rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  40use reqwest_client::ReqwestClient;
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone, Debug)]
 109pub enum Principal {
 110    User(User),
 111    Impersonated { user: User, admin: User },
 112    DevServer(dev_server::Model),
 113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
/// Per-connection state passed to every RPC handler (see `MessageHandler`).
///
/// Cloning is cheap for the shared handles (all `Arc`s); the principal and optional
/// country code are cloned by value.
#[derive(Clone)]
struct Session {
    /// Who is on the other end: a user, an admin impersonating a user, or a dev server.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Guarded by a synchronous `parking_lot` mutex; `Session::connection_pool` wraps the
    /// lock guard in a `!Send` `ConnectionPoolGuard`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// `None` when no Supermaven admin client is configured (presumably; the constructor
    /// is not visible here).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Not read anywhere in this excerpt, hence the lint allowance.
    // NOTE(review): confirm whether it is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Leading underscore: held by the session but not read here.
    _executor: Executor,
}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 195        if self.is_staff() {
 196            return Ok(proto::Plan::ZedPro);
 197        }
 198
 199        let Some(user_id) = self.user_id() else {
 200            return Ok(proto::Plan::Free);
 201        };
 202
 203        if db.has_active_billing_subscription(user_id).await? {
 204            Ok(proto::Plan::ZedPro)
 205        } else {
 206            Ok(proto::Plan::Free)
 207        }
 208    }
 209
 210    fn dev_server_id(&self) -> Option<DevServerId> {
 211        match &self.principal {
 212            Principal::User(_) | Principal::Impersonated { .. } => None,
 213            Principal::DevServer(dev_server) => Some(dev_server.id),
 214        }
 215    }
 216
 217    fn principal_id(&self) -> PrincipalId {
 218        match &self.principal {
 219            Principal::User(user) => PrincipalId::UserId(user.id),
 220            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 221            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 222        }
 223    }
 224}
 225
 226impl Debug for Session {
 227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 228        let mut result = f.debug_struct("Session");
 229        match &self.principal {
 230            Principal::User(user) => {
 231                result.field("user", &user.github_login);
 232            }
 233            Principal::Impersonated { user, admin } => {
 234                result.field("user", &user.github_login);
 235                result.field("impersonator", &admin.github_login);
 236            }
 237            Principal::DevServer(dev_server) => {
 238                result.field("dev_server", &dev_server.id);
 239            }
 240        }
 241        result.field("connection_id", &self.connection_id).finish()
 242    }
 243}
 244
 245struct UserSession(Session);
 246
 247impl UserSession {
 248    pub fn new(s: Session) -> Option<Self> {
 249        s.user_id().map(|_| UserSession(s))
 250    }
 251    pub fn user_id(&self) -> UserId {
 252        self.0.user_id().unwrap()
 253    }
 254
 255    pub fn email(&self) -> Option<String> {
 256        match &self.0.principal {
 257            Principal::User(user) => user.email_address.clone(),
 258            Principal::Impersonated { user, .. } => user.email_address.clone(),
 259            Principal::DevServer(..) => None,
 260        }
 261    }
 262}
 263
 264impl Deref for UserSession {
 265    type Target = Session;
 266
 267    fn deref(&self) -> &Self::Target {
 268        &self.0
 269    }
 270}
 271impl DerefMut for UserSession {
 272    fn deref_mut(&mut self) -> &mut Self::Target {
 273        &mut self.0
 274    }
 275}
 276
 277struct DevServerSession(Session);
 278
 279impl DevServerSession {
 280    pub fn new(s: Session) -> Option<Self> {
 281        s.dev_server_id().map(|_| DevServerSession(s))
 282    }
 283    pub fn dev_server_id(&self) -> DevServerId {
 284        self.0.dev_server_id().unwrap()
 285    }
 286
 287    fn dev_server(&self) -> &dev_server::Model {
 288        match &self.0.principal {
 289            Principal::DevServer(dev_server) => dev_server,
 290            _ => unreachable!(),
 291        }
 292    }
 293}
 294
 295impl Deref for DevServerSession {
 296    type Target = Session;
 297
 298    fn deref(&self) -> &Self::Target {
 299        &self.0
 300    }
 301}
 302impl DerefMut for DevServerSession {
 303    fn deref_mut(&mut self) -> &mut Self::Target {
 304        &mut self.0
 305    }
 306}
 307
 308fn user_handler<M: RequestMessage, Fut>(
 309    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 311where
 312    Fut: Send + Future<Output = Result<()>>,
 313{
 314    let handler = Arc::new(handler);
 315    move |message, response, session| {
 316        let handler = handler.clone();
 317        Box::pin(async move {
 318            if let Some(user_session) = session.for_user() {
 319                Ok(handler(message, response, user_session).await?)
 320            } else {
 321                Err(Error::Internal(anyhow!(
 322                    "must be a user to call {}",
 323                    M::NAME
 324                )))
 325            }
 326        })
 327    }
 328}
 329
 330fn dev_server_handler<M: RequestMessage, Fut>(
 331    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 333where
 334    Fut: Send + Future<Output = Result<()>>,
 335{
 336    let handler = Arc::new(handler);
 337    move |message, response, session| {
 338        let handler = handler.clone();
 339        Box::pin(async move {
 340            if let Some(dev_server_session) = session.for_dev_server() {
 341                Ok(handler(message, response, dev_server_session).await?)
 342            } else {
 343                Err(Error::Internal(anyhow!(
 344                    "must be a dev server to call {}",
 345                    M::NAME
 346                )))
 347            }
 348        })
 349    }
 350}
 351
 352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 353    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 355where
 356    InnertRetFut: Send + Future<Output = Result<()>>,
 357{
 358    let handler = Arc::new(handler);
 359    move |message, session| {
 360        let handler = handler.clone();
 361        Box::pin(async move {
 362            if let Some(user_session) = session.for_user() {
 363                Ok(handler(message, user_session).await?)
 364            } else {
 365                Err(Error::Internal(anyhow!(
 366                    "must be a user to call {}",
 367                    M::NAME
 368                )))
 369            }
 370        })
 371    }
 372}
 373
 374struct DbHandle(Arc<Database>);
 375
 376impl Deref for DbHandle {
 377    type Target = Database;
 378
 379    fn deref(&self) -> &Self::Target {
 380        self.0.as_ref()
 381    }
 382}
 383
/// The collaboration RPC server: owns the peer, the shared connection pool, and the
/// handler table populated in `Server::new`.
pub struct Server {
    /// Behind a mutex, presumably so the id can be replaced after construction
    /// (the writer is not visible here — confirm).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Shared with each `Session` (both hold an `Arc` to a `parking_lot`-guarded pool).
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers; presumably keyed by the `TypeId` of the message each one
    /// handles (the insertion code is not visible here).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch sender initialized to `false` in `Server::new`; presumably flipped to signal
    /// teardown to subscribers.
    teardown: watch::Sender<bool>,
}
 392
/// Lock guard over the shared [`ConnectionPool`], returned by `Session::connection_pool`.
///
/// `Rc<()>` is `!Send`, so the `PhantomData<Rc<()>>` marker makes this guard `!Send` too.
/// Presumably this makes the compiler reject holding the synchronous lock across an
/// `.await` inside a `Send` future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 397
/// Serializable view of the server's peer and connection pool.
///
/// Holds a `ConnectionPoolGuard`, so the connection-pool lock stays held for as long as
/// the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as whatever the guard dereferences to
    // (assumes a `Deref` impl for `ConnectionPoolGuard` defined elsewhere in the file).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 404
 405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 406where
 407    S: Serializer,
 408    T: Deref<Target = U>,
 409    U: Serialize,
 410{
 411    Serialize::serialize(value.deref(), serializer)
 412}
 413
 414impl Server {
    /// Creates a server identified by `id` and registers every RPC handler on it.
    ///
    /// The wrapper chosen for each handler decides who may call it. `user_handler` and
    /// `user_message_handler` reject callers that are not users. `dev_server_handler`
    /// rejects callers that are not dev servers. Handlers registered without a wrapper
    /// take a plain `Session`, so presumably any principal can reach them
    /// (`add_request_handler`/`add_message_handler` are not visible here — confirm).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/reinterprets if the id does not fit
            // in a u32 — confirm the range of `ServerId`'s inner value.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial value is `false`.
            teardown: watch::channel(false).0,
        };

        server
            // Connectivity, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers (user side).
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Callable only by dev servers.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project/worktree/language-server state updates.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed by the `forward_*` helpers.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, tokens and acknowledgements.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session`. Each call clones the
            // captured `app_state` so the returned future can read `config`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is cloned out of the config on every call.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 651
 652    pub async fn start(&self) -> Result<()> {
 653        let server_id = *self.id.lock();
 654        let app_state = self.app_state.clone();
 655        let peer = self.peer.clone();
 656        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 657        let pool = self.connection_pool.clone();
 658        let live_kit_client = self.app_state.live_kit_client.clone();
 659
 660        let span = info_span!("start server");
 661        self.app_state.executor.spawn_detached(
 662            async move {
 663                tracing::info!("waiting for cleanup timeout");
 664                timeout.await;
 665                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 666                if let Some((room_ids, channel_ids)) = app_state
 667                    .db
 668                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 669                    .await
 670                    .trace_err()
 671                {
 672                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 673                    tracing::info!(
 674                        stale_channel_buffer_count = channel_ids.len(),
 675                        "retrieved stale channel buffers"
 676                    );
 677
 678                    for channel_id in channel_ids {
 679                        if let Some(refreshed_channel_buffer) = app_state
 680                            .db
 681                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 682                            .await
 683                            .trace_err()
 684                        {
 685                            for connection_id in refreshed_channel_buffer.connection_ids {
 686                                peer.send(
 687                                    connection_id,
 688                                    proto::UpdateChannelBufferCollaborators {
 689                                        channel_id: channel_id.to_proto(),
 690                                        collaborators: refreshed_channel_buffer
 691                                            .collaborators
 692                                            .clone(),
 693                                    },
 694                                )
 695                                .trace_err();
 696                            }
 697                        }
 698                    }
 699
 700                    for room_id in room_ids {
 701                        let mut contacts_to_update = HashSet::default();
 702                        let mut canceled_calls_to_user_ids = Vec::new();
 703                        let mut live_kit_room = String::new();
 704                        let mut delete_live_kit_room = false;
 705
 706                        if let Some(mut refreshed_room) = app_state
 707                            .db
 708                            .clear_stale_room_participants(room_id, server_id)
 709                            .await
 710                            .trace_err()
 711                        {
 712                            tracing::info!(
 713                                room_id = room_id.0,
 714                                new_participant_count = refreshed_room.room.participants.len(),
 715                                "refreshed room"
 716                            );
 717                            room_updated(&refreshed_room.room, &peer);
 718                            if let Some(channel) = refreshed_room.channel.as_ref() {
 719                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 720                            }
 721                            contacts_to_update
 722                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 723                            contacts_to_update
 724                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 725                            canceled_calls_to_user_ids =
 726                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 727                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 728                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 729                        }
 730
 731                        {
 732                            let pool = pool.lock();
 733                            for canceled_user_id in canceled_calls_to_user_ids {
 734                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 735                                    peer.send(
 736                                        connection_id,
 737                                        proto::CallCanceled {
 738                                            room_id: room_id.to_proto(),
 739                                        },
 740                                    )
 741                                    .trace_err();
 742                                }
 743                            }
 744                        }
 745
 746                        for user_id in contacts_to_update {
 747                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 748                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 749                            if let Some((busy, contacts)) = busy.zip(contacts) {
 750                                let pool = pool.lock();
 751                                let updated_contact = contact_for_user(user_id, busy, &pool);
 752                                for contact in contacts {
 753                                    if let db::Contact::Accepted {
 754                                        user_id: contact_user_id,
 755                                        ..
 756                                    } = contact
 757                                    {
 758                                        for contact_conn_id in
 759                                            pool.user_connection_ids(contact_user_id)
 760                                        {
 761                                            peer.send(
 762                                                contact_conn_id,
 763                                                proto::UpdateContacts {
 764                                                    contacts: vec![updated_contact.clone()],
 765                                                    remove_contacts: Default::default(),
 766                                                    incoming_requests: Default::default(),
 767                                                    remove_incoming_requests: Default::default(),
 768                                                    outgoing_requests: Default::default(),
 769                                                    remove_outgoing_requests: Default::default(),
 770                                                },
 771                                            )
 772                                            .trace_err();
 773                                        }
 774                                    }
 775                                }
 776                            }
 777                        }
 778
 779                        if let Some(live_kit) = live_kit_client.as_ref() {
 780                            if delete_live_kit_room {
 781                                live_kit.delete_room(live_kit_room).await.trace_err();
 782                            }
 783                        }
 784                    }
 785                }
 786
 787                app_state
 788                    .db
 789                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 790                    .await
 791                    .trace_err();
 792            }
 793            .instrument(span),
 794        );
 795        Ok(())
 796    }
 797
 798    pub fn teardown(&self) {
 799        self.peer.teardown();
 800        self.connection_pool.lock().reset();
 801        let _ = self.teardown.send(true);
 802    }
 803
 804    #[cfg(test)]
 805    pub fn reset(&self, id: ServerId) {
 806        self.teardown();
 807        *self.id.lock() = id;
 808        self.peer.reset(id.0 as u32);
 809        let _ = self.teardown.send(false);
 810    }
 811
 812    #[cfg(test)]
 813    pub fn id(&self) -> ServerId {
 814        *self.id.lock()
 815    }
 816
 817    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 818    where
 819        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 820        Fut: 'static + Send + Future<Output = Result<()>>,
 821        M: EnvelopedMessage,
 822    {
 823        let prev_handler = self.handlers.insert(
 824            TypeId::of::<M>(),
 825            Box::new(move |envelope, session| {
 826                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 827                let received_at = envelope.received_at;
 828                tracing::info!("message received");
 829                let start_time = Instant::now();
 830                let future = (handler)(*envelope, session);
 831                async move {
 832                    let result = future.await;
 833                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 834                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 835                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 836                    let payload_type = M::NAME;
 837
 838                    match result {
 839                        Err(error) => {
 840                            tracing::error!(
 841                                ?error,
 842                                total_duration_ms,
 843                                processing_duration_ms,
 844                                queue_duration_ms,
 845                                payload_type,
 846                                "error handling message"
 847                            )
 848                        }
 849                        Ok(()) => tracing::info!(
 850                            total_duration_ms,
 851                            processing_duration_ms,
 852                            queue_duration_ms,
 853                            "finished handling message"
 854                        ),
 855                    }
 856                }
 857                .boxed()
 858            }),
 859        );
 860        if prev_handler.is_some() {
 861            panic!("registered a handler for the same message twice");
 862        }
 863        self
 864    }
 865
 866    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 869        Fut: 'static + Send + Future<Output = Result<()>>,
 870        M: EnvelopedMessage,
 871    {
 872        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 873        self
 874    }
 875
 876    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 879        Fut: Send + Future<Output = Result<()>>,
 880        M: RequestMessage,
 881    {
 882        let handler = Arc::new(handler);
 883        self.add_handler(move |envelope, session| {
 884            let receipt = envelope.receipt();
 885            let handler = handler.clone();
 886            async move {
 887                let peer = session.peer.clone();
 888                let responded = Arc::new(AtomicBool::default());
 889                let response = Response {
 890                    peer: peer.clone(),
 891                    responded: responded.clone(),
 892                    receipt,
 893                };
 894                match (handler)(envelope.payload, response, session).await {
 895                    Ok(()) => {
 896                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 897                            Ok(())
 898                        } else {
 899                            Err(anyhow!("handler did not send a response"))?
 900                        }
 901                    }
 902                    Err(error) => {
 903                        let proto_err = match &error {
 904                            Error::Internal(err) => err.to_proto(),
 905                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 906                        };
 907                        peer.respond_with_error(receipt, proto_err)?;
 908                        Err(error)
 909                    }
 910                }
 911            }
 912        })
 913    }
 914
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// When polled, the returned future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer, builds the per-connection [`Session`], and sends
    ///    the initial client update (`send_initial_client_update`);
    /// 3. dispatches incoming messages to the registered handlers until the I/O future finishes,
    ///    the incoming stream ends, or teardown is signalled;
    /// 4. runs `connection_lost`, except when it exited because of teardown or an early setup
    ///    failure.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`, which
    /// sends it the new connection id. All work runs inside a `handle connection` span that
    /// records the address, the principal (via `Principal::update_span`) and, when known, the
    /// GeoIP country code.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: `connection_id` once the peer assigns one,
        // the principal-related fields presumably by `update_span`, and the country code below.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly so the select loop below observes any later teardown signal.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection outright if teardown has already begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` is the I/O future polled in the loop below; `incoming_rx` yields the
            // messages received on this connection.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Per-connection HTTP client, stored on the session and shared with the Supermaven
            // admin client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this return (like the one after `send_initial_client_update`)
                    // happens after `add_connection` but skips `connection_lost`, so the
                    // connection is not disconnected from the peer or removed from the pool here.
                    // Confirm cleanup happens elsewhere.
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) may be in flight per connection.
            // A permit is acquired *before* reading the next message, so reading pauses while the
            // limit is reached; each permit is released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, then I/O, then progress on in-flight
                // foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    // On teardown, exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is created; the
                            // future itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the concurrency permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are driven by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1064
1065    async fn send_initial_client_update(
1066        &self,
1067        connection_id: ConnectionId,
1068        principal: &Principal,
1069        zed_version: ZedVersion,
1070        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1071        session: &Session,
1072    ) -> Result<()> {
1073        self.peer.send(
1074            connection_id,
1075            proto::Hello {
1076                peer_id: Some(connection_id.into()),
1077            },
1078        )?;
1079        tracing::info!("sent hello message");
1080        if let Some(send_connection_id) = send_connection_id.take() {
1081            let _ = send_connection_id.send(connection_id);
1082        }
1083
1084        match principal {
1085            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1086                if !user.connected_once {
1087                    self.peer.send(connection_id, proto::ShowContacts {})?;
1088                    self.app_state
1089                        .db
1090                        .set_user_connected_once(user.id, true)
1091                        .await?;
1092                }
1093
1094                update_user_plan(user.id, session).await?;
1095
1096                let (contacts, dev_server_projects) = future::try_join(
1097                    self.app_state.db.get_contacts(user.id),
1098                    self.app_state.db.dev_server_projects_update(user.id),
1099                )
1100                .await?;
1101
1102                {
1103                    let mut pool = self.connection_pool.lock();
1104                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1105                    self.peer.send(
1106                        connection_id,
1107                        build_initial_contacts_update(contacts, &pool),
1108                    )?;
1109                }
1110
1111                if should_auto_subscribe_to_channels(zed_version) {
1112                    subscribe_user_to_channels(user.id, session).await?;
1113                }
1114
1115                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1116
1117                if let Some(incoming_call) =
1118                    self.app_state.db.incoming_call_for_user(user.id).await?
1119                {
1120                    self.peer.send(connection_id, incoming_call)?;
1121                }
1122
1123                update_user_contacts(user.id, session).await?;
1124            }
1125            Principal::DevServer(dev_server) => {
1126                {
1127                    let mut pool = self.connection_pool.lock();
1128                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1129                    {
1130                        self.peer.send(
1131                            stale_connection_id,
1132                            proto::ShutdownDevServer {
1133                                reason: Some(
1134                                    "another dev server connected with the same token".to_string(),
1135                                ),
1136                            },
1137                        )?;
1138                        pool.remove_connection(stale_connection_id)?;
1139                    };
1140                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1141                }
1142
1143                let projects = self
1144                    .app_state
1145                    .db
1146                    .get_projects_for_dev_server(dev_server.id)
1147                    .await?;
1148                self.peer
1149                    .send(connection_id, proto::DevServerInstructions { projects })?;
1150
1151                let status = self
1152                    .app_state
1153                    .db
1154                    .dev_server_projects_update(dev_server.user_id)
1155                    .await?;
1156                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1157            }
1158        }
1159
1160        Ok(())
1161    }
1162
1163    pub async fn invite_code_redeemed(
1164        self: &Arc<Self>,
1165        inviter_id: UserId,
1166        invitee_id: UserId,
1167    ) -> Result<()> {
1168        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1169            if let Some(code) = &user.invite_code {
1170                let pool = self.connection_pool.lock();
1171                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1172                for connection_id in pool.user_connection_ids(inviter_id) {
1173                    self.peer.send(
1174                        connection_id,
1175                        proto::UpdateContacts {
1176                            contacts: vec![invitee_contact.clone()],
1177                            ..Default::default()
1178                        },
1179                    )?;
1180                    self.peer.send(
1181                        connection_id,
1182                        proto::UpdateInviteInfo {
1183                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1184                            count: user.invite_count as u32,
1185                        },
1186                    )?;
1187                }
1188            }
1189        }
1190        Ok(())
1191    }
1192
1193    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1194        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1195            if let Some(invite_code) = &user.invite_code {
1196                let pool = self.connection_pool.lock();
1197                for connection_id in pool.user_connection_ids(user_id) {
1198                    self.peer.send(
1199                        connection_id,
1200                        proto::UpdateInviteInfo {
1201                            url: format!(
1202                                "{}{}",
1203                                self.app_state.config.invite_link_prefix, invite_code
1204                            ),
1205                            count: user.invite_count as u32,
1206                        },
1207                    )?;
1208                }
1209            }
1210        }
1211        Ok(())
1212    }
1213
1214    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1215        ServerSnapshot {
1216            connection_pool: ConnectionPoolGuard {
1217                guard: self.connection_pool.lock(),
1218                _not_send: PhantomData,
1219            },
1220            peer: &self.peer,
1221        }
1222    }
1223}
1224
1225impl<'a> Deref for ConnectionPoolGuard<'a> {
1226    type Target = ConnectionPool;
1227
1228    fn deref(&self) -> &Self::Target {
1229        &self.guard
1230    }
1231}
1232
1233impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1234    fn deref_mut(&mut self) -> &mut Self::Target {
1235        &mut self.guard
1236    }
1237}
1238
1239impl<'a> Drop for ConnectionPoolGuard<'a> {
1240    fn drop(&mut self) {
1241        #[cfg(test)]
1242        self.check_invariants();
1243    }
1244}
1245
1246fn broadcast<F>(
1247    sender_id: Option<ConnectionId>,
1248    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1249    mut f: F,
1250) where
1251    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1252{
1253    for receiver_id in receiver_ids {
1254        if Some(receiver_id) != sender_id {
1255            if let Err(error) = f(receiver_id) {
1256                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1257            }
1258        }
1259    }
1260}
1261
1262pub struct ProtocolVersion(u32);
1263
1264impl Header for ProtocolVersion {
1265    fn name() -> &'static HeaderName {
1266        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1267        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1268    }
1269
1270    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1271    where
1272        Self: Sized,
1273        I: Iterator<Item = &'i axum::http::HeaderValue>,
1274    {
1275        let version = values
1276            .next()
1277            .ok_or_else(axum::headers::Error::invalid)?
1278            .to_str()
1279            .map_err(|_| axum::headers::Error::invalid())?
1280            .parse()
1281            .map_err(|_| axum::headers::Error::invalid())?;
1282        Ok(Self(version))
1283    }
1284
1285    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1286        values.extend([self.0.to_string().parse().unwrap()]);
1287    }
1288}
1289
1290pub struct AppVersionHeader(SemanticVersion);
1291impl Header for AppVersionHeader {
1292    fn name() -> &'static HeaderName {
1293        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1295    }
1296
1297    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298    where
1299        Self: Sized,
1300        I: Iterator<Item = &'i axum::http::HeaderValue>,
1301    {
1302        let version = values
1303            .next()
1304            .ok_or_else(axum::headers::Error::invalid)?
1305            .to_str()
1306            .map_err(|_| axum::headers::Error::invalid())?
1307            .parse()
1308            .map_err(|_| axum::headers::Error::invalid())?;
1309        Ok(Self(version))
1310    }
1311
1312    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313        values.extend([self.0.to_string().parse().unwrap()]);
1314    }
1315}
1316
1317pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1318    Router::new()
1319        .route("/rpc", get(handle_websocket_request))
1320        .layer(
1321            ServiceBuilder::new()
1322                .layer(Extension(server.app_state.clone()))
1323                .layer(middleware::from_fn(auth::validate_header)),
1324        )
1325        .route("/metrics", get(handle_metrics))
1326        .layer(Extension(server))
1327}
1328
1329pub async fn handle_websocket_request(
1330    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1331    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1332    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1333    Extension(server): Extension<Arc<Server>>,
1334    Extension(principal): Extension<Principal>,
1335    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1336    ws: WebSocketUpgrade,
1337) -> axum::response::Response {
1338    if protocol_version != rpc::PROTOCOL_VERSION {
1339        return (
1340            StatusCode::UPGRADE_REQUIRED,
1341            "client must be upgraded".to_string(),
1342        )
1343            .into_response();
1344    }
1345
1346    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1347        return (
1348            StatusCode::UPGRADE_REQUIRED,
1349            "no version header found".to_string(),
1350        )
1351            .into_response();
1352    };
1353
1354    if !version.can_collaborate() {
1355        return (
1356            StatusCode::UPGRADE_REQUIRED,
1357            "client must be upgraded".to_string(),
1358        )
1359            .into_response();
1360    }
1361
1362    let socket_address = socket_address.to_string();
1363    ws.on_upgrade(move |socket| {
1364        let socket = socket
1365            .map_ok(to_tungstenite_message)
1366            .err_into()
1367            .with(|message| async move { to_axum_message(message) });
1368        let connection = Connection::new(Box::pin(socket));
1369        async move {
1370            server
1371                .handle_connection(
1372                    connection,
1373                    socket_address,
1374                    principal,
1375                    version,
1376                    country_code_header.map(|header| header.to_string()),
1377                    None,
1378                    Executor::Production,
1379                )
1380                .await;
1381        }
1382    })
1383}
1384
1385pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1386    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1387    let connections_metric = CONNECTIONS_METRIC
1388        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1389
1390    let connections = server
1391        .connection_pool
1392        .lock()
1393        .connections()
1394        .filter(|connection| !connection.admin)
1395        .count();
1396    connections_metric.set(connections as _);
1397
1398    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1399    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1400        register_int_gauge!(
1401            "shared_projects",
1402            "number of open projects with one or more guests"
1403        )
1404        .unwrap()
1405    });
1406
1407    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1408    shared_projects_metric.set(shared_projects as _);
1409
1410    let encoder = prometheus::TextEncoder::new();
1411    let metric_families = prometheus::gather();
1412    let encoded_metrics = encoder
1413        .encode_to_string(&metric_families)
1414        .map_err(|err| anyhow!("{}", err))?;
1415    Ok(encoded_metrics)
1416}
1417
1418#[instrument(err, skip(executor))]
1419async fn connection_lost(
1420    session: Session,
1421    mut teardown: watch::Receiver<bool>,
1422    executor: Executor,
1423) -> Result<()> {
1424    session.peer.disconnect(session.connection_id);
1425    session
1426        .connection_pool()
1427        .await
1428        .remove_connection(session.connection_id)?;
1429
1430    session
1431        .db()
1432        .await
1433        .connection_lost(session.connection_id)
1434        .await
1435        .trace_err();
1436
1437    futures::select_biased! {
1438        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1439            match &session.principal {
1440                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1441                    let session = session.for_user().unwrap();
1442
1443                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1444                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1445                    leave_channel_buffers_for_session(&session)
1446                        .await
1447                        .trace_err();
1448
1449                    if !session
1450                        .connection_pool()
1451                        .await
1452                        .is_user_online(session.user_id())
1453                    {
1454                        let db = session.db().await;
1455                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1456                            room_updated(&room, &session.peer);
1457                        }
1458                    }
1459
1460                    update_user_contacts(session.user_id(), &session).await?;
1461                },
1462            Principal::DevServer(_) => {
1463                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1464            },
1465        }
1466        },
1467        _ = teardown.changed().fuse() => {}
1468    }
1469
1470    Ok(())
1471}
1472
1473/// Acknowledges a ping from a client, used to keep the connection alive.
1474async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1475    response.send(proto::Ack {})?;
1476    Ok(())
1477}
1478
1479/// Creates a new room for calling (outside of channels)
1480async fn create_room(
1481    _request: proto::CreateRoom,
1482    response: Response<proto::CreateRoom>,
1483    session: UserSession,
1484) -> Result<()> {
1485    let live_kit_room = nanoid::nanoid!(30);
1486
1487    let live_kit_connection_info = util::maybe!(async {
1488        let live_kit = session.app_state.live_kit_client.as_ref();
1489        let live_kit = live_kit?;
1490        let user_id = session.user_id().to_string();
1491
1492        let token = live_kit
1493            .room_token(&live_kit_room, &user_id.to_string())
1494            .trace_err()?;
1495
1496        Some(proto::LiveKitConnectionInfo {
1497            server_url: live_kit.url().into(),
1498            token,
1499            can_publish: true,
1500        })
1501    })
1502    .await;
1503
1504    let room = session
1505        .db()
1506        .await
1507        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1508        .await?;
1509
1510    response.send(proto::CreateRoomResponse {
1511        room: Some(room.clone()),
1512        live_kit_connection_info,
1513    })?;
1514
1515    update_user_contacts(session.user_id(), &session).await?;
1516    Ok(())
1517}
1518
1519/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1520async fn join_room(
1521    request: proto::JoinRoom,
1522    response: Response<proto::JoinRoom>,
1523    session: UserSession,
1524) -> Result<()> {
1525    let room_id = RoomId::from_proto(request.id);
1526
1527    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1528
1529    if let Some(channel_id) = channel_id {
1530        return join_channel_internal(channel_id, Box::new(response), session).await;
1531    }
1532
1533    let joined_room = {
1534        let room = session
1535            .db()
1536            .await
1537            .join_room(room_id, session.user_id(), session.connection_id)
1538            .await?;
1539        room_updated(&room.room, &session.peer);
1540        room.into_inner()
1541    };
1542
1543    for connection_id in session
1544        .connection_pool()
1545        .await
1546        .user_connection_ids(session.user_id())
1547    {
1548        session
1549            .peer
1550            .send(
1551                connection_id,
1552                proto::CallCanceled {
1553                    room_id: room_id.to_proto(),
1554                },
1555            )
1556            .trace_err();
1557    }
1558
1559    let live_kit_connection_info =
1560        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1561            live_kit
1562                .room_token(
1563                    &joined_room.room.live_kit_room,
1564                    &session.user_id().to_string(),
1565                )
1566                .trace_err()
1567                .map(|token| proto::LiveKitConnectionInfo {
1568                    server_url: live_kit.url().into(),
1569                    token,
1570                    can_publish: true,
1571                })
1572        } else {
1573            None
1574        };
1575
1576    response.send(proto::JoinRoomResponse {
1577        room: Some(joined_room.room),
1578        channel_id: None,
1579        live_kit_connection_info,
1580    })?;
1581
1582    update_user_contacts(session.user_id(), &session).await?;
1583    Ok(())
1584}
1585
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-attaches the current connection to the room in the database and replies
/// with the room plus its "reshared" and "rejoined" projects. Then:
/// - collaborators of each reshared project are told the peer id changed from
///   the project's old connection to this session's connection, and are
///   forwarded the project's worktree metadata from this connection
///   (presumably these are projects this user hosted before disconnecting —
///   TODO confirm against `Database::rejoin_room`);
/// - rejoined projects are streamed back to this connection by
///   `notify_rejoined_projects`;
/// - if the room has a channel, `channel_updated` is invoked, and finally the
///   user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in from the database result inside the block below; the block
    // bounds how long that result stays alive (it ends before the connection
    // pool is locked for `channel_updated`).
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply with the room and a summary of both kinds of projects.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Swap the stale peer id for the current connection on every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree metadata, forwarded as if sent by
            // this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Stream rejoined projects' worktree contents back to this connection.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1677
