rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use isahc_http_client::IsahcHttpClient;
  40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone, Debug)]
 109pub enum Principal {
 110    User(User),
 111    Impersonated { user: User, admin: User },
 112    DevServer(dev_server::Model),
 113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
 134#[derive(Clone)]
 135struct Session {
 136    principal: Principal,
 137    connection_id: ConnectionId,
 138    db: Arc<tokio::sync::Mutex<DbHandle>>,
 139    peer: Arc<Peer>,
 140    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 141    app_state: Arc<AppState>,
 142    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 143    http_client: Arc<dyn HttpClient>,
 144    /// The GeoIP country code for the user.
 145    #[allow(unused)]
 146    geoip_country_code: Option<String>,
 147    _executor: Executor,
 148}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 195        if self.is_staff() {
 196            return Ok(proto::Plan::ZedPro);
 197        }
 198
 199        let Some(user_id) = self.user_id() else {
 200            return Ok(proto::Plan::Free);
 201        };
 202
 203        if db.has_active_billing_subscription(user_id).await? {
 204            Ok(proto::Plan::ZedPro)
 205        } else {
 206            Ok(proto::Plan::Free)
 207        }
 208    }
 209
 210    fn dev_server_id(&self) -> Option<DevServerId> {
 211        match &self.principal {
 212            Principal::User(_) | Principal::Impersonated { .. } => None,
 213            Principal::DevServer(dev_server) => Some(dev_server.id),
 214        }
 215    }
 216
 217    fn principal_id(&self) -> PrincipalId {
 218        match &self.principal {
 219            Principal::User(user) => PrincipalId::UserId(user.id),
 220            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 221            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 222        }
 223    }
 224}
 225
 226impl Debug for Session {
 227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 228        let mut result = f.debug_struct("Session");
 229        match &self.principal {
 230            Principal::User(user) => {
 231                result.field("user", &user.github_login);
 232            }
 233            Principal::Impersonated { user, admin } => {
 234                result.field("user", &user.github_login);
 235                result.field("impersonator", &admin.github_login);
 236            }
 237            Principal::DevServer(dev_server) => {
 238                result.field("dev_server", &dev_server.id);
 239            }
 240        }
 241        result.field("connection_id", &self.connection_id).finish()
 242    }
 243}
 244
 245struct UserSession(Session);
 246
 247impl UserSession {
 248    pub fn new(s: Session) -> Option<Self> {
 249        s.user_id().map(|_| UserSession(s))
 250    }
 251    pub fn user_id(&self) -> UserId {
 252        self.0.user_id().unwrap()
 253    }
 254
 255    pub fn email(&self) -> Option<String> {
 256        match &self.0.principal {
 257            Principal::User(user) => user.email_address.clone(),
 258            Principal::Impersonated { user, .. } => user.email_address.clone(),
 259            Principal::DevServer(..) => None,
 260        }
 261    }
 262}
 263
 264impl Deref for UserSession {
 265    type Target = Session;
 266
 267    fn deref(&self) -> &Self::Target {
 268        &self.0
 269    }
 270}
 271impl DerefMut for UserSession {
 272    fn deref_mut(&mut self) -> &mut Self::Target {
 273        &mut self.0
 274    }
 275}
 276
 277struct DevServerSession(Session);
 278
 279impl DevServerSession {
 280    pub fn new(s: Session) -> Option<Self> {
 281        s.dev_server_id().map(|_| DevServerSession(s))
 282    }
 283    pub fn dev_server_id(&self) -> DevServerId {
 284        self.0.dev_server_id().unwrap()
 285    }
 286
 287    fn dev_server(&self) -> &dev_server::Model {
 288        match &self.0.principal {
 289            Principal::DevServer(dev_server) => dev_server,
 290            _ => unreachable!(),
 291        }
 292    }
 293}
 294
 295impl Deref for DevServerSession {
 296    type Target = Session;
 297
 298    fn deref(&self) -> &Self::Target {
 299        &self.0
 300    }
 301}
 302impl DerefMut for DevServerSession {
 303    fn deref_mut(&mut self) -> &mut Self::Target {
 304        &mut self.0
 305    }
 306}
 307
 308fn user_handler<M: RequestMessage, Fut>(
 309    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 311where
 312    Fut: Send + Future<Output = Result<()>>,
 313{
 314    let handler = Arc::new(handler);
 315    move |message, response, session| {
 316        let handler = handler.clone();
 317        Box::pin(async move {
 318            if let Some(user_session) = session.for_user() {
 319                Ok(handler(message, response, user_session).await?)
 320            } else {
 321                Err(Error::Internal(anyhow!(
 322                    "must be a user to call {}",
 323                    M::NAME
 324                )))
 325            }
 326        })
 327    }
 328}
 329
 330fn dev_server_handler<M: RequestMessage, Fut>(
 331    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 333where
 334    Fut: Send + Future<Output = Result<()>>,
 335{
 336    let handler = Arc::new(handler);
 337    move |message, response, session| {
 338        let handler = handler.clone();
 339        Box::pin(async move {
 340            if let Some(dev_server_session) = session.for_dev_server() {
 341                Ok(handler(message, response, dev_server_session).await?)
 342            } else {
 343                Err(Error::Internal(anyhow!(
 344                    "must be a dev server to call {}",
 345                    M::NAME
 346                )))
 347            }
 348        })
 349    }
 350}
 351
 352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 353    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 355where
 356    InnertRetFut: Send + Future<Output = Result<()>>,
 357{
 358    let handler = Arc::new(handler);
 359    move |message, session| {
 360        let handler = handler.clone();
 361        Box::pin(async move {
 362            if let Some(user_session) = session.for_user() {
 363                Ok(handler(message, user_session).await?)
 364            } else {
 365                Err(Error::Internal(anyhow!(
 366                    "must be a user to call {}",
 367                    M::NAME
 368                )))
 369            }
 370        })
 371    }
 372}
 373
 374struct DbHandle(Arc<Database>);
 375
 376impl Deref for DbHandle {
 377    type Target = Database;
 378
 379    fn deref(&self) -> &Self::Target {
 380        self.0.as_ref()
 381    }
 382}
 383
/// The RPC server: owns the `Peer`, the shared connection pool, and the table of
/// registered message handlers (filled in by `Server::new`).
pub struct Server {
    // Kept behind a mutex; presumably so the id can be replaced after
    // construction — TODO confirm against callers.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of live connections. It has the same shared type as
    /// `Session::connection_pool`, presumably the same `Arc` handed to each session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Populated by the `add_request_handler` / `add_message_handler` chain in
    // `Server::new`; presumably keyed by the `TypeId` of each message type.
    handlers: HashMap<TypeId, MessageHandler>,
    // Created with an initial value of `false` in `Server::new`; presumably set to
    // `true` to signal teardown to watchers — confirm.
    teardown: watch::Sender<bool>,
}
 392
/// A held lock on the shared `ConnectionPool`, as returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (and `!Sync`). The
/// compiler therefore rejects holding it across an `.await` inside a future that
/// must be `Send`, which keeps the synchronous `parking_lot` lock short-lived.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 397
/// A serializable view of a peer together with the connection pool.
///
/// `connection_pool` is a live lock guard, so the pool stays locked for as long as
/// the snapshot exists. It is serialized via `serialize_deref`, i.e. as the value
/// the guard dereferences to (presumably the `ConnectionPool` itself).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 404
 405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 406where
 407    S: Serializer,
 408    T: Deref<Target = U>,
 409    U: Serialize,
 410{
 411    Serialize::serialize(value.deref(), serializer)
 412}
 413
 414impl Server {
    /// Creates a server identified by `id` and registers every RPC handler on it.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail with an
    /// internal error when called from a dev-server session. Handlers wrapped in
    /// `dev_server_handler` fail when called from a user session. Unwrapped
    /// handlers receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if `ServerId`'s inner
            // integer is wider than u32 and ever exceeds `u32::MAX` — confirm range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Only callable by dev-server principals.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the `forward_*` helpers.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account and service credentials.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session` plus the
            // server config, so any principal check is up to the handler itself.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // NOTE(review): this closure only forwards its arguments;
            // `user_handler(get_cached_embeddings)` is likely equivalent — confirm.
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Each call receives its own clone of the configured OpenAI API key.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 651
 652    pub async fn start(&self) -> Result<()> {
 653        let server_id = *self.id.lock();
 654        let app_state = self.app_state.clone();
 655        let peer = self.peer.clone();
 656        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 657        let pool = self.connection_pool.clone();
 658        let live_kit_client = self.app_state.live_kit_client.clone();
 659
 660        let span = info_span!("start server");
 661        self.app_state.executor.spawn_detached(
 662            async move {
 663                tracing::info!("waiting for cleanup timeout");
 664                timeout.await;
 665                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 666                if let Some((room_ids, channel_ids)) = app_state
 667                    .db
 668                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 669                    .await
 670                    .trace_err()
 671                {
 672                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 673                    tracing::info!(
 674                        stale_channel_buffer_count = channel_ids.len(),
 675                        "retrieved stale channel buffers"
 676                    );
 677
 678                    for channel_id in channel_ids {
 679                        if let Some(refreshed_channel_buffer) = app_state
 680                            .db
 681                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 682                            .await
 683                            .trace_err()
 684                        {
 685                            for connection_id in refreshed_channel_buffer.connection_ids {
 686                                peer.send(
 687                                    connection_id,
 688                                    proto::UpdateChannelBufferCollaborators {
 689                                        channel_id: channel_id.to_proto(),
 690                                        collaborators: refreshed_channel_buffer
 691                                            .collaborators
 692                                            .clone(),
 693                                    },
 694                                )
 695                                .trace_err();
 696                            }
 697                        }
 698                    }
 699
 700                    for room_id in room_ids {
 701                        let mut contacts_to_update = HashSet::default();
 702                        let mut canceled_calls_to_user_ids = Vec::new();
 703                        let mut live_kit_room = String::new();
 704                        let mut delete_live_kit_room = false;
 705
 706                        if let Some(mut refreshed_room) = app_state
 707                            .db
 708                            .clear_stale_room_participants(room_id, server_id)
 709                            .await
 710                            .trace_err()
 711                        {
 712                            tracing::info!(
 713                                room_id = room_id.0,
 714                                new_participant_count = refreshed_room.room.participants.len(),
 715                                "refreshed room"
 716                            );
 717                            room_updated(&refreshed_room.room, &peer);
 718                            if let Some(channel) = refreshed_room.channel.as_ref() {
 719                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 720                            }
 721                            contacts_to_update
 722                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 723                            contacts_to_update
 724                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 725                            canceled_calls_to_user_ids =
 726                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 727                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 728                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 729                        }
 730
 731                        {
 732                            let pool = pool.lock();
 733                            for canceled_user_id in canceled_calls_to_user_ids {
 734                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 735                                    peer.send(
 736                                        connection_id,
 737                                        proto::CallCanceled {
 738                                            room_id: room_id.to_proto(),
 739                                        },
 740                                    )
 741                                    .trace_err();
 742                                }
 743                            }
 744                        }
 745
 746                        for user_id in contacts_to_update {
 747                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 748                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 749                            if let Some((busy, contacts)) = busy.zip(contacts) {
 750                                let pool = pool.lock();
 751                                let updated_contact = contact_for_user(user_id, busy, &pool);
 752                                for contact in contacts {
 753                                    if let db::Contact::Accepted {
 754                                        user_id: contact_user_id,
 755                                        ..
 756                                    } = contact
 757                                    {
 758                                        for contact_conn_id in
 759                                            pool.user_connection_ids(contact_user_id)
 760                                        {
 761                                            peer.send(
 762                                                contact_conn_id,
 763                                                proto::UpdateContacts {
 764                                                    contacts: vec![updated_contact.clone()],
 765                                                    remove_contacts: Default::default(),
 766                                                    incoming_requests: Default::default(),
 767                                                    remove_incoming_requests: Default::default(),
 768                                                    outgoing_requests: Default::default(),
 769                                                    remove_outgoing_requests: Default::default(),
 770                                                },
 771                                            )
 772                                            .trace_err();
 773                                        }
 774                                    }
 775                                }
 776                            }
 777                        }
 778
 779                        if let Some(live_kit) = live_kit_client.as_ref() {
 780                            if delete_live_kit_room {
 781                                live_kit.delete_room(live_kit_room).await.trace_err();
 782                            }
 783                        }
 784                    }
 785                }
 786
 787                app_state
 788                    .db
 789                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 790                    .await
 791                    .trace_err();
 792            }
 793            .instrument(span),
 794        );
 795        Ok(())
 796    }
 797
 798    pub fn teardown(&self) {
 799        self.peer.teardown();
 800        self.connection_pool.lock().reset();
 801        let _ = self.teardown.send(true);
 802    }
 803
 804    #[cfg(test)]
 805    pub fn reset(&self, id: ServerId) {
 806        self.teardown();
 807        *self.id.lock() = id;
 808        self.peer.reset(id.0 as u32);
 809        let _ = self.teardown.send(false);
 810    }
 811
 812    #[cfg(test)]
 813    pub fn id(&self) -> ServerId {
 814        *self.id.lock()
 815    }
 816
 817    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 818    where
 819        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 820        Fut: 'static + Send + Future<Output = Result<()>>,
 821        M: EnvelopedMessage,
 822    {
 823        let prev_handler = self.handlers.insert(
 824            TypeId::of::<M>(),
 825            Box::new(move |envelope, session| {
 826                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 827                let received_at = envelope.received_at;
 828                tracing::info!("message received");
 829                let start_time = Instant::now();
 830                let future = (handler)(*envelope, session);
 831                async move {
 832                    let result = future.await;
 833                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 834                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 835                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 836                    let payload_type = M::NAME;
 837
 838                    match result {
 839                        Err(error) => {
 840                            tracing::error!(
 841                                ?error,
 842                                total_duration_ms,
 843                                processing_duration_ms,
 844                                queue_duration_ms,
 845                                payload_type,
 846                                "error handling message"
 847                            )
 848                        }
 849                        Ok(()) => tracing::info!(
 850                            total_duration_ms,
 851                            processing_duration_ms,
 852                            queue_duration_ms,
 853                            "finished handling message"
 854                        ),
 855                    }
 856                }
 857                .boxed()
 858            }),
 859        );
 860        if prev_handler.is_some() {
 861            panic!("registered a handler for the same message twice");
 862        }
 863        self
 864    }
 865
 866    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 869        Fut: 'static + Send + Future<Output = Result<()>>,
 870        M: EnvelopedMessage,
 871    {
 872        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 873        self
 874    }
 875
 876    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 879        Fut: Send + Future<Output = Result<()>>,
 880        M: RequestMessage,
 881    {
 882        let handler = Arc::new(handler);
 883        self.add_handler(move |envelope, session| {
 884            let receipt = envelope.receipt();
 885            let handler = handler.clone();
 886            async move {
 887                let peer = session.peer.clone();
 888                let responded = Arc::new(AtomicBool::default());
 889                let response = Response {
 890                    peer: peer.clone(),
 891                    responded: responded.clone(),
 892                    receipt,
 893                };
 894                match (handler)(envelope.payload, response, session).await {
 895                    Ok(()) => {
 896                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 897                            Ok(())
 898                        } else {
 899                            Err(anyhow!("handler did not send a response"))?
 900                        }
 901                    }
 902                    Err(error) => {
 903                        let proto_err = match &error {
 904                            Error::Internal(err) => err.to_proto(),
 905                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 906                        };
 907                        peer.respond_with_error(receipt, proto_err)?;
 908                        Err(error)
 909                    }
 910                }
 911            }
 912        })
 913    }
 914
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds a per-connection `Session` (including
    ///    an HTTP client and, when an admin API key is configured, a Supermaven client);
    /// 3. sends the initial client update, which also reports the connection id through
    ///    `send_connection_id` when one is provided;
    /// 4. dispatches incoming messages to the registered handlers until the I/O future finishes
    ///    or the incoming message stream ends;
    /// 5. calls `connection_lost` to clean up.
    ///
    /// If teardown is signalled while the message loop is running, the future returns without
    /// calling `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown signalled later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` drives the socket; `incoming_rx` yields the decoded incoming messages.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client is built per connection; failing to build one drops the connection.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            // A failed handshake ends the connection before any message is processed
            // (and without running `connection_lost`).
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers of this connection may be in flight. A permit is acquired
            // before each message is read and released only when its handler completes, so once
            // the limit is reached no further messages are read from `incoming_rx`.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks branches top to bottom: teardown, then I/O completion,
                // then progress on pending foreground handlers, then reading the next message.
                futures::select_biased! {
                    // Returning here skips `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // NOTE(review): when the set is empty this relies on `FuturesUnordered` being a
                    // fused stream so this branch does not fire repeatedly — confirm.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // Entered only while the handler future is created; `.instrument(span)`
                            // below attaches the span to the future itself.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor and are not
                                // cancelled by dropping `foreground_message_handlers` below.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1064
1065    async fn send_initial_client_update(
1066        &self,
1067        connection_id: ConnectionId,
1068        principal: &Principal,
1069        zed_version: ZedVersion,
1070        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1071        session: &Session,
1072    ) -> Result<()> {
1073        self.peer.send(
1074            connection_id,
1075            proto::Hello {
1076                peer_id: Some(connection_id.into()),
1077            },
1078        )?;
1079        tracing::info!("sent hello message");
1080        if let Some(send_connection_id) = send_connection_id.take() {
1081            let _ = send_connection_id.send(connection_id);
1082        }
1083
1084        match principal {
1085            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1086                if !user.connected_once {
1087                    self.peer.send(connection_id, proto::ShowContacts {})?;
1088                    self.app_state
1089                        .db
1090                        .set_user_connected_once(user.id, true)
1091                        .await?;
1092                }
1093
1094                update_user_plan(user.id, session).await?;
1095
1096                let (contacts, dev_server_projects) = future::try_join(
1097                    self.app_state.db.get_contacts(user.id),
1098                    self.app_state.db.dev_server_projects_update(user.id),
1099                )
1100                .await?;
1101
1102                {
1103                    let mut pool = self.connection_pool.lock();
1104                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1105                    self.peer.send(
1106                        connection_id,
1107                        build_initial_contacts_update(contacts, &pool),
1108                    )?;
1109                }
1110
1111                if should_auto_subscribe_to_channels(zed_version) {
1112                    subscribe_user_to_channels(user.id, session).await?;
1113                }
1114
1115                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1116
1117                if let Some(incoming_call) =
1118                    self.app_state.db.incoming_call_for_user(user.id).await?
1119                {
1120                    self.peer.send(connection_id, incoming_call)?;
1121                }
1122
1123                update_user_contacts(user.id, session).await?;
1124            }
1125            Principal::DevServer(dev_server) => {
1126                {
1127                    let mut pool = self.connection_pool.lock();
1128                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1129                    {
1130                        self.peer.send(
1131                            stale_connection_id,
1132                            proto::ShutdownDevServer {
1133                                reason: Some(
1134                                    "another dev server connected with the same token".to_string(),
1135                                ),
1136                            },
1137                        )?;
1138                        pool.remove_connection(stale_connection_id)?;
1139                    };
1140                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1141                }
1142
1143                let projects = self
1144                    .app_state
1145                    .db
1146                    .get_projects_for_dev_server(dev_server.id)
1147                    .await?;
1148                self.peer
1149                    .send(connection_id, proto::DevServerInstructions { projects })?;
1150
1151                let status = self
1152                    .app_state
1153                    .db
1154                    .dev_server_projects_update(dev_server.user_id)
1155                    .await?;
1156                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1157            }
1158        }
1159
1160        Ok(())
1161    }
1162
1163    pub async fn invite_code_redeemed(
1164        self: &Arc<Self>,
1165        inviter_id: UserId,
1166        invitee_id: UserId,
1167    ) -> Result<()> {
1168        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1169            if let Some(code) = &user.invite_code {
1170                let pool = self.connection_pool.lock();
1171                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1172                for connection_id in pool.user_connection_ids(inviter_id) {
1173                    self.peer.send(
1174                        connection_id,
1175                        proto::UpdateContacts {
1176                            contacts: vec![invitee_contact.clone()],
1177                            ..Default::default()
1178                        },
1179                    )?;
1180                    self.peer.send(
1181                        connection_id,
1182                        proto::UpdateInviteInfo {
1183                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1184                            count: user.invite_count as u32,
1185                        },
1186                    )?;
1187                }
1188            }
1189        }
1190        Ok(())
1191    }
1192
1193    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1194        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1195            if let Some(invite_code) = &user.invite_code {
1196                let pool = self.connection_pool.lock();
1197                for connection_id in pool.user_connection_ids(user_id) {
1198                    self.peer.send(
1199                        connection_id,
1200                        proto::UpdateInviteInfo {
1201                            url: format!(
1202                                "{}{}",
1203                                self.app_state.config.invite_link_prefix, invite_code
1204                            ),
1205                            count: user.invite_count as u32,
1206                        },
1207                    )?;
1208                }
1209            }
1210        }
1211        Ok(())
1212    }
1213
1214    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1215        ServerSnapshot {
1216            connection_pool: ConnectionPoolGuard {
1217                guard: self.connection_pool.lock(),
1218                _not_send: PhantomData,
1219            },
1220            peer: &self.peer,
1221        }
1222    }
1223}
1224
1225impl<'a> Deref for ConnectionPoolGuard<'a> {
1226    type Target = ConnectionPool;
1227
1228    fn deref(&self) -> &Self::Target {
1229        &self.guard
1230    }
1231}
1232
1233impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1234    fn deref_mut(&mut self) -> &mut Self::Target {
1235        &mut self.guard
1236    }
1237}
1238
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, verifies the pool's invariants each time the guard is released.
    fn drop(&mut self) {
        // Compiled out outside `cfg(test)`; the inner `guard` field is still dropped afterwards,
        // which releases the lock.
        #[cfg(test)]
        self.check_invariants();
    }
}
1245
1246fn broadcast<F>(
1247    sender_id: Option<ConnectionId>,
1248    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1249    mut f: F,
1250) where
1251    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1252{
1253    for receiver_id in receiver_ids {
1254        if Some(receiver_id) != sender_id {
1255            if let Err(error) = f(receiver_id) {
1256                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1257            }
1258        }
1259    }
1260}
1261
1262pub struct ProtocolVersion(u32);
1263
1264impl Header for ProtocolVersion {
1265    fn name() -> &'static HeaderName {
1266        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1267        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1268    }
1269
1270    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1271    where
1272        Self: Sized,
1273        I: Iterator<Item = &'i axum::http::HeaderValue>,
1274    {
1275        let version = values
1276            .next()
1277            .ok_or_else(axum::headers::Error::invalid)?
1278            .to_str()
1279            .map_err(|_| axum::headers::Error::invalid())?
1280            .parse()
1281            .map_err(|_| axum::headers::Error::invalid())?;
1282        Ok(Self(version))
1283    }
1284
1285    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1286        values.extend([self.0.to_string().parse().unwrap()]);
1287    }
1288}
1289
1290pub struct AppVersionHeader(SemanticVersion);
1291impl Header for AppVersionHeader {
1292    fn name() -> &'static HeaderName {
1293        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1295    }
1296
1297    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298    where
1299        Self: Sized,
1300        I: Iterator<Item = &'i axum::http::HeaderValue>,
1301    {
1302        let version = values
1303            .next()
1304            .ok_or_else(axum::headers::Error::invalid)?
1305            .to_str()
1306            .map_err(|_| axum::headers::Error::invalid())?
1307            .parse()
1308            .map_err(|_| axum::headers::Error::invalid())?;
1309        Ok(Self(version))
1310    }
1311
1312    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313        values.extend([self.0.to_string().parse().unwrap()]);
1314    }
1315}
1316
1317pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1318    Router::new()
1319        .route("/rpc", get(handle_websocket_request))
1320        .layer(
1321            ServiceBuilder::new()
1322                .layer(Extension(server.app_state.clone()))
1323                .layer(middleware::from_fn(auth::validate_header)),
1324        )
1325        .route("/metrics", get(handle_metrics))
1326        .layer(Extension(server))
1327}
1328
1329pub async fn handle_websocket_request(
1330    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1331    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1332    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1333    Extension(server): Extension<Arc<Server>>,
1334    Extension(principal): Extension<Principal>,
1335    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1336    ws: WebSocketUpgrade,
1337) -> axum::response::Response {
1338    if protocol_version != rpc::PROTOCOL_VERSION {
1339        return (
1340            StatusCode::UPGRADE_REQUIRED,
1341            "client must be upgraded".to_string(),
1342        )
1343            .into_response();
1344    }
1345
1346    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1347        return (
1348            StatusCode::UPGRADE_REQUIRED,
1349            "no version header found".to_string(),
1350        )
1351            .into_response();
1352    };
1353
1354    if !version.can_collaborate() {
1355        return (
1356            StatusCode::UPGRADE_REQUIRED,
1357            "client must be upgraded".to_string(),
1358        )
1359            .into_response();
1360    }
1361
1362    let socket_address = socket_address.to_string();
1363    ws.on_upgrade(move |socket| {
1364        let socket = socket
1365            .map_ok(to_tungstenite_message)
1366            .err_into()
1367            .with(|message| async move { to_axum_message(message) });
1368        let connection = Connection::new(Box::pin(socket));
1369        async move {
1370            server
1371                .handle_connection(
1372                    connection,
1373                    socket_address,
1374                    principal,
1375                    version,
1376                    country_code_header.map(|header| header.to_string()),
1377                    None,
1378                    Executor::Production,
1379                )
1380                .await;
1381        }
1382    })
1383}
1384
1385pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1386    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1387    let connections_metric = CONNECTIONS_METRIC
1388        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1389
1390    let connections = server
1391        .connection_pool
1392        .lock()
1393        .connections()
1394        .filter(|connection| !connection.admin)
1395        .count();
1396    connections_metric.set(connections as _);
1397
1398    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1399    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1400        register_int_gauge!(
1401            "shared_projects",
1402            "number of open projects with one or more guests"
1403        )
1404        .unwrap()
1405    });
1406
1407    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1408    shared_projects_metric.set(shared_projects as _);
1409
1410    let encoder = prometheus::TextEncoder::new();
1411    let metric_families = prometheus::gather();
1412    let encoded_metrics = encoder
1413        .encode_to_string(&metric_families)
1414        .map_err(|err| anyhow!("{}", err))?;
1415    Ok(encoded_metrics)
1416}
1417
/// Cleans up after a connection has gone away.
///
/// Immediately disconnects the peer and removes the connection from the pool (returning early
/// with an error if that removal fails), then records the loss in the database (errors are only
/// traced). It then waits `RECONNECT_TIMEOUT` before releasing the principal's remaining
/// resources — presumably to give the client a chance to reconnect — unless server teardown is
/// signalled first, in which case that cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // NOTE(review): the unwrap relies on `for_user` succeeding for user principals.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                // NOTE(review): the unwrap relies on `for_dev_server` succeeding for dev server principals.
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown wins the race: skip per-principal cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1472
1473/// Acknowledges a ping from a client, used to keep the connection alive.
1474async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1475    response.send(proto::Ack {})?;
1476    Ok(())
1477}
1478
1479/// Creates a new room for calling (outside of channels)
1480async fn create_room(
1481    _request: proto::CreateRoom,
1482    response: Response<proto::CreateRoom>,
1483    session: UserSession,
1484) -> Result<()> {
1485    let live_kit_room = nanoid::nanoid!(30);
1486
1487    let live_kit_connection_info = util::maybe!(async {
1488        let live_kit = session.app_state.live_kit_client.as_ref();
1489        let live_kit = live_kit?;
1490        let user_id = session.user_id().to_string();
1491
1492        let token = live_kit
1493            .room_token(&live_kit_room, &user_id.to_string())
1494            .trace_err()?;
1495
1496        Some(proto::LiveKitConnectionInfo {
1497            server_url: live_kit.url().into(),
1498            token,
1499            can_publish: true,
1500        })
1501    })
1502    .await;
1503
1504    let room = session
1505        .db()
1506        .await
1507        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1508        .await?;
1509
1510    response.send(proto::CreateRoomResponse {
1511        room: Some(room.clone()),
1512        live_kit_connection_info,
1513    })?;
1514
1515    update_user_contacts(session.user_id(), &session).await?;
1516    Ok(())
1517}
1518
1519/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1520async fn join_room(
1521    request: proto::JoinRoom,
1522    response: Response<proto::JoinRoom>,
1523    session: UserSession,
1524) -> Result<()> {
1525    let room_id = RoomId::from_proto(request.id);
1526
1527    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1528
1529    if let Some(channel_id) = channel_id {
1530        return join_channel_internal(channel_id, Box::new(response), session).await;
1531    }
1532
1533    let joined_room = {
1534        let room = session
1535            .db()
1536            .await
1537            .join_room(room_id, session.user_id(), session.connection_id)
1538            .await?;
1539        room_updated(&room.room, &session.peer);
1540        room.into_inner()
1541    };
1542
1543    for connection_id in session
1544        .connection_pool()
1545        .await
1546        .user_connection_ids(session.user_id())
1547    {
1548        session
1549            .peer
1550            .send(
1551                connection_id,
1552                proto::CallCanceled {
1553                    room_id: room_id.to_proto(),
1554                },
1555            )
1556            .trace_err();
1557    }
1558
1559    let live_kit_connection_info =
1560        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1561            live_kit
1562                .room_token(
1563                    &joined_room.room.live_kit_room,
1564                    &session.user_id().to_string(),
1565                )
1566                .trace_err()
1567                .map(|token| proto::LiveKitConnectionInfo {
1568                    server_url: live_kit.url().into(),
1569                    token,
1570                    can_publish: true,
1571                })
1572        } else {
1573            None
1574        };
1575
1576    response.send(proto::JoinRoomResponse {
1577        room: Some(joined_room.room),
1578        channel_id: None,
1579        live_kit_connection_info,
1580    })?;
1581
1582    update_user_contacts(session.user_id(), &session).await?;
1583    Ok(())
1584}
1585
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database records the rejoin for the new `session.connection_id` and returns the room,
/// its channel (if any), the `reshared_projects` (presumably the projects this user hosts —
/// confirm against `Database::rejoin_room`) and the `rejoined_projects`. In order, this handler:
/// 1. replies with the room plus per-project collaborator / rejoin data;
/// 2. broadcasts the room update;
/// 3. for each reshared project, tells its collaborators that the user's peer id changed from
///    `old_connection_id` to the new connection, and forwards an `UpdateProject` carrying the
///    project's worktrees;
/// 4. streams rejoined-project state to the new connection via `notify_rejoined_projects`;
/// 5. broadcasts a channel update when the room belongs to a channel, then refreshes contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so the data outlives `rejoined_room`, which is consumed by
    // `into_inner` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-point each collaborator from the old connection's peer id to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators, excluding the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Mutable borrow: worktree data is drained from the rejoined projects while streaming.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1677