1678fn notify_rejoined_projects(
1679    rejoined_projects: &mut Vec<RejoinedProject>,
1680    session: &UserSession,
1681) -> Result<()> {
1682    for project in rejoined_projects.iter() {
1683        for collaborator in &project.collaborators {
1684            session
1685                .peer
1686                .send(
1687                    collaborator.connection_id,
1688                    proto::UpdateProjectCollaborator {
1689                        project_id: project.id.to_proto(),
1690                        old_peer_id: Some(project.old_connection_id.into()),
1691                        new_peer_id: Some(session.connection_id.into()),
1692                    },
1693                )
1694                .trace_err();
1695        }
1696    }
1697
1698    for project in rejoined_projects {
1699        for worktree in mem::take(&mut project.worktrees) {
1700            #[cfg(any(test, feature = "test-support"))]
1701            const MAX_CHUNK_SIZE: usize = 2;
1702            #[cfg(not(any(test, feature = "test-support")))]
1703            const MAX_CHUNK_SIZE: usize = 256;
1704
1705            // Stream this worktree's entries.
1706            let message = proto::UpdateWorktree {
1707                project_id: project.id.to_proto(),
1708                worktree_id: worktree.id,
1709                abs_path: worktree.abs_path.clone(),
1710                root_name: worktree.root_name,
1711                updated_entries: worktree.updated_entries,
1712                removed_entries: worktree.removed_entries,
1713                scan_id: worktree.scan_id,
1714                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1715                updated_repositories: worktree.updated_repositories,
1716                removed_repositories: worktree.removed_repositories,
1717            };
1718            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1719                session.peer.send(session.connection_id, update.clone())?;
1720            }
1721
1722            // Stream this worktree's diagnostics.
1723            for summary in worktree.diagnostic_summaries {
1724                session.peer.send(
1725                    session.connection_id,
1726                    proto::UpdateDiagnosticSummary {
1727                        project_id: project.id.to_proto(),
1728                        worktree_id: worktree.id,
1729                        summary: Some(summary),
1730                    },
1731                )?;
1732            }
1733
1734            for settings_file in worktree.settings_files {
1735                session.peer.send(
1736                    session.connection_id,
1737                    proto::UpdateWorktreeSettings {
1738                        project_id: project.id.to_proto(),
1739                        worktree_id: worktree.id,
1740                        path: settings_file.path,
1741                        content: Some(settings_file.content),
1742                        kind: Some(settings_file.kind.to_proto().into()),
1743                    },
1744                )?;
1745            }
1746        }
1747
1748        for language_server in &project.language_servers {
1749            session.peer.send(
1750                session.connection_id,
1751                proto::UpdateLanguageServer {
1752                    project_id: project.id.to_proto(),
1753                    language_server_id: language_server.id,
1754                    variant: Some(
1755                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1756                            proto::LspDiskBasedDiagnosticsUpdated {},
1757                        ),
1758                    ),
1759                },
1760            )?;
1761        }
1762    }
1763    Ok(())
1764}
1765
1766/// leave room disconnects from the room.
1767async fn leave_room(
1768    _: proto::LeaveRoom,
1769    response: Response<proto::LeaveRoom>,
1770    session: UserSession,
1771) -> Result<()> {
1772    leave_room_for_session(&session, session.connection_id).await?;
1773    response.send(proto::Ack {})?;
1774    Ok(())
1775}
1776
1777/// Updates the permissions of someone else in the room.
1778async fn set_room_participant_role(
1779    request: proto::SetRoomParticipantRole,
1780    response: Response<proto::SetRoomParticipantRole>,
1781    session: UserSession,
1782) -> Result<()> {
1783    let user_id = UserId::from_proto(request.user_id);
1784    let role = ChannelRole::from(request.role());
1785
1786    let (live_kit_room, can_publish) = {
1787        let room = session
1788            .db()
1789            .await
1790            .set_room_participant_role(
1791                session.user_id(),
1792                RoomId::from_proto(request.room_id),
1793                user_id,
1794                role,
1795            )
1796            .await?;
1797
1798        let live_kit_room = room.live_kit_room.clone();
1799        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1800        room_updated(&room, &session.peer);
1801        (live_kit_room, can_publish)
1802    };
1803
1804    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1805        live_kit
1806            .update_participant(
1807                live_kit_room.clone(),
1808                request.user_id.to_string(),
1809                live_kit_server::proto::ParticipantPermission {
1810                    can_subscribe: true,
1811                    can_publish,
1812                    can_publish_data: can_publish,
1813                    hidden: false,
1814                    recorder: false,
1815                },
1816            )
1817            .await
1818            .trace_err();
1819    }
1820
1821    response.send(proto::Ack {})?;
1822    Ok(())
1823}
1824
1825/// Call someone else into the current room
1826async fn call(
1827    request: proto::Call,
1828    response: Response<proto::Call>,
1829    session: UserSession,
1830) -> Result<()> {
1831    let room_id = RoomId::from_proto(request.room_id);
1832    let calling_user_id = session.user_id();
1833    let calling_connection_id = session.connection_id;
1834    let called_user_id = UserId::from_proto(request.called_user_id);
1835    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1836    if !session
1837        .db()
1838        .await
1839        .has_contact(calling_user_id, called_user_id)
1840        .await?
1841    {
1842        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1843    }
1844
1845    let incoming_call = {
1846        let (room, incoming_call) = &mut *session
1847            .db()
1848            .await
1849            .call(
1850                room_id,
1851                calling_user_id,
1852                calling_connection_id,
1853                called_user_id,
1854                initial_project_id,
1855            )
1856            .await?;
1857        room_updated(room, &session.peer);
1858        mem::take(incoming_call)
1859    };
1860    update_user_contacts(called_user_id, &session).await?;
1861
1862    let mut calls = session
1863        .connection_pool()
1864        .await
1865        .user_connection_ids(called_user_id)
1866        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1867        .collect::<FuturesUnordered<_>>();
1868
1869    while let Some(call_response) = calls.next().await {
1870        match call_response.as_ref() {
1871            Ok(_) => {
1872                response.send(proto::Ack {})?;
1873                return Ok(());
1874            }
1875            Err(_) => {
1876                call_response.trace_err();
1877            }
1878        }
1879    }
1880
1881    {
1882        let room = session
1883            .db()
1884            .await
1885            .call_failed(room_id, called_user_id)
1886            .await?;
1887        room_updated(&room, &session.peer);
1888    }
1889    update_user_contacts(called_user_id, &session).await?;
1890
1891    Err(anyhow!("failed to ring user"))?
1892}
1893
1894/// Cancel an outgoing call.
1895async fn cancel_call(
1896    request: proto::CancelCall,
1897    response: Response<proto::CancelCall>,
1898    session: UserSession,
1899) -> Result<()> {
1900    let called_user_id = UserId::from_proto(request.called_user_id);
1901    let room_id = RoomId::from_proto(request.room_id);
1902    {
1903        let room = session
1904            .db()
1905            .await
1906            .cancel_call(room_id, session.connection_id, called_user_id)
1907            .await?;
1908        room_updated(&room, &session.peer);
1909    }
1910
1911    for connection_id in session
1912        .connection_pool()
1913        .await
1914        .user_connection_ids(called_user_id)
1915    {
1916        session
1917            .peer
1918            .send(
1919                connection_id,
1920                proto::CallCanceled {
1921                    room_id: room_id.to_proto(),
1922                },
1923            )
1924            .trace_err();
1925    }
1926    response.send(proto::Ack {})?;
1927
1928    update_user_contacts(called_user_id, &session).await?;
1929    Ok(())
1930}
1931
1932/// Decline an incoming call.
1933async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1934    let room_id = RoomId::from_proto(message.room_id);
1935    {
1936        let room = session
1937            .db()
1938            .await
1939            .decline_call(Some(room_id), session.user_id())
1940            .await?
1941            .ok_or_else(|| anyhow!("failed to decline call"))?;
1942        room_updated(&room, &session.peer);
1943    }
1944
1945    for connection_id in session
1946        .connection_pool()
1947        .await
1948        .user_connection_ids(session.user_id())
1949    {
1950        session
1951            .peer
1952            .send(
1953                connection_id,
1954                proto::CallCanceled {
1955                    room_id: room_id.to_proto(),
1956                },
1957            )
1958            .trace_err();
1959    }
1960    update_user_contacts(session.user_id(), &session).await?;
1961    Ok(())
1962}
1963
1964/// Updates other participants in the room with your current location.
1965async fn update_participant_location(
1966    request: proto::UpdateParticipantLocation,
1967    response: Response<proto::UpdateParticipantLocation>,
1968    session: UserSession,
1969) -> Result<()> {
1970    let room_id = RoomId::from_proto(request.room_id);
1971    let location = request
1972        .location
1973        .ok_or_else(|| anyhow!("invalid location"))?;
1974
1975    let db = session.db().await;
1976    let room = db
1977        .update_room_participant_location(room_id, session.connection_id, location)
1978        .await?;
1979
1980    room_updated(&room, &session.peer);
1981    response.send(proto::Ack {})?;
1982    Ok(())
1983}
1984
1985/// Share a project into the room.
1986async fn share_project(
1987    request: proto::ShareProject,
1988    response: Response<proto::ShareProject>,
1989    session: UserSession,
1990) -> Result<()> {
1991    let (project_id, room) = &*session
1992        .db()
1993        .await
1994        .share_project(
1995            RoomId::from_proto(request.room_id),
1996            session.connection_id,
1997            &request.worktrees,
1998            request.is_ssh_project,
1999            request
2000                .dev_server_project_id
2001                .map(DevServerProjectId::from_proto),
2002        )
2003        .await?;
2004    response.send(proto::ShareProjectResponse {
2005        project_id: project_id.to_proto(),
2006    })?;
2007    room_updated(room, &session.peer);
2008
2009    Ok(())
2010}
2011
2012/// Unshare a project from the room.
2013async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2014    let project_id = ProjectId::from_proto(message.project_id);
2015    unshare_project_internal(
2016        project_id,
2017        session.connection_id,
2018        session.user_id(),
2019        &session,
2020    )
2021    .await
2022}
2023
2024async fn unshare_project_internal(
2025    project_id: ProjectId,
2026    connection_id: ConnectionId,
2027    user_id: Option<UserId>,
2028    session: &Session,
2029) -> Result<()> {
2030    let delete = {
2031        let room_guard = session
2032            .db()
2033            .await
2034            .unshare_project(project_id, connection_id, user_id)
2035            .await?;
2036
2037        let (delete, room, guest_connection_ids) = &*room_guard;
2038
2039        let message = proto::UnshareProject {
2040            project_id: project_id.to_proto(),
2041        };
2042
2043        broadcast(
2044            Some(connection_id),
2045            guest_connection_ids.iter().copied(),
2046            |conn_id| session.peer.send(conn_id, message.clone()),
2047        );
2048        if let Some(room) = room {
2049            room_updated(room, &session.peer);
2050        }
2051
2052        *delete
2053    };
2054
2055    if delete {
2056        let db = session.db().await;
2057        db.delete_project(project_id).await?;
2058    }
2059
2060    Ok(())
2061}
2062
2063/// DevServer makes a project available online
2064async fn share_dev_server_project(
2065    request: proto::ShareDevServerProject,
2066    response: Response<proto::ShareDevServerProject>,
2067    session: DevServerSession,
2068) -> Result<()> {
2069    let (dev_server_project, user_id, status) = session
2070        .db()
2071        .await
2072        .share_dev_server_project(
2073            DevServerProjectId::from_proto(request.dev_server_project_id),
2074            session.dev_server_id(),
2075            session.connection_id,
2076            &request.worktrees,
2077        )
2078        .await?;
2079    let Some(project_id) = dev_server_project.project_id else {
2080        return Err(anyhow!("failed to share remote project"))?;
2081    };
2082
2083    send_dev_server_projects_update(user_id, status, &session).await;
2084
2085    response.send(proto::ShareProjectResponse { project_id })?;
2086
2087    Ok(())
2088}
2089
2090/// Join someone elses shared project.
2091async fn join_project(
2092    request: proto::JoinProject,
2093    response: Response<proto::JoinProject>,
2094    session: UserSession,
2095) -> Result<()> {
2096    let project_id = ProjectId::from_proto(request.project_id);
2097
2098    tracing::info!(%project_id, "join project");
2099
2100    let db = session.db().await;
2101    let (project, replica_id) = &mut *db
2102        .join_project(project_id, session.connection_id, session.user_id())
2103        .await?;
2104    drop(db);
2105    tracing::info!(%project_id, "join remote project");
2106    join_project_internal(response, session, project, replica_id)
2107}
2108
/// Abstracts over the response handles of the requests that reply with a
/// `proto::JoinProjectResponse` (`JoinProject` and `JoinHostedProject`), so
/// `join_project_internal` can serve both.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Lets a plain `JoinProject` response be used by `join_project_internal`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Calls `Response`'s own `send` via an explicit type path, presumably
        // to make clear this is not a recursive call to the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Lets a `JoinHostedProject` response be used by `join_project_internal`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Calls `Response`'s own `send` via an explicit type path, presumably
        // to make clear this is not a recursive call to the trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2122
2123fn join_project_internal(
2124    response: impl JoinProjectInternalResponse,
2125    session: UserSession,
2126    project: &mut Project,
2127    replica_id: &ReplicaId,
2128) -> Result<()> {
2129    let collaborators = project
2130        .collaborators
2131        .iter()
2132        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2133        .map(|collaborator| collaborator.to_proto())
2134        .collect::<Vec<_>>();
2135    let project_id = project.id;
2136    let guest_user_id = session.user_id();
2137
2138    let worktrees = project
2139        .worktrees
2140        .iter()
2141        .map(|(id, worktree)| proto::WorktreeMetadata {
2142            id: *id,
2143            root_name: worktree.root_name.clone(),
2144            visible: worktree.visible,
2145            abs_path: worktree.abs_path.clone(),
2146        })
2147        .collect::<Vec<_>>();
2148
2149    let add_project_collaborator = proto::AddProjectCollaborator {
2150        project_id: project_id.to_proto(),
2151        collaborator: Some(proto::Collaborator {
2152            peer_id: Some(session.connection_id.into()),
2153            replica_id: replica_id.0 as u32,
2154            user_id: guest_user_id.to_proto(),
2155        }),
2156    };
2157
2158    for collaborator in &collaborators {
2159        session
2160            .peer
2161            .send(
2162                collaborator.peer_id.unwrap().into(),
2163                add_project_collaborator.clone(),
2164            )
2165            .trace_err();
2166    }
2167
2168    // First, we send the metadata associated with each worktree.
2169    response.send(proto::JoinProjectResponse {
2170        project_id: project.id.0 as u64,
2171        worktrees: worktrees.clone(),
2172        replica_id: replica_id.0 as u32,
2173        collaborators: collaborators.clone(),
2174        language_servers: project.language_servers.clone(),
2175        role: project.role.into(),
2176        dev_server_project_id: project
2177            .dev_server_project_id
2178            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2179    })?;
2180
2181    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2182        #[cfg(any(test, feature = "test-support"))]
2183        const MAX_CHUNK_SIZE: usize = 2;
2184        #[cfg(not(any(test, feature = "test-support")))]
2185        const MAX_CHUNK_SIZE: usize = 256;
2186
2187        // Stream this worktree's entries.
2188        let message = proto::UpdateWorktree {
2189            project_id: project_id.to_proto(),
2190            worktree_id,
2191            abs_path: worktree.abs_path.clone(),
2192            root_name: worktree.root_name,
2193            updated_entries: worktree.entries,
2194            removed_entries: Default::default(),
2195            scan_id: worktree.scan_id,
2196            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2197            updated_repositories: worktree.repository_entries.into_values().collect(),
2198            removed_repositories: Default::default(),
2199        };
2200        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2201            session.peer.send(session.connection_id, update.clone())?;
2202        }
2203
2204        // Stream this worktree's diagnostics.
2205        for summary in worktree.diagnostic_summaries {
2206            session.peer.send(
2207                session.connection_id,
2208                proto::UpdateDiagnosticSummary {
2209                    project_id: project_id.to_proto(),
2210                    worktree_id: worktree.id,
2211                    summary: Some(summary),
2212                },
2213            )?;
2214        }
2215
2216        for settings_file in worktree.settings_files {
2217            session.peer.send(
2218                session.connection_id,
2219                proto::UpdateWorktreeSettings {
2220                    project_id: project_id.to_proto(),
2221                    worktree_id: worktree.id,
2222                    path: settings_file.path,
2223                    content: Some(settings_file.content),
2224                    kind: Some(proto::update_user_settings::Kind::Settings.into()),
2225                },
2226            )?;
2227        }
2228    }
2229
2230    for language_server in &project.language_servers {
2231        session.peer.send(
2232            session.connection_id,
2233            proto::UpdateLanguageServer {
2234                project_id: project_id.to_proto(),
2235                language_server_id: language_server.id,
2236                variant: Some(
2237                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2238                        proto::LspDiskBasedDiagnosticsUpdated {},
2239                    ),
2240                ),
2241            },
2242        )?;
2243    }
2244
2245    Ok(())
2246}
2247
2248/// Leave someone elses shared project.
2249async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2250    let sender_id = session.connection_id;
2251    let project_id = ProjectId::from_proto(request.project_id);
2252    let db = session.db().await;
2253    if db.is_hosted_project(project_id).await? {
2254        let project = db.leave_hosted_project(project_id, sender_id).await?;
2255        project_left(&project, &session);
2256        return Ok(());
2257    }
2258
2259    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2260    tracing::info!(
2261        %project_id,
2262        "leave project"
2263    );
2264
2265    project_left(project, &session);
2266    if let Some(room) = room {
2267        room_updated(room, &session.peer);
2268    }
2269
2270    Ok(())
2271}
2272
2273async fn join_hosted_project(
2274    request: proto::JoinHostedProject,
2275    response: Response<proto::JoinHostedProject>,
2276    session: UserSession,
2277) -> Result<()> {
2278    let (mut project, replica_id) = session
2279        .db()
2280        .await
2281        .join_hosted_project(
2282            ProjectId(request.project_id as i32),
2283            session.user_id(),
2284            session.connection_id,
2285        )
2286        .await?;
2287
2288    join_project_internal(response, session, &mut project, &replica_id)
2289}
2290
2291async fn list_remote_directory(
2292    request: proto::ListRemoteDirectory,
2293    response: Response<proto::ListRemoteDirectory>,
2294    session: UserSession,
2295) -> Result<()> {
2296    let dev_server_id = DevServerId(request.dev_server_id as i32);
2297    let dev_server_connection_id = session
2298        .connection_pool()
2299        .await
2300        .online_dev_server_connection_id(dev_server_id)?;
2301
2302    session
2303        .db()
2304        .await
2305        .get_dev_server_for_user(dev_server_id, session.user_id())
2306        .await?;
2307
2308    response.send(
2309        session
2310            .peer
2311            .forward_request(session.connection_id, dev_server_connection_id, request)
2312            .await?,
2313    )?;
2314    Ok(())
2315}
2316
2317async fn update_dev_server_project(
2318    request: proto::UpdateDevServerProject,
2319    response: Response<proto::UpdateDevServerProject>,
2320    session: UserSession,
2321) -> Result<()> {
2322    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2323
2324    let (dev_server_project, update) = session
2325        .db()
2326        .await
2327        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2328        .await?;
2329
2330    let projects = session
2331        .db()
2332        .await
2333        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2334        .await?;
2335
2336    let dev_server_connection_id = session
2337        .connection_pool()
2338        .await
2339        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2340
2341    session.peer.send(
2342        dev_server_connection_id,
2343        proto::DevServerInstructions { projects },
2344    )?;
2345
2346    send_dev_server_projects_update(session.user_id(), update, &session).await;
2347
2348    response.send(proto::Ack {})
2349}
2350
2351async fn create_dev_server_project(
2352    request: proto::CreateDevServerProject,
2353    response: Response<proto::CreateDevServerProject>,
2354    session: UserSession,
2355) -> Result<()> {
2356    let dev_server_id = DevServerId(request.dev_server_id as i32);
2357    let dev_server_connection_id = session
2358        .connection_pool()
2359        .await
2360        .dev_server_connection_id(dev_server_id);
2361    let Some(dev_server_connection_id) = dev_server_connection_id else {
2362        Err(ErrorCode::DevServerOffline
2363            .message("Cannot create a remote project when the dev server is offline".to_string())
2364            .anyhow())?
2365    };
2366
2367    let path = request.path.clone();
2368    //Check that the path exists on the dev server
2369    session
2370        .peer
2371        .forward_request(
2372            session.connection_id,
2373            dev_server_connection_id,
2374            proto::ValidateDevServerProjectRequest { path: path.clone() },
2375        )
2376        .await?;
2377
2378    let (dev_server_project, update) = session
2379        .db()
2380        .await
2381        .create_dev_server_project(
2382            DevServerId(request.dev_server_id as i32),
2383            &request.path,
2384            session.user_id(),
2385        )
2386        .await?;
2387
2388    let projects = session
2389        .db()
2390        .await
2391        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2392        .await?;
2393
2394    session.peer.send(
2395        dev_server_connection_id,
2396        proto::DevServerInstructions { projects },
2397    )?;
2398
2399    send_dev_server_projects_update(session.user_id(), update, &session).await;
2400
2401    response.send(proto::CreateDevServerProjectResponse {
2402        dev_server_project: Some(dev_server_project.to_proto(None)),
2403    })?;
2404    Ok(())
2405}
2406
2407async fn create_dev_server(
2408    request: proto::CreateDevServer,
2409    response: Response<proto::CreateDevServer>,
2410    session: UserSession,
2411) -> Result<()> {
2412    let access_token = auth::random_token();
2413    let hashed_access_token = auth::hash_access_token(&access_token);
2414
2415    if request.name.is_empty() {
2416        return Err(proto::ErrorCode::Forbidden
2417            .message("Dev server name cannot be empty".to_string())
2418            .anyhow())?;
2419    }
2420
2421    let (dev_server, status) = session
2422        .db()
2423        .await
2424        .create_dev_server(
2425            &request.name,
2426            request.ssh_connection_string.as_deref(),
2427            &hashed_access_token,
2428            session.user_id(),
2429        )
2430        .await?;
2431
2432    send_dev_server_projects_update(session.user_id(), status, &session).await;
2433
2434    response.send(proto::CreateDevServerResponse {
2435        dev_server_id: dev_server.id.0 as u64,
2436        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2437        name: request.name,
2438    })?;
2439    Ok(())
2440}
2441
2442async fn regenerate_dev_server_token(
2443    request: proto::RegenerateDevServerToken,
2444    response: Response<proto::RegenerateDevServerToken>,
2445    session: UserSession,
2446) -> Result<()> {
2447    let dev_server_id = DevServerId(request.dev_server_id as i32);
2448    let access_token = auth::random_token();
2449    let hashed_access_token = auth::hash_access_token(&access_token);
2450
2451    let connection_id = session
2452        .connection_pool()
2453        .await
2454        .dev_server_connection_id(dev_server_id);
2455    if let Some(connection_id) = connection_id {
2456        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2457        session.peer.send(
2458            connection_id,
2459            proto::ShutdownDevServer {
2460                reason: Some("dev server token was regenerated".to_string()),
2461            },
2462        )?;
2463        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2464    }
2465
2466    let status = session
2467        .db()
2468        .await
2469        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2470        .await?;
2471
2472    send_dev_server_projects_update(session.user_id(), status, &session).await;
2473
2474    response.send(proto::RegenerateDevServerTokenResponse {
2475        dev_server_id: dev_server_id.to_proto(),
2476        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2477    })?;
2478    Ok(())
2479}
2480
2481async fn rename_dev_server(
2482    request: proto::RenameDevServer,
2483    response: Response<proto::RenameDevServer>,
2484    session: UserSession,
2485) -> Result<()> {
2486    if request.name.trim().is_empty() {
2487        return Err(proto::ErrorCode::Forbidden
2488            .message("Dev server name cannot be empty".to_string())
2489            .anyhow())?;
2490    }
2491
2492    let dev_server_id = DevServerId(request.dev_server_id as i32);
2493    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2494    if dev_server.user_id != session.user_id() {
2495        return Err(anyhow!(ErrorCode::Forbidden))?;
2496    }
2497
2498    let status = session
2499        .db()
2500        .await
2501        .rename_dev_server(
2502            dev_server_id,
2503            &request.name,
2504            request.ssh_connection_string.as_deref(),
2505            session.user_id(),
2506        )
2507        .await?;
2508
2509    send_dev_server_projects_update(session.user_id(), status, &session).await;
2510
2511    response.send(proto::Ack {})?;
2512    Ok(())
2513}
2514
2515async fn delete_dev_server(
2516    request: proto::DeleteDevServer,
2517    response: Response<proto::DeleteDevServer>,
2518    session: UserSession,
2519) -> Result<()> {
2520    let dev_server_id = DevServerId(request.dev_server_id as i32);
2521    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2522    if dev_server.user_id != session.user_id() {
2523        return Err(anyhow!(ErrorCode::Forbidden))?;
2524    }
2525
2526    let connection_id = session
2527        .connection_pool()
2528        .await
2529        .dev_server_connection_id(dev_server_id);
2530    if let Some(connection_id) = connection_id {
2531        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2532        session.peer.send(
2533            connection_id,
2534            proto::ShutdownDevServer {
2535                reason: Some("dev server was deleted".to_string()),
2536            },
2537        )?;
2538        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2539    }
2540
2541    let status = session
2542        .db()
2543        .await
2544        .delete_dev_server(dev_server_id, session.user_id())
2545        .await?;
2546
2547    send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549    response.send(proto::Ack {})?;
2550    Ok(())
2551}
2552
2553async fn delete_dev_server_project(
2554    request: proto::DeleteDevServerProject,
2555    response: Response<proto::DeleteDevServerProject>,
2556    session: UserSession,
2557) -> Result<()> {
2558    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2559    let dev_server_project = session
2560        .db()
2561        .await
2562        .get_dev_server_project(dev_server_project_id)
2563        .await?;
2564
2565    let dev_server = session
2566        .db()
2567        .await
2568        .get_dev_server(dev_server_project.dev_server_id)
2569        .await?;
2570    if dev_server.user_id != session.user_id() {
2571        return Err(anyhow!(ErrorCode::Forbidden))?;
2572    }
2573
2574    let dev_server_connection_id = session
2575        .connection_pool()
2576        .await
2577        .dev_server_connection_id(dev_server.id);
2578
2579    if let Some(dev_server_connection_id) = dev_server_connection_id {
2580        let project = session
2581            .db()
2582            .await
2583            .find_dev_server_project(dev_server_project_id)
2584            .await;
2585        if let Ok(project) = project {
2586            unshare_project_internal(
2587                project.id,
2588                dev_server_connection_id,
2589                Some(session.user_id()),
2590                &session,
2591            )
2592            .await?;
2593        }
2594    }
2595
2596    let (projects, status) = session
2597        .db()
2598        .await
2599        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2600        .await?;
2601
2602    if let Some(dev_server_connection_id) = dev_server_connection_id {
2603        session.peer.send(
2604            dev_server_connection_id,
2605            proto::DevServerInstructions { projects },
2606        )?;
2607    }
2608
2609    send_dev_server_projects_update(session.user_id(), status, &session).await;
2610
2611    response.send(proto::Ack {})?;
2612    Ok(())
2613}
2614
2615async fn rejoin_dev_server_projects(
2616    request: proto::RejoinRemoteProjects,
2617    response: Response<proto::RejoinRemoteProjects>,
2618    session: UserSession,
2619) -> Result<()> {
2620    let mut rejoined_projects = {
2621        let db = session.db().await;
2622        db.rejoin_dev_server_projects(
2623            &request.rejoined_projects,
2624            session.user_id(),
2625            session.0.connection_id,
2626        )
2627        .await?
2628    };
2629    response.send(proto::RejoinRemoteProjectsResponse {
2630        rejoined_projects: rejoined_projects
2631            .iter()
2632            .map(|project| project.to_proto())
2633            .collect(),
2634    })?;
2635    notify_rejoined_projects(&mut rejoined_projects, &session)
2636}
2637
2638async fn reconnect_dev_server(
2639    request: proto::ReconnectDevServer,
2640    response: Response<proto::ReconnectDevServer>,
2641    session: DevServerSession,
2642) -> Result<()> {
2643    let reshared_projects = {
2644        let db = session.db().await;
2645        db.reshare_dev_server_projects(
2646            &request.reshared_projects,
2647            session.dev_server_id(),
2648            session.0.connection_id,
2649        )
2650        .await?
2651    };
2652
2653    for project in &reshared_projects {
2654        for collaborator in &project.collaborators {
2655            session
2656                .peer
2657                .send(
2658                    collaborator.connection_id,
2659                    proto::UpdateProjectCollaborator {
2660                        project_id: project.id.to_proto(),
2661                        old_peer_id: Some(project.old_connection_id.into()),
2662                        new_peer_id: Some(session.connection_id.into()),
2663                    },
2664                )
2665                .trace_err();
2666        }
2667
2668        broadcast(
2669            Some(session.connection_id),
2670            project
2671                .collaborators
2672                .iter()
2673                .map(|collaborator| collaborator.connection_id),
2674            |connection_id| {
2675                session.peer.forward_send(
2676                    session.connection_id,
2677                    connection_id,
2678                    proto::UpdateProject {
2679                        project_id: project.id.to_proto(),
2680                        worktrees: project.worktrees.clone(),
2681                    },
2682                )
2683            },
2684        );
2685    }
2686
2687    response.send(proto::ReconnectDevServerResponse {
2688        reshared_projects: reshared_projects
2689            .iter()
2690            .map(|project| proto::ResharedProject {
2691                id: project.id.to_proto(),
2692                collaborators: project
2693                    .collaborators
2694                    .iter()
2695                    .map(|collaborator| collaborator.to_proto())
2696                    .collect(),
2697            })
2698            .collect(),
2699    })?;
2700
2701    Ok(())
2702}
2703
2704async fn shutdown_dev_server(
2705    _: proto::ShutdownDevServer,
2706    response: Response<proto::ShutdownDevServer>,
2707    session: DevServerSession,
2708) -> Result<()> {
2709    response.send(proto::Ack {})?;
2710    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2711    remove_dev_server_connection(session.dev_server_id(), &session).await
2712}
2713
2714async fn shutdown_dev_server_internal(
2715    dev_server_id: DevServerId,
2716    connection_id: ConnectionId,
2717    session: &Session,
2718) -> Result<()> {
2719    let (dev_server_projects, dev_server) = {
2720        let db = session.db().await;
2721        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2722        let dev_server = db.get_dev_server(dev_server_id).await?;
2723        (dev_server_projects, dev_server)
2724    };
2725
2726    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2727        unshare_project_internal(
2728            ProjectId::from_proto(project_id),
2729            connection_id,
2730            None,
2731            session,
2732        )
2733        .await?;
2734    }
2735
2736    session
2737        .connection_pool()
2738        .await
2739        .set_dev_server_offline(dev_server_id);
2740
2741    let status = session
2742        .db()
2743        .await
2744        .dev_server_projects_update(dev_server.user_id)
2745        .await?;
2746    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2747
2748    Ok(())
2749}
2750
2751async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2752    let dev_server_connection = session
2753        .connection_pool()
2754        .await
2755        .dev_server_connection_id(dev_server_id);
2756
2757    if let Some(dev_server_connection) = dev_server_connection {
2758        session
2759            .connection_pool()
2760            .await
2761            .remove_connection(dev_server_connection)?;
2762    }
2763    Ok(())
2764}
2765
2766/// Updates other participants with changes to the project
2767async fn update_project(
2768    request: proto::UpdateProject,
2769    response: Response<proto::UpdateProject>,
2770    session: Session,
2771) -> Result<()> {
2772    let project_id = ProjectId::from_proto(request.project_id);
2773    let (room, guest_connection_ids) = &*session
2774        .db()
2775        .await
2776        .update_project(project_id, session.connection_id, &request.worktrees)
2777        .await?;
2778    broadcast(
2779        Some(session.connection_id),
2780        guest_connection_ids.iter().copied(),
2781        |connection_id| {
2782            session
2783                .peer
2784                .forward_send(session.connection_id, connection_id, request.clone())
2785        },
2786    );
2787    if let Some(room) = room {
2788        room_updated(room, &session.peer);
2789    }
2790    response.send(proto::Ack {})?;
2791
2792    Ok(())
2793}
2794
2795/// Updates other participants with changes to the worktree
2796async fn update_worktree(
2797    request: proto::UpdateWorktree,
2798    response: Response<proto::UpdateWorktree>,
2799    session: Session,
2800) -> Result<()> {
2801    let guest_connection_ids = session
2802        .db()
2803        .await
2804        .update_worktree(&request, session.connection_id)
2805        .await?;
2806
2807    broadcast(
2808        Some(session.connection_id),
2809        guest_connection_ids.iter().copied(),
2810        |connection_id| {
2811            session
2812                .peer
2813                .forward_send(session.connection_id, connection_id, request.clone())
2814        },
2815    );
2816    response.send(proto::Ack {})?;
2817    Ok(())
2818}
2819
2820/// Updates other participants with changes to the diagnostics
2821async fn update_diagnostic_summary(
2822    message: proto::UpdateDiagnosticSummary,
2823    session: Session,
2824) -> Result<()> {
2825    let guest_connection_ids = session
2826        .db()
2827        .await
2828        .update_diagnostic_summary(&message, session.connection_id)
2829        .await?;
2830
2831    broadcast(
2832        Some(session.connection_id),
2833        guest_connection_ids.iter().copied(),
2834        |connection_id| {
2835            session
2836                .peer
2837                .forward_send(session.connection_id, connection_id, message.clone())
2838        },
2839    );
2840
2841    Ok(())
2842}
2843
2844/// Updates other participants with changes to the worktree settings
2845async fn update_worktree_settings(
2846    message: proto::UpdateWorktreeSettings,
2847    session: Session,
2848) -> Result<()> {
2849    let guest_connection_ids = session
2850        .db()
2851        .await
2852        .update_worktree_settings(&message, session.connection_id)
2853        .await?;
2854
2855    broadcast(
2856        Some(session.connection_id),
2857        guest_connection_ids.iter().copied(),
2858        |connection_id| {
2859            session
2860                .peer
2861                .forward_send(session.connection_id, connection_id, message.clone())
2862        },
2863    );
2864
2865    Ok(())
2866}
2867
2868/// Notify other participants that a  language server has started.
2869async fn start_language_server(
2870    request: proto::StartLanguageServer,
2871    session: Session,
2872) -> Result<()> {
2873    let guest_connection_ids = session
2874        .db()
2875        .await
2876        .start_language_server(&request, session.connection_id)
2877        .await?;
2878
2879    broadcast(
2880        Some(session.connection_id),
2881        guest_connection_ids.iter().copied(),
2882        |connection_id| {
2883            session
2884                .peer
2885                .forward_send(session.connection_id, connection_id, request.clone())
2886        },
2887    );
2888    Ok(())
2889}
2890
2891/// Notify other participants that a language server has changed.
2892async fn update_language_server(
2893    request: proto::UpdateLanguageServer,
2894    session: Session,
2895) -> Result<()> {
2896    let project_id = ProjectId::from_proto(request.project_id);
2897    let project_connection_ids = session
2898        .db()
2899        .await
2900        .project_connection_ids(project_id, session.connection_id, true)
2901        .await?;
2902    broadcast(
2903        Some(session.connection_id),
2904        project_connection_ids.iter().copied(),
2905        |connection_id| {
2906            session
2907                .peer
2908                .forward_send(session.connection_id, connection_id, request.clone())
2909        },
2910    );
2911    Ok(())
2912}
2913
2914/// forward a project request to the host. These requests should be read only
2915/// as guests are allowed to send them.
2916async fn forward_read_only_project_request<T>(
2917    request: T,
2918    response: Response<T>,
2919    session: UserSession,
2920) -> Result<()>
2921where
2922    T: EntityMessage + RequestMessage,
2923{
2924    let project_id = ProjectId::from_proto(request.remote_entity_id());
2925    let host_connection_id = session
2926        .db()
2927        .await
2928        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2929        .await?;
2930    let payload = session
2931        .peer
2932        .forward_request(session.connection_id, host_connection_id, request)
2933        .await?;
2934    response.send(payload)?;
2935    Ok(())
2936}
2937
2938async fn forward_find_search_candidates_request(
2939    request: proto::FindSearchCandidates,
2940    response: Response<proto::FindSearchCandidates>,
2941    session: UserSession,
2942) -> Result<()> {
2943    let project_id = ProjectId::from_proto(request.remote_entity_id());
2944    let host_connection_id = session
2945        .db()
2946        .await
2947        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2948        .await?;
2949    let payload = session
2950        .peer
2951        .forward_request(session.connection_id, host_connection_id, request)
2952        .await?;
2953    response.send(payload)?;
2954    Ok(())
2955}
2956
2957/// forward a project request to the dev server. Only allowed
2958/// if it's your dev server.
2959async fn forward_project_request_for_owner<T>(
2960    request: T,
2961    response: Response<T>,
2962    session: UserSession,
2963) -> Result<()>
2964where
2965    T: EntityMessage + RequestMessage,
2966{
2967    let project_id = ProjectId::from_proto(request.remote_entity_id());
2968
2969    let host_connection_id = session
2970        .db()
2971        .await
2972        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2973        .await?;
2974    let payload = session
2975        .peer
2976        .forward_request(session.connection_id, host_connection_id, request)
2977        .await?;
2978    response.send(payload)?;
2979    Ok(())
2980}
2981
2982/// forward a project request to the host. These requests are disallowed
2983/// for guests.
2984async fn forward_mutating_project_request<T>(
2985    request: T,
2986    response: Response<T>,
2987    session: UserSession,
2988) -> Result<()>
2989where
2990    T: EntityMessage + RequestMessage,
2991{
2992    let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994    let host_connection_id = session
2995        .db()
2996        .await
2997        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2998        .await?;
2999    let payload = session
3000        .peer
3001        .forward_request(session.connection_id, host_connection_id, request)
3002        .await?;
3003    response.send(payload)?;
3004    Ok(())
3005}
3006
3007/// Notify other participants that a new buffer has been created
3008async fn create_buffer_for_peer(
3009    request: proto::CreateBufferForPeer,
3010    session: Session,
3011) -> Result<()> {
3012    session
3013        .db()
3014        .await
3015        .check_user_is_project_host(
3016            ProjectId::from_proto(request.project_id),
3017            session.connection_id,
3018        )
3019        .await?;
3020    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3021    session
3022        .peer
3023        .forward_send(session.connection_id, peer_id.into(), request)?;
3024    Ok(())
3025}
3026
3027/// Notify other participants that a buffer has been updated. This is
3028/// allowed for guests as long as the update is limited to selections.
3029async fn update_buffer(
3030    request: proto::UpdateBuffer,
3031    response: Response<proto::UpdateBuffer>,
3032    session: Session,
3033) -> Result<()> {
3034    let project_id = ProjectId::from_proto(request.project_id);
3035    let mut capability = Capability::ReadOnly;
3036
3037    for op in request.operations.iter() {
3038        match op.variant {
3039            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3040            Some(_) => capability = Capability::ReadWrite,
3041        }
3042    }
3043
3044    let host = {
3045        let guard = session
3046            .db()
3047            .await
3048            .connections_for_buffer_update(
3049                project_id,
3050                session.principal_id(),
3051                session.connection_id,
3052                capability,
3053            )
3054            .await?;
3055
3056        let (host, guests) = &*guard;
3057
3058        broadcast(
3059            Some(session.connection_id),
3060            guests.clone(),
3061            |connection_id| {
3062                session
3063                    .peer
3064                    .forward_send(session.connection_id, connection_id, request.clone())
3065            },
3066        );
3067
3068        *host
3069    };
3070
3071    if host != session.connection_id {
3072        session
3073            .peer
3074            .forward_request(session.connection_id, host, request.clone())
3075            .await?;
3076    }
3077
3078    response.send(proto::Ack {})?;
3079    Ok(())
3080}
3081
3082async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3083    let project_id = ProjectId::from_proto(message.project_id);
3084
3085    let operation = message.operation.as_ref().context("invalid operation")?;
3086    let capability = match operation.variant.as_ref() {
3087        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3088            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3089                match buffer_op.variant {
3090                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3091                        Capability::ReadOnly
3092                    }
3093                    _ => Capability::ReadWrite,
3094                }
3095            } else {
3096                Capability::ReadWrite
3097            }
3098        }
3099        Some(_) => Capability::ReadWrite,
3100        None => Capability::ReadOnly,
3101    };
3102
3103    let guard = session
3104        .db()
3105        .await
3106        .connections_for_buffer_update(
3107            project_id,
3108            session.principal_id(),
3109            session.connection_id,
3110            capability,
3111        )
3112        .await?;
3113
3114    let (host, guests) = &*guard;
3115
3116    broadcast(
3117        Some(session.connection_id),
3118        guests.iter().chain([host]).copied(),
3119        |connection_id| {
3120            session
3121                .peer
3122                .forward_send(session.connection_id, connection_id, message.clone())
3123        },
3124    );
3125
3126    Ok(())
3127}
3128
3129/// Notify other participants that a project has been updated.
3130async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3131    request: T,
3132    session: Session,
3133) -> Result<()> {
3134    let project_id = ProjectId::from_proto(request.remote_entity_id());
3135    let project_connection_ids = session
3136        .db()
3137        .await
3138        .project_connection_ids(project_id, session.connection_id, false)
3139        .await?;
3140
3141    broadcast(
3142        Some(session.connection_id),
3143        project_connection_ids.iter().copied(),
3144        |connection_id| {
3145            session
3146                .peer
3147                .forward_send(session.connection_id, connection_id, request.clone())
3148        },
3149    );
3150    Ok(())
3151}
3152
3153/// Start following another user in a call.
3154async fn follow(
3155    request: proto::Follow,
3156    response: Response<proto::Follow>,
3157    session: UserSession,
3158) -> Result<()> {
3159    let room_id = RoomId::from_proto(request.room_id);
3160    let project_id = request.project_id.map(ProjectId::from_proto);
3161    let leader_id = request
3162        .leader_id
3163        .ok_or_else(|| anyhow!("invalid leader id"))?
3164        .into();
3165    let follower_id = session.connection_id;
3166
3167    session
3168        .db()
3169        .await
3170        .check_room_participants(room_id, leader_id, session.connection_id)
3171        .await?;
3172
3173    let response_payload = session
3174        .peer
3175        .forward_request(session.connection_id, leader_id, request)
3176        .await?;
3177    response.send(response_payload)?;
3178
3179    if let Some(project_id) = project_id {
3180        let room = session
3181            .db()
3182            .await
3183            .follow(room_id, project_id, leader_id, follower_id)
3184            .await?;
3185        room_updated(&room, &session.peer);
3186    }
3187
3188    Ok(())
3189}
3190
3191/// Stop following another user in a call.
3192async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3193    let room_id = RoomId::from_proto(request.room_id);
3194    let project_id = request.project_id.map(ProjectId::from_proto);
3195    let leader_id = request
3196        .leader_id
3197        .ok_or_else(|| anyhow!("invalid leader id"))?
3198        .into();
3199    let follower_id = session.connection_id;
3200
3201    session
3202        .db()
3203        .await
3204        .check_room_participants(room_id, leader_id, session.connection_id)
3205        .await?;
3206
3207    session
3208        .peer
3209        .forward_send(session.connection_id, leader_id, request)?;
3210
3211    if let Some(project_id) = project_id {
3212        let room = session
3213            .db()
3214            .await
3215            .unfollow(room_id, project_id, leader_id, follower_id)
3216            .await?;
3217        room_updated(&room, &session.peer);
3218    }
3219
3220    Ok(())
3221}
3222
3223/// Notify everyone following you of your current location.
3224async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3225    let room_id = RoomId::from_proto(request.room_id);
3226    let database = session.db.lock().await;
3227
3228    let connection_ids = if let Some(project_id) = request.project_id {
3229        let project_id = ProjectId::from_proto(project_id);
3230        database
3231            .project_connection_ids(project_id, session.connection_id, true)
3232            .await?
3233    } else {
3234        database
3235            .room_connection_ids(room_id, session.connection_id)
3236            .await?
3237    };
3238
3239    // For now, don't send view update messages back to that view's current leader.
3240    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3241        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3242        _ => None,
3243    });
3244
3245    for connection_id in connection_ids.iter().cloned() {
3246        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3247            session
3248                .peer
3249                .forward_send(session.connection_id, connection_id, request.clone())?;
3250        }
3251    }
3252    Ok(())
3253}
3254
3255/// Get public data about users.
3256async fn get_users(
3257    request: proto::GetUsers,
3258    response: Response<proto::GetUsers>,
3259    session: Session,
3260) -> Result<()> {
3261    let user_ids = request
3262        .user_ids
3263        .into_iter()
3264        .map(UserId::from_proto)
3265        .collect();
3266    let users = session
3267        .db()
3268        .await
3269        .get_users_by_ids(user_ids)
3270        .await?
3271        .into_iter()
3272        .map(|user| proto::User {
3273            id: user.id.to_proto(),
3274            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3275            github_login: user.github_login,
3276        })
3277        .collect();
3278    response.send(proto::UsersResponse { users })?;
3279    Ok(())
3280}
3281
3282/// Search for users (to invite) buy Github login
3283async fn fuzzy_search_users(
3284    request: proto::FuzzySearchUsers,
3285    response: Response<proto::FuzzySearchUsers>,
3286    session: UserSession,
3287) -> Result<()> {
3288    let query = request.query;
3289    let users = match query.len() {
3290        0 => vec![],
3291        1 | 2 => session
3292            .db()
3293            .await
3294            .get_user_by_github_login(&query)
3295            .await?
3296            .into_iter()
3297            .collect(),
3298        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3299    };
3300    let users = users
3301        .into_iter()
3302        .filter(|user| user.id != session.user_id())
3303        .map(|user| proto::User {
3304            id: user.id.to_proto(),
3305            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3306            github_login: user.github_login,
3307        })
3308        .collect();
3309    response.send(proto::UsersResponse { users })?;
3310    Ok(())
3311}
3312
3313/// Send a contact request to another user.
3314async fn request_contact(
3315    request: proto::RequestContact,
3316    response: Response<proto::RequestContact>,
3317    session: UserSession,
3318) -> Result<()> {
3319    let requester_id = session.user_id();
3320    let responder_id = UserId::from_proto(request.responder_id);
3321    if requester_id == responder_id {
3322        return Err(anyhow!("cannot add yourself as a contact"))?;
3323    }
3324
3325    let notifications = session
3326        .db()
3327        .await
3328        .send_contact_request(requester_id, responder_id)
3329        .await?;
3330
3331    // Update outgoing contact requests of requester
3332    let mut update = proto::UpdateContacts::default();
3333    update.outgoing_requests.push(responder_id.to_proto());
3334    for connection_id in session
3335        .connection_pool()
3336        .await
3337        .user_connection_ids(requester_id)
3338    {
3339        session.peer.send(connection_id, update.clone())?;
3340    }
3341
3342    // Update incoming contact requests of responder
3343    let mut update = proto::UpdateContacts::default();
3344    update
3345        .incoming_requests
3346        .push(proto::IncomingContactRequest {
3347            requester_id: requester_id.to_proto(),
3348        });
3349    let connection_pool = session.connection_pool().await;
3350    for connection_id in connection_pool.user_connection_ids(responder_id) {
3351        session.peer.send(connection_id, update.clone())?;
3352    }
3353
3354    send_notifications(&connection_pool, &session.peer, notifications);
3355
3356    response.send(proto::Ack {})?;
3357    Ok(())
3358}
3359
3360/// Accept or decline a contact request
3361async fn respond_to_contact_request(
3362    request: proto::RespondToContactRequest,
3363    response: Response<proto::RespondToContactRequest>,
3364    session: UserSession,
3365) -> Result<()> {
3366    let responder_id = session.user_id();
3367    let requester_id = UserId::from_proto(request.requester_id);
3368    let db = session.db().await;
3369    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3370        db.dismiss_contact_notification(responder_id, requester_id)
3371            .await?;
3372    } else {
3373        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3374
3375        let notifications = db
3376            .respond_to_contact_request(responder_id, requester_id, accept)
3377            .await?;
3378        let requester_busy = db.is_user_busy(requester_id).await?;
3379        let responder_busy = db.is_user_busy(responder_id).await?;
3380
3381        let pool = session.connection_pool().await;
3382        // Update responder with new contact
3383        let mut update = proto::UpdateContacts::default();
3384        if accept {
3385            update
3386                .contacts
3387                .push(contact_for_user(requester_id, requester_busy, &pool));
3388        }
3389        update
3390            .remove_incoming_requests
3391            .push(requester_id.to_proto());
3392        for connection_id in pool.user_connection_ids(responder_id) {
3393            session.peer.send(connection_id, update.clone())?;
3394        }
3395
3396        // Update requester with new contact
3397        let mut update = proto::UpdateContacts::default();
3398        if accept {
3399            update
3400                .contacts
3401                .push(contact_for_user(responder_id, responder_busy, &pool));
3402        }
3403        update
3404            .remove_outgoing_requests
3405            .push(responder_id.to_proto());
3406
3407        for connection_id in pool.user_connection_ids(requester_id) {
3408            session.peer.send(connection_id, update.clone())?;
3409        }
3410
3411        send_notifications(&pool, &session.peer, notifications);
3412    }
3413
3414    response.send(proto::Ack {})?;
3415    Ok(())
3416}
3417
3418/// Remove a contact.
3419async fn remove_contact(
3420    request: proto::RemoveContact,
3421    response: Response<proto::RemoveContact>,
3422    session: UserSession,
3423) -> Result<()> {
3424    let requester_id = session.user_id();
3425    let responder_id = UserId::from_proto(request.user_id);
3426    let db = session.db().await;
3427    let (contact_accepted, deleted_notification_id) =
3428        db.remove_contact(requester_id, responder_id).await?;
3429
3430    let pool = session.connection_pool().await;
3431    // Update outgoing contact requests of requester
3432    let mut update = proto::UpdateContacts::default();
3433    if contact_accepted {
3434        update.remove_contacts.push(responder_id.to_proto());
3435    } else {
3436        update
3437            .remove_outgoing_requests
3438            .push(responder_id.to_proto());
3439    }
3440    for connection_id in pool.user_connection_ids(requester_id) {
3441        session.peer.send(connection_id, update.clone())?;
3442    }
3443
3444    // Update incoming contact requests of responder
3445    let mut update = proto::UpdateContacts::default();
3446    if contact_accepted {
3447        update.remove_contacts.push(requester_id.to_proto());
3448    } else {
3449        update
3450            .remove_incoming_requests
3451            .push(requester_id.to_proto());
3452    }
3453    for connection_id in pool.user_connection_ids(responder_id) {
3454        session.peer.send(connection_id, update.clone())?;
3455        if let Some(notification_id) = deleted_notification_id {
3456            session.peer.send(
3457                connection_id,
3458                proto::DeleteNotification {
3459                    notification_id: notification_id.to_proto(),
3460                },
3461            )?;
3462        }
3463    }
3464
3465    response.send(proto::Ack {})?;
3466    Ok(())
3467}
3468
3469fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3470    version.0.minor() < 139
3471}
3472
3473async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3474    let plan = session.current_plan(session.db().await).await?;
3475
3476    session
3477        .peer
3478        .send(
3479            session.connection_id,
3480            proto::UpdateUserPlan { plan: plan.into() },
3481        )
3482        .trace_err();
3483
3484    Ok(())
3485}
3486
3487async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3488    subscribe_user_to_channels(
3489        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3490        &session,
3491    )
3492    .await?;
3493    Ok(())
3494}
3495
3496async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3497    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3498    let mut pool = session.connection_pool().await;
3499    for membership in &channels_for_user.channel_memberships {
3500        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3501    }
3502    session.peer.send(
3503        session.connection_id,
3504        build_update_user_channels(&channels_for_user),
3505    )?;
3506    session.peer.send(
3507        session.connection_id,
3508        build_channels_update(channels_for_user),
3509    )?;
3510    Ok(())
3511}
3512
3513/// Creates a new channel.
3514async fn create_channel(
3515    request: proto::CreateChannel,
3516    response: Response<proto::CreateChannel>,
3517    session: UserSession,
3518) -> Result<()> {
3519    let db = session.db().await;
3520
3521    let parent_id = request.parent_id.map(ChannelId::from_proto);
3522    let (channel, membership) = db
3523        .create_channel(&request.name, parent_id, session.user_id())
3524        .await?;
3525
3526    let root_id = channel.root_id();
3527    let channel = Channel::from_model(channel);
3528
3529    response.send(proto::CreateChannelResponse {
3530        channel: Some(channel.to_proto()),
3531        parent_id: request.parent_id,
3532    })?;
3533
3534    let mut connection_pool = session.connection_pool().await;
3535    if let Some(membership) = membership {
3536        connection_pool.subscribe_to_channel(
3537            membership.user_id,
3538            membership.channel_id,
3539            membership.role,
3540        );
3541        let update = proto::UpdateUserChannels {
3542            channel_memberships: vec![proto::ChannelMembership {
3543                channel_id: membership.channel_id.to_proto(),
3544                role: membership.role.into(),
3545            }],
3546            ..Default::default()
3547        };
3548        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3549            session.peer.send(connection_id, update.clone())?;
3550        }
3551    }
3552
3553    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3554        if !role.can_see_channel(channel.visibility) {
3555            continue;
3556        }
3557
3558        let update = proto::UpdateChannels {
3559            channels: vec![channel.to_proto()],
3560            ..Default::default()
3561        };
3562        session.peer.send(connection_id, update.clone())?;
3563    }
3564
3565    Ok(())
3566}
3567
3568/// Delete a channel
3569async fn delete_channel(
3570    request: proto::DeleteChannel,
3571    response: Response<proto::DeleteChannel>,
3572    session: UserSession,
3573) -> Result<()> {
3574    let db = session.db().await;
3575
3576    let channel_id = request.channel_id;
3577    let (root_channel, removed_channels) = db
3578        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3579        .await?;
3580    response.send(proto::Ack {})?;
3581
3582    // Notify members of removed channels
3583    let mut update = proto::UpdateChannels::default();
3584    update
3585        .delete_channels
3586        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3587
3588    let connection_pool = session.connection_pool().await;
3589    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3590        session.peer.send(connection_id, update.clone())?;
3591    }
3592
3593    Ok(())
3594}
3595
3596/// Invite someone to join a channel.
3597async fn invite_channel_member(
3598    request: proto::InviteChannelMember,
3599    response: Response<proto::InviteChannelMember>,
3600    session: UserSession,
3601) -> Result<()> {
3602    let db = session.db().await;
3603    let channel_id = ChannelId::from_proto(request.channel_id);
3604    let invitee_id = UserId::from_proto(request.user_id);
3605    let InviteMemberResult {
3606        channel,
3607        notifications,
3608    } = db
3609        .invite_channel_member(
3610            channel_id,
3611            invitee_id,
3612            session.user_id(),
3613            request.role().into(),
3614        )
3615        .await?;
3616
3617    let update = proto::UpdateChannels {
3618        channel_invitations: vec![channel.to_proto()],
3619        ..Default::default()
3620    };
3621
3622    let connection_pool = session.connection_pool().await;
3623    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3624        session.peer.send(connection_id, update.clone())?;
3625    }
3626
3627    send_notifications(&connection_pool, &session.peer, notifications);
3628
3629    response.send(proto::Ack {})?;
3630    Ok(())
3631}
3632
3633/// remove someone from a channel
3634async fn remove_channel_member(
3635    request: proto::RemoveChannelMember,
3636    response: Response<proto::RemoveChannelMember>,
3637    session: UserSession,
3638) -> Result<()> {
3639    let db = session.db().await;
3640    let channel_id = ChannelId::from_proto(request.channel_id);
3641    let member_id = UserId::from_proto(request.user_id);
3642
3643    let RemoveChannelMemberResult {
3644        membership_update,
3645        notification_id,
3646    } = db
3647        .remove_channel_member(channel_id, member_id, session.user_id())
3648        .await?;
3649
3650    let mut connection_pool = session.connection_pool().await;
3651    notify_membership_updated(
3652        &mut connection_pool,
3653        membership_update,
3654        member_id,
3655        &session.peer,
3656    );
3657    for connection_id in connection_pool.user_connection_ids(member_id) {
3658        if let Some(notification_id) = notification_id {
3659            session
3660                .peer
3661                .send(
3662                    connection_id,
3663                    proto::DeleteNotification {
3664                        notification_id: notification_id.to_proto(),
3665                    },
3666                )
3667                .trace_err();
3668        }
3669    }
3670
3671    response.send(proto::Ack {})?;
3672    Ok(())
3673}
3674
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database update, every user subscribed to the channel's root is re-evaluated.
/// Users whose role can still see the channel are subscribed to it and receive the updated
/// channel. Everyone else is unsubscribed from it and told to delete it. The request is
/// acknowledged only after all of these updates have been sent.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates the pool via
    // subscribe/unsubscribe, which the borrow checker would not allow while iterating it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            // The user can no longer see this channel: drop the subscription and have the
            // client delete it.
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3721
3722/// Alter the role for a user in the channel.
3723async fn set_channel_member_role(
3724    request: proto::SetChannelMemberRole,
3725    response: Response<proto::SetChannelMemberRole>,
3726    session: UserSession,
3727) -> Result<()> {
3728    let db = session.db().await;
3729    let channel_id = ChannelId::from_proto(request.channel_id);
3730    let member_id = UserId::from_proto(request.user_id);
3731    let result = db
3732        .set_channel_member_role(
3733            channel_id,
3734            session.user_id(),
3735            member_id,
3736            request.role().into(),
3737        )
3738        .await?;
3739
3740    match result {
3741        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3742            let mut connection_pool = session.connection_pool().await;
3743            notify_membership_updated(
3744                &mut connection_pool,
3745                membership_update,
3746                member_id,
3747                &session.peer,
3748            )
3749        }
3750        db::SetMemberRoleResult::InviteUpdated(channel) => {
3751            let update = proto::UpdateChannels {
3752                channel_invitations: vec![channel.to_proto()],
3753                ..Default::default()
3754            };
3755
3756            for connection_id in session
3757                .connection_pool()
3758                .await
3759                .user_connection_ids(member_id)
3760            {
3761                session.peer.send(connection_id, update.clone())?;
3762            }
3763        }
3764    }
3765
3766    response.send(proto::Ack {})?;
3767    Ok(())
3768}
3769
3770/// Change the name of a channel
3771async fn rename_channel(
3772    request: proto::RenameChannel,
3773    response: Response<proto::RenameChannel>,
3774    session: UserSession,
3775) -> Result<()> {
3776    let db = session.db().await;
3777    let channel_id = ChannelId::from_proto(request.channel_id);
3778    let channel_model = db
3779        .rename_channel(channel_id, session.user_id(), &request.name)
3780        .await?;
3781    let root_id = channel_model.root_id();
3782    let channel = Channel::from_model(channel_model);
3783
3784    response.send(proto::RenameChannelResponse {
3785        channel: Some(channel.to_proto()),
3786    })?;
3787
3788    let connection_pool = session.connection_pool().await;
3789    let update = proto::UpdateChannels {
3790        channels: vec![channel.to_proto()],
3791        ..Default::default()
3792    };
3793    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3794        if role.can_see_channel(channel.visibility) {
3795            session.peer.send(connection_id, update.clone())?;
3796        }
3797    }
3798
3799    Ok(())
3800}
3801
3802/// Move a channel to a new parent.
3803async fn move_channel(
3804    request: proto::MoveChannel,
3805    response: Response<proto::MoveChannel>,
3806    session: UserSession,
3807) -> Result<()> {
3808    let channel_id = ChannelId::from_proto(request.channel_id);
3809    let to = ChannelId::from_proto(request.to);
3810
3811    let (root_id, channels) = session
3812        .db()
3813        .await
3814        .move_channel(channel_id, to, session.user_id())
3815        .await?;
3816
3817    let connection_pool = session.connection_pool().await;
3818    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3819        let channels = channels
3820            .iter()
3821            .filter_map(|channel| {
3822                if role.can_see_channel(channel.visibility) {
3823                    Some(channel.to_proto())
3824                } else {
3825                    None
3826                }
3827            })
3828            .collect::<Vec<_>>();
3829        if channels.is_empty() {
3830            continue;
3831        }
3832
3833        let update = proto::UpdateChannels {
3834            channels,
3835            ..Default::default()
3836        };
3837
3838        session.peer.send(connection_id, update.clone())?;
3839    }
3840
3841    response.send(Ack {})?;
3842    Ok(())
3843}
3844
3845/// Get the list of channel members
3846async fn get_channel_members(
3847    request: proto::GetChannelMembers,
3848    response: Response<proto::GetChannelMembers>,
3849    session: UserSession,
3850) -> Result<()> {
3851    let db = session.db().await;
3852    let channel_id = ChannelId::from_proto(request.channel_id);
3853    let limit = if request.limit == 0 {
3854        u16::MAX as u64
3855    } else {
3856        request.limit
3857    };
3858    let (members, users) = db
3859        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3860        .await?;
3861    response.send(proto::GetChannelMembersResponse { members, users })?;
3862    Ok(())
3863}
3864
3865/// Accept or decline a channel invitation.
3866async fn respond_to_channel_invite(
3867    request: proto::RespondToChannelInvite,
3868    response: Response<proto::RespondToChannelInvite>,
3869    session: UserSession,
3870) -> Result<()> {
3871    let db = session.db().await;
3872    let channel_id = ChannelId::from_proto(request.channel_id);
3873    let RespondToChannelInvite {
3874        membership_update,
3875        notifications,
3876    } = db
3877        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3878        .await?;
3879
3880    let mut connection_pool = session.connection_pool().await;
3881    if let Some(membership_update) = membership_update {
3882        notify_membership_updated(
3883            &mut connection_pool,
3884            membership_update,
3885            session.user_id(),
3886            &session.peer,
3887        );
3888    } else {
3889        let update = proto::UpdateChannels {
3890            remove_channel_invitations: vec![channel_id.to_proto()],
3891            ..Default::default()
3892        };
3893
3894        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3895            session.peer.send(connection_id, update.clone())?;
3896        }
3897    };
3898
3899    send_notifications(&connection_pool, &session.peer, notifications);
3900
3901    response.send(proto::Ack {})?;
3902
3903    Ok(())
3904}
3905
3906/// Join the channels' room
3907async fn join_channel(
3908    request: proto::JoinChannel,
3909    response: Response<proto::JoinChannel>,
3910    session: UserSession,
3911) -> Result<()> {
3912    let channel_id = ChannelId::from_proto(request.channel_id);
3913    join_channel_internal(channel_id, Box::new(response), session).await
3914}
3915
/// A response handle that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests reply with a `proto::JoinRoomResponse`, so this
/// trait lets `join_channel_internal` answer either kind of request.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is clear this delegates to `Response`'s own `send`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so it is clear this delegates to `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3929
/// Joins the room that belongs to `channel_id` on behalf of the session's user, replying through
/// `response` (a `JoinChannel` or `JoinRoom` response handle).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This yields the room, an optional membership
///    update and the user's channel role.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest token and
///    `can_publish = false`; every other role gets a room token and `can_publish = true`. If
///    minting fails, the error is logged and no LiveKit connection info is sent.
/// 4. Send the `JoinRoomResponse`, propagate any membership change, and call `room_updated`
///    for the joined room.
/// 5. Once the database handle has been released (end of the inner block), call
///    `channel_updated` and refresh the user's contacts.
///
/// # Errors
/// Propagates database and send failures. Also fails with "channel not returned" when the joined
/// room carries no channel. Note that the join response has already been sent by that point.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room (presumably
            // `leave_room_for_session` needs the database itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token fails (the
        // `trace_err()?` below logs the error and short-circuits the closure to `None`).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may listen but not publish; other roles get a full room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops both the database handle and the pool guard.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4025