1678fn notify_rejoined_projects(
1679    rejoined_projects: &mut Vec<RejoinedProject>,
1680    session: &UserSession,
1681) -> Result<()> {
1682    for project in rejoined_projects.iter() {
1683        for collaborator in &project.collaborators {
1684            session
1685                .peer
1686                .send(
1687                    collaborator.connection_id,
1688                    proto::UpdateProjectCollaborator {
1689                        project_id: project.id.to_proto(),
1690                        old_peer_id: Some(project.old_connection_id.into()),
1691                        new_peer_id: Some(session.connection_id.into()),
1692                    },
1693                )
1694                .trace_err();
1695        }
1696    }
1697
1698    for project in rejoined_projects {
1699        for worktree in mem::take(&mut project.worktrees) {
1700            #[cfg(any(test, feature = "test-support"))]
1701            const MAX_CHUNK_SIZE: usize = 2;
1702            #[cfg(not(any(test, feature = "test-support")))]
1703            const MAX_CHUNK_SIZE: usize = 256;
1704
1705            // Stream this worktree's entries.
1706            let message = proto::UpdateWorktree {
1707                project_id: project.id.to_proto(),
1708                worktree_id: worktree.id,
1709                abs_path: worktree.abs_path.clone(),
1710                root_name: worktree.root_name,
1711                updated_entries: worktree.updated_entries,
1712                removed_entries: worktree.removed_entries,
1713                scan_id: worktree.scan_id,
1714                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1715                updated_repositories: worktree.updated_repositories,
1716                removed_repositories: worktree.removed_repositories,
1717            };
1718            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1719                session.peer.send(session.connection_id, update.clone())?;
1720            }
1721
1722            // Stream this worktree's diagnostics.
1723            for summary in worktree.diagnostic_summaries {
1724                session.peer.send(
1725                    session.connection_id,
1726                    proto::UpdateDiagnosticSummary {
1727                        project_id: project.id.to_proto(),
1728                        worktree_id: worktree.id,
1729                        summary: Some(summary),
1730                    },
1731                )?;
1732            }
1733
1734            for settings_file in worktree.settings_files {
1735                session.peer.send(
1736                    session.connection_id,
1737                    proto::UpdateWorktreeSettings {
1738                        project_id: project.id.to_proto(),
1739                        worktree_id: worktree.id,
1740                        path: settings_file.path,
1741                        content: Some(settings_file.content),
1742                    },
1743                )?;
1744            }
1745        }
1746
1747        for language_server in &project.language_servers {
1748            session.peer.send(
1749                session.connection_id,
1750                proto::UpdateLanguageServer {
1751                    project_id: project.id.to_proto(),
1752                    language_server_id: language_server.id,
1753                    variant: Some(
1754                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1755                            proto::LspDiskBasedDiagnosticsUpdated {},
1756                        ),
1757                    ),
1758                },
1759            )?;
1760        }
1761    }
1762    Ok(())
1763}
1764
1765/// leave room disconnects from the room.
1766async fn leave_room(
1767    _: proto::LeaveRoom,
1768    response: Response<proto::LeaveRoom>,
1769    session: UserSession,
1770) -> Result<()> {
1771    leave_room_for_session(&session, session.connection_id).await?;
1772    response.send(proto::Ack {})?;
1773    Ok(())
1774}
1775
1776/// Updates the permissions of someone else in the room.
1777async fn set_room_participant_role(
1778    request: proto::SetRoomParticipantRole,
1779    response: Response<proto::SetRoomParticipantRole>,
1780    session: UserSession,
1781) -> Result<()> {
1782    let user_id = UserId::from_proto(request.user_id);
1783    let role = ChannelRole::from(request.role());
1784
1785    let (live_kit_room, can_publish) = {
1786        let room = session
1787            .db()
1788            .await
1789            .set_room_participant_role(
1790                session.user_id(),
1791                RoomId::from_proto(request.room_id),
1792                user_id,
1793                role,
1794            )
1795            .await?;
1796
1797        let live_kit_room = room.live_kit_room.clone();
1798        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1799        room_updated(&room, &session.peer);
1800        (live_kit_room, can_publish)
1801    };
1802
1803    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1804        live_kit
1805            .update_participant(
1806                live_kit_room.clone(),
1807                request.user_id.to_string(),
1808                live_kit_server::proto::ParticipantPermission {
1809                    can_subscribe: true,
1810                    can_publish,
1811                    can_publish_data: can_publish,
1812                    hidden: false,
1813                    recorder: false,
1814                },
1815            )
1816            .await
1817            .trace_err();
1818    }
1819
1820    response.send(proto::Ack {})?;
1821    Ok(())
1822}
1823
1824/// Call someone else into the current room
1825async fn call(
1826    request: proto::Call,
1827    response: Response<proto::Call>,
1828    session: UserSession,
1829) -> Result<()> {
1830    let room_id = RoomId::from_proto(request.room_id);
1831    let calling_user_id = session.user_id();
1832    let calling_connection_id = session.connection_id;
1833    let called_user_id = UserId::from_proto(request.called_user_id);
1834    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1835    if !session
1836        .db()
1837        .await
1838        .has_contact(calling_user_id, called_user_id)
1839        .await?
1840    {
1841        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1842    }
1843
1844    let incoming_call = {
1845        let (room, incoming_call) = &mut *session
1846            .db()
1847            .await
1848            .call(
1849                room_id,
1850                calling_user_id,
1851                calling_connection_id,
1852                called_user_id,
1853                initial_project_id,
1854            )
1855            .await?;
1856        room_updated(room, &session.peer);
1857        mem::take(incoming_call)
1858    };
1859    update_user_contacts(called_user_id, &session).await?;
1860
1861    let mut calls = session
1862        .connection_pool()
1863        .await
1864        .user_connection_ids(called_user_id)
1865        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1866        .collect::<FuturesUnordered<_>>();
1867
1868    while let Some(call_response) = calls.next().await {
1869        match call_response.as_ref() {
1870            Ok(_) => {
1871                response.send(proto::Ack {})?;
1872                return Ok(());
1873            }
1874            Err(_) => {
1875                call_response.trace_err();
1876            }
1877        }
1878    }
1879
1880    {
1881        let room = session
1882            .db()
1883            .await
1884            .call_failed(room_id, called_user_id)
1885            .await?;
1886        room_updated(&room, &session.peer);
1887    }
1888    update_user_contacts(called_user_id, &session).await?;
1889
1890    Err(anyhow!("failed to ring user"))?
1891}
1892
1893/// Cancel an outgoing call.
1894async fn cancel_call(
1895    request: proto::CancelCall,
1896    response: Response<proto::CancelCall>,
1897    session: UserSession,
1898) -> Result<()> {
1899    let called_user_id = UserId::from_proto(request.called_user_id);
1900    let room_id = RoomId::from_proto(request.room_id);
1901    {
1902        let room = session
1903            .db()
1904            .await
1905            .cancel_call(room_id, session.connection_id, called_user_id)
1906            .await?;
1907        room_updated(&room, &session.peer);
1908    }
1909
1910    for connection_id in session
1911        .connection_pool()
1912        .await
1913        .user_connection_ids(called_user_id)
1914    {
1915        session
1916            .peer
1917            .send(
1918                connection_id,
1919                proto::CallCanceled {
1920                    room_id: room_id.to_proto(),
1921                },
1922            )
1923            .trace_err();
1924    }
1925    response.send(proto::Ack {})?;
1926
1927    update_user_contacts(called_user_id, &session).await?;
1928    Ok(())
1929}
1930
1931/// Decline an incoming call.
1932async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1933    let room_id = RoomId::from_proto(message.room_id);
1934    {
1935        let room = session
1936            .db()
1937            .await
1938            .decline_call(Some(room_id), session.user_id())
1939            .await?
1940            .ok_or_else(|| anyhow!("failed to decline call"))?;
1941        room_updated(&room, &session.peer);
1942    }
1943
1944    for connection_id in session
1945        .connection_pool()
1946        .await
1947        .user_connection_ids(session.user_id())
1948    {
1949        session
1950            .peer
1951            .send(
1952                connection_id,
1953                proto::CallCanceled {
1954                    room_id: room_id.to_proto(),
1955                },
1956            )
1957            .trace_err();
1958    }
1959    update_user_contacts(session.user_id(), &session).await?;
1960    Ok(())
1961}
1962
1963/// Updates other participants in the room with your current location.
1964async fn update_participant_location(
1965    request: proto::UpdateParticipantLocation,
1966    response: Response<proto::UpdateParticipantLocation>,
1967    session: UserSession,
1968) -> Result<()> {
1969    let room_id = RoomId::from_proto(request.room_id);
1970    let location = request
1971        .location
1972        .ok_or_else(|| anyhow!("invalid location"))?;
1973
1974    let db = session.db().await;
1975    let room = db
1976        .update_room_participant_location(room_id, session.connection_id, location)
1977        .await?;
1978
1979    room_updated(&room, &session.peer);
1980    response.send(proto::Ack {})?;
1981    Ok(())
1982}
1983
1984/// Share a project into the room.
1985async fn share_project(
1986    request: proto::ShareProject,
1987    response: Response<proto::ShareProject>,
1988    session: UserSession,
1989) -> Result<()> {
1990    let (project_id, room) = &*session
1991        .db()
1992        .await
1993        .share_project(
1994            RoomId::from_proto(request.room_id),
1995            session.connection_id,
1996            &request.worktrees,
1997            request.is_ssh_project,
1998            request
1999                .dev_server_project_id
2000                .map(DevServerProjectId::from_proto),
2001        )
2002        .await?;
2003    response.send(proto::ShareProjectResponse {
2004        project_id: project_id.to_proto(),
2005    })?;
2006    room_updated(room, &session.peer);
2007
2008    Ok(())
2009}
2010
2011/// Unshare a project from the room.
2012async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2013    let project_id = ProjectId::from_proto(message.project_id);
2014    unshare_project_internal(
2015        project_id,
2016        session.connection_id,
2017        session.user_id(),
2018        &session,
2019    )
2020    .await
2021}
2022
2023async fn unshare_project_internal(
2024    project_id: ProjectId,
2025    connection_id: ConnectionId,
2026    user_id: Option<UserId>,
2027    session: &Session,
2028) -> Result<()> {
2029    let delete = {
2030        let room_guard = session
2031            .db()
2032            .await
2033            .unshare_project(project_id, connection_id, user_id)
2034            .await?;
2035
2036        let (delete, room, guest_connection_ids) = &*room_guard;
2037
2038        let message = proto::UnshareProject {
2039            project_id: project_id.to_proto(),
2040        };
2041
2042        broadcast(
2043            Some(connection_id),
2044            guest_connection_ids.iter().copied(),
2045            |conn_id| session.peer.send(conn_id, message.clone()),
2046        );
2047        if let Some(room) = room {
2048            room_updated(room, &session.peer);
2049        }
2050
2051        *delete
2052    };
2053
2054    if delete {
2055        let db = session.db().await;
2056        db.delete_project(project_id).await?;
2057    }
2058
2059    Ok(())
2060}
2061
2062/// DevServer makes a project available online
2063async fn share_dev_server_project(
2064    request: proto::ShareDevServerProject,
2065    response: Response<proto::ShareDevServerProject>,
2066    session: DevServerSession,
2067) -> Result<()> {
2068    let (dev_server_project, user_id, status) = session
2069        .db()
2070        .await
2071        .share_dev_server_project(
2072            DevServerProjectId::from_proto(request.dev_server_project_id),
2073            session.dev_server_id(),
2074            session.connection_id,
2075            &request.worktrees,
2076        )
2077        .await?;
2078    let Some(project_id) = dev_server_project.project_id else {
2079        return Err(anyhow!("failed to share remote project"))?;
2080    };
2081
2082    send_dev_server_projects_update(user_id, status, &session).await;
2083
2084    response.send(proto::ShareProjectResponse { project_id })?;
2085
2086    Ok(())
2087}
2088
/// Join someone else's shared project.
///
/// Registers the caller's connection as a collaborator on `request.project_id` (the database
/// returns the project and the replica id assigned to this guest), then hands off to
/// `join_project_internal` to notify existing collaborators and stream the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The value returned by `join_project` is kept alive by temporary-lifetime extension of
    // the `&mut *` borrow, while the `db` handle itself is released explicitly below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the guest.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2107