4026/// Start editing the channel notes
4027async fn join_channel_buffer(
4028    request: proto::JoinChannelBuffer,
4029    response: Response<proto::JoinChannelBuffer>,
4030    session: UserSession,
4031) -> Result<()> {
4032    let db = session.db().await;
4033    let channel_id = ChannelId::from_proto(request.channel_id);
4034
4035    let open_response = db
4036        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4037        .await?;
4038
4039    let collaborators = open_response.collaborators.clone();
4040    response.send(open_response)?;
4041
4042    let update = UpdateChannelBufferCollaborators {
4043        channel_id: channel_id.to_proto(),
4044        collaborators: collaborators.clone(),
4045    };
4046    channel_buffer_updated(
4047        session.connection_id,
4048        collaborators
4049            .iter()
4050            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4051        &update,
4052        &session.peer,
4053    );
4054
4055    Ok(())
4056}
4057
/// Edit the channel notes
///
/// Applies `request.operations` to the channel's buffer in the database and relays the same
/// operations to the collaborators the database returns. Every other connection subscribed to
/// the channel (one not in that collaborator set) is sent the buffer's latest epoch and version
/// in an `UpdateChannels` message.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Cloned because `collaborators` is still needed below to compute the non-collaborators.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers that did not receive the operations above.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    // No sender to skip here: collaborators were already filtered out above.
    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
4108
4109/// Rejoin the channel notes after a connection blip
4110async fn rejoin_channel_buffers(
4111    request: proto::RejoinChannelBuffers,
4112    response: Response<proto::RejoinChannelBuffers>,
4113    session: UserSession,
4114) -> Result<()> {
4115    let db = session.db().await;
4116    let buffers = db
4117        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4118        .await?;
4119
4120    for rejoined_buffer in &buffers {
4121        let collaborators_to_notify = rejoined_buffer
4122            .buffer
4123            .collaborators
4124            .iter()
4125            .filter_map(|c| Some(c.peer_id?.into()));
4126        channel_buffer_updated(
4127            session.connection_id,
4128            collaborators_to_notify,
4129            &proto::UpdateChannelBufferCollaborators {
4130                channel_id: rejoined_buffer.buffer.channel_id,
4131                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4132            },
4133            &session.peer,
4134        );
4135    }
4136
4137    response.send(proto::RejoinChannelBuffersResponse {
4138        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4139    })?;
4140
4141    Ok(())
4142}
4143
4144/// Stop editing the channel notes
4145async fn leave_channel_buffer(
4146    request: proto::LeaveChannelBuffer,
4147    response: Response<proto::LeaveChannelBuffer>,
4148    session: UserSession,
4149) -> Result<()> {
4150    let db = session.db().await;
4151    let channel_id = ChannelId::from_proto(request.channel_id);
4152
4153    let left_buffer = db
4154        .leave_channel_buffer(channel_id, session.connection_id)
4155        .await?;
4156
4157    response.send(Ack {})?;
4158
4159    channel_buffer_updated(
4160        session.connection_id,
4161        left_buffer.connections,
4162        &proto::UpdateChannelBufferCollaborators {
4163            channel_id: channel_id.to_proto(),
4164            collaborators: left_buffer.collaborators,
4165        },
4166        &session.peer,
4167    );
4168
4169    Ok(())
4170}
4171
4172fn channel_buffer_updated<T: EnvelopedMessage>(
4173    sender_id: ConnectionId,
4174    collaborators: impl IntoIterator<Item = ConnectionId>,
4175    message: &T,
4176    peer: &Peer,
4177) {
4178    broadcast(Some(sender_id), collaborators, |peer_id| {
4179        peer.send(peer_id, message.clone())
4180    });
4181}
4182
4183fn send_notifications(
4184    connection_pool: &ConnectionPool,
4185    peer: &Peer,
4186    notifications: db::NotificationBatch,
4187) {
4188    for (user_id, notification) in notifications {
4189        for connection_id in connection_pool.user_connection_ids(user_id) {
4190            if let Err(error) = peer.send(
4191                connection_id,
4192                proto::AddNotification {
4193                    notification: Some(notification.clone()),
4194                },
4195            ) {
4196                tracing::error!(
4197                    "failed to send notification to {:?} {}",
4198                    connection_id,
4199                    error
4200                );
4201            }
4202        }
4203    }
4204}
4205
/// Send a message to the channel
///
/// The body is trimmed and must be non-empty and at most `MAX_MESSAGE_LEN` bytes; a nonce is
/// required. After the message is stored:
/// - it is broadcast as `ChannelMessageSent` to the participant connections the database
///   returns, with this connection passed as the sender;
/// - it is returned to the sender in the response;
/// - channel subscribers that are not participants are told the channel's latest message id;
/// - any notifications produced by the database are delivered.
///
/// # Errors
/// Fails if the body is too long or blank, if the nonce is missing, or if the database write or
/// the response send fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `len()` counts bytes, not characters.
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below. If mentions hold
    // offsets into the untrimmed body, removing leading whitespace shifts them. Confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers that are not participants only get a lightweight "latest message id" update.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4300