2108trait JoinProjectInternalResponse {
2109    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2110}
2111impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2112    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2113        Response::<proto::JoinProject>::send(self, result)
2114    }
2115}
2116impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2117    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2118        Response::<proto::JoinHostedProject>::send(self, result)
2119    }
2120}
2121
2122fn join_project_internal(
2123    response: impl JoinProjectInternalResponse,
2124    session: UserSession,
2125    project: &mut Project,
2126    replica_id: &ReplicaId,
2127) -> Result<()> {
2128    let collaborators = project
2129        .collaborators
2130        .iter()
2131        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2132        .map(|collaborator| collaborator.to_proto())
2133        .collect::<Vec<_>>();
2134    let project_id = project.id;
2135    let guest_user_id = session.user_id();
2136
2137    let worktrees = project
2138        .worktrees
2139        .iter()
2140        .map(|(id, worktree)| proto::WorktreeMetadata {
2141            id: *id,
2142            root_name: worktree.root_name.clone(),
2143            visible: worktree.visible,
2144            abs_path: worktree.abs_path.clone(),
2145        })
2146        .collect::<Vec<_>>();
2147
2148    let add_project_collaborator = proto::AddProjectCollaborator {
2149        project_id: project_id.to_proto(),
2150        collaborator: Some(proto::Collaborator {
2151            peer_id: Some(session.connection_id.into()),
2152            replica_id: replica_id.0 as u32,
2153            user_id: guest_user_id.to_proto(),
2154        }),
2155    };
2156
2157    for collaborator in &collaborators {
2158        session
2159            .peer
2160            .send(
2161                collaborator.peer_id.unwrap().into(),
2162                add_project_collaborator.clone(),
2163            )
2164            .trace_err();
2165    }
2166
2167    // First, we send the metadata associated with each worktree.
2168    response.send(proto::JoinProjectResponse {
2169        project_id: project.id.0 as u64,
2170        worktrees: worktrees.clone(),
2171        replica_id: replica_id.0 as u32,
2172        collaborators: collaborators.clone(),
2173        language_servers: project.language_servers.clone(),
2174        role: project.role.into(),
2175        dev_server_project_id: project
2176            .dev_server_project_id
2177            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2178    })?;
2179
2180    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2181        #[cfg(any(test, feature = "test-support"))]
2182        const MAX_CHUNK_SIZE: usize = 2;
2183        #[cfg(not(any(test, feature = "test-support")))]
2184        const MAX_CHUNK_SIZE: usize = 256;
2185
2186        // Stream this worktree's entries.
2187        let message = proto::UpdateWorktree {
2188            project_id: project_id.to_proto(),
2189            worktree_id,
2190            abs_path: worktree.abs_path.clone(),
2191            root_name: worktree.root_name,
2192            updated_entries: worktree.entries,
2193            removed_entries: Default::default(),
2194            scan_id: worktree.scan_id,
2195            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2196            updated_repositories: worktree.repository_entries.into_values().collect(),
2197            removed_repositories: Default::default(),
2198        };
2199        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2200            session.peer.send(session.connection_id, update.clone())?;
2201        }
2202
2203        // Stream this worktree's diagnostics.
2204        for summary in worktree.diagnostic_summaries {
2205            session.peer.send(
2206                session.connection_id,
2207                proto::UpdateDiagnosticSummary {
2208                    project_id: project_id.to_proto(),
2209                    worktree_id: worktree.id,
2210                    summary: Some(summary),
2211                },
2212            )?;
2213        }
2214
2215        for settings_file in worktree.settings_files {
2216            session.peer.send(
2217                session.connection_id,
2218                proto::UpdateWorktreeSettings {
2219                    project_id: project_id.to_proto(),
2220                    worktree_id: worktree.id,
2221                    path: settings_file.path,
2222                    content: Some(settings_file.content),
2223                },
2224            )?;
2225        }
2226    }
2227
2228    for language_server in &project.language_servers {
2229        session.peer.send(
2230            session.connection_id,
2231            proto::UpdateLanguageServer {
2232                project_id: project_id.to_proto(),
2233                language_server_id: language_server.id,
2234                variant: Some(
2235                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2236                        proto::LspDiskBasedDiagnosticsUpdated {},
2237                    ),
2238                ),
2239            },
2240        )?;
2241    }
2242
2243    Ok(())
2244}
2245
2246/// Leave someone elses shared project.
2247async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2248    let sender_id = session.connection_id;
2249    let project_id = ProjectId::from_proto(request.project_id);
2250    let db = session.db().await;
2251    if db.is_hosted_project(project_id).await? {
2252        let project = db.leave_hosted_project(project_id, sender_id).await?;
2253        project_left(&project, &session);
2254        return Ok(());
2255    }
2256
2257    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2258    tracing::info!(
2259        %project_id,
2260        "leave project"
2261    );
2262
2263    project_left(project, &session);
2264    if let Some(room) = room {
2265        room_updated(room, &session.peer);
2266    }
2267
2268    Ok(())
2269}
2270
2271async fn join_hosted_project(
2272    request: proto::JoinHostedProject,
2273    response: Response<proto::JoinHostedProject>,
2274    session: UserSession,
2275) -> Result<()> {
2276    let (mut project, replica_id) = session
2277        .db()
2278        .await
2279        .join_hosted_project(
2280            ProjectId(request.project_id as i32),
2281            session.user_id(),
2282            session.connection_id,
2283        )
2284        .await?;
2285
2286    join_project_internal(response, session, &mut project, &replica_id)
2287}
2288
2289async fn list_remote_directory(
2290    request: proto::ListRemoteDirectory,
2291    response: Response<proto::ListRemoteDirectory>,
2292    session: UserSession,
2293) -> Result<()> {
2294    let dev_server_id = DevServerId(request.dev_server_id as i32);
2295    let dev_server_connection_id = session
2296        .connection_pool()
2297        .await
2298        .online_dev_server_connection_id(dev_server_id)?;
2299
2300    session
2301        .db()
2302        .await
2303        .get_dev_server_for_user(dev_server_id, session.user_id())
2304        .await?;
2305
2306    response.send(
2307        session
2308            .peer
2309            .forward_request(session.connection_id, dev_server_connection_id, request)
2310            .await?,
2311    )?;
2312    Ok(())
2313}
2314
2315async fn update_dev_server_project(
2316    request: proto::UpdateDevServerProject,
2317    response: Response<proto::UpdateDevServerProject>,
2318    session: UserSession,
2319) -> Result<()> {
2320    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2321
2322    let (dev_server_project, update) = session
2323        .db()
2324        .await
2325        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2326        .await?;
2327
2328    let projects = session
2329        .db()
2330        .await
2331        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2332        .await?;
2333
2334    let dev_server_connection_id = session
2335        .connection_pool()
2336        .await
2337        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2338
2339    session.peer.send(
2340        dev_server_connection_id,
2341        proto::DevServerInstructions { projects },
2342    )?;
2343
2344    send_dev_server_projects_update(session.user_id(), update, &session).await;
2345
2346    response.send(proto::Ack {})
2347}
2348
2349async fn create_dev_server_project(
2350    request: proto::CreateDevServerProject,
2351    response: Response<proto::CreateDevServerProject>,
2352    session: UserSession,
2353) -> Result<()> {
2354    let dev_server_id = DevServerId(request.dev_server_id as i32);
2355    let dev_server_connection_id = session
2356        .connection_pool()
2357        .await
2358        .dev_server_connection_id(dev_server_id);
2359    let Some(dev_server_connection_id) = dev_server_connection_id else {
2360        Err(ErrorCode::DevServerOffline
2361            .message("Cannot create a remote project when the dev server is offline".to_string())
2362            .anyhow())?
2363    };
2364
2365    let path = request.path.clone();
2366    //Check that the path exists on the dev server
2367    session
2368        .peer
2369        .forward_request(
2370            session.connection_id,
2371            dev_server_connection_id,
2372            proto::ValidateDevServerProjectRequest { path: path.clone() },
2373        )
2374        .await?;
2375
2376    let (dev_server_project, update) = session
2377        .db()
2378        .await
2379        .create_dev_server_project(
2380            DevServerId(request.dev_server_id as i32),
2381            &request.path,
2382            session.user_id(),
2383        )
2384        .await?;
2385
2386    let projects = session
2387        .db()
2388        .await
2389        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2390        .await?;
2391
2392    session.peer.send(
2393        dev_server_connection_id,
2394        proto::DevServerInstructions { projects },
2395    )?;
2396
2397    send_dev_server_projects_update(session.user_id(), update, &session).await;
2398
2399    response.send(proto::CreateDevServerProjectResponse {
2400        dev_server_project: Some(dev_server_project.to_proto(None)),
2401    })?;
2402    Ok(())
2403}
2404
2405async fn create_dev_server(
2406    request: proto::CreateDevServer,
2407    response: Response<proto::CreateDevServer>,
2408    session: UserSession,
2409) -> Result<()> {
2410    let access_token = auth::random_token();
2411    let hashed_access_token = auth::hash_access_token(&access_token);
2412
2413    if request.name.is_empty() {
2414        return Err(proto::ErrorCode::Forbidden
2415            .message("Dev server name cannot be empty".to_string())
2416            .anyhow())?;
2417    }
2418
2419    let (dev_server, status) = session
2420        .db()
2421        .await
2422        .create_dev_server(
2423            &request.name,
2424            request.ssh_connection_string.as_deref(),
2425            &hashed_access_token,
2426            session.user_id(),
2427        )
2428        .await?;
2429
2430    send_dev_server_projects_update(session.user_id(), status, &session).await;
2431
2432    response.send(proto::CreateDevServerResponse {
2433        dev_server_id: dev_server.id.0 as u64,
2434        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2435        name: request.name,
2436    })?;
2437    Ok(())
2438}
2439
2440async fn regenerate_dev_server_token(
2441    request: proto::RegenerateDevServerToken,
2442    response: Response<proto::RegenerateDevServerToken>,
2443    session: UserSession,
2444) -> Result<()> {
2445    let dev_server_id = DevServerId(request.dev_server_id as i32);
2446    let access_token = auth::random_token();
2447    let hashed_access_token = auth::hash_access_token(&access_token);
2448
2449    let connection_id = session
2450        .connection_pool()
2451        .await
2452        .dev_server_connection_id(dev_server_id);
2453    if let Some(connection_id) = connection_id {
2454        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2455        session.peer.send(
2456            connection_id,
2457            proto::ShutdownDevServer {
2458                reason: Some("dev server token was regenerated".to_string()),
2459            },
2460        )?;
2461        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2462    }
2463
2464    let status = session
2465        .db()
2466        .await
2467        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2468        .await?;
2469
2470    send_dev_server_projects_update(session.user_id(), status, &session).await;
2471
2472    response.send(proto::RegenerateDevServerTokenResponse {
2473        dev_server_id: dev_server_id.to_proto(),
2474        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2475    })?;
2476    Ok(())
2477}
2478
2479async fn rename_dev_server(
2480    request: proto::RenameDevServer,
2481    response: Response<proto::RenameDevServer>,
2482    session: UserSession,
2483) -> Result<()> {
2484    if request.name.trim().is_empty() {
2485        return Err(proto::ErrorCode::Forbidden
2486            .message("Dev server name cannot be empty".to_string())
2487            .anyhow())?;
2488    }
2489
2490    let dev_server_id = DevServerId(request.dev_server_id as i32);
2491    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2492    if dev_server.user_id != session.user_id() {
2493        return Err(anyhow!(ErrorCode::Forbidden))?;
2494    }
2495
2496    let status = session
2497        .db()
2498        .await
2499        .rename_dev_server(
2500            dev_server_id,
2501            &request.name,
2502            request.ssh_connection_string.as_deref(),
2503            session.user_id(),
2504        )
2505        .await?;
2506
2507    send_dev_server_projects_update(session.user_id(), status, &session).await;
2508
2509    response.send(proto::Ack {})?;
2510    Ok(())
2511}
2512
2513async fn delete_dev_server(
2514    request: proto::DeleteDevServer,
2515    response: Response<proto::DeleteDevServer>,
2516    session: UserSession,
2517) -> Result<()> {
2518    let dev_server_id = DevServerId(request.dev_server_id as i32);
2519    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2520    if dev_server.user_id != session.user_id() {
2521        return Err(anyhow!(ErrorCode::Forbidden))?;
2522    }
2523
2524    let connection_id = session
2525        .connection_pool()
2526        .await
2527        .dev_server_connection_id(dev_server_id);
2528    if let Some(connection_id) = connection_id {
2529        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2530        session.peer.send(
2531            connection_id,
2532            proto::ShutdownDevServer {
2533                reason: Some("dev server was deleted".to_string()),
2534            },
2535        )?;
2536        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2537    }
2538
2539    let status = session
2540        .db()
2541        .await
2542        .delete_dev_server(dev_server_id, session.user_id())
2543        .await?;
2544
2545    send_dev_server_projects_update(session.user_id(), status, &session).await;
2546
2547    response.send(proto::Ack {})?;
2548    Ok(())
2549}
2550
2551async fn delete_dev_server_project(
2552    request: proto::DeleteDevServerProject,
2553    response: Response<proto::DeleteDevServerProject>,
2554    session: UserSession,
2555) -> Result<()> {
2556    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2557    let dev_server_project = session
2558        .db()
2559        .await
2560        .get_dev_server_project(dev_server_project_id)
2561        .await?;
2562
2563    let dev_server = session
2564        .db()
2565        .await
2566        .get_dev_server(dev_server_project.dev_server_id)
2567        .await?;
2568    if dev_server.user_id != session.user_id() {
2569        return Err(anyhow!(ErrorCode::Forbidden))?;
2570    }
2571
2572    let dev_server_connection_id = session
2573        .connection_pool()
2574        .await
2575        .dev_server_connection_id(dev_server.id);
2576
2577    if let Some(dev_server_connection_id) = dev_server_connection_id {
2578        let project = session
2579            .db()
2580            .await
2581            .find_dev_server_project(dev_server_project_id)
2582            .await;
2583        if let Ok(project) = project {
2584            unshare_project_internal(
2585                project.id,
2586                dev_server_connection_id,
2587                Some(session.user_id()),
2588                &session,
2589            )
2590            .await?;
2591        }
2592    }
2593
2594    let (projects, status) = session
2595        .db()
2596        .await
2597        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2598        .await?;
2599
2600    if let Some(dev_server_connection_id) = dev_server_connection_id {
2601        session.peer.send(
2602            dev_server_connection_id,
2603            proto::DevServerInstructions { projects },
2604        )?;
2605    }
2606
2607    send_dev_server_projects_update(session.user_id(), status, &session).await;
2608
2609    response.send(proto::Ack {})?;
2610    Ok(())
2611}
2612
2613async fn rejoin_dev_server_projects(
2614    request: proto::RejoinRemoteProjects,
2615    response: Response<proto::RejoinRemoteProjects>,
2616    session: UserSession,
2617) -> Result<()> {
2618    let mut rejoined_projects = {
2619        let db = session.db().await;
2620        db.rejoin_dev_server_projects(
2621            &request.rejoined_projects,
2622            session.user_id(),
2623            session.0.connection_id,
2624        )
2625        .await?
2626    };
2627    response.send(proto::RejoinRemoteProjectsResponse {
2628        rejoined_projects: rejoined_projects
2629            .iter()
2630            .map(|project| project.to_proto())
2631            .collect(),
2632    })?;
2633    notify_rejoined_projects(&mut rejoined_projects, &session)
2634}
2635
2636async fn reconnect_dev_server(
2637    request: proto::ReconnectDevServer,
2638    response: Response<proto::ReconnectDevServer>,
2639    session: DevServerSession,
2640) -> Result<()> {
2641    let reshared_projects = {
2642        let db = session.db().await;
2643        db.reshare_dev_server_projects(
2644            &request.reshared_projects,
2645            session.dev_server_id(),
2646            session.0.connection_id,
2647        )
2648        .await?
2649    };
2650
2651    for project in &reshared_projects {
2652        for collaborator in &project.collaborators {
2653            session
2654                .peer
2655                .send(
2656                    collaborator.connection_id,
2657                    proto::UpdateProjectCollaborator {
2658                        project_id: project.id.to_proto(),
2659                        old_peer_id: Some(project.old_connection_id.into()),
2660                        new_peer_id: Some(session.connection_id.into()),
2661                    },
2662                )
2663                .trace_err();
2664        }
2665
2666        broadcast(
2667            Some(session.connection_id),
2668            project
2669                .collaborators
2670                .iter()
2671                .map(|collaborator| collaborator.connection_id),
2672            |connection_id| {
2673                session.peer.forward_send(
2674                    session.connection_id,
2675                    connection_id,
2676                    proto::UpdateProject {
2677                        project_id: project.id.to_proto(),
2678                        worktrees: project.worktrees.clone(),
2679                    },
2680                )
2681            },
2682        );
2683    }
2684
2685    response.send(proto::ReconnectDevServerResponse {
2686        reshared_projects: reshared_projects
2687            .iter()
2688            .map(|project| proto::ResharedProject {
2689                id: project.id.to_proto(),
2690                collaborators: project
2691                    .collaborators
2692                    .iter()
2693                    .map(|collaborator| collaborator.to_proto())
2694                    .collect(),
2695            })
2696            .collect(),
2697    })?;
2698
2699    Ok(())
2700}
2701
2702async fn shutdown_dev_server(
2703    _: proto::ShutdownDevServer,
2704    response: Response<proto::ShutdownDevServer>,
2705    session: DevServerSession,
2706) -> Result<()> {
2707    response.send(proto::Ack {})?;
2708    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2709    remove_dev_server_connection(session.dev_server_id(), &session).await
2710}
2711
2712async fn shutdown_dev_server_internal(
2713    dev_server_id: DevServerId,
2714    connection_id: ConnectionId,
2715    session: &Session,
2716) -> Result<()> {
2717    let (dev_server_projects, dev_server) = {
2718        let db = session.db().await;
2719        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2720        let dev_server = db.get_dev_server(dev_server_id).await?;
2721        (dev_server_projects, dev_server)
2722    };
2723
2724    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2725        unshare_project_internal(
2726            ProjectId::from_proto(project_id),
2727            connection_id,
2728            None,
2729            session,
2730        )
2731        .await?;
2732    }
2733
2734    session
2735        .connection_pool()
2736        .await
2737        .set_dev_server_offline(dev_server_id);
2738
2739    let status = session
2740        .db()
2741        .await
2742        .dev_server_projects_update(dev_server.user_id)
2743        .await?;
2744    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2745
2746    Ok(())
2747}
2748
2749async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2750    let dev_server_connection = session
2751        .connection_pool()
2752        .await
2753        .dev_server_connection_id(dev_server_id);
2754
2755    if let Some(dev_server_connection) = dev_server_connection {
2756        session
2757            .connection_pool()
2758            .await
2759            .remove_connection(dev_server_connection)?;
2760    }
2761    Ok(())
2762}
2763
2764/// Updates other participants with changes to the project
2765async fn update_project(
2766    request: proto::UpdateProject,
2767    response: Response<proto::UpdateProject>,
2768    session: Session,
2769) -> Result<()> {
2770    let project_id = ProjectId::from_proto(request.project_id);
2771    let (room, guest_connection_ids) = &*session
2772        .db()
2773        .await
2774        .update_project(project_id, session.connection_id, &request.worktrees)
2775        .await?;
2776    broadcast(
2777        Some(session.connection_id),
2778        guest_connection_ids.iter().copied(),
2779        |connection_id| {
2780            session
2781                .peer
2782                .forward_send(session.connection_id, connection_id, request.clone())
2783        },
2784    );
2785    if let Some(room) = room {
2786        room_updated(room, &session.peer);
2787    }
2788    response.send(proto::Ack {})?;
2789
2790    Ok(())
2791}
2792
2793/// Updates other participants with changes to the worktree
2794async fn update_worktree(
2795    request: proto::UpdateWorktree,
2796    response: Response<proto::UpdateWorktree>,
2797    session: Session,
2798) -> Result<()> {
2799    let guest_connection_ids = session
2800        .db()
2801        .await
2802        .update_worktree(&request, session.connection_id)
2803        .await?;
2804
2805    broadcast(
2806        Some(session.connection_id),
2807        guest_connection_ids.iter().copied(),
2808        |connection_id| {
2809            session
2810                .peer
2811                .forward_send(session.connection_id, connection_id, request.clone())
2812        },
2813    );
2814    response.send(proto::Ack {})?;
2815    Ok(())
2816}
2817
2818/// Updates other participants with changes to the diagnostics
2819async fn update_diagnostic_summary(
2820    message: proto::UpdateDiagnosticSummary,
2821    session: Session,
2822) -> Result<()> {
2823    let guest_connection_ids = session
2824        .db()
2825        .await
2826        .update_diagnostic_summary(&message, session.connection_id)
2827        .await?;
2828
2829    broadcast(
2830        Some(session.connection_id),
2831        guest_connection_ids.iter().copied(),
2832        |connection_id| {
2833            session
2834                .peer
2835                .forward_send(session.connection_id, connection_id, message.clone())
2836        },
2837    );
2838
2839    Ok(())
2840}
2841
2842/// Updates other participants with changes to the worktree settings
2843async fn update_worktree_settings(
2844    message: proto::UpdateWorktreeSettings,
2845    session: Session,
2846) -> Result<()> {
2847    let guest_connection_ids = session
2848        .db()
2849        .await
2850        .update_worktree_settings(&message, session.connection_id)
2851        .await?;
2852
2853    broadcast(
2854        Some(session.connection_id),
2855        guest_connection_ids.iter().copied(),
2856        |connection_id| {
2857            session
2858                .peer
2859                .forward_send(session.connection_id, connection_id, message.clone())
2860        },
2861    );
2862
2863    Ok(())
2864}
2865
2866/// Notify other participants that a  language server has started.
2867async fn start_language_server(
2868    request: proto::StartLanguageServer,
2869    session: Session,
2870) -> Result<()> {
2871    let guest_connection_ids = session
2872        .db()
2873        .await
2874        .start_language_server(&request, session.connection_id)
2875        .await?;
2876
2877    broadcast(
2878        Some(session.connection_id),
2879        guest_connection_ids.iter().copied(),
2880        |connection_id| {
2881            session
2882                .peer
2883                .forward_send(session.connection_id, connection_id, request.clone())
2884        },
2885    );
2886    Ok(())
2887}
2888
2889/// Notify other participants that a language server has changed.
2890async fn update_language_server(
2891    request: proto::UpdateLanguageServer,
2892    session: Session,
2893) -> Result<()> {
2894    let project_id = ProjectId::from_proto(request.project_id);
2895    let project_connection_ids = session
2896        .db()
2897        .await
2898        .project_connection_ids(project_id, session.connection_id, true)
2899        .await?;
2900    broadcast(
2901        Some(session.connection_id),
2902        project_connection_ids.iter().copied(),
2903        |connection_id| {
2904            session
2905                .peer
2906                .forward_send(session.connection_id, connection_id, request.clone())
2907        },
2908    );
2909    Ok(())
2910}
2911
2912/// forward a project request to the host. These requests should be read only
2913/// as guests are allowed to send them.
2914async fn forward_read_only_project_request<T>(
2915    request: T,
2916    response: Response<T>,
2917    session: UserSession,
2918) -> Result<()>
2919where
2920    T: EntityMessage + RequestMessage,
2921{
2922    let project_id = ProjectId::from_proto(request.remote_entity_id());
2923    let host_connection_id = session
2924        .db()
2925        .await
2926        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2927        .await?;
2928    let payload = session
2929        .peer
2930        .forward_request(session.connection_id, host_connection_id, request)
2931        .await?;
2932    response.send(payload)?;
2933    Ok(())
2934}
2935
2936async fn forward_find_search_candidates_request(
2937    request: proto::FindSearchCandidates,
2938    response: Response<proto::FindSearchCandidates>,
2939    session: UserSession,
2940) -> Result<()> {
2941    let project_id = ProjectId::from_proto(request.remote_entity_id());
2942    let host_connection_id = session
2943        .db()
2944        .await
2945        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2946        .await?;
2947    let payload = session
2948        .peer
2949        .forward_request(session.connection_id, host_connection_id, request)
2950        .await?;
2951    response.send(payload)?;
2952    Ok(())
2953}
2954
2955/// forward a project request to the dev server. Only allowed
2956/// if it's your dev server.
2957async fn forward_project_request_for_owner<T>(
2958    request: T,
2959    response: Response<T>,
2960    session: UserSession,
2961) -> Result<()>
2962where
2963    T: EntityMessage + RequestMessage,
2964{
2965    let project_id = ProjectId::from_proto(request.remote_entity_id());
2966
2967    let host_connection_id = session
2968        .db()
2969        .await
2970        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2971        .await?;
2972    let payload = session
2973        .peer
2974        .forward_request(session.connection_id, host_connection_id, request)
2975        .await?;
2976    response.send(payload)?;
2977    Ok(())
2978}
2979
2980/// forward a project request to the host. These requests are disallowed
2981/// for guests.
2982async fn forward_mutating_project_request<T>(
2983    request: T,
2984    response: Response<T>,
2985    session: UserSession,
2986) -> Result<()>
2987where
2988    T: EntityMessage + RequestMessage,
2989{
2990    let project_id = ProjectId::from_proto(request.remote_entity_id());
2991
2992    let host_connection_id = session
2993        .db()
2994        .await
2995        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2996        .await?;
2997    let payload = session
2998        .peer
2999        .forward_request(session.connection_id, host_connection_id, request)
3000        .await?;
3001    response.send(payload)?;
3002    Ok(())
3003}
3004
3005/// Notify other participants that a new buffer has been created
3006async fn create_buffer_for_peer(
3007    request: proto::CreateBufferForPeer,
3008    session: Session,
3009) -> Result<()> {
3010    session
3011        .db()
3012        .await
3013        .check_user_is_project_host(
3014            ProjectId::from_proto(request.project_id),
3015            session.connection_id,
3016        )
3017        .await?;
3018    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3019    session
3020        .peer
3021        .forward_send(session.connection_id, peer_id.into(), request)?;
3022    Ok(())
3023}
3024
3025/// Notify other participants that a buffer has been updated. This is
3026/// allowed for guests as long as the update is limited to selections.
3027async fn update_buffer(
3028    request: proto::UpdateBuffer,
3029    response: Response<proto::UpdateBuffer>,
3030    session: Session,
3031) -> Result<()> {
3032    let project_id = ProjectId::from_proto(request.project_id);
3033    let mut capability = Capability::ReadOnly;
3034
3035    for op in request.operations.iter() {
3036        match op.variant {
3037            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3038            Some(_) => capability = Capability::ReadWrite,
3039        }
3040    }
3041
3042    let host = {
3043        let guard = session
3044            .db()
3045            .await
3046            .connections_for_buffer_update(
3047                project_id,
3048                session.principal_id(),
3049                session.connection_id,
3050                capability,
3051            )
3052            .await?;
3053
3054        let (host, guests) = &*guard;
3055
3056        broadcast(
3057            Some(session.connection_id),
3058            guests.clone(),
3059            |connection_id| {
3060                session
3061                    .peer
3062                    .forward_send(session.connection_id, connection_id, request.clone())
3063            },
3064        );
3065
3066        *host
3067    };
3068
3069    if host != session.connection_id {
3070        session
3071            .peer
3072            .forward_request(session.connection_id, host, request.clone())
3073            .await?;
3074    }
3075
3076    response.send(proto::Ack {})?;
3077    Ok(())
3078}
3079
3080async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3081    let project_id = ProjectId::from_proto(message.project_id);
3082
3083    let operation = message.operation.as_ref().context("invalid operation")?;
3084    let capability = match operation.variant.as_ref() {
3085        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3086            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3087                match buffer_op.variant {
3088                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3089                        Capability::ReadOnly
3090                    }
3091                    _ => Capability::ReadWrite,
3092                }
3093            } else {
3094                Capability::ReadWrite
3095            }
3096        }
3097        Some(_) => Capability::ReadWrite,
3098        None => Capability::ReadOnly,
3099    };
3100
3101    let guard = session
3102        .db()
3103        .await
3104        .connections_for_buffer_update(
3105            project_id,
3106            session.principal_id(),
3107            session.connection_id,
3108            capability,
3109        )
3110        .await?;
3111
3112    let (host, guests) = &*guard;
3113
3114    broadcast(
3115        Some(session.connection_id),
3116        guests.iter().chain([host]).copied(),
3117        |connection_id| {
3118            session
3119                .peer
3120                .forward_send(session.connection_id, connection_id, message.clone())
3121        },
3122    );
3123
3124    Ok(())
3125}
3126
3127/// Notify other participants that a project has been updated.
3128async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3129    request: T,
3130    session: Session,
3131) -> Result<()> {
3132    let project_id = ProjectId::from_proto(request.remote_entity_id());
3133    let project_connection_ids = session
3134        .db()
3135        .await
3136        .project_connection_ids(project_id, session.connection_id, false)
3137        .await?;
3138
3139    broadcast(
3140        Some(session.connection_id),
3141        project_connection_ids.iter().copied(),
3142        |connection_id| {
3143            session
3144                .peer
3145                .forward_send(session.connection_id, connection_id, request.clone())
3146        },
3147    );
3148    Ok(())
3149}
3150
3151/// Start following another user in a call.
3152async fn follow(
3153    request: proto::Follow,
3154    response: Response<proto::Follow>,
3155    session: UserSession,
3156) -> Result<()> {
3157    let room_id = RoomId::from_proto(request.room_id);
3158    let project_id = request.project_id.map(ProjectId::from_proto);
3159    let leader_id = request
3160        .leader_id
3161        .ok_or_else(|| anyhow!("invalid leader id"))?
3162        .into();
3163    let follower_id = session.connection_id;
3164
3165    session
3166        .db()
3167        .await
3168        .check_room_participants(room_id, leader_id, session.connection_id)
3169        .await?;
3170
3171    let response_payload = session
3172        .peer
3173        .forward_request(session.connection_id, leader_id, request)
3174        .await?;
3175    response.send(response_payload)?;
3176
3177    if let Some(project_id) = project_id {
3178        let room = session
3179            .db()
3180            .await
3181            .follow(room_id, project_id, leader_id, follower_id)
3182            .await?;
3183        room_updated(&room, &session.peer);
3184    }
3185
3186    Ok(())
3187}
3188
3189/// Stop following another user in a call.
3190async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3191    let room_id = RoomId::from_proto(request.room_id);
3192    let project_id = request.project_id.map(ProjectId::from_proto);
3193    let leader_id = request
3194        .leader_id
3195        .ok_or_else(|| anyhow!("invalid leader id"))?
3196        .into();
3197    let follower_id = session.connection_id;
3198
3199    session
3200        .db()
3201        .await
3202        .check_room_participants(room_id, leader_id, session.connection_id)
3203        .await?;
3204
3205    session
3206        .peer
3207        .forward_send(session.connection_id, leader_id, request)?;
3208
3209    if let Some(project_id) = project_id {
3210        let room = session
3211            .db()
3212            .await
3213            .unfollow(room_id, project_id, leader_id, follower_id)
3214            .await?;
3215        room_updated(&room, &session.peer);
3216    }
3217
3218    Ok(())
3219}
3220
3221/// Notify everyone following you of your current location.
3222async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3223    let room_id = RoomId::from_proto(request.room_id);
3224    let database = session.db.lock().await;
3225
3226    let connection_ids = if let Some(project_id) = request.project_id {
3227        let project_id = ProjectId::from_proto(project_id);
3228        database
3229            .project_connection_ids(project_id, session.connection_id, true)
3230            .await?
3231    } else {
3232        database
3233            .room_connection_ids(room_id, session.connection_id)
3234            .await?
3235    };
3236
3237    // For now, don't send view update messages back to that view's current leader.
3238    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3239        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3240        _ => None,
3241    });
3242
3243    for connection_id in connection_ids.iter().cloned() {
3244        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3245            session
3246                .peer
3247                .forward_send(session.connection_id, connection_id, request.clone())?;
3248        }
3249    }
3250    Ok(())
3251}
3252
3253/// Get public data about users.
3254async fn get_users(
3255    request: proto::GetUsers,
3256    response: Response<proto::GetUsers>,
3257    session: Session,
3258) -> Result<()> {
3259    let user_ids = request
3260        .user_ids
3261        .into_iter()
3262        .map(UserId::from_proto)
3263        .collect();
3264    let users = session
3265        .db()
3266        .await
3267        .get_users_by_ids(user_ids)
3268        .await?
3269        .into_iter()
3270        .map(|user| proto::User {
3271            id: user.id.to_proto(),
3272            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3273            github_login: user.github_login,
3274        })
3275        .collect();
3276    response.send(proto::UsersResponse { users })?;
3277    Ok(())
3278}
3279
3280/// Search for users (to invite) buy Github login
3281async fn fuzzy_search_users(
3282    request: proto::FuzzySearchUsers,
3283    response: Response<proto::FuzzySearchUsers>,
3284    session: UserSession,
3285) -> Result<()> {
3286    let query = request.query;
3287    let users = match query.len() {
3288        0 => vec![],
3289        1 | 2 => session
3290            .db()
3291            .await
3292            .get_user_by_github_login(&query)
3293            .await?
3294            .into_iter()
3295            .collect(),
3296        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3297    };
3298    let users = users
3299        .into_iter()
3300        .filter(|user| user.id != session.user_id())
3301        .map(|user| proto::User {
3302            id: user.id.to_proto(),
3303            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3304            github_login: user.github_login,
3305        })
3306        .collect();
3307    response.send(proto::UsersResponse { users })?;
3308    Ok(())
3309}
3310
3311/// Send a contact request to another user.
3312async fn request_contact(
3313    request: proto::RequestContact,
3314    response: Response<proto::RequestContact>,
3315    session: UserSession,
3316) -> Result<()> {
3317    let requester_id = session.user_id();
3318    let responder_id = UserId::from_proto(request.responder_id);
3319    if requester_id == responder_id {
3320        return Err(anyhow!("cannot add yourself as a contact"))?;
3321    }
3322
3323    let notifications = session
3324        .db()
3325        .await
3326        .send_contact_request(requester_id, responder_id)
3327        .await?;
3328
3329    // Update outgoing contact requests of requester
3330    let mut update = proto::UpdateContacts::default();
3331    update.outgoing_requests.push(responder_id.to_proto());
3332    for connection_id in session
3333        .connection_pool()
3334        .await
3335        .user_connection_ids(requester_id)
3336    {
3337        session.peer.send(connection_id, update.clone())?;
3338    }
3339
3340    // Update incoming contact requests of responder
3341    let mut update = proto::UpdateContacts::default();
3342    update
3343        .incoming_requests
3344        .push(proto::IncomingContactRequest {
3345            requester_id: requester_id.to_proto(),
3346        });
3347    let connection_pool = session.connection_pool().await;
3348    for connection_id in connection_pool.user_connection_ids(responder_id) {
3349        session.peer.send(connection_id, update.clone())?;
3350    }
3351
3352    send_notifications(&connection_pool, &session.peer, notifications);
3353
3354    response.send(proto::Ack {})?;
3355    Ok(())
3356}
3357
3358/// Accept or decline a contact request
3359async fn respond_to_contact_request(
3360    request: proto::RespondToContactRequest,
3361    response: Response<proto::RespondToContactRequest>,
3362    session: UserSession,
3363) -> Result<()> {
3364    let responder_id = session.user_id();
3365    let requester_id = UserId::from_proto(request.requester_id);
3366    let db = session.db().await;
3367    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3368        db.dismiss_contact_notification(responder_id, requester_id)
3369            .await?;
3370    } else {
3371        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3372
3373        let notifications = db
3374            .respond_to_contact_request(responder_id, requester_id, accept)
3375            .await?;
3376        let requester_busy = db.is_user_busy(requester_id).await?;
3377        let responder_busy = db.is_user_busy(responder_id).await?;
3378
3379        let pool = session.connection_pool().await;
3380        // Update responder with new contact
3381        let mut update = proto::UpdateContacts::default();
3382        if accept {
3383            update
3384                .contacts
3385                .push(contact_for_user(requester_id, requester_busy, &pool));
3386        }
3387        update
3388            .remove_incoming_requests
3389            .push(requester_id.to_proto());
3390        for connection_id in pool.user_connection_ids(responder_id) {
3391            session.peer.send(connection_id, update.clone())?;
3392        }
3393
3394        // Update requester with new contact
3395        let mut update = proto::UpdateContacts::default();
3396        if accept {
3397            update
3398                .contacts
3399                .push(contact_for_user(responder_id, responder_busy, &pool));
3400        }
3401        update
3402            .remove_outgoing_requests
3403            .push(responder_id.to_proto());
3404
3405        for connection_id in pool.user_connection_ids(requester_id) {
3406            session.peer.send(connection_id, update.clone())?;
3407        }
3408
3409        send_notifications(&pool, &session.peer, notifications);
3410    }
3411
3412    response.send(proto::Ack {})?;
3413    Ok(())
3414}
3415
3416/// Remove a contact.
3417async fn remove_contact(
3418    request: proto::RemoveContact,
3419    response: Response<proto::RemoveContact>,
3420    session: UserSession,
3421) -> Result<()> {
3422    let requester_id = session.user_id();
3423    let responder_id = UserId::from_proto(request.user_id);
3424    let db = session.db().await;
3425    let (contact_accepted, deleted_notification_id) =
3426        db.remove_contact(requester_id, responder_id).await?;
3427
3428    let pool = session.connection_pool().await;
3429    // Update outgoing contact requests of requester
3430    let mut update = proto::UpdateContacts::default();
3431    if contact_accepted {
3432        update.remove_contacts.push(responder_id.to_proto());
3433    } else {
3434        update
3435            .remove_outgoing_requests
3436            .push(responder_id.to_proto());
3437    }
3438    for connection_id in pool.user_connection_ids(requester_id) {
3439        session.peer.send(connection_id, update.clone())?;
3440    }
3441
3442    // Update incoming contact requests of responder
3443    let mut update = proto::UpdateContacts::default();
3444    if contact_accepted {
3445        update.remove_contacts.push(requester_id.to_proto());
3446    } else {
3447        update
3448            .remove_incoming_requests
3449            .push(requester_id.to_proto());
3450    }
3451    for connection_id in pool.user_connection_ids(responder_id) {
3452        session.peer.send(connection_id, update.clone())?;
3453        if let Some(notification_id) = deleted_notification_id {
3454            session.peer.send(
3455                connection_id,
3456                proto::DeleteNotification {
3457                    notification_id: notification_id.to_proto(),
3458                },
3459            )?;
3460        }
3461    }
3462
3463    response.send(proto::Ack {})?;
3464    Ok(())
3465}
3466
3467fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3468    version.0.minor() < 139
3469}
3470
3471async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3472    let plan = session.current_plan(session.db().await).await?;
3473
3474    session
3475        .peer
3476        .send(
3477            session.connection_id,
3478            proto::UpdateUserPlan { plan: plan.into() },
3479        )
3480        .trace_err();
3481
3482    Ok(())
3483}
3484
3485async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3486    subscribe_user_to_channels(
3487        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3488        &session,
3489    )
3490    .await?;
3491    Ok(())
3492}
3493
3494async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3495    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3496    let mut pool = session.connection_pool().await;
3497    for membership in &channels_for_user.channel_memberships {
3498        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3499    }
3500    session.peer.send(
3501        session.connection_id,
3502        build_update_user_channels(&channels_for_user),
3503    )?;
3504    session.peer.send(
3505        session.connection_id,
3506        build_channels_update(channels_for_user),
3507    )?;
3508    Ok(())
3509}
3510
3511/// Creates a new channel.
3512async fn create_channel(
3513    request: proto::CreateChannel,
3514    response: Response<proto::CreateChannel>,
3515    session: UserSession,
3516) -> Result<()> {
3517    let db = session.db().await;
3518
3519    let parent_id = request.parent_id.map(ChannelId::from_proto);
3520    let (channel, membership) = db
3521        .create_channel(&request.name, parent_id, session.user_id())
3522        .await?;
3523
3524    let root_id = channel.root_id();
3525    let channel = Channel::from_model(channel);
3526
3527    response.send(proto::CreateChannelResponse {
3528        channel: Some(channel.to_proto()),
3529        parent_id: request.parent_id,
3530    })?;
3531
3532    let mut connection_pool = session.connection_pool().await;
3533    if let Some(membership) = membership {
3534        connection_pool.subscribe_to_channel(
3535            membership.user_id,
3536            membership.channel_id,
3537            membership.role,
3538        );
3539        let update = proto::UpdateUserChannels {
3540            channel_memberships: vec![proto::ChannelMembership {
3541                channel_id: membership.channel_id.to_proto(),
3542                role: membership.role.into(),
3543            }],
3544            ..Default::default()
3545        };
3546        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3547            session.peer.send(connection_id, update.clone())?;
3548        }
3549    }
3550
3551    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3552        if !role.can_see_channel(channel.visibility) {
3553            continue;
3554        }
3555
3556        let update = proto::UpdateChannels {
3557            channels: vec![channel.to_proto()],
3558            ..Default::default()
3559        };
3560        session.peer.send(connection_id, update.clone())?;
3561    }
3562
3563    Ok(())
3564}
3565
3566/// Delete a channel
3567async fn delete_channel(
3568    request: proto::DeleteChannel,
3569    response: Response<proto::DeleteChannel>,
3570    session: UserSession,
3571) -> Result<()> {
3572    let db = session.db().await;
3573
3574    let channel_id = request.channel_id;
3575    let (root_channel, removed_channels) = db
3576        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3577        .await?;
3578    response.send(proto::Ack {})?;
3579
3580    // Notify members of removed channels
3581    let mut update = proto::UpdateChannels::default();
3582    update
3583        .delete_channels
3584        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3585
3586    let connection_pool = session.connection_pool().await;
3587    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3588        session.peer.send(connection_id, update.clone())?;
3589    }
3590
3591    Ok(())
3592}
3593
3594/// Invite someone to join a channel.
3595async fn invite_channel_member(
3596    request: proto::InviteChannelMember,
3597    response: Response<proto::InviteChannelMember>,
3598    session: UserSession,
3599) -> Result<()> {
3600    let db = session.db().await;
3601    let channel_id = ChannelId::from_proto(request.channel_id);
3602    let invitee_id = UserId::from_proto(request.user_id);
3603    let InviteMemberResult {
3604        channel,
3605        notifications,
3606    } = db
3607        .invite_channel_member(
3608            channel_id,
3609            invitee_id,
3610            session.user_id(),
3611            request.role().into(),
3612        )
3613        .await?;
3614
3615    let update = proto::UpdateChannels {
3616        channel_invitations: vec![channel.to_proto()],
3617        ..Default::default()
3618    };
3619
3620    let connection_pool = session.connection_pool().await;
3621    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3622        session.peer.send(connection_id, update.clone())?;
3623    }
3624
3625    send_notifications(&connection_pool, &session.peer, notifications);
3626
3627    response.send(proto::Ack {})?;
3628    Ok(())
3629}
3630
3631/// remove someone from a channel
3632async fn remove_channel_member(
3633    request: proto::RemoveChannelMember,
3634    response: Response<proto::RemoveChannelMember>,
3635    session: UserSession,
3636) -> Result<()> {
3637    let db = session.db().await;
3638    let channel_id = ChannelId::from_proto(request.channel_id);
3639    let member_id = UserId::from_proto(request.user_id);
3640
3641    let RemoveChannelMemberResult {
3642        membership_update,
3643        notification_id,
3644    } = db
3645        .remove_channel_member(channel_id, member_id, session.user_id())
3646        .await?;
3647
3648    let mut connection_pool = session.connection_pool().await;
3649    notify_membership_updated(
3650        &mut connection_pool,
3651        membership_update,
3652        member_id,
3653        &session.peer,
3654    );
3655    for connection_id in connection_pool.user_connection_ids(member_id) {
3656        if let Some(notification_id) = notification_id {
3657            session
3658                .peer
3659                .send(
3660                    connection_id,
3661                    proto::DeleteNotification {
3662                        notification_id: notification_id.to_proto(),
3663                    },
3664                )
3665                .trace_err();
3666        }
3667    }
3668
3669    response.send(proto::Ack {})?;
3670    Ok(())
3671}
3672
3673/// Toggle the channel between public and private.
3674/// Care is taken to maintain the invariant that public channels only descend from public channels,
3675/// (though members-only channels can appear at any point in the hierarchy).
3676async fn set_channel_visibility(
3677    request: proto::SetChannelVisibility,
3678    response: Response<proto::SetChannelVisibility>,
3679    session: UserSession,
3680) -> Result<()> {
3681    let db = session.db().await;
3682    let channel_id = ChannelId::from_proto(request.channel_id);
3683    let visibility = request.visibility().into();
3684
3685    let channel_model = db
3686        .set_channel_visibility(channel_id, visibility, session.user_id())
3687        .await?;
3688    let root_id = channel_model.root_id();
3689    let channel = Channel::from_model(channel_model);
3690
3691    let mut connection_pool = session.connection_pool().await;
3692    for (user_id, role) in connection_pool
3693        .channel_user_ids(root_id)
3694        .collect::<Vec<_>>()
3695        .into_iter()
3696    {
3697        let update = if role.can_see_channel(channel.visibility) {
3698            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3699            proto::UpdateChannels {
3700                channels: vec![channel.to_proto()],
3701                ..Default::default()
3702            }
3703        } else {
3704            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3705            proto::UpdateChannels {
3706                delete_channels: vec![channel.id.to_proto()],
3707                ..Default::default()
3708            }
3709        };
3710
3711        for connection_id in connection_pool.user_connection_ids(user_id) {
3712            session.peer.send(connection_id, update.clone())?;
3713        }
3714    }
3715
3716    response.send(proto::Ack {})?;
3717    Ok(())
3718}
3719
3720/// Alter the role for a user in the channel.
3721async fn set_channel_member_role(
3722    request: proto::SetChannelMemberRole,
3723    response: Response<proto::SetChannelMemberRole>,
3724    session: UserSession,
3725) -> Result<()> {
3726    let db = session.db().await;
3727    let channel_id = ChannelId::from_proto(request.channel_id);
3728    let member_id = UserId::from_proto(request.user_id);
3729    let result = db
3730        .set_channel_member_role(
3731            channel_id,
3732            session.user_id(),
3733            member_id,
3734            request.role().into(),
3735        )
3736        .await?;
3737
3738    match result {
3739        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3740            let mut connection_pool = session.connection_pool().await;
3741            notify_membership_updated(
3742                &mut connection_pool,
3743                membership_update,
3744                member_id,
3745                &session.peer,
3746            )
3747        }
3748        db::SetMemberRoleResult::InviteUpdated(channel) => {
3749            let update = proto::UpdateChannels {
3750                channel_invitations: vec![channel.to_proto()],
3751                ..Default::default()
3752            };
3753
3754            for connection_id in session
3755                .connection_pool()
3756                .await
3757                .user_connection_ids(member_id)
3758            {
3759                session.peer.send(connection_id, update.clone())?;
3760            }
3761        }
3762    }
3763
3764    response.send(proto::Ack {})?;
3765    Ok(())
3766}
3767
3768/// Change the name of a channel
3769async fn rename_channel(
3770    request: proto::RenameChannel,
3771    response: Response<proto::RenameChannel>,
3772    session: UserSession,
3773) -> Result<()> {
3774    let db = session.db().await;
3775    let channel_id = ChannelId::from_proto(request.channel_id);
3776    let channel_model = db
3777        .rename_channel(channel_id, session.user_id(), &request.name)
3778        .await?;
3779    let root_id = channel_model.root_id();
3780    let channel = Channel::from_model(channel_model);
3781
3782    response.send(proto::RenameChannelResponse {
3783        channel: Some(channel.to_proto()),
3784    })?;
3785
3786    let connection_pool = session.connection_pool().await;
3787    let update = proto::UpdateChannels {
3788        channels: vec![channel.to_proto()],
3789        ..Default::default()
3790    };
3791    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3792        if role.can_see_channel(channel.visibility) {
3793            session.peer.send(connection_id, update.clone())?;
3794        }
3795    }
3796
3797    Ok(())
3798}
3799
3800/// Move a channel to a new parent.
3801async fn move_channel(
3802    request: proto::MoveChannel,
3803    response: Response<proto::MoveChannel>,
3804    session: UserSession,
3805) -> Result<()> {
3806    let channel_id = ChannelId::from_proto(request.channel_id);
3807    let to = ChannelId::from_proto(request.to);
3808
3809    let (root_id, channels) = session
3810        .db()
3811        .await
3812        .move_channel(channel_id, to, session.user_id())
3813        .await?;
3814
3815    let connection_pool = session.connection_pool().await;
3816    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3817        let channels = channels
3818            .iter()
3819            .filter_map(|channel| {
3820                if role.can_see_channel(channel.visibility) {
3821                    Some(channel.to_proto())
3822                } else {
3823                    None
3824                }
3825            })
3826            .collect::<Vec<_>>();
3827        if channels.is_empty() {
3828            continue;
3829        }
3830
3831        let update = proto::UpdateChannels {
3832            channels,
3833            ..Default::default()
3834        };
3835
3836        session.peer.send(connection_id, update.clone())?;
3837    }
3838
3839    response.send(Ack {})?;
3840    Ok(())
3841}
3842
3843/// Get the list of channel members
3844async fn get_channel_members(
3845    request: proto::GetChannelMembers,
3846    response: Response<proto::GetChannelMembers>,
3847    session: UserSession,
3848) -> Result<()> {
3849    let db = session.db().await;
3850    let channel_id = ChannelId::from_proto(request.channel_id);
3851    let limit = if request.limit == 0 {
3852        u16::MAX as u64
3853    } else {
3854        request.limit
3855    };
3856    let (members, users) = db
3857        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3858        .await?;
3859    response.send(proto::GetChannelMembersResponse { members, users })?;
3860    Ok(())
3861}
3862
3863/// Accept or decline a channel invitation.
3864async fn respond_to_channel_invite(
3865    request: proto::RespondToChannelInvite,
3866    response: Response<proto::RespondToChannelInvite>,
3867    session: UserSession,
3868) -> Result<()> {
3869    let db = session.db().await;
3870    let channel_id = ChannelId::from_proto(request.channel_id);
3871    let RespondToChannelInvite {
3872        membership_update,
3873        notifications,
3874    } = db
3875        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3876        .await?;
3877
3878    let mut connection_pool = session.connection_pool().await;
3879    if let Some(membership_update) = membership_update {
3880        notify_membership_updated(
3881            &mut connection_pool,
3882            membership_update,
3883            session.user_id(),
3884            &session.peer,
3885        );
3886    } else {
3887        let update = proto::UpdateChannels {
3888            remove_channel_invitations: vec![channel_id.to_proto()],
3889            ..Default::default()
3890        };
3891
3892        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3893            session.peer.send(connection_id, update.clone())?;
3894        }
3895    };
3896
3897    send_notifications(&connection_pool, &session.peer, notifications);
3898
3899    response.send(proto::Ack {})?;
3900
3901    Ok(())
3902}
3903
3904/// Join the channels' room
3905async fn join_channel(
3906    request: proto::JoinChannel,
3907    response: Response<proto::JoinChannel>,
3908    session: UserSession,
3909) -> Result<()> {
3910    let channel_id = ChannelId::from_proto(request.channel_id);
3911    join_channel_internal(channel_id, Box::new(response), session).await
3912}
3913
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, which lets `join_channel_internal` answer
/// either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Presumably resolves to `Response`'s inherent `send` (inherent items take
        // precedence in `Type::item` paths), so this does not recurse — confirm.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3927