4301/// Delete a channel message
4302async fn remove_channel_message(
4303    request: proto::RemoveChannelMessage,
4304    response: Response<proto::RemoveChannelMessage>,
4305    session: UserSession,
4306) -> Result<()> {
4307    let channel_id = ChannelId::from_proto(request.channel_id);
4308    let message_id = MessageId::from_proto(request.message_id);
4309    let (connection_ids, existing_notification_ids) = session
4310        .db()
4311        .await
4312        .remove_channel_message(channel_id, message_id, session.user_id())
4313        .await?;
4314
4315    broadcast(
4316        Some(session.connection_id),
4317        connection_ids,
4318        move |connection| {
4319            session.peer.send(connection, request.clone())?;
4320
4321            for notification_id in &existing_notification_ids {
4322                session.peer.send(
4323                    connection,
4324                    proto::DeleteNotification {
4325                        notification_id: (*notification_id).to_proto(),
4326                    },
4327                )?;
4328            }
4329
4330            Ok(())
4331        },
4332    );
4333    response.send(proto::Ack {})?;
4334    Ok(())
4335}
4336
4337async fn update_channel_message(
4338    request: proto::UpdateChannelMessage,
4339    response: Response<proto::UpdateChannelMessage>,
4340    session: UserSession,
4341) -> Result<()> {
4342    let channel_id = ChannelId::from_proto(request.channel_id);
4343    let message_id = MessageId::from_proto(request.message_id);
4344    let updated_at = OffsetDateTime::now_utc();
4345    let UpdatedChannelMessage {
4346        message_id,
4347        participant_connection_ids,
4348        notifications,
4349        reply_to_message_id,
4350        timestamp,
4351        deleted_mention_notification_ids,
4352        updated_mention_notifications,
4353    } = session
4354        .db()
4355        .await
4356        .update_channel_message(
4357            channel_id,
4358            message_id,
4359            session.user_id(),
4360            request.body.as_str(),
4361            &request.mentions,
4362            updated_at,
4363        )
4364        .await?;
4365
4366    let nonce = request
4367        .nonce
4368        .clone()
4369        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4370
4371    let message = proto::ChannelMessage {
4372        sender_id: session.user_id().to_proto(),
4373        id: message_id.to_proto(),
4374        body: request.body.clone(),
4375        mentions: request.mentions.clone(),
4376        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4377        nonce: Some(nonce),
4378        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4379        edited_at: Some(updated_at.unix_timestamp() as u64),
4380    };
4381
4382    response.send(proto::Ack {})?;
4383
4384    let pool = &*session.connection_pool().await;
4385    broadcast(
4386        Some(session.connection_id),
4387        participant_connection_ids,
4388        |connection| {
4389            session.peer.send(
4390                connection,
4391                proto::ChannelMessageUpdate {
4392                    channel_id: channel_id.to_proto(),
4393                    message: Some(message.clone()),
4394                },
4395            )?;
4396
4397            for notification_id in &deleted_mention_notification_ids {
4398                session.peer.send(
4399                    connection,
4400                    proto::DeleteNotification {
4401                        notification_id: (*notification_id).to_proto(),
4402                    },
4403                )?;
4404            }
4405
4406            for notification in &updated_mention_notifications {
4407                session.peer.send(
4408                    connection,
4409                    proto::UpdateNotification {
4410                        notification: Some(notification.clone()),
4411                    },
4412                )?;
4413            }
4414
4415            Ok(())
4416        },
4417    );
4418
4419    send_notifications(pool, &session.peer, notifications);
4420
4421    Ok(())
4422}
4423
4424/// Mark a channel message as read
4425async fn acknowledge_channel_message(
4426    request: proto::AckChannelMessage,
4427    session: UserSession,
4428) -> Result<()> {
4429    let channel_id = ChannelId::from_proto(request.channel_id);
4430    let message_id = MessageId::from_proto(request.message_id);
4431    let notifications = session
4432        .db()
4433        .await
4434        .observe_channel_message(channel_id, session.user_id(), message_id)
4435        .await?;
4436    send_notifications(
4437        &*session.connection_pool().await,
4438        &session.peer,
4439        notifications,
4440    );
4441    Ok(())
4442}
4443
4444/// Mark a buffer version as synced
4445async fn acknowledge_buffer_version(
4446    request: proto::AckBufferOperation,
4447    session: UserSession,
4448) -> Result<()> {
4449    let buffer_id = BufferId::from_proto(request.buffer_id);
4450    session
4451        .db()
4452        .await
4453        .observe_buffer_version(
4454            buffer_id,
4455            session.user_id(),
4456            request.epoch as i32,
4457            &request.version,
4458        )
4459        .await?;
4460    Ok(())
4461}
4462
4463async fn count_language_model_tokens(
4464    request: proto::CountLanguageModelTokens,
4465    response: Response<proto::CountLanguageModelTokens>,
4466    session: Session,
4467    config: &Config,
4468) -> Result<()> {
4469    let Some(session) = session.for_user() else {
4470        return Err(anyhow!("user not found"))?;
4471    };
4472    authorize_access_to_legacy_llm_endpoints(&session).await?;
4473
4474    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4475        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4476        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4477    };
4478
4479    session
4480        .app_state
4481        .rate_limiter
4482        .check(&*rate_limit, session.user_id())
4483        .await?;
4484
4485    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4486        Some(proto::LanguageModelProvider::Google) => {
4487            let api_key = config
4488                .google_ai_api_key
4489                .as_ref()
4490                .context("no Google AI API key configured on the server")?;
4491            google_ai::count_tokens(
4492                session.http_client.as_ref(),
4493                google_ai::API_URL,
4494                api_key,
4495                serde_json::from_str(&request.request)?,
4496                None,
4497            )
4498            .await?
4499        }
4500        _ => return Err(anyhow!("unsupported provider"))?,
4501    };
4502
4503    response.send(proto::CountLanguageModelTokensResponse {
4504        token_count: result.total_tokens as u32,
4505    })?;
4506
4507    Ok(())
4508}
4509
4510struct ZedProCountLanguageModelTokensRateLimit;
4511
4512impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4513    fn capacity(&self) -> usize {
4514        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4515            .ok()
4516            .and_then(|v| v.parse().ok())
4517            .unwrap_or(600) // Picked arbitrarily
4518    }
4519
4520    fn refill_duration(&self) -> chrono::Duration {
4521        chrono::Duration::hours(1)
4522    }
4523
4524    fn db_name(&self) -> &'static str {
4525        "zed-pro:count-language-model-tokens"
4526    }
4527}
4528
4529struct FreeCountLanguageModelTokensRateLimit;
4530
4531impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4532    fn capacity(&self) -> usize {
4533        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4534            .ok()
4535            .and_then(|v| v.parse().ok())
4536            .unwrap_or(600 / 10) // Picked arbitrarily
4537    }
4538
4539    fn refill_duration(&self) -> chrono::Duration {
4540        chrono::Duration::hours(1)
4541    }
4542
4543    fn db_name(&self) -> &'static str {
4544        "free:count-language-model-tokens"
4545    }
4546}
4547
4548struct ZedProComputeEmbeddingsRateLimit;
4549
4550impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4551    fn capacity(&self) -> usize {
4552        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4553            .ok()
4554            .and_then(|v| v.parse().ok())
4555            .unwrap_or(5000) // Picked arbitrarily
4556    }
4557
4558    fn refill_duration(&self) -> chrono::Duration {
4559        chrono::Duration::hours(1)
4560    }
4561
4562    fn db_name(&self) -> &'static str {
4563        "zed-pro:compute-embeddings"
4564    }
4565}
4566
4567struct FreeComputeEmbeddingsRateLimit;
4568
4569impl RateLimit for FreeComputeEmbeddingsRateLimit {
4570    fn capacity(&self) -> usize {
4571        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4572            .ok()
4573            .and_then(|v| v.parse().ok())
4574            .unwrap_or(5000 / 10) // Picked arbitrarily
4575    }
4576
4577    fn refill_duration(&self) -> chrono::Duration {
4578        chrono::Duration::hours(1)
4579    }
4580
4581    fn db_name(&self) -> &'static str {
4582        "free:compute-embeddings"
4583    }
4584}
4585
4586async fn compute_embeddings(
4587    request: proto::ComputeEmbeddings,
4588    response: Response<proto::ComputeEmbeddings>,
4589    session: UserSession,
4590    api_key: Option<Arc<str>>,
4591) -> Result<()> {
4592    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4593    authorize_access_to_legacy_llm_endpoints(&session).await?;
4594
4595    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4596        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4597        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4598    };
4599
4600    session
4601        .app_state
4602        .rate_limiter
4603        .check(&*rate_limit, session.user_id())
4604        .await?;
4605
4606    let embeddings = match request.model.as_str() {
4607        "openai/text-embedding-3-small" => {
4608            open_ai::embed(
4609                session.http_client.as_ref(),
4610                OPEN_AI_API_URL,
4611                &api_key,
4612                OpenAiEmbeddingModel::TextEmbedding3Small,
4613                request.texts.iter().map(|text| text.as_str()),
4614            )
4615            .await?
4616        }
4617        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4618    };
4619
4620    let embeddings = request
4621        .texts
4622        .iter()
4623        .map(|text| {
4624            let mut hasher = sha2::Sha256::new();
4625            hasher.update(text.as_bytes());
4626            let result = hasher.finalize();
4627            result.to_vec()
4628        })
4629        .zip(
4630            embeddings
4631                .data
4632                .into_iter()
4633                .map(|embedding| embedding.embedding),
4634        )
4635        .collect::<HashMap<_, _>>();
4636
4637    let db = session.db().await;
4638    db.save_embeddings(&request.model, &embeddings)
4639        .await
4640        .context("failed to save embeddings")
4641        .trace_err();
4642
4643    response.send(proto::ComputeEmbeddingsResponse {
4644        embeddings: embeddings
4645            .into_iter()
4646            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4647            .collect(),
4648    })?;
4649    Ok(())
4650}
4651
4652async fn get_cached_embeddings(
4653    request: proto::GetCachedEmbeddings,
4654    response: Response<proto::GetCachedEmbeddings>,
4655    session: UserSession,
4656) -> Result<()> {
4657    authorize_access_to_legacy_llm_endpoints(&session).await?;
4658
4659    let db = session.db().await;
4660    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4661
4662    response.send(proto::GetCachedEmbeddingsResponse {
4663        embeddings: embeddings
4664            .into_iter()
4665            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4666            .collect(),
4667    })?;
4668    Ok(())
4669}
4670
4671/// This is leftover from before the LLM service.
4672///
4673/// The endpoints protected by this check will be moved there eventually.
4674async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4675    if session.is_staff() {
4676        Ok(())
4677    } else {
4678        Err(anyhow!("permission denied"))?
4679    }
4680}
4681
4682/// Get a Supermaven API key for the user
4683async fn get_supermaven_api_key(
4684    _request: proto::GetSupermavenApiKey,
4685    response: Response<proto::GetSupermavenApiKey>,
4686    session: UserSession,
4687) -> Result<()> {
4688    let user_id: String = session.user_id().to_string();
4689    if !session.is_staff() {
4690        return Err(anyhow!("supermaven not enabled for this account"))?;
4691    }
4692
4693    let email = session
4694        .email()
4695        .ok_or_else(|| anyhow!("user must have an email"))?;
4696
4697    let supermaven_admin_api = session
4698        .supermaven_client
4699        .as_ref()
4700        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4701
4702    let result = supermaven_admin_api
4703        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4704        .await?;
4705
4706    response.send(proto::GetSupermavenApiKeyResponse {
4707        api_key: result.api_key,
4708    })?;
4709
4710    Ok(())
4711}
4712
4713/// Start receiving chat updates for a channel
4714async fn join_channel_chat(
4715    request: proto::JoinChannelChat,
4716    response: Response<proto::JoinChannelChat>,
4717    session: UserSession,
4718) -> Result<()> {
4719    let channel_id = ChannelId::from_proto(request.channel_id);
4720
4721    let db = session.db().await;
4722    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4723        .await?;
4724    let messages = db
4725        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4726        .await?;
4727    response.send(proto::JoinChannelChatResponse {
4728        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4729        messages,
4730    })?;
4731    Ok(())
4732}
4733
4734/// Stop receiving chat updates for a channel
4735async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4736    let channel_id = ChannelId::from_proto(request.channel_id);
4737    session
4738        .db()
4739        .await
4740        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4741        .await?;
4742    Ok(())
4743}
4744
4745/// Retrieve the chat history for a channel
4746async fn get_channel_messages(
4747    request: proto::GetChannelMessages,
4748    response: Response<proto::GetChannelMessages>,
4749    session: UserSession,
4750) -> Result<()> {
4751    let channel_id = ChannelId::from_proto(request.channel_id);
4752    let messages = session
4753        .db()
4754        .await
4755        .get_channel_messages(
4756            channel_id,
4757            session.user_id(),
4758            MESSAGE_COUNT_PER_PAGE,
4759            Some(MessageId::from_proto(request.before_message_id)),
4760        )
4761        .await?;
4762    response.send(proto::GetChannelMessagesResponse {
4763        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4764        messages,
4765    })?;
4766    Ok(())
4767}
4768
4769/// Retrieve specific chat messages
4770async fn get_channel_messages_by_id(
4771    request: proto::GetChannelMessagesById,
4772    response: Response<proto::GetChannelMessagesById>,
4773    session: UserSession,
4774) -> Result<()> {
4775    let message_ids = request
4776        .message_ids
4777        .iter()
4778        .map(|id| MessageId::from_proto(*id))
4779        .collect::<Vec<_>>();
4780    let messages = session
4781        .db()
4782        .await
4783        .get_channel_messages_by_id(session.user_id(), &message_ids)
4784        .await?;
4785    response.send(proto::GetChannelMessagesResponse {
4786        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4787        messages,
4788    })?;
4789    Ok(())
4790}
4791
4792/// Retrieve the current users notifications
4793async fn get_notifications(
4794    request: proto::GetNotifications,
4795    response: Response<proto::GetNotifications>,
4796    session: UserSession,
4797) -> Result<()> {
4798    let notifications = session
4799        .db()
4800        .await
4801        .get_notifications(
4802            session.user_id(),
4803            NOTIFICATION_COUNT_PER_PAGE,
4804            request.before_id.map(db::NotificationId::from_proto),
4805        )
4806        .await?;
4807    response.send(proto::GetNotificationsResponse {
4808        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4809        notifications,
4810    })?;
4811    Ok(())
4812}
4813
4814/// Mark notifications as read
4815async fn mark_notification_as_read(
4816    request: proto::MarkNotificationRead,
4817    response: Response<proto::MarkNotificationRead>,
4818    session: UserSession,
4819) -> Result<()> {
4820    let database = &session.db().await;
4821    let notifications = database
4822        .mark_notification_as_read_by_id(
4823            session.user_id(),
4824            NotificationId::from_proto(request.notification_id),
4825        )
4826        .await?;
4827    send_notifications(
4828        &*session.connection_pool().await,
4829        &session.peer,
4830        notifications,
4831    );
4832    response.send(proto::Ack {})?;
4833    Ok(())
4834}
4835
4836/// Get the current users information
4837async fn get_private_user_info(
4838    _request: proto::GetPrivateUserInfo,
4839    response: Response<proto::GetPrivateUserInfo>,
4840    session: UserSession,
4841) -> Result<()> {
4842    let db = session.db().await;
4843
4844    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4845    let user = db
4846        .get_user_by_id(session.user_id())
4847        .await?
4848        .ok_or_else(|| anyhow!("user not found"))?;
4849    let flags = db.get_user_flags(session.user_id()).await?;
4850
4851    response.send(proto::GetPrivateUserInfoResponse {
4852        metrics_id,
4853        staff: user.admin,
4854        flags,
4855        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4856    })?;
4857    Ok(())
4858}
4859
4860/// Accept the terms of service (tos) on behalf of the current user
4861async fn accept_terms_of_service(
4862    _request: proto::AcceptTermsOfService,
4863    response: Response<proto::AcceptTermsOfService>,
4864    session: UserSession,
4865) -> Result<()> {
4866    let db = session.db().await;
4867
4868    let accepted_tos_at = Utc::now();
4869    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4870        .await?;
4871
4872    response.send(proto::AcceptTermsOfServiceResponse {
4873        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4874    })?;
4875    Ok(())
4876}
4877
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of two times: the user's
/// `created_at`, and their `github_user_created_at` when that is known.
/// Younger accounts are refused a token.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4880
4881async fn get_llm_api_token(
4882    _request: proto::GetLlmToken,
4883    response: Response<proto::GetLlmToken>,
4884    session: UserSession,
4885) -> Result<()> {
4886    let db = session.db().await;
4887
4888    let flags = db.get_user_flags(session.user_id()).await?;
4889    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4890    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4891
4892    if !session.is_staff() && !has_language_models_feature_flag {
4893        Err(anyhow!("permission denied"))?
4894    }
4895
4896    let user_id = session.user_id();
4897    let user = db
4898        .get_user_by_id(user_id)
4899        .await?
4900        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4901
4902    if user.accepted_tos_at.is_none() {
4903        Err(anyhow!("terms of service not accepted"))?
4904    }
4905
4906    let mut account_created_at = user.created_at;
4907    if let Some(github_created_at) = user.github_user_created_at {
4908        account_created_at = account_created_at.min(github_created_at);
4909    }
4910    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4911        Err(anyhow!("account too young"))?
4912    }
4913    let token = LlmTokenClaims::create(
4914        user.id,
4915        user.github_login.clone(),
4916        session.is_staff(),
4917        has_llm_closed_beta_feature_flag,
4918        session.current_plan(db).await?,
4919        &session.app_state.config,
4920    )?;
4921    response.send(proto::GetLlmTokenResponse { token })?;
4922    Ok(())
4923}
4924
4925fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4926    let message = match message {
4927        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4928        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4929        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4930        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4931        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4932            code: frame.code.into(),
4933            reason: frame.reason,
4934        })),
4935        // We should never receive a frame while reading the message, according
4936        // to the `tungstenite` maintainers:
4937        //
4938        // > It cannot occur when you read messages from the WebSocket, but it
4939        // > can be used when you want to send the raw frames (e.g. you want to
4940        // > send the frames to the WebSocket without composing the full message first).
4941        // >
4942        // > — https://github.com/snapview/tungstenite-rs/issues/268
4943        TungsteniteMessage::Frame(_) => {
4944            bail!("received an unexpected frame while reading the message")
4945        }
4946    };
4947
4948    Ok(message)
4949}
4950
4951fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4952    match message {
4953        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4954        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4955        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4956        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4957        AxumMessage::Close(frame) => {
4958            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4959                code: frame.code.into(),
4960                reason: frame.reason,
4961            }))
4962        }
4963    }
4964}
4965
4966fn notify_membership_updated(
4967    connection_pool: &mut ConnectionPool,
4968    result: MembershipUpdated,
4969    user_id: UserId,
4970    peer: &Peer,
4971) {
4972    for membership in &result.new_channels.channel_memberships {
4973        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4974    }
4975    for channel_id in &result.removed_channels {
4976        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4977    }
4978
4979    let user_channels_update = proto::UpdateUserChannels {
4980        channel_memberships: result
4981            .new_channels
4982            .channel_memberships
4983            .iter()
4984            .map(|cm| proto::ChannelMembership {
4985                channel_id: cm.channel_id.to_proto(),
4986                role: cm.role.into(),
4987            })
4988            .collect(),
4989        ..Default::default()
4990    };
4991
4992    let mut update = build_channels_update(result.new_channels);
4993    update.delete_channels = result
4994        .removed_channels
4995        .into_iter()
4996        .map(|id| id.to_proto())
4997        .collect();
4998    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4999
5000    for connection_id in connection_pool.user_connection_ids(user_id) {
5001        peer.send(connection_id, user_channels_update.clone())
5002            .trace_err();
5003        peer.send(connection_id, update.clone()).trace_err();
5004    }
5005}
5006
5007fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5008    proto::UpdateUserChannels {
5009        channel_memberships: channels
5010            .channel_memberships
5011            .iter()
5012            .map(|m| proto::ChannelMembership {
5013                channel_id: m.channel_id.to_proto(),
5014                role: m.role.into(),
5015            })
5016            .collect(),
5017        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5018        observed_channel_message_id: channels.observed_channel_messages.clone(),
5019    }
5020}
5021
5022fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5023    let mut update = proto::UpdateChannels::default();
5024
5025    for channel in channels.channels {
5026        update.channels.push(channel.to_proto());
5027    }
5028
5029    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5030    update.latest_channel_message_ids = channels.latest_channel_messages;
5031
5032    for (channel_id, participants) in channels.channel_participants {
5033        update
5034            .channel_participants
5035            .push(proto::ChannelParticipants {
5036                channel_id: channel_id.to_proto(),
5037                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5038            });
5039    }
5040
5041    for channel in channels.invited_channels {
5042        update.channel_invitations.push(channel.to_proto());
5043    }
5044
5045    update.hosted_projects = channels.hosted_projects;
5046    update
5047}
5048
5049fn build_initial_contacts_update(
5050    contacts: Vec<db::Contact>,
5051    pool: &ConnectionPool,
5052) -> proto::UpdateContacts {
5053    let mut update = proto::UpdateContacts::default();
5054
5055    for contact in contacts {
5056        match contact {
5057            db::Contact::Accepted { user_id, busy } => {
5058                update.contacts.push(contact_for_user(user_id, busy, pool));
5059            }
5060            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5061            db::Contact::Incoming { user_id } => {
5062                update
5063                    .incoming_requests
5064                    .push(proto::IncomingContactRequest {
5065                        requester_id: user_id.to_proto(),
5066                    })
5067            }
5068        }
5069    }
5070
5071    update
5072}
5073
5074fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5075    proto::Contact {
5076        user_id: user_id.to_proto(),
5077        online: pool.is_user_online(user_id),
5078        busy,
5079    }
5080}
5081
5082fn room_updated(room: &proto::Room, peer: &Peer) {
5083    broadcast(
5084        None,
5085        room.participants
5086            .iter()
5087            .filter_map(|participant| Some(participant.peer_id?.into())),
5088        |peer_id| {
5089            peer.send(
5090                peer_id,
5091                proto::RoomUpdated {
5092                    room: Some(room.clone()),
5093                },
5094            )
5095        },
5096    );
5097}
5098
5099fn channel_updated(
5100    channel: &db::channel::Model,
5101    room: &proto::Room,
5102    peer: &Peer,
5103    pool: &ConnectionPool,
5104) {
5105    let participants = room
5106        .participants
5107        .iter()
5108        .map(|p| p.user_id)
5109        .collect::<Vec<_>>();
5110
5111    broadcast(
5112        None,
5113        pool.channel_connection_ids(channel.root_id())
5114            .filter_map(|(channel_id, role)| {
5115                role.can_see_channel(channel.visibility)
5116                    .then_some(channel_id)
5117            }),
5118        |peer_id| {
5119            peer.send(
5120                peer_id,
5121                proto::UpdateChannels {
5122                    channel_participants: vec![proto::ChannelParticipants {
5123                        channel_id: channel.id.to_proto(),
5124                        participant_user_ids: participants.clone(),
5125                    }],
5126                    ..Default::default()
5127                },
5128            )
5129        },
5130    );
5131}
5132
5133async fn send_dev_server_projects_update(
5134    user_id: UserId,
5135    mut status: proto::DevServerProjectsUpdate,
5136    session: &Session,
5137) {
5138    let pool = session.connection_pool().await;
5139    for dev_server in &mut status.dev_servers {
5140        dev_server.status =
5141            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5142    }
5143    let connections = pool.user_connection_ids(user_id);
5144    for connection_id in connections {
5145        session.peer.send(connection_id, status.clone()).trace_err();
5146    }
5147}
5148
5149async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5150    let db = session.db().await;
5151
5152    let contacts = db.get_contacts(user_id).await?;
5153    let busy = db.is_user_busy(user_id).await?;
5154
5155    let pool = session.connection_pool().await;
5156    let updated_contact = contact_for_user(user_id, busy, &pool);
5157    for contact in contacts {
5158        if let db::Contact::Accepted {
5159            user_id: contact_user_id,
5160            ..
5161        } = contact
5162        {
5163            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5164                session
5165                    .peer
5166                    .send(
5167                        contact_conn_id,
5168                        proto::UpdateContacts {
5169                            contacts: vec![updated_contact.clone()],
5170                            remove_contacts: Default::default(),
5171                            incoming_requests: Default::default(),
5172                            remove_incoming_requests: Default::default(),
5173                            outgoing_requests: Default::default(),
5174                            remove_outgoing_requests: Default::default(),
5175                        },
5176                    )
5177                    .trace_err();
5178            }
5179        }
5180    }
5181    Ok(())
5182}
5183
5184async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5185    log::info!("lost dev server connection, unsharing projects");
5186    let project_ids = session
5187        .db()
5188        .await
5189        .get_stale_dev_server_projects(session.connection_id)
5190        .await?;
5191
5192    for project_id in project_ids {
5193        // not unshare re-checks the connection ids match, so we get away with no transaction
5194        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5195    }
5196
5197    let user_id = session.dev_server().user_id;
5198    let update = session
5199        .db()
5200        .await
5201        .dev_server_projects_update(user_id)
5202        .await?;
5203
5204    send_dev_server_projects_update(user_id, update, session).await;
5205
5206    Ok(())
5207}
5208
5209async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5210    let mut contacts_to_update = HashSet::default();
5211
5212    let room_id;
5213    let canceled_calls_to_user_ids;
5214    let live_kit_room;
5215    let delete_live_kit_room;
5216    let room;
5217    let channel;
5218
5219    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5220        contacts_to_update.insert(session.user_id());
5221
5222        for project in left_room.left_projects.values() {
5223            project_left(project, session);
5224        }
5225
5226        room_id = RoomId::from_proto(left_room.room.id);
5227        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5228        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5229        delete_live_kit_room = left_room.deleted;
5230        room = mem::take(&mut left_room.room);
5231        channel = mem::take(&mut left_room.channel);
5232
5233        room_updated(&room, &session.peer);
5234    } else {
5235        return Ok(());
5236    }
5237
5238    if let Some(channel) = channel {
5239        channel_updated(
5240            &channel,
5241            &room,
5242            &session.peer,
5243            &*session.connection_pool().await,
5244        );
5245    }
5246
5247    {
5248        let pool = session.connection_pool().await;
5249        for canceled_user_id in canceled_calls_to_user_ids {
5250            for connection_id in pool.user_connection_ids(canceled_user_id) {
5251                session
5252                    .peer
5253                    .send(
5254                        connection_id,
5255                        proto::CallCanceled {
5256                            room_id: room_id.to_proto(),
5257                        },
5258                    )
5259                    .trace_err();
5260            }
5261            contacts_to_update.insert(canceled_user_id);
5262        }
5263    }
5264
5265    for contact_user_id in contacts_to_update {
5266        update_user_contacts(contact_user_id, session).await?;
5267    }
5268
5269    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5270        live_kit
5271            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5272            .await
5273            .trace_err();
5274
5275        if delete_live_kit_room {
5276            live_kit.delete_room(live_kit_room).await.trace_err();
5277        }
5278    }
5279
5280    Ok(())
5281}
5282
5283async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5284    let left_channel_buffers = session
5285        .db()
5286        .await
5287        .leave_channel_buffers(session.connection_id)
5288        .await?;
5289
5290    for left_buffer in left_channel_buffers {
5291        channel_buffer_updated(
5292            session.connection_id,
5293            left_buffer.connections,
5294            &proto::UpdateChannelBufferCollaborators {
5295                channel_id: left_buffer.channel_id.to_proto(),
5296                collaborators: left_buffer.collaborators,
5297            },
5298            &session.peer,
5299        );
5300    }
5301
5302    Ok(())
5303}
5304
5305fn project_left(project: &db::LeftProject, session: &UserSession) {
5306    for connection_id in &project.connection_ids {
5307        if project.should_unshare {
5308            session
5309                .peer
5310                .send(
5311                    *connection_id,
5312                    proto::UnshareProject {
5313                        project_id: project.id.to_proto(),
5314                    },
5315                )
5316                .trace_err();
5317        } else {
5318            session
5319                .peer
5320                .send(
5321                    *connection_id,
5322                    proto::RemoveProjectCollaborator {
5323                        project_id: project.id.to_proto(),
5324                        peer_id: Some(session.connection_id.into()),
5325                    },
5326                )
5327                .trace_err();
5328        }
5329    }
5330}
5331
/// Extension for fallible values whose errors should be reported and dropped
/// rather than propagated.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Turns `self` into an `Option`, discarding any error.
    ///
    /// Implementations are expected to report the discarded error (for
    /// example, by logging it) before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5337
5338impl<T, E> ResultExt for Result<T, E>
5339where
5340    E: std::fmt::Debug,
5341{
5342    type Ok = T;
5343
5344    #[track_caller]
5345    fn trace_err(self) -> Option<T> {
5346        match self {
5347            Ok(value) => Some(value),
5348            Err(error) => {
5349                tracing::error!("{:?}", error);
5350                None
5351            }
5352        }
5353    }
5354}