/// Shared implementation for joining a channel's room. `join_channel` calls it,
/// and so can any handler whose response type implements `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database guard is released while doing so.
/// 2. Join the channel's room in the database. This also yields any membership
///    change and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint credentials. Guests get a guest
///    token and `can_publish = false`; every other role gets a room token and
///    `can_publish = true`. If token creation fails, the error is logged and the
///    response carries no LiveKit info.
/// 4. Reply to the client, then apply any membership update and broadcast the room update.
/// 5. After the database and connection-pool guards from the inner block are
///    dropped, broadcast the channel update and refresh the user's contacts.
///
/// Returns an error if the joined room has no channel ("channel not returned").
/// The client has already been sent its response by then.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (the failure is logged by `trace_err`).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and cannot publish; everyone else can.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database guard and the connection-pool guard.
        joined_room
    };

    // Re-acquire the connection pool just for the duration of this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4023
4024/// Start editing the channel notes
4025async fn join_channel_buffer(
4026    request: proto::JoinChannelBuffer,
4027    response: Response<proto::JoinChannelBuffer>,
4028    session: UserSession,
4029) -> Result<()> {
4030    let db = session.db().await;
4031    let channel_id = ChannelId::from_proto(request.channel_id);
4032
4033    let open_response = db
4034        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4035        .await?;
4036
4037    let collaborators = open_response.collaborators.clone();
4038    response.send(open_response)?;
4039
4040    let update = UpdateChannelBufferCollaborators {
4041        channel_id: channel_id.to_proto(),
4042        collaborators: collaborators.clone(),
4043    };
4044    channel_buffer_updated(
4045        session.connection_id,
4046        collaborators
4047            .iter()
4048            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4049        &update,
4050        &session.peer,
4051    );
4052
4053    Ok(())
4054}
4055
4056/// Edit the channel notes
4057async fn update_channel_buffer(
4058    request: proto::UpdateChannelBuffer,
4059    session: UserSession,
4060) -> Result<()> {
4061    let db = session.db().await;
4062    let channel_id = ChannelId::from_proto(request.channel_id);
4063
4064    let (collaborators, epoch, version) = db
4065        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4066        .await?;
4067
4068    channel_buffer_updated(
4069        session.connection_id,
4070        collaborators.clone(),
4071        &proto::UpdateChannelBuffer {
4072            channel_id: channel_id.to_proto(),
4073            operations: request.operations,
4074        },
4075        &session.peer,
4076    );
4077
4078    let pool = &*session.connection_pool().await;
4079
4080    let non_collaborators =
4081        pool.channel_connection_ids(channel_id)
4082            .filter_map(|(connection_id, _)| {
4083                if collaborators.contains(&connection_id) {
4084                    None
4085                } else {
4086                    Some(connection_id)
4087                }
4088            });
4089
4090    broadcast(None, non_collaborators, |peer_id| {
4091        session.peer.send(
4092            peer_id,
4093            proto::UpdateChannels {
4094                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4095                    channel_id: channel_id.to_proto(),
4096                    epoch: epoch as u64,
4097                    version: version.clone(),
4098                }],
4099                ..Default::default()
4100            },
4101        )
4102    });
4103
4104    Ok(())
4105}
4106
4107/// Rejoin the channel notes after a connection blip
4108async fn rejoin_channel_buffers(
4109    request: proto::RejoinChannelBuffers,
4110    response: Response<proto::RejoinChannelBuffers>,
4111    session: UserSession,
4112) -> Result<()> {
4113    let db = session.db().await;
4114    let buffers = db
4115        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4116        .await?;
4117
4118    for rejoined_buffer in &buffers {
4119        let collaborators_to_notify = rejoined_buffer
4120            .buffer
4121            .collaborators
4122            .iter()
4123            .filter_map(|c| Some(c.peer_id?.into()));
4124        channel_buffer_updated(
4125            session.connection_id,
4126            collaborators_to_notify,
4127            &proto::UpdateChannelBufferCollaborators {
4128                channel_id: rejoined_buffer.buffer.channel_id,
4129                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4130            },
4131            &session.peer,
4132        );
4133    }
4134
4135    response.send(proto::RejoinChannelBuffersResponse {
4136        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4137    })?;
4138
4139    Ok(())
4140}
4141
4142/// Stop editing the channel notes
4143async fn leave_channel_buffer(
4144    request: proto::LeaveChannelBuffer,
4145    response: Response<proto::LeaveChannelBuffer>,
4146    session: UserSession,
4147) -> Result<()> {
4148    let db = session.db().await;
4149    let channel_id = ChannelId::from_proto(request.channel_id);
4150
4151    let left_buffer = db
4152        .leave_channel_buffer(channel_id, session.connection_id)
4153        .await?;
4154
4155    response.send(Ack {})?;
4156
4157    channel_buffer_updated(
4158        session.connection_id,
4159        left_buffer.connections,
4160        &proto::UpdateChannelBufferCollaborators {
4161            channel_id: channel_id.to_proto(),
4162            collaborators: left_buffer.collaborators,
4163        },
4164        &session.peer,
4165    );
4166
4167    Ok(())
4168}
4169
4170fn channel_buffer_updated<T: EnvelopedMessage>(
4171    sender_id: ConnectionId,
4172    collaborators: impl IntoIterator<Item = ConnectionId>,
4173    message: &T,
4174    peer: &Peer,
4175) {
4176    broadcast(Some(sender_id), collaborators, |peer_id| {
4177        peer.send(peer_id, message.clone())
4178    });
4179}
4180
4181fn send_notifications(
4182    connection_pool: &ConnectionPool,
4183    peer: &Peer,
4184    notifications: db::NotificationBatch,
4185) {
4186    for (user_id, notification) in notifications {
4187        for connection_id in connection_pool.user_connection_ids(user_id) {
4188            if let Err(error) = peer.send(
4189                connection_id,
4190                proto::AddNotification {
4191                    notification: Some(notification.clone()),
4192                },
4193            ) {
4194                tracing::error!(
4195                    "failed to send notification to {:?} {}",
4196                    connection_id,
4197                    error
4198                );
4199            }
4200        }
4201    }
4202}
4203
4204/// Send a message to the channel
4205async fn send_channel_message(
4206    request: proto::SendChannelMessage,
4207    response: Response<proto::SendChannelMessage>,
4208    session: UserSession,
4209) -> Result<()> {
4210    // Validate the message body.
4211    let body = request.body.trim().to_string();
4212    if body.len() > MAX_MESSAGE_LEN {
4213        return Err(anyhow!("message is too long"))?;
4214    }
4215    if body.is_empty() {
4216        return Err(anyhow!("message can't be blank"))?;
4217    }
4218
4219    // TODO: adjust mentions if body is trimmed
4220
4221    let timestamp = OffsetDateTime::now_utc();
4222    let nonce = request
4223        .nonce
4224        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4225
4226    let channel_id = ChannelId::from_proto(request.channel_id);
4227    let CreatedChannelMessage {
4228        message_id,
4229        participant_connection_ids,
4230        notifications,
4231    } = session
4232        .db()
4233        .await
4234        .create_channel_message(
4235            channel_id,
4236            session.user_id(),
4237            &body,
4238            &request.mentions,
4239            timestamp,
4240            nonce.clone().into(),
4241            request.reply_to_message_id.map(MessageId::from_proto),
4242        )
4243        .await?;
4244
4245    let message = proto::ChannelMessage {
4246        sender_id: session.user_id().to_proto(),
4247        id: message_id.to_proto(),
4248        body,
4249        mentions: request.mentions,
4250        timestamp: timestamp.unix_timestamp() as u64,
4251        nonce: Some(nonce),
4252        reply_to_message_id: request.reply_to_message_id,
4253        edited_at: None,
4254    };
4255    broadcast(
4256        Some(session.connection_id),
4257        participant_connection_ids.clone(),
4258        |connection| {
4259            session.peer.send(
4260                connection,
4261                proto::ChannelMessageSent {
4262                    channel_id: channel_id.to_proto(),
4263                    message: Some(message.clone()),
4264                },
4265            )
4266        },
4267    );
4268    response.send(proto::SendChannelMessageResponse {
4269        message: Some(message),
4270    })?;
4271
4272    let pool = &*session.connection_pool().await;
4273    let non_participants =
4274        pool.channel_connection_ids(channel_id)
4275            .filter_map(|(connection_id, _)| {
4276                if participant_connection_ids.contains(&connection_id) {
4277                    None
4278                } else {
4279                    Some(connection_id)
4280                }
4281            });
4282    broadcast(None, non_participants, |peer_id| {
4283        session.peer.send(
4284            peer_id,
4285            proto::UpdateChannels {
4286                latest_channel_message_ids: vec![proto::ChannelMessageId {
4287                    channel_id: channel_id.to_proto(),
4288                    message_id: message_id.to_proto(),
4289                }],
4290                ..Default::default()
4291            },
4292        )
4293    });
4294    send_notifications(pool, &session.peer, notifications);
4295
4296    Ok(())
4297}
4298
4299/// Delete a channel message
4300async fn remove_channel_message(
4301    request: proto::RemoveChannelMessage,
4302    response: Response<proto::RemoveChannelMessage>,
4303    session: UserSession,
4304) -> Result<()> {
4305    let channel_id = ChannelId::from_proto(request.channel_id);
4306    let message_id = MessageId::from_proto(request.message_id);
4307    let (connection_ids, existing_notification_ids) = session
4308        .db()
4309        .await
4310        .remove_channel_message(channel_id, message_id, session.user_id())
4311        .await?;
4312
4313    broadcast(
4314        Some(session.connection_id),
4315        connection_ids,
4316        move |connection| {
4317            session.peer.send(connection, request.clone())?;
4318
4319            for notification_id in &existing_notification_ids {
4320                session.peer.send(
4321                    connection,
4322                    proto::DeleteNotification {
4323                        notification_id: (*notification_id).to_proto(),
4324                    },
4325                )?;
4326            }
4327
4328            Ok(())
4329        },
4330    );
4331    response.send(proto::Ack {})?;
4332    Ok(())
4333}
4334
4335async fn update_channel_message(
4336    request: proto::UpdateChannelMessage,
4337    response: Response<proto::UpdateChannelMessage>,
4338    session: UserSession,
4339) -> Result<()> {
4340    let channel_id = ChannelId::from_proto(request.channel_id);
4341    let message_id = MessageId::from_proto(request.message_id);
4342    let updated_at = OffsetDateTime::now_utc();
4343    let UpdatedChannelMessage {
4344        message_id,
4345        participant_connection_ids,
4346        notifications,
4347        reply_to_message_id,
4348        timestamp,
4349        deleted_mention_notification_ids,
4350        updated_mention_notifications,
4351    } = session
4352        .db()
4353        .await
4354        .update_channel_message(
4355            channel_id,
4356            message_id,
4357            session.user_id(),
4358            request.body.as_str(),
4359            &request.mentions,
4360            updated_at,
4361        )
4362        .await?;
4363
4364    let nonce = request
4365        .nonce
4366        .clone()
4367        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4368
4369    let message = proto::ChannelMessage {
4370        sender_id: session.user_id().to_proto(),
4371        id: message_id.to_proto(),
4372        body: request.body.clone(),
4373        mentions: request.mentions.clone(),
4374        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4375        nonce: Some(nonce),
4376        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4377        edited_at: Some(updated_at.unix_timestamp() as u64),
4378    };
4379
4380    response.send(proto::Ack {})?;
4381
4382    let pool = &*session.connection_pool().await;
4383    broadcast(
4384        Some(session.connection_id),
4385        participant_connection_ids,
4386        |connection| {
4387            session.peer.send(
4388                connection,
4389                proto::ChannelMessageUpdate {
4390                    channel_id: channel_id.to_proto(),
4391                    message: Some(message.clone()),
4392                },
4393            )?;
4394
4395            for notification_id in &deleted_mention_notification_ids {
4396                session.peer.send(
4397                    connection,
4398                    proto::DeleteNotification {
4399                        notification_id: (*notification_id).to_proto(),
4400                    },
4401                )?;
4402            }
4403
4404            for notification in &updated_mention_notifications {
4405                session.peer.send(
4406                    connection,
4407                    proto::UpdateNotification {
4408                        notification: Some(notification.clone()),
4409                    },
4410                )?;
4411            }
4412
4413            Ok(())
4414        },
4415    );
4416
4417    send_notifications(pool, &session.peer, notifications);
4418
4419    Ok(())
4420}
4421
4422/// Mark a channel message as read
4423async fn acknowledge_channel_message(
4424    request: proto::AckChannelMessage,
4425    session: UserSession,
4426) -> Result<()> {
4427    let channel_id = ChannelId::from_proto(request.channel_id);
4428    let message_id = MessageId::from_proto(request.message_id);
4429    let notifications = session
4430        .db()
4431        .await
4432        .observe_channel_message(channel_id, session.user_id(), message_id)
4433        .await?;
4434    send_notifications(
4435        &*session.connection_pool().await,
4436        &session.peer,
4437        notifications,
4438    );
4439    Ok(())
4440}
4441
4442/// Mark a buffer version as synced
4443async fn acknowledge_buffer_version(
4444    request: proto::AckBufferOperation,
4445    session: UserSession,
4446) -> Result<()> {
4447    let buffer_id = BufferId::from_proto(request.buffer_id);
4448    session
4449        .db()
4450        .await
4451        .observe_buffer_version(
4452            buffer_id,
4453            session.user_id(),
4454            request.epoch as i32,
4455            &request.version,
4456        )
4457        .await?;
4458    Ok(())
4459}
4460
4461async fn count_language_model_tokens(
4462    request: proto::CountLanguageModelTokens,
4463    response: Response<proto::CountLanguageModelTokens>,
4464    session: Session,
4465    config: &Config,
4466) -> Result<()> {
4467    let Some(session) = session.for_user() else {
4468        return Err(anyhow!("user not found"))?;
4469    };
4470    authorize_access_to_legacy_llm_endpoints(&session).await?;
4471
4472    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4473        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4474        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4475    };
4476
4477    session
4478        .app_state
4479        .rate_limiter
4480        .check(&*rate_limit, session.user_id())
4481        .await?;
4482
4483    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4484        Some(proto::LanguageModelProvider::Google) => {
4485            let api_key = config
4486                .google_ai_api_key
4487                .as_ref()
4488                .context("no Google AI API key configured on the server")?;
4489            google_ai::count_tokens(
4490                session.http_client.as_ref(),
4491                google_ai::API_URL,
4492                api_key,
4493                serde_json::from_str(&request.request)?,
4494                None,
4495            )
4496            .await?
4497        }
4498        _ => return Err(anyhow!("unsupported provider"))?,
4499    };
4500
4501    response.send(proto::CountLanguageModelTokensResponse {
4502        token_count: result.total_tokens as u32,
4503    })?;
4504
4505    Ok(())
4506}
4507
4508struct ZedProCountLanguageModelTokensRateLimit;
4509
4510impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4511    fn capacity(&self) -> usize {
4512        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4513            .ok()
4514            .and_then(|v| v.parse().ok())
4515            .unwrap_or(600) // Picked arbitrarily
4516    }
4517
4518    fn refill_duration(&self) -> chrono::Duration {
4519        chrono::Duration::hours(1)
4520    }
4521
4522    fn db_name(&self) -> &'static str {
4523        "zed-pro:count-language-model-tokens"
4524    }
4525}
4526
4527struct FreeCountLanguageModelTokensRateLimit;
4528
4529impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4530    fn capacity(&self) -> usize {
4531        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4532            .ok()
4533            .and_then(|v| v.parse().ok())
4534            .unwrap_or(600 / 10) // Picked arbitrarily
4535    }
4536
4537    fn refill_duration(&self) -> chrono::Duration {
4538        chrono::Duration::hours(1)
4539    }
4540
4541    fn db_name(&self) -> &'static str {
4542        "free:count-language-model-tokens"
4543    }
4544}
4545
4546struct ZedProComputeEmbeddingsRateLimit;
4547
4548impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4549    fn capacity(&self) -> usize {
4550        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4551            .ok()
4552            .and_then(|v| v.parse().ok())
4553            .unwrap_or(5000) // Picked arbitrarily
4554    }
4555
4556    fn refill_duration(&self) -> chrono::Duration {
4557        chrono::Duration::hours(1)
4558    }
4559
4560    fn db_name(&self) -> &'static str {
4561        "zed-pro:compute-embeddings"
4562    }
4563}
4564
4565struct FreeComputeEmbeddingsRateLimit;
4566
4567impl RateLimit for FreeComputeEmbeddingsRateLimit {
4568    fn capacity(&self) -> usize {
4569        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4570            .ok()
4571            .and_then(|v| v.parse().ok())
4572            .unwrap_or(5000 / 10) // Picked arbitrarily
4573    }
4574
4575    fn refill_duration(&self) -> chrono::Duration {
4576        chrono::Duration::hours(1)
4577    }
4578
4579    fn db_name(&self) -> &'static str {
4580        "free:compute-embeddings"
4581    }
4582}
4583
4584async fn compute_embeddings(
4585    request: proto::ComputeEmbeddings,
4586    response: Response<proto::ComputeEmbeddings>,
4587    session: UserSession,
4588    api_key: Option<Arc<str>>,
4589) -> Result<()> {
4590    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4591    authorize_access_to_legacy_llm_endpoints(&session).await?;
4592
4593    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4594        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4595        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4596    };
4597
4598    session
4599        .app_state
4600        .rate_limiter
4601        .check(&*rate_limit, session.user_id())
4602        .await?;
4603
4604    let embeddings = match request.model.as_str() {
4605        "openai/text-embedding-3-small" => {
4606            open_ai::embed(
4607                session.http_client.as_ref(),
4608                OPEN_AI_API_URL,
4609                &api_key,
4610                OpenAiEmbeddingModel::TextEmbedding3Small,
4611                request.texts.iter().map(|text| text.as_str()),
4612            )
4613            .await?
4614        }
4615        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4616    };
4617
4618    let embeddings = request
4619        .texts
4620        .iter()
4621        .map(|text| {
4622            let mut hasher = sha2::Sha256::new();
4623            hasher.update(text.as_bytes());
4624            let result = hasher.finalize();
4625            result.to_vec()
4626        })
4627        .zip(
4628            embeddings
4629                .data
4630                .into_iter()
4631                .map(|embedding| embedding.embedding),
4632        )
4633        .collect::<HashMap<_, _>>();
4634
4635    let db = session.db().await;
4636    db.save_embeddings(&request.model, &embeddings)
4637        .await
4638        .context("failed to save embeddings")
4639        .trace_err();
4640
4641    response.send(proto::ComputeEmbeddingsResponse {
4642        embeddings: embeddings
4643            .into_iter()
4644            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4645            .collect(),
4646    })?;
4647    Ok(())
4648}
4649
4650async fn get_cached_embeddings(
4651    request: proto::GetCachedEmbeddings,
4652    response: Response<proto::GetCachedEmbeddings>,
4653    session: UserSession,
4654) -> Result<()> {
4655    authorize_access_to_legacy_llm_endpoints(&session).await?;
4656
4657    let db = session.db().await;
4658    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4659
4660    response.send(proto::GetCachedEmbeddingsResponse {
4661        embeddings: embeddings
4662            .into_iter()
4663            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4664            .collect(),
4665    })?;
4666    Ok(())
4667}
4668
4669/// This is leftover from before the LLM service.
4670///
4671/// The endpoints protected by this check will be moved there eventually.
4672async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4673    if session.is_staff() {
4674        Ok(())
4675    } else {
4676        Err(anyhow!("permission denied"))?
4677    }
4678}
4679
4680/// Get a Supermaven API key for the user
4681async fn get_supermaven_api_key(
4682    _request: proto::GetSupermavenApiKey,
4683    response: Response<proto::GetSupermavenApiKey>,
4684    session: UserSession,
4685) -> Result<()> {
4686    let user_id: String = session.user_id().to_string();
4687    if !session.is_staff() {
4688        return Err(anyhow!("supermaven not enabled for this account"))?;
4689    }
4690
4691    let email = session
4692        .email()
4693        .ok_or_else(|| anyhow!("user must have an email"))?;
4694
4695    let supermaven_admin_api = session
4696        .supermaven_client
4697        .as_ref()
4698        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4699
4700    let result = supermaven_admin_api
4701        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4702        .await?;
4703
4704    response.send(proto::GetSupermavenApiKeyResponse {
4705        api_key: result.api_key,
4706    })?;
4707
4708    Ok(())
4709}
4710
4711/// Start receiving chat updates for a channel
4712async fn join_channel_chat(
4713    request: proto::JoinChannelChat,
4714    response: Response<proto::JoinChannelChat>,
4715    session: UserSession,
4716) -> Result<()> {
4717    let channel_id = ChannelId::from_proto(request.channel_id);
4718
4719    let db = session.db().await;
4720    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4721        .await?;
4722    let messages = db
4723        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4724        .await?;
4725    response.send(proto::JoinChannelChatResponse {
4726        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4727        messages,
4728    })?;
4729    Ok(())
4730}
4731
4732/// Stop receiving chat updates for a channel
4733async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4734    let channel_id = ChannelId::from_proto(request.channel_id);
4735    session
4736        .db()
4737        .await
4738        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4739        .await?;
4740    Ok(())
4741}
4742
4743/// Retrieve the chat history for a channel
4744async fn get_channel_messages(
4745    request: proto::GetChannelMessages,
4746    response: Response<proto::GetChannelMessages>,
4747    session: UserSession,
4748) -> Result<()> {
4749    let channel_id = ChannelId::from_proto(request.channel_id);
4750    let messages = session
4751        .db()
4752        .await
4753        .get_channel_messages(
4754            channel_id,
4755            session.user_id(),
4756            MESSAGE_COUNT_PER_PAGE,
4757            Some(MessageId::from_proto(request.before_message_id)),
4758        )
4759        .await?;
4760    response.send(proto::GetChannelMessagesResponse {
4761        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4762        messages,
4763    })?;
4764    Ok(())
4765}
4766
4767/// Retrieve specific chat messages
4768async fn get_channel_messages_by_id(
4769    request: proto::GetChannelMessagesById,
4770    response: Response<proto::GetChannelMessagesById>,
4771    session: UserSession,
4772) -> Result<()> {
4773    let message_ids = request
4774        .message_ids
4775        .iter()
4776        .map(|id| MessageId::from_proto(*id))
4777        .collect::<Vec<_>>();
4778    let messages = session
4779        .db()
4780        .await
4781        .get_channel_messages_by_id(session.user_id(), &message_ids)
4782        .await?;
4783    response.send(proto::GetChannelMessagesResponse {
4784        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4785        messages,
4786    })?;
4787    Ok(())
4788}
4789
4790/// Retrieve the current users notifications
4791async fn get_notifications(
4792    request: proto::GetNotifications,
4793    response: Response<proto::GetNotifications>,
4794    session: UserSession,
4795) -> Result<()> {
4796    let notifications = session
4797        .db()
4798        .await
4799        .get_notifications(
4800            session.user_id(),
4801            NOTIFICATION_COUNT_PER_PAGE,
4802            request.before_id.map(db::NotificationId::from_proto),
4803        )
4804        .await?;
4805    response.send(proto::GetNotificationsResponse {
4806        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4807        notifications,
4808    })?;
4809    Ok(())
4810}
4811
4812/// Mark notifications as read
4813async fn mark_notification_as_read(
4814    request: proto::MarkNotificationRead,
4815    response: Response<proto::MarkNotificationRead>,
4816    session: UserSession,
4817) -> Result<()> {
4818    let database = &session.db().await;
4819    let notifications = database
4820        .mark_notification_as_read_by_id(
4821            session.user_id(),
4822            NotificationId::from_proto(request.notification_id),
4823        )
4824        .await?;
4825    send_notifications(
4826        &*session.connection_pool().await,
4827        &session.peer,
4828        notifications,
4829    );
4830    response.send(proto::Ack {})?;
4831    Ok(())
4832}
4833
4834/// Get the current users information
4835async fn get_private_user_info(
4836    _request: proto::GetPrivateUserInfo,
4837    response: Response<proto::GetPrivateUserInfo>,
4838    session: UserSession,
4839) -> Result<()> {
4840    let db = session.db().await;
4841
4842    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4843    let user = db
4844        .get_user_by_id(session.user_id())
4845        .await?
4846        .ok_or_else(|| anyhow!("user not found"))?;
4847    let flags = db.get_user_flags(session.user_id()).await?;
4848
4849    response.send(proto::GetPrivateUserInfoResponse {
4850        metrics_id,
4851        staff: user.admin,
4852        flags,
4853        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4854    })?;
4855    Ok(())
4856}
4857
4858/// Accept the terms of service (tos) on behalf of the current user
4859async fn accept_terms_of_service(
4860    _request: proto::AcceptTermsOfService,
4861    response: Response<proto::AcceptTermsOfService>,
4862    session: UserSession,
4863) -> Result<()> {
4864    let db = session.db().await;
4865
4866    let accepted_tos_at = Utc::now();
4867    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4868        .await?;
4869
4870    response.send(proto::AcceptTermsOfServiceResponse {
4871        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4872    })?;
4873    Ok(())
4874}
4875
4876/// The minimum account age an account must have in order to use the LLM service.
4877const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4878
4879async fn get_llm_api_token(
4880    _request: proto::GetLlmToken,
4881    response: Response<proto::GetLlmToken>,
4882    session: UserSession,
4883) -> Result<()> {
4884    let db = session.db().await;
4885
4886    let flags = db.get_user_flags(session.user_id()).await?;
4887    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4888    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4889
4890    if !session.is_staff() && !has_language_models_feature_flag {
4891        Err(anyhow!("permission denied"))?
4892    }
4893
4894    let user_id = session.user_id();
4895    let user = db
4896        .get_user_by_id(user_id)
4897        .await?
4898        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4899
4900    if user.accepted_tos_at.is_none() {
4901        Err(anyhow!("terms of service not accepted"))?
4902    }
4903
4904    let mut account_created_at = user.created_at;
4905    if let Some(github_created_at) = user.github_user_created_at {
4906        account_created_at = account_created_at.min(github_created_at);
4907    }
4908    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4909        Err(anyhow!("account too young"))?
4910    }
4911    let token = LlmTokenClaims::create(
4912        user.id,
4913        user.github_login.clone(),
4914        session.is_staff(),
4915        has_llm_closed_beta_feature_flag,
4916        session.current_plan(db).await?,
4917        &session.app_state.config,
4918    )?;
4919    response.send(proto::GetLlmTokenResponse { token })?;
4920    Ok(())
4921}
4922
4923fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4924    let message = match message {
4925        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4926        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4927        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4928        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4929        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4930            code: frame.code.into(),
4931            reason: frame.reason,
4932        })),
4933        // We should never receive a frame while reading the message, according
4934        // to the `tungstenite` maintainers:
4935        //
4936        // > It cannot occur when you read messages from the WebSocket, but it
4937        // > can be used when you want to send the raw frames (e.g. you want to
4938        // > send the frames to the WebSocket without composing the full message first).
4939        // >
4940        // > — https://github.com/snapview/tungstenite-rs/issues/268
4941        TungsteniteMessage::Frame(_) => {
4942            bail!("received an unexpected frame while reading the message")
4943        }
4944    };
4945
4946    Ok(message)
4947}
4948
4949fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4950    match message {
4951        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4952        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4953        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4954        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4955        AxumMessage::Close(frame) => {
4956            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4957                code: frame.code.into(),
4958                reason: frame.reason,
4959            }))
4960        }
4961    }
4962}
4963
4964fn notify_membership_updated(
4965    connection_pool: &mut ConnectionPool,
4966    result: MembershipUpdated,
4967    user_id: UserId,
4968    peer: &Peer,
4969) {
4970    for membership in &result.new_channels.channel_memberships {
4971        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4972    }
4973    for channel_id in &result.removed_channels {
4974        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4975    }
4976
4977    let user_channels_update = proto::UpdateUserChannels {
4978        channel_memberships: result
4979            .new_channels
4980            .channel_memberships
4981            .iter()
4982            .map(|cm| proto::ChannelMembership {
4983                channel_id: cm.channel_id.to_proto(),
4984                role: cm.role.into(),
4985            })
4986            .collect(),
4987        ..Default::default()
4988    };
4989
4990    let mut update = build_channels_update(result.new_channels);
4991    update.delete_channels = result
4992        .removed_channels
4993        .into_iter()
4994        .map(|id| id.to_proto())
4995        .collect();
4996    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4997
4998    for connection_id in connection_pool.user_connection_ids(user_id) {
4999        peer.send(connection_id, user_channels_update.clone())
5000            .trace_err();
5001        peer.send(connection_id, update.clone()).trace_err();
5002    }
5003}
5004
5005fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5006    proto::UpdateUserChannels {
5007        channel_memberships: channels
5008            .channel_memberships
5009            .iter()
5010            .map(|m| proto::ChannelMembership {
5011                channel_id: m.channel_id.to_proto(),
5012                role: m.role.into(),
5013            })
5014            .collect(),
5015        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5016        observed_channel_message_id: channels.observed_channel_messages.clone(),
5017    }
5018}
5019
5020fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5021    let mut update = proto::UpdateChannels::default();
5022
5023    for channel in channels.channels {
5024        update.channels.push(channel.to_proto());
5025    }
5026
5027    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5028    update.latest_channel_message_ids = channels.latest_channel_messages;
5029
5030    for (channel_id, participants) in channels.channel_participants {
5031        update
5032            .channel_participants
5033            .push(proto::ChannelParticipants {
5034                channel_id: channel_id.to_proto(),
5035                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5036            });
5037    }
5038
5039    for channel in channels.invited_channels {
5040        update.channel_invitations.push(channel.to_proto());
5041    }
5042
5043    update.hosted_projects = channels.hosted_projects;
5044    update
5045}
5046
5047fn build_initial_contacts_update(
5048    contacts: Vec<db::Contact>,
5049    pool: &ConnectionPool,
5050) -> proto::UpdateContacts {
5051    let mut update = proto::UpdateContacts::default();
5052
5053    for contact in contacts {
5054        match contact {
5055            db::Contact::Accepted { user_id, busy } => {
5056                update.contacts.push(contact_for_user(user_id, busy, pool));
5057            }
5058            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5059            db::Contact::Incoming { user_id } => {
5060                update
5061                    .incoming_requests
5062                    .push(proto::IncomingContactRequest {
5063                        requester_id: user_id.to_proto(),
5064                    })
5065            }
5066        }
5067    }
5068
5069    update
5070}
5071
5072fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5073    proto::Contact {
5074        user_id: user_id.to_proto(),
5075        online: pool.is_user_online(user_id),
5076        busy,
5077    }
5078}
5079
5080fn room_updated(room: &proto::Room, peer: &Peer) {
5081    broadcast(
5082        None,
5083        room.participants
5084            .iter()
5085            .filter_map(|participant| Some(participant.peer_id?.into())),
5086        |peer_id| {
5087            peer.send(
5088                peer_id,
5089                proto::RoomUpdated {
5090                    room: Some(room.clone()),
5091                },
5092            )
5093        },
5094    );
5095}
5096
5097fn channel_updated(
5098    channel: &db::channel::Model,
5099    room: &proto::Room,
5100    peer: &Peer,
5101    pool: &ConnectionPool,
5102) {
5103    let participants = room
5104        .participants
5105        .iter()
5106        .map(|p| p.user_id)
5107        .collect::<Vec<_>>();
5108
5109    broadcast(
5110        None,
5111        pool.channel_connection_ids(channel.root_id())
5112            .filter_map(|(channel_id, role)| {
5113                role.can_see_channel(channel.visibility)
5114                    .then_some(channel_id)
5115            }),
5116        |peer_id| {
5117            peer.send(
5118                peer_id,
5119                proto::UpdateChannels {
5120                    channel_participants: vec![proto::ChannelParticipants {
5121                        channel_id: channel.id.to_proto(),
5122                        participant_user_ids: participants.clone(),
5123                    }],
5124                    ..Default::default()
5125                },
5126            )
5127        },
5128    );
5129}
5130
5131async fn send_dev_server_projects_update(
5132    user_id: UserId,
5133    mut status: proto::DevServerProjectsUpdate,
5134    session: &Session,
5135) {
5136    let pool = session.connection_pool().await;
5137    for dev_server in &mut status.dev_servers {
5138        dev_server.status =
5139            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5140    }
5141    let connections = pool.user_connection_ids(user_id);
5142    for connection_id in connections {
5143        session.peer.send(connection_id, status.clone()).trace_err();
5144    }
5145}
5146
5147async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5148    let db = session.db().await;
5149
5150    let contacts = db.get_contacts(user_id).await?;
5151    let busy = db.is_user_busy(user_id).await?;
5152
5153    let pool = session.connection_pool().await;
5154    let updated_contact = contact_for_user(user_id, busy, &pool);
5155    for contact in contacts {
5156        if let db::Contact::Accepted {
5157            user_id: contact_user_id,
5158            ..
5159        } = contact
5160        {
5161            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5162                session
5163                    .peer
5164                    .send(
5165                        contact_conn_id,
5166                        proto::UpdateContacts {
5167                            contacts: vec![updated_contact.clone()],
5168                            remove_contacts: Default::default(),
5169                            incoming_requests: Default::default(),
5170                            remove_incoming_requests: Default::default(),
5171                            outgoing_requests: Default::default(),
5172                            remove_outgoing_requests: Default::default(),
5173                        },
5174                    )
5175                    .trace_err();
5176            }
5177        }
5178    }
5179    Ok(())
5180}
5181
5182async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5183    log::info!("lost dev server connection, unsharing projects");
5184    let project_ids = session
5185        .db()
5186        .await
5187        .get_stale_dev_server_projects(session.connection_id)
5188        .await?;
5189
5190    for project_id in project_ids {
5191        // not unshare re-checks the connection ids match, so we get away with no transaction
5192        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5193    }
5194
5195    let user_id = session.dev_server().user_id;
5196    let update = session
5197        .db()
5198        .await
5199        .dev_server_projects_update(user_id)
5200        .await?;
5201
5202    send_dev_server_projects_update(user_id, update, session).await;
5203
5204    Ok(())
5205}
5206
5207async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5208    let mut contacts_to_update = HashSet::default();
5209
5210    let room_id;
5211    let canceled_calls_to_user_ids;
5212    let live_kit_room;
5213    let delete_live_kit_room;
5214    let room;
5215    let channel;
5216
5217    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5218        contacts_to_update.insert(session.user_id());
5219
5220        for project in left_room.left_projects.values() {
5221            project_left(project, session);
5222        }
5223
5224        room_id = RoomId::from_proto(left_room.room.id);
5225        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5226        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5227        delete_live_kit_room = left_room.deleted;
5228        room = mem::take(&mut left_room.room);
5229        channel = mem::take(&mut left_room.channel);
5230
5231        room_updated(&room, &session.peer);
5232    } else {
5233        return Ok(());
5234    }
5235
5236    if let Some(channel) = channel {
5237        channel_updated(
5238            &channel,
5239            &room,
5240            &session.peer,
5241            &*session.connection_pool().await,
5242        );
5243    }
5244
5245    {
5246        let pool = session.connection_pool().await;
5247        for canceled_user_id in canceled_calls_to_user_ids {
5248            for connection_id in pool.user_connection_ids(canceled_user_id) {
5249                session
5250                    .peer
5251                    .send(
5252                        connection_id,
5253                        proto::CallCanceled {
5254                            room_id: room_id.to_proto(),
5255                        },
5256                    )
5257                    .trace_err();
5258            }
5259            contacts_to_update.insert(canceled_user_id);
5260        }
5261    }
5262
5263    for contact_user_id in contacts_to_update {
5264        update_user_contacts(contact_user_id, session).await?;
5265    }
5266
5267    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5268        live_kit
5269            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5270            .await
5271            .trace_err();
5272
5273        if delete_live_kit_room {
5274            live_kit.delete_room(live_kit_room).await.trace_err();
5275        }
5276    }
5277
5278    Ok(())
5279}
5280
5281async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5282    let left_channel_buffers = session
5283        .db()
5284        .await
5285        .leave_channel_buffers(session.connection_id)
5286        .await?;
5287
5288    for left_buffer in left_channel_buffers {
5289        channel_buffer_updated(
5290            session.connection_id,
5291            left_buffer.connections,
5292            &proto::UpdateChannelBufferCollaborators {
5293                channel_id: left_buffer.channel_id.to_proto(),
5294                collaborators: left_buffer.collaborators,
5295            },
5296            &session.peer,
5297        );
5298    }
5299
5300    Ok(())
5301}
5302
5303fn project_left(project: &db::LeftProject, session: &UserSession) {
5304    for connection_id in &project.connection_ids {
5305        if project.should_unshare {
5306            session
5307                .peer
5308                .send(
5309                    *connection_id,
5310                    proto::UnshareProject {
5311                        project_id: project.id.to_proto(),
5312                    },
5313                )
5314                .trace_err();
5315        } else {
5316            session
5317                .peer
5318                .send(
5319                    *connection_id,
5320                    proto::RemoveProjectCollaborator {
5321                        project_id: project.id.to_proto(),
5322                        peer_id: Some(session.connection_id.into()),
5323                    },
5324                )
5325                .trace_err();
5326        }
5327    }
5328}
5329
5330pub trait ResultExt {
5331    type Ok;
5332
5333    fn trace_err(self) -> Option<Self::Ok>;
5334}
5335
5336impl<T, E> ResultExt for Result<T, E>
5337where
5338    E: std::fmt::Debug,
5339{
5340    type Ok = T;
5341
5342    #[track_caller]
5343    fn trace_err(self) -> Option<T> {
5344        match self {
5345            Ok(value) => Some(value),
5346            Err(error) => {
5347                tracing::error!("{:?}", error);
5348                None
5349            }
5350        }
5351    }
5352}