rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  40use reqwest_client::ReqwestClient;
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone, Debug)]
 109pub enum Principal {
 110    User(User),
 111    Impersonated { user: User, admin: User },
 112    DevServer(dev_server::Model),
 113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
/// Per-connection state passed to every RPC handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// The identity on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it with [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin client, or `None` when unavailable.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // Appears to be unread so far, hence the `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn has_llm_subscription(
 195        &self,
 196        db: &MutexGuard<'_, DbHandle>,
 197    ) -> anyhow::Result<bool> {
 198        if self.is_staff() {
 199            return Ok(true);
 200        }
 201
 202        let Some(user_id) = self.user_id() else {
 203            return Ok(false);
 204        };
 205
 206        Ok(db.has_active_billing_subscription(user_id).await?)
 207    }
 208
 209    pub async fn current_plan(
 210        &self,
 211        _db: &MutexGuard<'_, DbHandle>,
 212    ) -> anyhow::Result<proto::Plan> {
 213        if self.is_staff() {
 214            Ok(proto::Plan::ZedPro)
 215        } else {
 216            Ok(proto::Plan::Free)
 217        }
 218    }
 219
 220    fn dev_server_id(&self) -> Option<DevServerId> {
 221        match &self.principal {
 222            Principal::User(_) | Principal::Impersonated { .. } => None,
 223            Principal::DevServer(dev_server) => Some(dev_server.id),
 224        }
 225    }
 226
 227    fn principal_id(&self) -> PrincipalId {
 228        match &self.principal {
 229            Principal::User(user) => PrincipalId::UserId(user.id),
 230            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 231            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 232        }
 233    }
 234}
 235
 236impl Debug for Session {
 237    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 238        let mut result = f.debug_struct("Session");
 239        match &self.principal {
 240            Principal::User(user) => {
 241                result.field("user", &user.github_login);
 242            }
 243            Principal::Impersonated { user, admin } => {
 244                result.field("user", &user.github_login);
 245                result.field("impersonator", &admin.github_login);
 246            }
 247            Principal::DevServer(dev_server) => {
 248                result.field("dev_server", &dev_server.id);
 249            }
 250        }
 251        result.field("connection_id", &self.connection_id).finish()
 252    }
 253}
 254
 255struct UserSession(Session);
 256
 257impl UserSession {
 258    pub fn new(s: Session) -> Option<Self> {
 259        s.user_id().map(|_| UserSession(s))
 260    }
 261    pub fn user_id(&self) -> UserId {
 262        self.0.user_id().unwrap()
 263    }
 264
 265    pub fn email(&self) -> Option<String> {
 266        match &self.0.principal {
 267            Principal::User(user) => user.email_address.clone(),
 268            Principal::Impersonated { user, .. } => user.email_address.clone(),
 269            Principal::DevServer(..) => None,
 270        }
 271    }
 272}
 273
 274impl Deref for UserSession {
 275    type Target = Session;
 276
 277    fn deref(&self) -> &Self::Target {
 278        &self.0
 279    }
 280}
 281impl DerefMut for UserSession {
 282    fn deref_mut(&mut self) -> &mut Self::Target {
 283        &mut self.0
 284    }
 285}
 286
 287struct DevServerSession(Session);
 288
 289impl DevServerSession {
 290    pub fn new(s: Session) -> Option<Self> {
 291        s.dev_server_id().map(|_| DevServerSession(s))
 292    }
 293    pub fn dev_server_id(&self) -> DevServerId {
 294        self.0.dev_server_id().unwrap()
 295    }
 296
 297    fn dev_server(&self) -> &dev_server::Model {
 298        match &self.0.principal {
 299            Principal::DevServer(dev_server) => dev_server,
 300            _ => unreachable!(),
 301        }
 302    }
 303}
 304
 305impl Deref for DevServerSession {
 306    type Target = Session;
 307
 308    fn deref(&self) -> &Self::Target {
 309        &self.0
 310    }
 311}
 312impl DerefMut for DevServerSession {
 313    fn deref_mut(&mut self) -> &mut Self::Target {
 314        &mut self.0
 315    }
 316}
 317
 318fn user_handler<M: RequestMessage, Fut>(
 319    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 321where
 322    Fut: Send + Future<Output = Result<()>>,
 323{
 324    let handler = Arc::new(handler);
 325    move |message, response, session| {
 326        let handler = handler.clone();
 327        Box::pin(async move {
 328            if let Some(user_session) = session.for_user() {
 329                Ok(handler(message, response, user_session).await?)
 330            } else {
 331                Err(Error::Internal(anyhow!(
 332                    "must be a user to call {}",
 333                    M::NAME
 334                )))
 335            }
 336        })
 337    }
 338}
 339
 340fn dev_server_handler<M: RequestMessage, Fut>(
 341    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 343where
 344    Fut: Send + Future<Output = Result<()>>,
 345{
 346    let handler = Arc::new(handler);
 347    move |message, response, session| {
 348        let handler = handler.clone();
 349        Box::pin(async move {
 350            if let Some(dev_server_session) = session.for_dev_server() {
 351                Ok(handler(message, response, dev_server_session).await?)
 352            } else {
 353                Err(Error::Internal(anyhow!(
 354                    "must be a dev server to call {}",
 355                    M::NAME
 356                )))
 357            }
 358        })
 359    }
 360}
 361
 362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 363    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 365where
 366    InnertRetFut: Send + Future<Output = Result<()>>,
 367{
 368    let handler = Arc::new(handler);
 369    move |message, session| {
 370        let handler = handler.clone();
 371        Box::pin(async move {
 372            if let Some(user_session) = session.for_user() {
 373                Ok(handler(message, user_session).await?)
 374            } else {
 375                Err(Error::Internal(anyhow!(
 376                    "must be a user to call {}",
 377                    M::NAME
 378                )))
 379            }
 380        })
 381    }
 382}
 383
 384struct DbHandle(Arc<Database>);
 385
 386impl Deref for DbHandle {
 387    type Target = Database;
 388
 389    fn deref(&self) -> &Self::Target {
 390        self.0.as_ref()
 391    }
 392}
 393
/// The collaboration RPC server: owns the peer, the shared connection pool, and the table of
/// registered message handlers (populated in `Server::new`).
pub struct Server {
    /// This server's id, kept behind a mutex; `start` reads it with `lock()`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Lock-protected pool of connections, shared through an `Arc`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by `TypeId` (presumably of the message type — confirm in
    /// `add_request_handler` / `add_message_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `watch` channel that `new` initialises to `false`.
    // NOTE(review): presumably flipped to `true` to signal teardown — confirm.
    teardown: watch::Sender<bool>,
}
 402
/// Lock guard on the shared [`ConnectionPool`] that is intentionally `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the whole guard `!Send`. That keeps the guard from
    // being held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 407
/// Serializable view of a server's peer together with its locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not serialized; the pool it dereferences to is.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 414
 415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 416where
 417    S: Serializer,
 418    T: Deref<Target = U>,
 419    U: Serialize,
 420{
 421    Serialize::serialize(value.deref(), serializer)
 422}
 423
 424impl Server {
    /// Builds a server with the given `id` and shared `app_state`, registering every request
    /// and message handler it dispatches.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
    /// principal is not a user. Handlers wrapped in `dev_server_handler` reject sessions whose
    /// principal is not a dev server. Handlers passed without a wrapper are called with the
    /// plain `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/wraps if the inner id does not fit in a
            // `u32` — confirm the range of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped right away.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests only a dev-server principal may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed via the `forward_*` helpers (defined outside this view).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer traffic, including messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, LLM, and Supermaven.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the plain `Session` plus the server config,
            // so the closure captures its own clone of `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Each call gets its own copy of the configured OpenAI API key.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 658
 659    pub async fn start(&self) -> Result<()> {
 660        let server_id = *self.id.lock();
 661        let app_state = self.app_state.clone();
 662        let peer = self.peer.clone();
 663        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 664        let pool = self.connection_pool.clone();
 665        let live_kit_client = self.app_state.live_kit_client.clone();
 666
 667        let span = info_span!("start server");
 668        self.app_state.executor.spawn_detached(
 669            async move {
 670                tracing::info!("waiting for cleanup timeout");
 671                timeout.await;
 672                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 673                if let Some((room_ids, channel_ids)) = app_state
 674                    .db
 675                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 676                    .await
 677                    .trace_err()
 678                {
 679                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 680                    tracing::info!(
 681                        stale_channel_buffer_count = channel_ids.len(),
 682                        "retrieved stale channel buffers"
 683                    );
 684
 685                    for channel_id in channel_ids {
 686                        if let Some(refreshed_channel_buffer) = app_state
 687                            .db
 688                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 689                            .await
 690                            .trace_err()
 691                        {
 692                            for connection_id in refreshed_channel_buffer.connection_ids {
 693                                peer.send(
 694                                    connection_id,
 695                                    proto::UpdateChannelBufferCollaborators {
 696                                        channel_id: channel_id.to_proto(),
 697                                        collaborators: refreshed_channel_buffer
 698                                            .collaborators
 699                                            .clone(),
 700                                    },
 701                                )
 702                                .trace_err();
 703                            }
 704                        }
 705                    }
 706
 707                    for room_id in room_ids {
 708                        let mut contacts_to_update = HashSet::default();
 709                        let mut canceled_calls_to_user_ids = Vec::new();
 710                        let mut live_kit_room = String::new();
 711                        let mut delete_live_kit_room = false;
 712
 713                        if let Some(mut refreshed_room) = app_state
 714                            .db
 715                            .clear_stale_room_participants(room_id, server_id)
 716                            .await
 717                            .trace_err()
 718                        {
 719                            tracing::info!(
 720                                room_id = room_id.0,
 721                                new_participant_count = refreshed_room.room.participants.len(),
 722                                "refreshed room"
 723                            );
 724                            room_updated(&refreshed_room.room, &peer);
 725                            if let Some(channel) = refreshed_room.channel.as_ref() {
 726                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 727                            }
 728                            contacts_to_update
 729                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 730                            contacts_to_update
 731                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 732                            canceled_calls_to_user_ids =
 733                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 734                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 735                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 736                        }
 737
 738                        {
 739                            let pool = pool.lock();
 740                            for canceled_user_id in canceled_calls_to_user_ids {
 741                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 742                                    peer.send(
 743                                        connection_id,
 744                                        proto::CallCanceled {
 745                                            room_id: room_id.to_proto(),
 746                                        },
 747                                    )
 748                                    .trace_err();
 749                                }
 750                            }
 751                        }
 752
 753                        for user_id in contacts_to_update {
 754                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 755                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 756                            if let Some((busy, contacts)) = busy.zip(contacts) {
 757                                let pool = pool.lock();
 758                                let updated_contact = contact_for_user(user_id, busy, &pool);
 759                                for contact in contacts {
 760                                    if let db::Contact::Accepted {
 761                                        user_id: contact_user_id,
 762                                        ..
 763                                    } = contact
 764                                    {
 765                                        for contact_conn_id in
 766                                            pool.user_connection_ids(contact_user_id)
 767                                        {
 768                                            peer.send(
 769                                                contact_conn_id,
 770                                                proto::UpdateContacts {
 771                                                    contacts: vec![updated_contact.clone()],
 772                                                    remove_contacts: Default::default(),
 773                                                    incoming_requests: Default::default(),
 774                                                    remove_incoming_requests: Default::default(),
 775                                                    outgoing_requests: Default::default(),
 776                                                    remove_outgoing_requests: Default::default(),
 777                                                },
 778                                            )
 779                                            .trace_err();
 780                                        }
 781                                    }
 782                                }
 783                            }
 784                        }
 785
 786                        if let Some(live_kit) = live_kit_client.as_ref() {
 787                            if delete_live_kit_room {
 788                                live_kit.delete_room(live_kit_room).await.trace_err();
 789                            }
 790                        }
 791                    }
 792                }
 793
 794                app_state
 795                    .db
 796                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 797                    .await
 798                    .trace_err();
 799            }
 800            .instrument(span),
 801        );
 802        Ok(())
 803    }
 804
 805    pub fn teardown(&self) {
 806        self.peer.teardown();
 807        self.connection_pool.lock().reset();
 808        let _ = self.teardown.send(true);
 809    }
 810
 811    #[cfg(test)]
 812    pub fn reset(&self, id: ServerId) {
 813        self.teardown();
 814        *self.id.lock() = id;
 815        self.peer.reset(id.0 as u32);
 816        let _ = self.teardown.send(false);
 817    }
 818
 819    #[cfg(test)]
 820    pub fn id(&self) -> ServerId {
 821        *self.id.lock()
 822    }
 823
    /// Registers `handler` as the handler for messages of type `M`.
    ///
    /// The handler is stored in `self.handlers` keyed by `TypeId::of::<M>()`, wrapped in a
    /// closure that downcasts the type-erased envelope, runs the handler, and logs the
    /// outcome together with timing information. Handler errors are logged here and not
    /// propagated: the stored future always completes with `()`.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` has already been registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are looked up by the payload's type id, so the envelope is expected
                // to hold a `TypedEnvelope<M>` here.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                // Processing time is measured from just before the handler is invoked, so it
                // also includes any time the returned future spends waiting to be polled.
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // NOTE: these local names double as the tracing field names below;
                    // renaming them changes the emitted log fields.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    // Time between receipt and the start of handling.
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    let payload_type = M::NAME;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                payload_type,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
 872
 873    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 874    where
 875        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 876        Fut: 'static + Send + Future<Output = Result<()>>,
 877        M: EnvelopedMessage,
 878    {
 879        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 880        self
 881    }
 882
 883    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 884    where
 885        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 886        Fut: Send + Future<Output = Result<()>>,
 887        M: RequestMessage,
 888    {
 889        let handler = Arc::new(handler);
 890        self.add_handler(move |envelope, session| {
 891            let receipt = envelope.receipt();
 892            let handler = handler.clone();
 893            async move {
 894                let peer = session.peer.clone();
 895                let responded = Arc::new(AtomicBool::default());
 896                let response = Response {
 897                    peer: peer.clone(),
 898                    responded: responded.clone(),
 899                    receipt,
 900                };
 901                match (handler)(envelope.payload, response, session).await {
 902                    Ok(()) => {
 903                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 904                            Ok(())
 905                        } else {
 906                            Err(anyhow!("handler did not send a response"))?
 907                        }
 908                    }
 909                    Err(error) => {
 910                        let proto_err = match &error {
 911                            Error::Internal(err) => err.to_proto(),
 912                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 913                        };
 914                        peer.respond_with_error(receipt, proto_err)?;
 915                        Err(error)
 916                    }
 917                }
 918            }
 919        })
 920    }
 921
 922    #[allow(clippy::too_many_arguments)]
 923    pub fn handle_connection(
 924        self: &Arc<Self>,
 925        connection: Connection,
 926        address: String,
 927        principal: Principal,
 928        zed_version: ZedVersion,
 929        geoip_country_code: Option<String>,
 930        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 931        executor: Executor,
 932    ) -> impl Future<Output = ()> {
 933        let this = self.clone();
 934        let span = info_span!("handle connection", %address,
 935            connection_id=field::Empty,
 936            user_id=field::Empty,
 937            login=field::Empty,
 938            impersonator=field::Empty,
 939            dev_server_id=field::Empty,
 940            geoip_country_code=field::Empty
 941        );
 942        principal.update_span(&span);
 943        if let Some(country_code) = geoip_country_code.as_ref() {
 944            span.record("geoip_country_code", country_code);
 945        }
 946
 947        let mut teardown = self.teardown.subscribe();
 948        async move {
 949            if *teardown.borrow() {
 950                tracing::error!("server is tearing down");
 951                return
 952            }
 953            let (connection_id, handle_io, mut incoming_rx) = this
 954                .peer
 955                .add_connection(connection, {
 956                    let executor = executor.clone();
 957                    move |duration| executor.sleep(duration)
 958                });
 959            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 960
 961            tracing::info!("connection opened");
 962
 963            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 964            let http_client = match ReqwestClient::user_agent(&user_agent) {
 965                Ok(http_client) => Arc::new(http_client),
 966                Err(error) => {
 967                    tracing::error!(?error, "failed to create HTTP client");
 968                    return;
 969                }
 970            };
 971
 972            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 973                    supermaven_admin_api_key.to_string(),
 974                    http_client.clone(),
 975                )));
 976
 977            let session = Session {
 978                principal: principal.clone(),
 979                connection_id,
 980                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 981                peer: this.peer.clone(),
 982                connection_pool: this.connection_pool.clone(),
 983                app_state: this.app_state.clone(),
 984                http_client,
 985                geoip_country_code,
 986                _executor: executor.clone(),
 987                supermaven_client,
 988            };
 989
 990            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 991                tracing::error!(?error, "failed to send initial client update");
 992                return;
 993            }
 994
 995            let handle_io = handle_io.fuse();
 996            futures::pin_mut!(handle_io);
 997
 998            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 999            // This prevents deadlocks when e.g., client A performs a request to client B and
1000            // client B performs a request to client A. If both clients stop processing further
1001            // messages until their respective request completes, they won't have a chance to
1002            // respond to the other client's request and cause a deadlock.
1003            //
1004            // This arrangement ensures we will attempt to process earlier messages first, but fall
1005            // back to processing messages arrived later in the spirit of making progress.
1006            let mut foreground_message_handlers = FuturesUnordered::new();
1007            let concurrent_handlers = Arc::new(Semaphore::new(256));
1008            loop {
1009                let next_message = async {
1010                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1011                    let message = incoming_rx.next().await;
1012                    (permit, message)
1013                }.fuse();
1014                futures::pin_mut!(next_message);
1015                futures::select_biased! {
1016                    _ = teardown.changed().fuse() => return,
1017                    result = handle_io => {
1018                        if let Err(error) = result {
1019                            tracing::error!(?error, "error handling I/O");
1020                        }
1021                        break;
1022                    }
1023                    _ = foreground_message_handlers.next() => {}
1024                    next_message = next_message => {
1025                        let (permit, message) = next_message;
1026                        if let Some(message) = message {
1027                            let type_name = message.payload_type_name();
1028                            // note: we copy all the fields from the parent span so we can query them in the logs.
1029                            // (https://github.com/tokio-rs/tracing/issues/2670).
1030                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1031                                user_id=field::Empty,
1032                                login=field::Empty,
1033                                impersonator=field::Empty,
1034                                dev_server_id=field::Empty
1035                            );
1036                            principal.update_span(&span);
1037                            let span_enter = span.enter();
1038                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1039                                let is_background = message.is_background();
1040                                let handle_message = (handler)(message, session.clone());
1041                                drop(span_enter);
1042
1043                                let handle_message = async move {
1044                                    handle_message.await;
1045                                    drop(permit);
1046                                }.instrument(span);
1047                                if is_background {
1048                                    executor.spawn_detached(handle_message);
1049                                } else {
1050                                    foreground_message_handlers.push(handle_message);
1051                                }
1052                            } else {
1053                                tracing::error!("no message handler");
1054                            }
1055                        } else {
1056                            tracing::info!("connection closed");
1057                            break;
1058                        }
1059                    }
1060                }
1061            }
1062
1063            drop(foreground_message_handlers);
1064            tracing::info!("signing out");
1065            if let Err(error) = connection_lost(session, teardown, executor).await {
1066                tracing::error!(?error, "error signing out");
1067            }
1068
1069        }.instrument(span)
1070    }
1071
1072    async fn send_initial_client_update(
1073        &self,
1074        connection_id: ConnectionId,
1075        principal: &Principal,
1076        zed_version: ZedVersion,
1077        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1078        session: &Session,
1079    ) -> Result<()> {
1080        self.peer.send(
1081            connection_id,
1082            proto::Hello {
1083                peer_id: Some(connection_id.into()),
1084            },
1085        )?;
1086        tracing::info!("sent hello message");
1087        if let Some(send_connection_id) = send_connection_id.take() {
1088            let _ = send_connection_id.send(connection_id);
1089        }
1090
1091        match principal {
1092            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1093                if !user.connected_once {
1094                    self.peer.send(connection_id, proto::ShowContacts {})?;
1095                    self.app_state
1096                        .db
1097                        .set_user_connected_once(user.id, true)
1098                        .await?;
1099                }
1100
1101                update_user_plan(user.id, session).await?;
1102
1103                let (contacts, dev_server_projects) = future::try_join(
1104                    self.app_state.db.get_contacts(user.id),
1105                    self.app_state.db.dev_server_projects_update(user.id),
1106                )
1107                .await?;
1108
1109                {
1110                    let mut pool = self.connection_pool.lock();
1111                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1112                    self.peer.send(
1113                        connection_id,
1114                        build_initial_contacts_update(contacts, &pool),
1115                    )?;
1116                }
1117
1118                if should_auto_subscribe_to_channels(zed_version) {
1119                    subscribe_user_to_channels(user.id, session).await?;
1120                }
1121
1122                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1123
1124                if let Some(incoming_call) =
1125                    self.app_state.db.incoming_call_for_user(user.id).await?
1126                {
1127                    self.peer.send(connection_id, incoming_call)?;
1128                }
1129
1130                update_user_contacts(user.id, session).await?;
1131            }
1132            Principal::DevServer(dev_server) => {
1133                {
1134                    let mut pool = self.connection_pool.lock();
1135                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1136                    {
1137                        self.peer.send(
1138                            stale_connection_id,
1139                            proto::ShutdownDevServer {
1140                                reason: Some(
1141                                    "another dev server connected with the same token".to_string(),
1142                                ),
1143                            },
1144                        )?;
1145                        pool.remove_connection(stale_connection_id)?;
1146                    };
1147                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1148                }
1149
1150                let projects = self
1151                    .app_state
1152                    .db
1153                    .get_projects_for_dev_server(dev_server.id)
1154                    .await?;
1155                self.peer
1156                    .send(connection_id, proto::DevServerInstructions { projects })?;
1157
1158                let status = self
1159                    .app_state
1160                    .db
1161                    .dev_server_projects_update(dev_server.user_id)
1162                    .await?;
1163                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1164            }
1165        }
1166
1167        Ok(())
1168    }
1169
1170    pub async fn invite_code_redeemed(
1171        self: &Arc<Self>,
1172        inviter_id: UserId,
1173        invitee_id: UserId,
1174    ) -> Result<()> {
1175        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1176            if let Some(code) = &user.invite_code {
1177                let pool = self.connection_pool.lock();
1178                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1179                for connection_id in pool.user_connection_ids(inviter_id) {
1180                    self.peer.send(
1181                        connection_id,
1182                        proto::UpdateContacts {
1183                            contacts: vec![invitee_contact.clone()],
1184                            ..Default::default()
1185                        },
1186                    )?;
1187                    self.peer.send(
1188                        connection_id,
1189                        proto::UpdateInviteInfo {
1190                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1191                            count: user.invite_count as u32,
1192                        },
1193                    )?;
1194                }
1195            }
1196        }
1197        Ok(())
1198    }
1199
1200    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1201        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1202            if let Some(invite_code) = &user.invite_code {
1203                let pool = self.connection_pool.lock();
1204                for connection_id in pool.user_connection_ids(user_id) {
1205                    self.peer.send(
1206                        connection_id,
1207                        proto::UpdateInviteInfo {
1208                            url: format!(
1209                                "{}{}",
1210                                self.app_state.config.invite_link_prefix, invite_code
1211                            ),
1212                            count: user.invite_count as u32,
1213                        },
1214                    )?;
1215                }
1216            }
1217        }
1218        Ok(())
1219    }
1220
1221    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1222        let pool = self.connection_pool.lock();
1223        for connection_id in pool.user_connection_ids(user_id) {
1224            self.peer
1225                .send(connection_id, proto::RefreshLlmToken {})
1226                .trace_err();
1227        }
1228    }
1229
1230    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1231        ServerSnapshot {
1232            connection_pool: ConnectionPoolGuard {
1233                guard: self.connection_pool.lock(),
1234                _not_send: PhantomData,
1235            },
1236            peer: &self.peer,
1237        }
1238    }
1239}
1240
1241impl<'a> Deref for ConnectionPoolGuard<'a> {
1242    type Target = ConnectionPool;
1243
1244    fn deref(&self) -> &Self::Target {
1245        &self.guard
1246    }
1247}
1248
1249impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1250    fn deref_mut(&mut self) -> &mut Self::Target {
1251        &mut self.guard
1252    }
1253}
1254
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` on the pool each time a guard is released,
    /// presumably to catch inconsistent state left behind by mutations made through the
    /// guard. In non-test builds this does nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1261
1262fn broadcast<F>(
1263    sender_id: Option<ConnectionId>,
1264    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1265    mut f: F,
1266) where
1267    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1268{
1269    for receiver_id in receiver_ids {
1270        if Some(receiver_id) != sender_id {
1271            if let Err(error) = f(receiver_id) {
1272                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1273            }
1274        }
1275    }
1276}
1277
1278pub struct ProtocolVersion(u32);
1279
1280impl Header for ProtocolVersion {
1281    fn name() -> &'static HeaderName {
1282        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1283        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1284    }
1285
1286    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1287    where
1288        Self: Sized,
1289        I: Iterator<Item = &'i axum::http::HeaderValue>,
1290    {
1291        let version = values
1292            .next()
1293            .ok_or_else(axum::headers::Error::invalid)?
1294            .to_str()
1295            .map_err(|_| axum::headers::Error::invalid())?
1296            .parse()
1297            .map_err(|_| axum::headers::Error::invalid())?;
1298        Ok(Self(version))
1299    }
1300
1301    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1302        values.extend([self.0.to_string().parse().unwrap()]);
1303    }
1304}
1305
1306pub struct AppVersionHeader(SemanticVersion);
1307impl Header for AppVersionHeader {
1308    fn name() -> &'static HeaderName {
1309        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1311    }
1312
1313    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314    where
1315        Self: Sized,
1316        I: Iterator<Item = &'i axum::http::HeaderValue>,
1317    {
1318        let version = values
1319            .next()
1320            .ok_or_else(axum::headers::Error::invalid)?
1321            .to_str()
1322            .map_err(|_| axum::headers::Error::invalid())?
1323            .parse()
1324            .map_err(|_| axum::headers::Error::invalid())?;
1325        Ok(Self(version))
1326    }
1327
1328    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329        values.extend([self.0.to_string().parse().unwrap()]);
1330    }
1331}
1332
1333pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1334    Router::new()
1335        .route("/rpc", get(handle_websocket_request))
1336        .layer(
1337            ServiceBuilder::new()
1338                .layer(Extension(server.app_state.clone()))
1339                .layer(middleware::from_fn(auth::validate_header)),
1340        )
1341        .route("/metrics", get(handle_metrics))
1342        .layer(Extension(server))
1343}
1344
1345pub async fn handle_websocket_request(
1346    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1347    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1348    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1349    Extension(server): Extension<Arc<Server>>,
1350    Extension(principal): Extension<Principal>,
1351    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1352    ws: WebSocketUpgrade,
1353) -> axum::response::Response {
1354    if protocol_version != rpc::PROTOCOL_VERSION {
1355        return (
1356            StatusCode::UPGRADE_REQUIRED,
1357            "client must be upgraded".to_string(),
1358        )
1359            .into_response();
1360    }
1361
1362    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1363        return (
1364            StatusCode::UPGRADE_REQUIRED,
1365            "no version header found".to_string(),
1366        )
1367            .into_response();
1368    };
1369
1370    if !version.can_collaborate() {
1371        return (
1372            StatusCode::UPGRADE_REQUIRED,
1373            "client must be upgraded".to_string(),
1374        )
1375            .into_response();
1376    }
1377
1378    let socket_address = socket_address.to_string();
1379    ws.on_upgrade(move |socket| {
1380        let socket = socket
1381            .map_ok(to_tungstenite_message)
1382            .err_into()
1383            .with(|message| async move { to_axum_message(message) });
1384        let connection = Connection::new(Box::pin(socket));
1385        async move {
1386            server
1387                .handle_connection(
1388                    connection,
1389                    socket_address,
1390                    principal,
1391                    version,
1392                    country_code_header.map(|header| header.to_string()),
1393                    None,
1394                    Executor::Production,
1395                )
1396                .await;
1397        }
1398    })
1399}
1400
1401pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1402    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let connections_metric = CONNECTIONS_METRIC
1404        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1405
1406    let connections = server
1407        .connection_pool
1408        .lock()
1409        .connections()
1410        .filter(|connection| !connection.admin)
1411        .count();
1412    connections_metric.set(connections as _);
1413
1414    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1416        register_int_gauge!(
1417            "shared_projects",
1418            "number of open projects with one or more guests"
1419        )
1420        .unwrap()
1421    });
1422
1423    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1424    shared_projects_metric.set(shared_projects as _);
1425
1426    let encoder = prometheus::TextEncoder::new();
1427    let metric_families = prometheus::gather();
1428    let encoded_metrics = encoder
1429        .encode_to_string(&metric_families)
1430        .map_err(|err| anyhow!("{}", err))?;
1431    Ok(encoded_metrics)
1432}
1433
1434#[instrument(err, skip(executor))]
1435async fn connection_lost(
1436    session: Session,
1437    mut teardown: watch::Receiver<bool>,
1438    executor: Executor,
1439) -> Result<()> {
1440    session.peer.disconnect(session.connection_id);
1441    session
1442        .connection_pool()
1443        .await
1444        .remove_connection(session.connection_id)?;
1445
1446    session
1447        .db()
1448        .await
1449        .connection_lost(session.connection_id)
1450        .await
1451        .trace_err();
1452
1453    futures::select_biased! {
1454        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1455            match &session.principal {
1456                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1457                    let session = session.for_user().unwrap();
1458
1459                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1460                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1461                    leave_channel_buffers_for_session(&session)
1462                        .await
1463                        .trace_err();
1464
1465                    if !session
1466                        .connection_pool()
1467                        .await
1468                        .is_user_online(session.user_id())
1469                    {
1470                        let db = session.db().await;
1471                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1472                            room_updated(&room, &session.peer);
1473                        }
1474                    }
1475
1476                    update_user_contacts(session.user_id(), &session).await?;
1477                },
1478            Principal::DevServer(_) => {
1479                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1480            },
1481        }
1482        },
1483        _ = teardown.changed().fuse() => {}
1484    }
1485
1486    Ok(())
1487}
1488
1489/// Acknowledges a ping from a client, used to keep the connection alive.
1490async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1491    response.send(proto::Ack {})?;
1492    Ok(())
1493}
1494
1495/// Creates a new room for calling (outside of channels)
1496async fn create_room(
1497    _request: proto::CreateRoom,
1498    response: Response<proto::CreateRoom>,
1499    session: UserSession,
1500) -> Result<()> {
1501    let live_kit_room = nanoid::nanoid!(30);
1502
1503    let live_kit_connection_info = util::maybe!(async {
1504        let live_kit = session.app_state.live_kit_client.as_ref();
1505        let live_kit = live_kit?;
1506        let user_id = session.user_id().to_string();
1507
1508        let token = live_kit
1509            .room_token(&live_kit_room, &user_id.to_string())
1510            .trace_err()?;
1511
1512        Some(proto::LiveKitConnectionInfo {
1513            server_url: live_kit.url().into(),
1514            token,
1515            can_publish: true,
1516        })
1517    })
1518    .await;
1519
1520    let room = session
1521        .db()
1522        .await
1523        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1524        .await?;
1525
1526    response.send(proto::CreateRoomResponse {
1527        room: Some(room.clone()),
1528        live_kit_connection_info,
1529    })?;
1530
1531    update_user_contacts(session.user_id(), &session).await?;
1532    Ok(())
1533}
1534
1535/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1536async fn join_room(
1537    request: proto::JoinRoom,
1538    response: Response<proto::JoinRoom>,
1539    session: UserSession,
1540) -> Result<()> {
1541    let room_id = RoomId::from_proto(request.id);
1542
1543    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1544
1545    if let Some(channel_id) = channel_id {
1546        return join_channel_internal(channel_id, Box::new(response), session).await;
1547    }
1548
1549    let joined_room = {
1550        let room = session
1551            .db()
1552            .await
1553            .join_room(room_id, session.user_id(), session.connection_id)
1554            .await?;
1555        room_updated(&room.room, &session.peer);
1556        room.into_inner()
1557    };
1558
1559    for connection_id in session
1560        .connection_pool()
1561        .await
1562        .user_connection_ids(session.user_id())
1563    {
1564        session
1565            .peer
1566            .send(
1567                connection_id,
1568                proto::CallCanceled {
1569                    room_id: room_id.to_proto(),
1570                },
1571            )
1572            .trace_err();
1573    }
1574
1575    let live_kit_connection_info =
1576        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1577            live_kit
1578                .room_token(
1579                    &joined_room.room.live_kit_room,
1580                    &session.user_id().to_string(),
1581                )
1582                .trace_err()
1583                .map(|token| proto::LiveKitConnectionInfo {
1584                    server_url: live_kit.url().into(),
1585                    token,
1586                    can_publish: true,
1587                })
1588        } else {
1589            None
1590        };
1591
1592    response.send(proto::JoinRoomResponse {
1593        room: Some(joined_room.room),
1594        channel_id: None,
1595        live_kit_connection_info,
1596    })?;
1597
1598    update_user_contacts(session.user_id(), &session).await?;
1599    Ok(())
1600}
1601
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database `rejoin_room` call is given the request, the caller's user id
/// and its current `session.connection_id`. The caller is then sent the room,
/// its reshared projects (with their collaborators) and its rejoined projects.
/// Collaborators of every reshared and rejoined project are told that the
/// caller's old peer id (`old_connection_id`) has been replaced by the current
/// connection. Reshared projects also have their worktree list re-broadcast,
/// and rejoined projects are re-streamed via `notify_rejoined_projects`.
/// Finally the room's channel (if any) and the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive the value returned by
    // `rejoin_room`, which is dropped when the block ends.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the caller's new peer id; failures are
            // only logged so one bad peer doesn't abort the rejoin.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send this project's worktree list to its collaborators,
            // forwarded as if it came from the caller's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes `&mut` because it drains each rejoined project's worktrees.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694fn notify_rejoined_projects(
1695    rejoined_projects: &mut Vec<RejoinedProject>,
1696    session: &UserSession,
1697) -> Result<()> {
1698    for project in rejoined_projects.iter() {
1699        for collaborator in &project.collaborators {
1700            session
1701                .peer
1702                .send(
1703                    collaborator.connection_id,
1704                    proto::UpdateProjectCollaborator {
1705                        project_id: project.id.to_proto(),
1706                        old_peer_id: Some(project.old_connection_id.into()),
1707                        new_peer_id: Some(session.connection_id.into()),
1708                    },
1709                )
1710                .trace_err();
1711        }
1712    }
1713
1714    for project in rejoined_projects {
1715        for worktree in mem::take(&mut project.worktrees) {
1716            #[cfg(any(test, feature = "test-support"))]
1717            const MAX_CHUNK_SIZE: usize = 2;
1718            #[cfg(not(any(test, feature = "test-support")))]
1719            const MAX_CHUNK_SIZE: usize = 256;
1720
1721            // Stream this worktree's entries.
1722            let message = proto::UpdateWorktree {
1723                project_id: project.id.to_proto(),
1724                worktree_id: worktree.id,
1725                abs_path: worktree.abs_path.clone(),
1726                root_name: worktree.root_name,
1727                updated_entries: worktree.updated_entries,
1728                removed_entries: worktree.removed_entries,
1729                scan_id: worktree.scan_id,
1730                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1731                updated_repositories: worktree.updated_repositories,
1732                removed_repositories: worktree.removed_repositories,
1733            };
1734            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1735                session.peer.send(session.connection_id, update.clone())?;
1736            }
1737
1738            // Stream this worktree's diagnostics.
1739            for summary in worktree.diagnostic_summaries {
1740                session.peer.send(
1741                    session.connection_id,
1742                    proto::UpdateDiagnosticSummary {
1743                        project_id: project.id.to_proto(),
1744                        worktree_id: worktree.id,
1745                        summary: Some(summary),
1746                    },
1747                )?;
1748            }
1749
1750            for settings_file in worktree.settings_files {
1751                session.peer.send(
1752                    session.connection_id,
1753                    proto::UpdateWorktreeSettings {
1754                        project_id: project.id.to_proto(),
1755                        worktree_id: worktree.id,
1756                        path: settings_file.path,
1757                        content: Some(settings_file.content),
1758                        kind: Some(settings_file.kind.to_proto().into()),
1759                    },
1760                )?;
1761            }
1762        }
1763
1764        for language_server in &project.language_servers {
1765            session.peer.send(
1766                session.connection_id,
1767                proto::UpdateLanguageServer {
1768                    project_id: project.id.to_proto(),
1769                    language_server_id: language_server.id,
1770                    variant: Some(
1771                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1772                            proto::LspDiskBasedDiagnosticsUpdated {},
1773                        ),
1774                    ),
1775                },
1776            )?;
1777        }
1778    }
1779    Ok(())
1780}
1781
1782/// leave room disconnects from the room.
1783async fn leave_room(
1784    _: proto::LeaveRoom,
1785    response: Response<proto::LeaveRoom>,
1786    session: UserSession,
1787) -> Result<()> {
1788    leave_room_for_session(&session, session.connection_id).await?;
1789    response.send(proto::Ack {})?;
1790    Ok(())
1791}
1792
1793/// Updates the permissions of someone else in the room.
1794async fn set_room_participant_role(
1795    request: proto::SetRoomParticipantRole,
1796    response: Response<proto::SetRoomParticipantRole>,
1797    session: UserSession,
1798) -> Result<()> {
1799    let user_id = UserId::from_proto(request.user_id);
1800    let role = ChannelRole::from(request.role());
1801
1802    let (live_kit_room, can_publish) = {
1803        let room = session
1804            .db()
1805            .await
1806            .set_room_participant_role(
1807                session.user_id(),
1808                RoomId::from_proto(request.room_id),
1809                user_id,
1810                role,
1811            )
1812            .await?;
1813
1814        let live_kit_room = room.live_kit_room.clone();
1815        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1816        room_updated(&room, &session.peer);
1817        (live_kit_room, can_publish)
1818    };
1819
1820    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1821        live_kit
1822            .update_participant(
1823                live_kit_room.clone(),
1824                request.user_id.to_string(),
1825                live_kit_server::proto::ParticipantPermission {
1826                    can_subscribe: true,
1827                    can_publish,
1828                    can_publish_data: can_publish,
1829                    hidden: false,
1830                    recorder: false,
1831                },
1832            )
1833            .await
1834            .trace_err();
1835    }
1836
1837    response.send(proto::Ack {})?;
1838    Ok(())
1839}
1840
1841/// Call someone else into the current room
1842async fn call(
1843    request: proto::Call,
1844    response: Response<proto::Call>,
1845    session: UserSession,
1846) -> Result<()> {
1847    let room_id = RoomId::from_proto(request.room_id);
1848    let calling_user_id = session.user_id();
1849    let calling_connection_id = session.connection_id;
1850    let called_user_id = UserId::from_proto(request.called_user_id);
1851    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1852    if !session
1853        .db()
1854        .await
1855        .has_contact(calling_user_id, called_user_id)
1856        .await?
1857    {
1858        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1859    }
1860
1861    let incoming_call = {
1862        let (room, incoming_call) = &mut *session
1863            .db()
1864            .await
1865            .call(
1866                room_id,
1867                calling_user_id,
1868                calling_connection_id,
1869                called_user_id,
1870                initial_project_id,
1871            )
1872            .await?;
1873        room_updated(room, &session.peer);
1874        mem::take(incoming_call)
1875    };
1876    update_user_contacts(called_user_id, &session).await?;
1877
1878    let mut calls = session
1879        .connection_pool()
1880        .await
1881        .user_connection_ids(called_user_id)
1882        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1883        .collect::<FuturesUnordered<_>>();
1884
1885    while let Some(call_response) = calls.next().await {
1886        match call_response.as_ref() {
1887            Ok(_) => {
1888                response.send(proto::Ack {})?;
1889                return Ok(());
1890            }
1891            Err(_) => {
1892                call_response.trace_err();
1893            }
1894        }
1895    }
1896
1897    {
1898        let room = session
1899            .db()
1900            .await
1901            .call_failed(room_id, called_user_id)
1902            .await?;
1903        room_updated(&room, &session.peer);
1904    }
1905    update_user_contacts(called_user_id, &session).await?;
1906
1907    Err(anyhow!("failed to ring user"))?
1908}
1909
1910/// Cancel an outgoing call.
1911async fn cancel_call(
1912    request: proto::CancelCall,
1913    response: Response<proto::CancelCall>,
1914    session: UserSession,
1915) -> Result<()> {
1916    let called_user_id = UserId::from_proto(request.called_user_id);
1917    let room_id = RoomId::from_proto(request.room_id);
1918    {
1919        let room = session
1920            .db()
1921            .await
1922            .cancel_call(room_id, session.connection_id, called_user_id)
1923            .await?;
1924        room_updated(&room, &session.peer);
1925    }
1926
1927    for connection_id in session
1928        .connection_pool()
1929        .await
1930        .user_connection_ids(called_user_id)
1931    {
1932        session
1933            .peer
1934            .send(
1935                connection_id,
1936                proto::CallCanceled {
1937                    room_id: room_id.to_proto(),
1938                },
1939            )
1940            .trace_err();
1941    }
1942    response.send(proto::Ack {})?;
1943
1944    update_user_contacts(called_user_id, &session).await?;
1945    Ok(())
1946}
1947
1948/// Decline an incoming call.
1949async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1950    let room_id = RoomId::from_proto(message.room_id);
1951    {
1952        let room = session
1953            .db()
1954            .await
1955            .decline_call(Some(room_id), session.user_id())
1956            .await?
1957            .ok_or_else(|| anyhow!("failed to decline call"))?;
1958        room_updated(&room, &session.peer);
1959    }
1960
1961    for connection_id in session
1962        .connection_pool()
1963        .await
1964        .user_connection_ids(session.user_id())
1965    {
1966        session
1967            .peer
1968            .send(
1969                connection_id,
1970                proto::CallCanceled {
1971                    room_id: room_id.to_proto(),
1972                },
1973            )
1974            .trace_err();
1975    }
1976    update_user_contacts(session.user_id(), &session).await?;
1977    Ok(())
1978}
1979
1980/// Updates other participants in the room with your current location.
1981async fn update_participant_location(
1982    request: proto::UpdateParticipantLocation,
1983    response: Response<proto::UpdateParticipantLocation>,
1984    session: UserSession,
1985) -> Result<()> {
1986    let room_id = RoomId::from_proto(request.room_id);
1987    let location = request
1988        .location
1989        .ok_or_else(|| anyhow!("invalid location"))?;
1990
1991    let db = session.db().await;
1992    let room = db
1993        .update_room_participant_location(room_id, session.connection_id, location)
1994        .await?;
1995
1996    room_updated(&room, &session.peer);
1997    response.send(proto::Ack {})?;
1998    Ok(())
1999}
2000
2001/// Share a project into the room.
2002async fn share_project(
2003    request: proto::ShareProject,
2004    response: Response<proto::ShareProject>,
2005    session: UserSession,
2006) -> Result<()> {
2007    let (project_id, room) = &*session
2008        .db()
2009        .await
2010        .share_project(
2011            RoomId::from_proto(request.room_id),
2012            session.connection_id,
2013            &request.worktrees,
2014            request.is_ssh_project,
2015            request
2016                .dev_server_project_id
2017                .map(DevServerProjectId::from_proto),
2018        )
2019        .await?;
2020    response.send(proto::ShareProjectResponse {
2021        project_id: project_id.to_proto(),
2022    })?;
2023    room_updated(room, &session.peer);
2024
2025    Ok(())
2026}
2027
2028/// Unshare a project from the room.
2029async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2030    let project_id = ProjectId::from_proto(message.project_id);
2031    unshare_project_internal(
2032        project_id,
2033        session.connection_id,
2034        session.user_id(),
2035        &session,
2036    )
2037    .await
2038}
2039
2040async fn unshare_project_internal(
2041    project_id: ProjectId,
2042    connection_id: ConnectionId,
2043    user_id: Option<UserId>,
2044    session: &Session,
2045) -> Result<()> {
2046    let delete = {
2047        let room_guard = session
2048            .db()
2049            .await
2050            .unshare_project(project_id, connection_id, user_id)
2051            .await?;
2052
2053        let (delete, room, guest_connection_ids) = &*room_guard;
2054
2055        let message = proto::UnshareProject {
2056            project_id: project_id.to_proto(),
2057        };
2058
2059        broadcast(
2060            Some(connection_id),
2061            guest_connection_ids.iter().copied(),
2062            |conn_id| session.peer.send(conn_id, message.clone()),
2063        );
2064        if let Some(room) = room {
2065            room_updated(room, &session.peer);
2066        }
2067
2068        *delete
2069    };
2070
2071    if delete {
2072        let db = session.db().await;
2073        db.delete_project(project_id).await?;
2074    }
2075
2076    Ok(())
2077}
2078
2079/// DevServer makes a project available online
2080async fn share_dev_server_project(
2081    request: proto::ShareDevServerProject,
2082    response: Response<proto::ShareDevServerProject>,
2083    session: DevServerSession,
2084) -> Result<()> {
2085    let (dev_server_project, user_id, status) = session
2086        .db()
2087        .await
2088        .share_dev_server_project(
2089            DevServerProjectId::from_proto(request.dev_server_project_id),
2090            session.dev_server_id(),
2091            session.connection_id,
2092            &request.worktrees,
2093        )
2094        .await?;
2095    let Some(project_id) = dev_server_project.project_id else {
2096        return Err(anyhow!("failed to share remote project"))?;
2097    };
2098
2099    send_dev_server_projects_update(user_id, status, &session).await;
2100
2101    response.send(proto::ShareProjectResponse { project_id })?;
2102
2103    Ok(())
2104}
2105
/// Join someone else's shared project.
///
/// Asks the database to join the caller's connection to the project, which
/// yields the project and a replica id for the caller. It then delegates to
/// `join_project_internal` to notify existing collaborators and stream the
/// project's state to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the value returned by
    // `db.join_project`, which is kept alive for the rest of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project state.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2124
/// Common interface over the response handles of the RPCs that finish with a
/// `proto::JoinProjectResponse` (`JoinProject` and `JoinHostedProject`), so
/// `join_project_internal` can serve both.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicit path so this resolves to `Response`'s own `send`
        // (presumably inherent) instead of recursing into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicit path; see the `JoinProject` impl above.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2138
2139fn join_project_internal(
2140    response: impl JoinProjectInternalResponse,
2141    session: UserSession,
2142    project: &mut Project,
2143    replica_id: &ReplicaId,
2144) -> Result<()> {
2145    let collaborators = project
2146        .collaborators
2147        .iter()
2148        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2149        .map(|collaborator| collaborator.to_proto())
2150        .collect::<Vec<_>>();
2151    let project_id = project.id;
2152    let guest_user_id = session.user_id();
2153
2154    let worktrees = project
2155        .worktrees
2156        .iter()
2157        .map(|(id, worktree)| proto::WorktreeMetadata {
2158            id: *id,
2159            root_name: worktree.root_name.clone(),
2160            visible: worktree.visible,
2161            abs_path: worktree.abs_path.clone(),
2162        })
2163        .collect::<Vec<_>>();
2164
2165    let add_project_collaborator = proto::AddProjectCollaborator {
2166        project_id: project_id.to_proto(),
2167        collaborator: Some(proto::Collaborator {
2168            peer_id: Some(session.connection_id.into()),
2169            replica_id: replica_id.0 as u32,
2170            user_id: guest_user_id.to_proto(),
2171        }),
2172    };
2173
2174    for collaborator in &collaborators {
2175        session
2176            .peer
2177            .send(
2178                collaborator.peer_id.unwrap().into(),
2179                add_project_collaborator.clone(),
2180            )
2181            .trace_err();
2182    }
2183
2184    // First, we send the metadata associated with each worktree.
2185    response.send(proto::JoinProjectResponse {
2186        project_id: project.id.0 as u64,
2187        worktrees: worktrees.clone(),
2188        replica_id: replica_id.0 as u32,
2189        collaborators: collaborators.clone(),
2190        language_servers: project.language_servers.clone(),
2191        role: project.role.into(),
2192        dev_server_project_id: project
2193            .dev_server_project_id
2194            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2195    })?;
2196
2197    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2198        #[cfg(any(test, feature = "test-support"))]
2199        const MAX_CHUNK_SIZE: usize = 2;
2200        #[cfg(not(any(test, feature = "test-support")))]
2201        const MAX_CHUNK_SIZE: usize = 256;
2202
2203        // Stream this worktree's entries.
2204        let message = proto::UpdateWorktree {
2205            project_id: project_id.to_proto(),
2206            worktree_id,
2207            abs_path: worktree.abs_path.clone(),
2208            root_name: worktree.root_name,
2209            updated_entries: worktree.entries,
2210            removed_entries: Default::default(),
2211            scan_id: worktree.scan_id,
2212            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2213            updated_repositories: worktree.repository_entries.into_values().collect(),
2214            removed_repositories: Default::default(),
2215        };
2216        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2217            session.peer.send(session.connection_id, update.clone())?;
2218        }
2219
2220        // Stream this worktree's diagnostics.
2221        for summary in worktree.diagnostic_summaries {
2222            session.peer.send(
2223                session.connection_id,
2224                proto::UpdateDiagnosticSummary {
2225                    project_id: project_id.to_proto(),
2226                    worktree_id: worktree.id,
2227                    summary: Some(summary),
2228                },
2229            )?;
2230        }
2231
2232        for settings_file in worktree.settings_files {
2233            session.peer.send(
2234                session.connection_id,
2235                proto::UpdateWorktreeSettings {
2236                    project_id: project_id.to_proto(),
2237                    worktree_id: worktree.id,
2238                    path: settings_file.path,
2239                    content: Some(settings_file.content),
2240                    kind: Some(proto::update_user_settings::Kind::Settings.into()),
2241                },
2242            )?;
2243        }
2244    }
2245
2246    for language_server in &project.language_servers {
2247        session.peer.send(
2248            session.connection_id,
2249            proto::UpdateLanguageServer {
2250                project_id: project_id.to_proto(),
2251                language_server_id: language_server.id,
2252                variant: Some(
2253                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2254                        proto::LspDiskBasedDiagnosticsUpdated {},
2255                    ),
2256                ),
2257            },
2258        )?;
2259    }
2260
2261    Ok(())
2262}
2263
2264/// Leave someone elses shared project.
2265async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2266    let sender_id = session.connection_id;
2267    let project_id = ProjectId::from_proto(request.project_id);
2268    let db = session.db().await;
2269    if db.is_hosted_project(project_id).await? {
2270        let project = db.leave_hosted_project(project_id, sender_id).await?;
2271        project_left(&project, &session);
2272        return Ok(());
2273    }
2274
2275    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2276    tracing::info!(
2277        %project_id,
2278        "leave project"
2279    );
2280
2281    project_left(project, &session);
2282    if let Some(room) = room {
2283        room_updated(room, &session.peer);
2284    }
2285
2286    Ok(())
2287}
2288
2289async fn join_hosted_project(
2290    request: proto::JoinHostedProject,
2291    response: Response<proto::JoinHostedProject>,
2292    session: UserSession,
2293) -> Result<()> {
2294    let (mut project, replica_id) = session
2295        .db()
2296        .await
2297        .join_hosted_project(
2298            ProjectId(request.project_id as i32),
2299            session.user_id(),
2300            session.connection_id,
2301        )
2302        .await?;
2303
2304    join_project_internal(response, session, &mut project, &replica_id)
2305}
2306
2307async fn list_remote_directory(
2308    request: proto::ListRemoteDirectory,
2309    response: Response<proto::ListRemoteDirectory>,
2310    session: UserSession,
2311) -> Result<()> {
2312    let dev_server_id = DevServerId(request.dev_server_id as i32);
2313    let dev_server_connection_id = session
2314        .connection_pool()
2315        .await
2316        .online_dev_server_connection_id(dev_server_id)?;
2317
2318    session
2319        .db()
2320        .await
2321        .get_dev_server_for_user(dev_server_id, session.user_id())
2322        .await?;
2323
2324    response.send(
2325        session
2326            .peer
2327            .forward_request(session.connection_id, dev_server_connection_id, request)
2328            .await?,
2329    )?;
2330    Ok(())
2331}
2332
2333async fn update_dev_server_project(
2334    request: proto::UpdateDevServerProject,
2335    response: Response<proto::UpdateDevServerProject>,
2336    session: UserSession,
2337) -> Result<()> {
2338    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2339
2340    let (dev_server_project, update) = session
2341        .db()
2342        .await
2343        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2344        .await?;
2345
2346    let projects = session
2347        .db()
2348        .await
2349        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2350        .await?;
2351
2352    let dev_server_connection_id = session
2353        .connection_pool()
2354        .await
2355        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2356
2357    session.peer.send(
2358        dev_server_connection_id,
2359        proto::DevServerInstructions { projects },
2360    )?;
2361
2362    send_dev_server_projects_update(session.user_id(), update, &session).await;
2363
2364    response.send(proto::Ack {})
2365}
2366
2367async fn create_dev_server_project(
2368    request: proto::CreateDevServerProject,
2369    response: Response<proto::CreateDevServerProject>,
2370    session: UserSession,
2371) -> Result<()> {
2372    let dev_server_id = DevServerId(request.dev_server_id as i32);
2373    let dev_server_connection_id = session
2374        .connection_pool()
2375        .await
2376        .dev_server_connection_id(dev_server_id);
2377    let Some(dev_server_connection_id) = dev_server_connection_id else {
2378        Err(ErrorCode::DevServerOffline
2379            .message("Cannot create a remote project when the dev server is offline".to_string())
2380            .anyhow())?
2381    };
2382
2383    let path = request.path.clone();
2384    //Check that the path exists on the dev server
2385    session
2386        .peer
2387        .forward_request(
2388            session.connection_id,
2389            dev_server_connection_id,
2390            proto::ValidateDevServerProjectRequest { path: path.clone() },
2391        )
2392        .await?;
2393
2394    let (dev_server_project, update) = session
2395        .db()
2396        .await
2397        .create_dev_server_project(
2398            DevServerId(request.dev_server_id as i32),
2399            &request.path,
2400            session.user_id(),
2401        )
2402        .await?;
2403
2404    let projects = session
2405        .db()
2406        .await
2407        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2408        .await?;
2409
2410    session.peer.send(
2411        dev_server_connection_id,
2412        proto::DevServerInstructions { projects },
2413    )?;
2414
2415    send_dev_server_projects_update(session.user_id(), update, &session).await;
2416
2417    response.send(proto::CreateDevServerProjectResponse {
2418        dev_server_project: Some(dev_server_project.to_proto(None)),
2419    })?;
2420    Ok(())
2421}
2422
2423async fn create_dev_server(
2424    request: proto::CreateDevServer,
2425    response: Response<proto::CreateDevServer>,
2426    session: UserSession,
2427) -> Result<()> {
2428    let access_token = auth::random_token();
2429    let hashed_access_token = auth::hash_access_token(&access_token);
2430
2431    if request.name.is_empty() {
2432        return Err(proto::ErrorCode::Forbidden
2433            .message("Dev server name cannot be empty".to_string())
2434            .anyhow())?;
2435    }
2436
2437    let (dev_server, status) = session
2438        .db()
2439        .await
2440        .create_dev_server(
2441            &request.name,
2442            request.ssh_connection_string.as_deref(),
2443            &hashed_access_token,
2444            session.user_id(),
2445        )
2446        .await?;
2447
2448    send_dev_server_projects_update(session.user_id(), status, &session).await;
2449
2450    response.send(proto::CreateDevServerResponse {
2451        dev_server_id: dev_server.id.0 as u64,
2452        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2453        name: request.name,
2454    })?;
2455    Ok(())
2456}
2457
2458async fn regenerate_dev_server_token(
2459    request: proto::RegenerateDevServerToken,
2460    response: Response<proto::RegenerateDevServerToken>,
2461    session: UserSession,
2462) -> Result<()> {
2463    let dev_server_id = DevServerId(request.dev_server_id as i32);
2464    let access_token = auth::random_token();
2465    let hashed_access_token = auth::hash_access_token(&access_token);
2466
2467    let connection_id = session
2468        .connection_pool()
2469        .await
2470        .dev_server_connection_id(dev_server_id);
2471    if let Some(connection_id) = connection_id {
2472        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2473        session.peer.send(
2474            connection_id,
2475            proto::ShutdownDevServer {
2476                reason: Some("dev server token was regenerated".to_string()),
2477            },
2478        )?;
2479        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2480    }
2481
2482    let status = session
2483        .db()
2484        .await
2485        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2486        .await?;
2487
2488    send_dev_server_projects_update(session.user_id(), status, &session).await;
2489
2490    response.send(proto::RegenerateDevServerTokenResponse {
2491        dev_server_id: dev_server_id.to_proto(),
2492        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2493    })?;
2494    Ok(())
2495}
2496
2497async fn rename_dev_server(
2498    request: proto::RenameDevServer,
2499    response: Response<proto::RenameDevServer>,
2500    session: UserSession,
2501) -> Result<()> {
2502    if request.name.trim().is_empty() {
2503        return Err(proto::ErrorCode::Forbidden
2504            .message("Dev server name cannot be empty".to_string())
2505            .anyhow())?;
2506    }
2507
2508    let dev_server_id = DevServerId(request.dev_server_id as i32);
2509    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2510    if dev_server.user_id != session.user_id() {
2511        return Err(anyhow!(ErrorCode::Forbidden))?;
2512    }
2513
2514    let status = session
2515        .db()
2516        .await
2517        .rename_dev_server(
2518            dev_server_id,
2519            &request.name,
2520            request.ssh_connection_string.as_deref(),
2521            session.user_id(),
2522        )
2523        .await?;
2524
2525    send_dev_server_projects_update(session.user_id(), status, &session).await;
2526
2527    response.send(proto::Ack {})?;
2528    Ok(())
2529}
2530
2531async fn delete_dev_server(
2532    request: proto::DeleteDevServer,
2533    response: Response<proto::DeleteDevServer>,
2534    session: UserSession,
2535) -> Result<()> {
2536    let dev_server_id = DevServerId(request.dev_server_id as i32);
2537    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2538    if dev_server.user_id != session.user_id() {
2539        return Err(anyhow!(ErrorCode::Forbidden))?;
2540    }
2541
2542    let connection_id = session
2543        .connection_pool()
2544        .await
2545        .dev_server_connection_id(dev_server_id);
2546    if let Some(connection_id) = connection_id {
2547        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2548        session.peer.send(
2549            connection_id,
2550            proto::ShutdownDevServer {
2551                reason: Some("dev server was deleted".to_string()),
2552            },
2553        )?;
2554        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2555    }
2556
2557    let status = session
2558        .db()
2559        .await
2560        .delete_dev_server(dev_server_id, session.user_id())
2561        .await?;
2562
2563    send_dev_server_projects_update(session.user_id(), status, &session).await;
2564
2565    response.send(proto::Ack {})?;
2566    Ok(())
2567}
2568
2569async fn delete_dev_server_project(
2570    request: proto::DeleteDevServerProject,
2571    response: Response<proto::DeleteDevServerProject>,
2572    session: UserSession,
2573) -> Result<()> {
2574    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2575    let dev_server_project = session
2576        .db()
2577        .await
2578        .get_dev_server_project(dev_server_project_id)
2579        .await?;
2580
2581    let dev_server = session
2582        .db()
2583        .await
2584        .get_dev_server(dev_server_project.dev_server_id)
2585        .await?;
2586    if dev_server.user_id != session.user_id() {
2587        return Err(anyhow!(ErrorCode::Forbidden))?;
2588    }
2589
2590    let dev_server_connection_id = session
2591        .connection_pool()
2592        .await
2593        .dev_server_connection_id(dev_server.id);
2594
2595    if let Some(dev_server_connection_id) = dev_server_connection_id {
2596        let project = session
2597            .db()
2598            .await
2599            .find_dev_server_project(dev_server_project_id)
2600            .await;
2601        if let Ok(project) = project {
2602            unshare_project_internal(
2603                project.id,
2604                dev_server_connection_id,
2605                Some(session.user_id()),
2606                &session,
2607            )
2608            .await?;
2609        }
2610    }
2611
2612    let (projects, status) = session
2613        .db()
2614        .await
2615        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2616        .await?;
2617
2618    if let Some(dev_server_connection_id) = dev_server_connection_id {
2619        session.peer.send(
2620            dev_server_connection_id,
2621            proto::DevServerInstructions { projects },
2622        )?;
2623    }
2624
2625    send_dev_server_projects_update(session.user_id(), status, &session).await;
2626
2627    response.send(proto::Ack {})?;
2628    Ok(())
2629}
2630
2631async fn rejoin_dev_server_projects(
2632    request: proto::RejoinRemoteProjects,
2633    response: Response<proto::RejoinRemoteProjects>,
2634    session: UserSession,
2635) -> Result<()> {
2636    let mut rejoined_projects = {
2637        let db = session.db().await;
2638        db.rejoin_dev_server_projects(
2639            &request.rejoined_projects,
2640            session.user_id(),
2641            session.0.connection_id,
2642        )
2643        .await?
2644    };
2645    response.send(proto::RejoinRemoteProjectsResponse {
2646        rejoined_projects: rejoined_projects
2647            .iter()
2648            .map(|project| project.to_proto())
2649            .collect(),
2650    })?;
2651    notify_rejoined_projects(&mut rejoined_projects, &session)
2652}
2653
2654async fn reconnect_dev_server(
2655    request: proto::ReconnectDevServer,
2656    response: Response<proto::ReconnectDevServer>,
2657    session: DevServerSession,
2658) -> Result<()> {
2659    let reshared_projects = {
2660        let db = session.db().await;
2661        db.reshare_dev_server_projects(
2662            &request.reshared_projects,
2663            session.dev_server_id(),
2664            session.0.connection_id,
2665        )
2666        .await?
2667    };
2668
2669    for project in &reshared_projects {
2670        for collaborator in &project.collaborators {
2671            session
2672                .peer
2673                .send(
2674                    collaborator.connection_id,
2675                    proto::UpdateProjectCollaborator {
2676                        project_id: project.id.to_proto(),
2677                        old_peer_id: Some(project.old_connection_id.into()),
2678                        new_peer_id: Some(session.connection_id.into()),
2679                    },
2680                )
2681                .trace_err();
2682        }
2683
2684        broadcast(
2685            Some(session.connection_id),
2686            project
2687                .collaborators
2688                .iter()
2689                .map(|collaborator| collaborator.connection_id),
2690            |connection_id| {
2691                session.peer.forward_send(
2692                    session.connection_id,
2693                    connection_id,
2694                    proto::UpdateProject {
2695                        project_id: project.id.to_proto(),
2696                        worktrees: project.worktrees.clone(),
2697                    },
2698                )
2699            },
2700        );
2701    }
2702
2703    response.send(proto::ReconnectDevServerResponse {
2704        reshared_projects: reshared_projects
2705            .iter()
2706            .map(|project| proto::ResharedProject {
2707                id: project.id.to_proto(),
2708                collaborators: project
2709                    .collaborators
2710                    .iter()
2711                    .map(|collaborator| collaborator.to_proto())
2712                    .collect(),
2713            })
2714            .collect(),
2715    })?;
2716
2717    Ok(())
2718}
2719
2720async fn shutdown_dev_server(
2721    _: proto::ShutdownDevServer,
2722    response: Response<proto::ShutdownDevServer>,
2723    session: DevServerSession,
2724) -> Result<()> {
2725    response.send(proto::Ack {})?;
2726    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2727    remove_dev_server_connection(session.dev_server_id(), &session).await
2728}
2729
2730async fn shutdown_dev_server_internal(
2731    dev_server_id: DevServerId,
2732    connection_id: ConnectionId,
2733    session: &Session,
2734) -> Result<()> {
2735    let (dev_server_projects, dev_server) = {
2736        let db = session.db().await;
2737        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2738        let dev_server = db.get_dev_server(dev_server_id).await?;
2739        (dev_server_projects, dev_server)
2740    };
2741
2742    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2743        unshare_project_internal(
2744            ProjectId::from_proto(project_id),
2745            connection_id,
2746            None,
2747            session,
2748        )
2749        .await?;
2750    }
2751
2752    session
2753        .connection_pool()
2754        .await
2755        .set_dev_server_offline(dev_server_id);
2756
2757    let status = session
2758        .db()
2759        .await
2760        .dev_server_projects_update(dev_server.user_id)
2761        .await?;
2762    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2763
2764    Ok(())
2765}
2766
2767async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2768    let dev_server_connection = session
2769        .connection_pool()
2770        .await
2771        .dev_server_connection_id(dev_server_id);
2772
2773    if let Some(dev_server_connection) = dev_server_connection {
2774        session
2775            .connection_pool()
2776            .await
2777            .remove_connection(dev_server_connection)?;
2778    }
2779    Ok(())
2780}
2781
2782/// Updates other participants with changes to the project
2783async fn update_project(
2784    request: proto::UpdateProject,
2785    response: Response<proto::UpdateProject>,
2786    session: Session,
2787) -> Result<()> {
2788    let project_id = ProjectId::from_proto(request.project_id);
2789    let (room, guest_connection_ids) = &*session
2790        .db()
2791        .await
2792        .update_project(project_id, session.connection_id, &request.worktrees)
2793        .await?;
2794    broadcast(
2795        Some(session.connection_id),
2796        guest_connection_ids.iter().copied(),
2797        |connection_id| {
2798            session
2799                .peer
2800                .forward_send(session.connection_id, connection_id, request.clone())
2801        },
2802    );
2803    if let Some(room) = room {
2804        room_updated(room, &session.peer);
2805    }
2806    response.send(proto::Ack {})?;
2807
2808    Ok(())
2809}
2810
2811/// Updates other participants with changes to the worktree
2812async fn update_worktree(
2813    request: proto::UpdateWorktree,
2814    response: Response<proto::UpdateWorktree>,
2815    session: Session,
2816) -> Result<()> {
2817    let guest_connection_ids = session
2818        .db()
2819        .await
2820        .update_worktree(&request, session.connection_id)
2821        .await?;
2822
2823    broadcast(
2824        Some(session.connection_id),
2825        guest_connection_ids.iter().copied(),
2826        |connection_id| {
2827            session
2828                .peer
2829                .forward_send(session.connection_id, connection_id, request.clone())
2830        },
2831    );
2832    response.send(proto::Ack {})?;
2833    Ok(())
2834}
2835
2836/// Updates other participants with changes to the diagnostics
2837async fn update_diagnostic_summary(
2838    message: proto::UpdateDiagnosticSummary,
2839    session: Session,
2840) -> Result<()> {
2841    let guest_connection_ids = session
2842        .db()
2843        .await
2844        .update_diagnostic_summary(&message, session.connection_id)
2845        .await?;
2846
2847    broadcast(
2848        Some(session.connection_id),
2849        guest_connection_ids.iter().copied(),
2850        |connection_id| {
2851            session
2852                .peer
2853                .forward_send(session.connection_id, connection_id, message.clone())
2854        },
2855    );
2856
2857    Ok(())
2858}
2859
2860/// Updates other participants with changes to the worktree settings
2861async fn update_worktree_settings(
2862    message: proto::UpdateWorktreeSettings,
2863    session: Session,
2864) -> Result<()> {
2865    let guest_connection_ids = session
2866        .db()
2867        .await
2868        .update_worktree_settings(&message, session.connection_id)
2869        .await?;
2870
2871    broadcast(
2872        Some(session.connection_id),
2873        guest_connection_ids.iter().copied(),
2874        |connection_id| {
2875            session
2876                .peer
2877                .forward_send(session.connection_id, connection_id, message.clone())
2878        },
2879    );
2880
2881    Ok(())
2882}
2883
2884/// Notify other participants that a  language server has started.
2885async fn start_language_server(
2886    request: proto::StartLanguageServer,
2887    session: Session,
2888) -> Result<()> {
2889    let guest_connection_ids = session
2890        .db()
2891        .await
2892        .start_language_server(&request, session.connection_id)
2893        .await?;
2894
2895    broadcast(
2896        Some(session.connection_id),
2897        guest_connection_ids.iter().copied(),
2898        |connection_id| {
2899            session
2900                .peer
2901                .forward_send(session.connection_id, connection_id, request.clone())
2902        },
2903    );
2904    Ok(())
2905}
2906
2907/// Notify other participants that a language server has changed.
2908async fn update_language_server(
2909    request: proto::UpdateLanguageServer,
2910    session: Session,
2911) -> Result<()> {
2912    let project_id = ProjectId::from_proto(request.project_id);
2913    let project_connection_ids = session
2914        .db()
2915        .await
2916        .project_connection_ids(project_id, session.connection_id, true)
2917        .await?;
2918    broadcast(
2919        Some(session.connection_id),
2920        project_connection_ids.iter().copied(),
2921        |connection_id| {
2922            session
2923                .peer
2924                .forward_send(session.connection_id, connection_id, request.clone())
2925        },
2926    );
2927    Ok(())
2928}
2929
2930/// forward a project request to the host. These requests should be read only
2931/// as guests are allowed to send them.
2932async fn forward_read_only_project_request<T>(
2933    request: T,
2934    response: Response<T>,
2935    session: UserSession,
2936) -> Result<()>
2937where
2938    T: EntityMessage + RequestMessage,
2939{
2940    let project_id = ProjectId::from_proto(request.remote_entity_id());
2941    let host_connection_id = session
2942        .db()
2943        .await
2944        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2945        .await?;
2946    let payload = session
2947        .peer
2948        .forward_request(session.connection_id, host_connection_id, request)
2949        .await?;
2950    response.send(payload)?;
2951    Ok(())
2952}
2953
2954async fn forward_find_search_candidates_request(
2955    request: proto::FindSearchCandidates,
2956    response: Response<proto::FindSearchCandidates>,
2957    session: UserSession,
2958) -> Result<()> {
2959    let project_id = ProjectId::from_proto(request.remote_entity_id());
2960    let host_connection_id = session
2961        .db()
2962        .await
2963        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2964        .await?;
2965    let payload = session
2966        .peer
2967        .forward_request(session.connection_id, host_connection_id, request)
2968        .await?;
2969    response.send(payload)?;
2970    Ok(())
2971}
2972
2973/// forward a project request to the dev server. Only allowed
2974/// if it's your dev server.
2975async fn forward_project_request_for_owner<T>(
2976    request: T,
2977    response: Response<T>,
2978    session: UserSession,
2979) -> Result<()>
2980where
2981    T: EntityMessage + RequestMessage,
2982{
2983    let project_id = ProjectId::from_proto(request.remote_entity_id());
2984
2985    let host_connection_id = session
2986        .db()
2987        .await
2988        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2989        .await?;
2990    let payload = session
2991        .peer
2992        .forward_request(session.connection_id, host_connection_id, request)
2993        .await?;
2994    response.send(payload)?;
2995    Ok(())
2996}
2997
2998/// forward a project request to the host. These requests are disallowed
2999/// for guests.
3000async fn forward_mutating_project_request<T>(
3001    request: T,
3002    response: Response<T>,
3003    session: UserSession,
3004) -> Result<()>
3005where
3006    T: EntityMessage + RequestMessage,
3007{
3008    let project_id = ProjectId::from_proto(request.remote_entity_id());
3009
3010    let host_connection_id = session
3011        .db()
3012        .await
3013        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3014        .await?;
3015    let payload = session
3016        .peer
3017        .forward_request(session.connection_id, host_connection_id, request)
3018        .await?;
3019    response.send(payload)?;
3020    Ok(())
3021}
3022
3023/// Notify other participants that a new buffer has been created
3024async fn create_buffer_for_peer(
3025    request: proto::CreateBufferForPeer,
3026    session: Session,
3027) -> Result<()> {
3028    session
3029        .db()
3030        .await
3031        .check_user_is_project_host(
3032            ProjectId::from_proto(request.project_id),
3033            session.connection_id,
3034        )
3035        .await?;
3036    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3037    session
3038        .peer
3039        .forward_send(session.connection_id, peer_id.into(), request)?;
3040    Ok(())
3041}
3042
3043/// Notify other participants that a buffer has been updated. This is
3044/// allowed for guests as long as the update is limited to selections.
3045async fn update_buffer(
3046    request: proto::UpdateBuffer,
3047    response: Response<proto::UpdateBuffer>,
3048    session: Session,
3049) -> Result<()> {
3050    let project_id = ProjectId::from_proto(request.project_id);
3051    let mut capability = Capability::ReadOnly;
3052
3053    for op in request.operations.iter() {
3054        match op.variant {
3055            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3056            Some(_) => capability = Capability::ReadWrite,
3057        }
3058    }
3059
3060    let host = {
3061        let guard = session
3062            .db()
3063            .await
3064            .connections_for_buffer_update(
3065                project_id,
3066                session.principal_id(),
3067                session.connection_id,
3068                capability,
3069            )
3070            .await?;
3071
3072        let (host, guests) = &*guard;
3073
3074        broadcast(
3075            Some(session.connection_id),
3076            guests.clone(),
3077            |connection_id| {
3078                session
3079                    .peer
3080                    .forward_send(session.connection_id, connection_id, request.clone())
3081            },
3082        );
3083
3084        *host
3085    };
3086
3087    if host != session.connection_id {
3088        session
3089            .peer
3090            .forward_request(session.connection_id, host, request.clone())
3091            .await?;
3092    }
3093
3094    response.send(proto::Ack {})?;
3095    Ok(())
3096}
3097
3098async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3099    let project_id = ProjectId::from_proto(message.project_id);
3100
3101    let operation = message.operation.as_ref().context("invalid operation")?;
3102    let capability = match operation.variant.as_ref() {
3103        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3104            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3105                match buffer_op.variant {
3106                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3107                        Capability::ReadOnly
3108                    }
3109                    _ => Capability::ReadWrite,
3110                }
3111            } else {
3112                Capability::ReadWrite
3113            }
3114        }
3115        Some(_) => Capability::ReadWrite,
3116        None => Capability::ReadOnly,
3117    };
3118
3119    let guard = session
3120        .db()
3121        .await
3122        .connections_for_buffer_update(
3123            project_id,
3124            session.principal_id(),
3125            session.connection_id,
3126            capability,
3127        )
3128        .await?;
3129
3130    let (host, guests) = &*guard;
3131
3132    broadcast(
3133        Some(session.connection_id),
3134        guests.iter().chain([host]).copied(),
3135        |connection_id| {
3136            session
3137                .peer
3138                .forward_send(session.connection_id, connection_id, message.clone())
3139        },
3140    );
3141
3142    Ok(())
3143}
3144
3145/// Notify other participants that a project has been updated.
3146async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3147    request: T,
3148    session: Session,
3149) -> Result<()> {
3150    let project_id = ProjectId::from_proto(request.remote_entity_id());
3151    let project_connection_ids = session
3152        .db()
3153        .await
3154        .project_connection_ids(project_id, session.connection_id, false)
3155        .await?;
3156
3157    broadcast(
3158        Some(session.connection_id),
3159        project_connection_ids.iter().copied(),
3160        |connection_id| {
3161            session
3162                .peer
3163                .forward_send(session.connection_id, connection_id, request.clone())
3164        },
3165    );
3166    Ok(())
3167}
3168
3169/// Start following another user in a call.
3170async fn follow(
3171    request: proto::Follow,
3172    response: Response<proto::Follow>,
3173    session: UserSession,
3174) -> Result<()> {
3175    let room_id = RoomId::from_proto(request.room_id);
3176    let project_id = request.project_id.map(ProjectId::from_proto);
3177    let leader_id = request
3178        .leader_id
3179        .ok_or_else(|| anyhow!("invalid leader id"))?
3180        .into();
3181    let follower_id = session.connection_id;
3182
3183    session
3184        .db()
3185        .await
3186        .check_room_participants(room_id, leader_id, session.connection_id)
3187        .await?;
3188
3189    let response_payload = session
3190        .peer
3191        .forward_request(session.connection_id, leader_id, request)
3192        .await?;
3193    response.send(response_payload)?;
3194
3195    if let Some(project_id) = project_id {
3196        let room = session
3197            .db()
3198            .await
3199            .follow(room_id, project_id, leader_id, follower_id)
3200            .await?;
3201        room_updated(&room, &session.peer);
3202    }
3203
3204    Ok(())
3205}
3206
3207/// Stop following another user in a call.
3208async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3209    let room_id = RoomId::from_proto(request.room_id);
3210    let project_id = request.project_id.map(ProjectId::from_proto);
3211    let leader_id = request
3212        .leader_id
3213        .ok_or_else(|| anyhow!("invalid leader id"))?
3214        .into();
3215    let follower_id = session.connection_id;
3216
3217    session
3218        .db()
3219        .await
3220        .check_room_participants(room_id, leader_id, session.connection_id)
3221        .await?;
3222
3223    session
3224        .peer
3225        .forward_send(session.connection_id, leader_id, request)?;
3226
3227    if let Some(project_id) = project_id {
3228        let room = session
3229            .db()
3230            .await
3231            .unfollow(room_id, project_id, leader_id, follower_id)
3232            .await?;
3233        room_updated(&room, &session.peer);
3234    }
3235
3236    Ok(())
3237}
3238
3239/// Notify everyone following you of your current location.
3240async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3241    let room_id = RoomId::from_proto(request.room_id);
3242    let database = session.db.lock().await;
3243
3244    let connection_ids = if let Some(project_id) = request.project_id {
3245        let project_id = ProjectId::from_proto(project_id);
3246        database
3247            .project_connection_ids(project_id, session.connection_id, true)
3248            .await?
3249    } else {
3250        database
3251            .room_connection_ids(room_id, session.connection_id)
3252            .await?
3253    };
3254
3255    // For now, don't send view update messages back to that view's current leader.
3256    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3257        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3258        _ => None,
3259    });
3260
3261    for connection_id in connection_ids.iter().cloned() {
3262        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3263            session
3264                .peer
3265                .forward_send(session.connection_id, connection_id, request.clone())?;
3266        }
3267    }
3268    Ok(())
3269}
3270
3271/// Get public data about users.
3272async fn get_users(
3273    request: proto::GetUsers,
3274    response: Response<proto::GetUsers>,
3275    session: Session,
3276) -> Result<()> {
3277    let user_ids = request
3278        .user_ids
3279        .into_iter()
3280        .map(UserId::from_proto)
3281        .collect();
3282    let users = session
3283        .db()
3284        .await
3285        .get_users_by_ids(user_ids)
3286        .await?
3287        .into_iter()
3288        .map(|user| proto::User {
3289            id: user.id.to_proto(),
3290            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3291            github_login: user.github_login,
3292        })
3293        .collect();
3294    response.send(proto::UsersResponse { users })?;
3295    Ok(())
3296}
3297
3298/// Search for users (to invite) buy Github login
3299async fn fuzzy_search_users(
3300    request: proto::FuzzySearchUsers,
3301    response: Response<proto::FuzzySearchUsers>,
3302    session: UserSession,
3303) -> Result<()> {
3304    let query = request.query;
3305    let users = match query.len() {
3306        0 => vec![],
3307        1 | 2 => session
3308            .db()
3309            .await
3310            .get_user_by_github_login(&query)
3311            .await?
3312            .into_iter()
3313            .collect(),
3314        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3315    };
3316    let users = users
3317        .into_iter()
3318        .filter(|user| user.id != session.user_id())
3319        .map(|user| proto::User {
3320            id: user.id.to_proto(),
3321            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3322            github_login: user.github_login,
3323        })
3324        .collect();
3325    response.send(proto::UsersResponse { users })?;
3326    Ok(())
3327}
3328
3329/// Send a contact request to another user.
3330async fn request_contact(
3331    request: proto::RequestContact,
3332    response: Response<proto::RequestContact>,
3333    session: UserSession,
3334) -> Result<()> {
3335    let requester_id = session.user_id();
3336    let responder_id = UserId::from_proto(request.responder_id);
3337    if requester_id == responder_id {
3338        return Err(anyhow!("cannot add yourself as a contact"))?;
3339    }
3340
3341    let notifications = session
3342        .db()
3343        .await
3344        .send_contact_request(requester_id, responder_id)
3345        .await?;
3346
3347    // Update outgoing contact requests of requester
3348    let mut update = proto::UpdateContacts::default();
3349    update.outgoing_requests.push(responder_id.to_proto());
3350    for connection_id in session
3351        .connection_pool()
3352        .await
3353        .user_connection_ids(requester_id)
3354    {
3355        session.peer.send(connection_id, update.clone())?;
3356    }
3357
3358    // Update incoming contact requests of responder
3359    let mut update = proto::UpdateContacts::default();
3360    update
3361        .incoming_requests
3362        .push(proto::IncomingContactRequest {
3363            requester_id: requester_id.to_proto(),
3364        });
3365    let connection_pool = session.connection_pool().await;
3366    for connection_id in connection_pool.user_connection_ids(responder_id) {
3367        session.peer.send(connection_id, update.clone())?;
3368    }
3369
3370    send_notifications(&connection_pool, &session.peer, notifications);
3371
3372    response.send(proto::Ack {})?;
3373    Ok(())
3374}
3375
3376/// Accept or decline a contact request
3377async fn respond_to_contact_request(
3378    request: proto::RespondToContactRequest,
3379    response: Response<proto::RespondToContactRequest>,
3380    session: UserSession,
3381) -> Result<()> {
3382    let responder_id = session.user_id();
3383    let requester_id = UserId::from_proto(request.requester_id);
3384    let db = session.db().await;
3385    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3386        db.dismiss_contact_notification(responder_id, requester_id)
3387            .await?;
3388    } else {
3389        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3390
3391        let notifications = db
3392            .respond_to_contact_request(responder_id, requester_id, accept)
3393            .await?;
3394        let requester_busy = db.is_user_busy(requester_id).await?;
3395        let responder_busy = db.is_user_busy(responder_id).await?;
3396
3397        let pool = session.connection_pool().await;
3398        // Update responder with new contact
3399        let mut update = proto::UpdateContacts::default();
3400        if accept {
3401            update
3402                .contacts
3403                .push(contact_for_user(requester_id, requester_busy, &pool));
3404        }
3405        update
3406            .remove_incoming_requests
3407            .push(requester_id.to_proto());
3408        for connection_id in pool.user_connection_ids(responder_id) {
3409            session.peer.send(connection_id, update.clone())?;
3410        }
3411
3412        // Update requester with new contact
3413        let mut update = proto::UpdateContacts::default();
3414        if accept {
3415            update
3416                .contacts
3417                .push(contact_for_user(responder_id, responder_busy, &pool));
3418        }
3419        update
3420            .remove_outgoing_requests
3421            .push(responder_id.to_proto());
3422
3423        for connection_id in pool.user_connection_ids(requester_id) {
3424            session.peer.send(connection_id, update.clone())?;
3425        }
3426
3427        send_notifications(&pool, &session.peer, notifications);
3428    }
3429
3430    response.send(proto::Ack {})?;
3431    Ok(())
3432}
3433
3434/// Remove a contact.
3435async fn remove_contact(
3436    request: proto::RemoveContact,
3437    response: Response<proto::RemoveContact>,
3438    session: UserSession,
3439) -> Result<()> {
3440    let requester_id = session.user_id();
3441    let responder_id = UserId::from_proto(request.user_id);
3442    let db = session.db().await;
3443    let (contact_accepted, deleted_notification_id) =
3444        db.remove_contact(requester_id, responder_id).await?;
3445
3446    let pool = session.connection_pool().await;
3447    // Update outgoing contact requests of requester
3448    let mut update = proto::UpdateContacts::default();
3449    if contact_accepted {
3450        update.remove_contacts.push(responder_id.to_proto());
3451    } else {
3452        update
3453            .remove_outgoing_requests
3454            .push(responder_id.to_proto());
3455    }
3456    for connection_id in pool.user_connection_ids(requester_id) {
3457        session.peer.send(connection_id, update.clone())?;
3458    }
3459
3460    // Update incoming contact requests of responder
3461    let mut update = proto::UpdateContacts::default();
3462    if contact_accepted {
3463        update.remove_contacts.push(requester_id.to_proto());
3464    } else {
3465        update
3466            .remove_incoming_requests
3467            .push(requester_id.to_proto());
3468    }
3469    for connection_id in pool.user_connection_ids(responder_id) {
3470        session.peer.send(connection_id, update.clone())?;
3471        if let Some(notification_id) = deleted_notification_id {
3472            session.peer.send(
3473                connection_id,
3474                proto::DeleteNotification {
3475                    notification_id: notification_id.to_proto(),
3476                },
3477            )?;
3478        }
3479    }
3480
3481    response.send(proto::Ack {})?;
3482    Ok(())
3483}
3484
3485fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3486    version.0.minor() < 139
3487}
3488
3489async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3490    let plan = session.current_plan(&session.db().await).await?;
3491
3492    session
3493        .peer
3494        .send(
3495            session.connection_id,
3496            proto::UpdateUserPlan { plan: plan.into() },
3497        )
3498        .trace_err();
3499
3500    Ok(())
3501}
3502
3503async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3504    subscribe_user_to_channels(
3505        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3506        &session,
3507    )
3508    .await?;
3509    Ok(())
3510}
3511
3512async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3513    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3514    let mut pool = session.connection_pool().await;
3515    for membership in &channels_for_user.channel_memberships {
3516        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3517    }
3518    session.peer.send(
3519        session.connection_id,
3520        build_update_user_channels(&channels_for_user),
3521    )?;
3522    session.peer.send(
3523        session.connection_id,
3524        build_channels_update(channels_for_user),
3525    )?;
3526    Ok(())
3527}
3528
3529/// Creates a new channel.
3530async fn create_channel(
3531    request: proto::CreateChannel,
3532    response: Response<proto::CreateChannel>,
3533    session: UserSession,
3534) -> Result<()> {
3535    let db = session.db().await;
3536
3537    let parent_id = request.parent_id.map(ChannelId::from_proto);
3538    let (channel, membership) = db
3539        .create_channel(&request.name, parent_id, session.user_id())
3540        .await?;
3541
3542    let root_id = channel.root_id();
3543    let channel = Channel::from_model(channel);
3544
3545    response.send(proto::CreateChannelResponse {
3546        channel: Some(channel.to_proto()),
3547        parent_id: request.parent_id,
3548    })?;
3549
3550    let mut connection_pool = session.connection_pool().await;
3551    if let Some(membership) = membership {
3552        connection_pool.subscribe_to_channel(
3553            membership.user_id,
3554            membership.channel_id,
3555            membership.role,
3556        );
3557        let update = proto::UpdateUserChannels {
3558            channel_memberships: vec![proto::ChannelMembership {
3559                channel_id: membership.channel_id.to_proto(),
3560                role: membership.role.into(),
3561            }],
3562            ..Default::default()
3563        };
3564        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3565            session.peer.send(connection_id, update.clone())?;
3566        }
3567    }
3568
3569    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3570        if !role.can_see_channel(channel.visibility) {
3571            continue;
3572        }
3573
3574        let update = proto::UpdateChannels {
3575            channels: vec![channel.to_proto()],
3576            ..Default::default()
3577        };
3578        session.peer.send(connection_id, update.clone())?;
3579    }
3580
3581    Ok(())
3582}
3583
3584/// Delete a channel
3585async fn delete_channel(
3586    request: proto::DeleteChannel,
3587    response: Response<proto::DeleteChannel>,
3588    session: UserSession,
3589) -> Result<()> {
3590    let db = session.db().await;
3591
3592    let channel_id = request.channel_id;
3593    let (root_channel, removed_channels) = db
3594        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3595        .await?;
3596    response.send(proto::Ack {})?;
3597
3598    // Notify members of removed channels
3599    let mut update = proto::UpdateChannels::default();
3600    update
3601        .delete_channels
3602        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3603
3604    let connection_pool = session.connection_pool().await;
3605    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3606        session.peer.send(connection_id, update.clone())?;
3607    }
3608
3609    Ok(())
3610}
3611
3612/// Invite someone to join a channel.
3613async fn invite_channel_member(
3614    request: proto::InviteChannelMember,
3615    response: Response<proto::InviteChannelMember>,
3616    session: UserSession,
3617) -> Result<()> {
3618    let db = session.db().await;
3619    let channel_id = ChannelId::from_proto(request.channel_id);
3620    let invitee_id = UserId::from_proto(request.user_id);
3621    let InviteMemberResult {
3622        channel,
3623        notifications,
3624    } = db
3625        .invite_channel_member(
3626            channel_id,
3627            invitee_id,
3628            session.user_id(),
3629            request.role().into(),
3630        )
3631        .await?;
3632
3633    let update = proto::UpdateChannels {
3634        channel_invitations: vec![channel.to_proto()],
3635        ..Default::default()
3636    };
3637
3638    let connection_pool = session.connection_pool().await;
3639    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3640        session.peer.send(connection_id, update.clone())?;
3641    }
3642
3643    send_notifications(&connection_pool, &session.peer, notifications);
3644
3645    response.send(proto::Ack {})?;
3646    Ok(())
3647}
3648
3649/// remove someone from a channel
3650async fn remove_channel_member(
3651    request: proto::RemoveChannelMember,
3652    response: Response<proto::RemoveChannelMember>,
3653    session: UserSession,
3654) -> Result<()> {
3655    let db = session.db().await;
3656    let channel_id = ChannelId::from_proto(request.channel_id);
3657    let member_id = UserId::from_proto(request.user_id);
3658
3659    let RemoveChannelMemberResult {
3660        membership_update,
3661        notification_id,
3662    } = db
3663        .remove_channel_member(channel_id, member_id, session.user_id())
3664        .await?;
3665
3666    let mut connection_pool = session.connection_pool().await;
3667    notify_membership_updated(
3668        &mut connection_pool,
3669        membership_update,
3670        member_id,
3671        &session.peer,
3672    );
3673    for connection_id in connection_pool.user_connection_ids(member_id) {
3674        if let Some(notification_id) = notification_id {
3675            session
3676                .peer
3677                .send(
3678                    connection_id,
3679                    proto::DeleteNotification {
3680                        notification_id: notification_id.to_proto(),
3681                    },
3682                )
3683                .trace_err();
3684        }
3685    }
3686
3687    response.send(proto::Ack {})?;
3688    Ok(())
3689}
3690
3691/// Toggle the channel between public and private.
3692/// Care is taken to maintain the invariant that public channels only descend from public channels,
3693/// (though members-only channels can appear at any point in the hierarchy).
3694async fn set_channel_visibility(
3695    request: proto::SetChannelVisibility,
3696    response: Response<proto::SetChannelVisibility>,
3697    session: UserSession,
3698) -> Result<()> {
3699    let db = session.db().await;
3700    let channel_id = ChannelId::from_proto(request.channel_id);
3701    let visibility = request.visibility().into();
3702
3703    let channel_model = db
3704        .set_channel_visibility(channel_id, visibility, session.user_id())
3705        .await?;
3706    let root_id = channel_model.root_id();
3707    let channel = Channel::from_model(channel_model);
3708
3709    let mut connection_pool = session.connection_pool().await;
3710    for (user_id, role) in connection_pool
3711        .channel_user_ids(root_id)
3712        .collect::<Vec<_>>()
3713        .into_iter()
3714    {
3715        let update = if role.can_see_channel(channel.visibility) {
3716            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3717            proto::UpdateChannels {
3718                channels: vec![channel.to_proto()],
3719                ..Default::default()
3720            }
3721        } else {
3722            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3723            proto::UpdateChannels {
3724                delete_channels: vec![channel.id.to_proto()],
3725                ..Default::default()
3726            }
3727        };
3728
3729        for connection_id in connection_pool.user_connection_ids(user_id) {
3730            session.peer.send(connection_id, update.clone())?;
3731        }
3732    }
3733
3734    response.send(proto::Ack {})?;
3735    Ok(())
3736}
3737
3738/// Alter the role for a user in the channel.
3739async fn set_channel_member_role(
3740    request: proto::SetChannelMemberRole,
3741    response: Response<proto::SetChannelMemberRole>,
3742    session: UserSession,
3743) -> Result<()> {
3744    let db = session.db().await;
3745    let channel_id = ChannelId::from_proto(request.channel_id);
3746    let member_id = UserId::from_proto(request.user_id);
3747    let result = db
3748        .set_channel_member_role(
3749            channel_id,
3750            session.user_id(),
3751            member_id,
3752            request.role().into(),
3753        )
3754        .await?;
3755
3756    match result {
3757        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3758            let mut connection_pool = session.connection_pool().await;
3759            notify_membership_updated(
3760                &mut connection_pool,
3761                membership_update,
3762                member_id,
3763                &session.peer,
3764            )
3765        }
3766        db::SetMemberRoleResult::InviteUpdated(channel) => {
3767            let update = proto::UpdateChannels {
3768                channel_invitations: vec![channel.to_proto()],
3769                ..Default::default()
3770            };
3771
3772            for connection_id in session
3773                .connection_pool()
3774                .await
3775                .user_connection_ids(member_id)
3776            {
3777                session.peer.send(connection_id, update.clone())?;
3778            }
3779        }
3780    }
3781
3782    response.send(proto::Ack {})?;
3783    Ok(())
3784}
3785
3786/// Change the name of a channel
3787async fn rename_channel(
3788    request: proto::RenameChannel,
3789    response: Response<proto::RenameChannel>,
3790    session: UserSession,
3791) -> Result<()> {
3792    let db = session.db().await;
3793    let channel_id = ChannelId::from_proto(request.channel_id);
3794    let channel_model = db
3795        .rename_channel(channel_id, session.user_id(), &request.name)
3796        .await?;
3797    let root_id = channel_model.root_id();
3798    let channel = Channel::from_model(channel_model);
3799
3800    response.send(proto::RenameChannelResponse {
3801        channel: Some(channel.to_proto()),
3802    })?;
3803
3804    let connection_pool = session.connection_pool().await;
3805    let update = proto::UpdateChannels {
3806        channels: vec![channel.to_proto()],
3807        ..Default::default()
3808    };
3809    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3810        if role.can_see_channel(channel.visibility) {
3811            session.peer.send(connection_id, update.clone())?;
3812        }
3813    }
3814
3815    Ok(())
3816}
3817
3818/// Move a channel to a new parent.
3819async fn move_channel(
3820    request: proto::MoveChannel,
3821    response: Response<proto::MoveChannel>,
3822    session: UserSession,
3823) -> Result<()> {
3824    let channel_id = ChannelId::from_proto(request.channel_id);
3825    let to = ChannelId::from_proto(request.to);
3826
3827    let (root_id, channels) = session
3828        .db()
3829        .await
3830        .move_channel(channel_id, to, session.user_id())
3831        .await?;
3832
3833    let connection_pool = session.connection_pool().await;
3834    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3835        let channels = channels
3836            .iter()
3837            .filter_map(|channel| {
3838                if role.can_see_channel(channel.visibility) {
3839                    Some(channel.to_proto())
3840                } else {
3841                    None
3842                }
3843            })
3844            .collect::<Vec<_>>();
3845        if channels.is_empty() {
3846            continue;
3847        }
3848
3849        let update = proto::UpdateChannels {
3850            channels,
3851            ..Default::default()
3852        };
3853
3854        session.peer.send(connection_id, update.clone())?;
3855    }
3856
3857    response.send(Ack {})?;
3858    Ok(())
3859}
3860
3861/// Get the list of channel members
3862async fn get_channel_members(
3863    request: proto::GetChannelMembers,
3864    response: Response<proto::GetChannelMembers>,
3865    session: UserSession,
3866) -> Result<()> {
3867    let db = session.db().await;
3868    let channel_id = ChannelId::from_proto(request.channel_id);
3869    let limit = if request.limit == 0 {
3870        u16::MAX as u64
3871    } else {
3872        request.limit
3873    };
3874    let (members, users) = db
3875        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3876        .await?;
3877    response.send(proto::GetChannelMembersResponse { members, users })?;
3878    Ok(())
3879}
3880
3881/// Accept or decline a channel invitation.
3882async fn respond_to_channel_invite(
3883    request: proto::RespondToChannelInvite,
3884    response: Response<proto::RespondToChannelInvite>,
3885    session: UserSession,
3886) -> Result<()> {
3887    let db = session.db().await;
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    let RespondToChannelInvite {
3890        membership_update,
3891        notifications,
3892    } = db
3893        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3894        .await?;
3895
3896    let mut connection_pool = session.connection_pool().await;
3897    if let Some(membership_update) = membership_update {
3898        notify_membership_updated(
3899            &mut connection_pool,
3900            membership_update,
3901            session.user_id(),
3902            &session.peer,
3903        );
3904    } else {
3905        let update = proto::UpdateChannels {
3906            remove_channel_invitations: vec![channel_id.to_proto()],
3907            ..Default::default()
3908        };
3909
3910        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3911            session.peer.send(connection_id, update.clone())?;
3912        }
3913    };
3914
3915    send_notifications(&connection_pool, &session.peer, notifications);
3916
3917    response.send(proto::Ack {})?;
3918
3919    Ok(())
3920}
3921
3922/// Join the channels' room
3923async fn join_channel(
3924    request: proto::JoinChannel,
3925    response: Response<proto::JoinChannel>,
3926    session: UserSession,
3927) -> Result<()> {
3928    let channel_id = ChannelId::from_proto(request.channel_id);
3929    join_channel_internal(channel_id, Box::new(response), session).await
3930}
3931
/// A response handle that can answer with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path makes explicit that the inherent
        // `Response::send` is called, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path makes explicit that the inherent
        // `Response::send` is called, not this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3945
/// Shared implementation for joining a channel's room. `response` may be any
/// handle implementing `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The db handle is dropped while doing so and then
///    re-acquired.
/// 2. Join the channel's room in the database. When a LiveKit client is
///    configured, mint a token: `ChannelRole::Guest` gets a guest token with
///    `can_publish = false`; every other role gets a room token with
///    `can_publish = true`.
/// 3. Reply to the caller, pass any membership change to
///    `notify_membership_updated`, and broadcast the room update.
/// 4. Broadcast the channel update and refresh the user's contacts.
///
/// A token-generation error is logged via `trace_err` and results in
/// `live_kit_connection_info: None` rather than failing the join. Returns an
/// error if the database does not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle before leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token
        // fails (the error is logged by `trace_err` and `?` yields `None`).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops `db` and `connection_pool`; the pool is
        // acquired again below for the channel broadcast.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4041
4042/// Start editing the channel notes
4043async fn join_channel_buffer(
4044    request: proto::JoinChannelBuffer,
4045    response: Response<proto::JoinChannelBuffer>,
4046    session: UserSession,
4047) -> Result<()> {
4048    let db = session.db().await;
4049    let channel_id = ChannelId::from_proto(request.channel_id);
4050
4051    let open_response = db
4052        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4053        .await?;
4054
4055    let collaborators = open_response.collaborators.clone();
4056    response.send(open_response)?;
4057
4058    let update = UpdateChannelBufferCollaborators {
4059        channel_id: channel_id.to_proto(),
4060        collaborators: collaborators.clone(),
4061    };
4062    channel_buffer_updated(
4063        session.connection_id,
4064        collaborators
4065            .iter()
4066            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4067        &update,
4068        &session.peer,
4069    );
4070
4071    Ok(())
4072}
4073
4074/// Edit the channel notes
4075async fn update_channel_buffer(
4076    request: proto::UpdateChannelBuffer,
4077    session: UserSession,
4078) -> Result<()> {
4079    let db = session.db().await;
4080    let channel_id = ChannelId::from_proto(request.channel_id);
4081
4082    let (collaborators, epoch, version) = db
4083        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4084        .await?;
4085
4086    channel_buffer_updated(
4087        session.connection_id,
4088        collaborators.clone(),
4089        &proto::UpdateChannelBuffer {
4090            channel_id: channel_id.to_proto(),
4091            operations: request.operations,
4092        },
4093        &session.peer,
4094    );
4095
4096    let pool = &*session.connection_pool().await;
4097
4098    let non_collaborators =
4099        pool.channel_connection_ids(channel_id)
4100            .filter_map(|(connection_id, _)| {
4101                if collaborators.contains(&connection_id) {
4102                    None
4103                } else {
4104                    Some(connection_id)
4105                }
4106            });
4107
4108    broadcast(None, non_collaborators, |peer_id| {
4109        session.peer.send(
4110            peer_id,
4111            proto::UpdateChannels {
4112                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4113                    channel_id: channel_id.to_proto(),
4114                    epoch: epoch as u64,
4115                    version: version.clone(),
4116                }],
4117                ..Default::default()
4118            },
4119        )
4120    });
4121
4122    Ok(())
4123}
4124
4125/// Rejoin the channel notes after a connection blip
4126async fn rejoin_channel_buffers(
4127    request: proto::RejoinChannelBuffers,
4128    response: Response<proto::RejoinChannelBuffers>,
4129    session: UserSession,
4130) -> Result<()> {
4131    let db = session.db().await;
4132    let buffers = db
4133        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4134        .await?;
4135
4136    for rejoined_buffer in &buffers {
4137        let collaborators_to_notify = rejoined_buffer
4138            .buffer
4139            .collaborators
4140            .iter()
4141            .filter_map(|c| Some(c.peer_id?.into()));
4142        channel_buffer_updated(
4143            session.connection_id,
4144            collaborators_to_notify,
4145            &proto::UpdateChannelBufferCollaborators {
4146                channel_id: rejoined_buffer.buffer.channel_id,
4147                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4148            },
4149            &session.peer,
4150        );
4151    }
4152
4153    response.send(proto::RejoinChannelBuffersResponse {
4154        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4155    })?;
4156
4157    Ok(())
4158}
4159
4160/// Stop editing the channel notes
4161async fn leave_channel_buffer(
4162    request: proto::LeaveChannelBuffer,
4163    response: Response<proto::LeaveChannelBuffer>,
4164    session: UserSession,
4165) -> Result<()> {
4166    let db = session.db().await;
4167    let channel_id = ChannelId::from_proto(request.channel_id);
4168
4169    let left_buffer = db
4170        .leave_channel_buffer(channel_id, session.connection_id)
4171        .await?;
4172
4173    response.send(Ack {})?;
4174
4175    channel_buffer_updated(
4176        session.connection_id,
4177        left_buffer.connections,
4178        &proto::UpdateChannelBufferCollaborators {
4179            channel_id: channel_id.to_proto(),
4180            collaborators: left_buffer.collaborators,
4181        },
4182        &session.peer,
4183    );
4184
4185    Ok(())
4186}
4187
4188fn channel_buffer_updated<T: EnvelopedMessage>(
4189    sender_id: ConnectionId,
4190    collaborators: impl IntoIterator<Item = ConnectionId>,
4191    message: &T,
4192    peer: &Peer,
4193) {
4194    broadcast(Some(sender_id), collaborators, |peer_id| {
4195        peer.send(peer_id, message.clone())
4196    });
4197}
4198
4199fn send_notifications(
4200    connection_pool: &ConnectionPool,
4201    peer: &Peer,
4202    notifications: db::NotificationBatch,
4203) {
4204    for (user_id, notification) in notifications {
4205        for connection_id in connection_pool.user_connection_ids(user_id) {
4206            if let Err(error) = peer.send(
4207                connection_id,
4208                proto::AddNotification {
4209                    notification: Some(notification.clone()),
4210                },
4211            ) {
4212                tracing::error!(
4213                    "failed to send notification to {:?} {}",
4214                    connection_id,
4215                    error
4216                );
4217            }
4218        }
4219    }
4220}
4221
/// Send a message to the channel.
///
/// The body is trimmed before validation and storage: the trimmed body must be
/// non-empty and at most `MAX_MESSAGE_LEN` bytes, and the request must carry a
/// nonce. After `create_channel_message` persists the message:
/// 1. a `ChannelMessageSent` is broadcast to the participant connections the
///    database returned (the sender's connection is passed to `broadcast` as
///    the sender);
/// 2. the full message is returned to the caller in the response;
/// 3. every other connection subscribed to the channel that is *not* among
///    those participants is sent an `UpdateChannels` carrying only the new
///    latest message id;
/// 4. any notifications produced by `create_channel_message` are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below, so
    // if mention ranges are offsets into the untrimmed body they are shifted by
    // any leading whitespace removed above.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Chat participants receive the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // The pool guard is a temporary whose lifetime is extended by this `let`,
    // so the connection pool stays locked until the end of the function.
    let pool = &*session.connection_pool().await;
    // Subscribers to the channel that are not chat participants only learn
    // the new latest message id.
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4316
4317/// Delete a channel message
4318async fn remove_channel_message(
4319    request: proto::RemoveChannelMessage,
4320    response: Response<proto::RemoveChannelMessage>,
4321    session: UserSession,
4322) -> Result<()> {
4323    let channel_id = ChannelId::from_proto(request.channel_id);
4324    let message_id = MessageId::from_proto(request.message_id);
4325    let (connection_ids, existing_notification_ids) = session
4326        .db()
4327        .await
4328        .remove_channel_message(channel_id, message_id, session.user_id())
4329        .await?;
4330
4331    broadcast(
4332        Some(session.connection_id),
4333        connection_ids,
4334        move |connection| {
4335            session.peer.send(connection, request.clone())?;
4336
4337            for notification_id in &existing_notification_ids {
4338                session.peer.send(
4339                    connection,
4340                    proto::DeleteNotification {
4341                        notification_id: (*notification_id).to_proto(),
4342                    },
4343                )?;
4344            }
4345
4346            Ok(())
4347        },
4348    );
4349    response.send(proto::Ack {})?;
4350    Ok(())
4351}
4352
4353async fn update_channel_message(
4354    request: proto::UpdateChannelMessage,
4355    response: Response<proto::UpdateChannelMessage>,
4356    session: UserSession,
4357) -> Result<()> {
4358    let channel_id = ChannelId::from_proto(request.channel_id);
4359    let message_id = MessageId::from_proto(request.message_id);
4360    let updated_at = OffsetDateTime::now_utc();
4361    let UpdatedChannelMessage {
4362        message_id,
4363        participant_connection_ids,
4364        notifications,
4365        reply_to_message_id,
4366        timestamp,
4367        deleted_mention_notification_ids,
4368        updated_mention_notifications,
4369    } = session
4370        .db()
4371        .await
4372        .update_channel_message(
4373            channel_id,
4374            message_id,
4375            session.user_id(),
4376            request.body.as_str(),
4377            &request.mentions,
4378            updated_at,
4379        )
4380        .await?;
4381
4382    let nonce = request
4383        .nonce
4384        .clone()
4385        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4386
4387    let message = proto::ChannelMessage {
4388        sender_id: session.user_id().to_proto(),
4389        id: message_id.to_proto(),
4390        body: request.body.clone(),
4391        mentions: request.mentions.clone(),
4392        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4393        nonce: Some(nonce),
4394        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4395        edited_at: Some(updated_at.unix_timestamp() as u64),
4396    };
4397
4398    response.send(proto::Ack {})?;
4399
4400    let pool = &*session.connection_pool().await;
4401    broadcast(
4402        Some(session.connection_id),
4403        participant_connection_ids,
4404        |connection| {
4405            session.peer.send(
4406                connection,
4407                proto::ChannelMessageUpdate {
4408                    channel_id: channel_id.to_proto(),
4409                    message: Some(message.clone()),
4410                },
4411            )?;
4412
4413            for notification_id in &deleted_mention_notification_ids {
4414                session.peer.send(
4415                    connection,
4416                    proto::DeleteNotification {
4417                        notification_id: (*notification_id).to_proto(),
4418                    },
4419                )?;
4420            }
4421
4422            for notification in &updated_mention_notifications {
4423                session.peer.send(
4424                    connection,
4425                    proto::UpdateNotification {
4426                        notification: Some(notification.clone()),
4427                    },
4428                )?;
4429            }
4430
4431            Ok(())
4432        },
4433    );
4434
4435    send_notifications(pool, &session.peer, notifications);
4436
4437    Ok(())
4438}
4439
4440/// Mark a channel message as read
4441async fn acknowledge_channel_message(
4442    request: proto::AckChannelMessage,
4443    session: UserSession,
4444) -> Result<()> {
4445    let channel_id = ChannelId::from_proto(request.channel_id);
4446    let message_id = MessageId::from_proto(request.message_id);
4447    let notifications = session
4448        .db()
4449        .await
4450        .observe_channel_message(channel_id, session.user_id(), message_id)
4451        .await?;
4452    send_notifications(
4453        &*session.connection_pool().await,
4454        &session.peer,
4455        notifications,
4456    );
4457    Ok(())
4458}
4459
4460/// Mark a buffer version as synced
4461async fn acknowledge_buffer_version(
4462    request: proto::AckBufferOperation,
4463    session: UserSession,
4464) -> Result<()> {
4465    let buffer_id = BufferId::from_proto(request.buffer_id);
4466    session
4467        .db()
4468        .await
4469        .observe_buffer_version(
4470            buffer_id,
4471            session.user_id(),
4472            request.epoch as i32,
4473            &request.version,
4474        )
4475        .await?;
4476    Ok(())
4477}
4478
4479async fn count_language_model_tokens(
4480    request: proto::CountLanguageModelTokens,
4481    response: Response<proto::CountLanguageModelTokens>,
4482    session: Session,
4483    config: &Config,
4484) -> Result<()> {
4485    let Some(session) = session.for_user() else {
4486        return Err(anyhow!("user not found"))?;
4487    };
4488    authorize_access_to_legacy_llm_endpoints(&session).await?;
4489
4490    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4491        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4492        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4493    };
4494
4495    session
4496        .app_state
4497        .rate_limiter
4498        .check(&*rate_limit, session.user_id())
4499        .await?;
4500
4501    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4502        Some(proto::LanguageModelProvider::Google) => {
4503            let api_key = config
4504                .google_ai_api_key
4505                .as_ref()
4506                .context("no Google AI API key configured on the server")?;
4507            google_ai::count_tokens(
4508                session.http_client.as_ref(),
4509                google_ai::API_URL,
4510                api_key,
4511                serde_json::from_str(&request.request)?,
4512                None,
4513            )
4514            .await?
4515        }
4516        _ => return Err(anyhow!("unsupported provider"))?,
4517    };
4518
4519    response.send(proto::CountLanguageModelTokensResponse {
4520        token_count: result.total_tokens as u32,
4521    })?;
4522
4523    Ok(())
4524}
4525
4526struct ZedProCountLanguageModelTokensRateLimit;
4527
4528impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4529    fn capacity(&self) -> usize {
4530        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4531            .ok()
4532            .and_then(|v| v.parse().ok())
4533            .unwrap_or(600) // Picked arbitrarily
4534    }
4535
4536    fn refill_duration(&self) -> chrono::Duration {
4537        chrono::Duration::hours(1)
4538    }
4539
4540    fn db_name(&self) -> &'static str {
4541        "zed-pro:count-language-model-tokens"
4542    }
4543}
4544
4545struct FreeCountLanguageModelTokensRateLimit;
4546
4547impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4548    fn capacity(&self) -> usize {
4549        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4550            .ok()
4551            .and_then(|v| v.parse().ok())
4552            .unwrap_or(600 / 10) // Picked arbitrarily
4553    }
4554
4555    fn refill_duration(&self) -> chrono::Duration {
4556        chrono::Duration::hours(1)
4557    }
4558
4559    fn db_name(&self) -> &'static str {
4560        "free:count-language-model-tokens"
4561    }
4562}
4563
4564struct ZedProComputeEmbeddingsRateLimit;
4565
4566impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4567    fn capacity(&self) -> usize {
4568        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4569            .ok()
4570            .and_then(|v| v.parse().ok())
4571            .unwrap_or(5000) // Picked arbitrarily
4572    }
4573
4574    fn refill_duration(&self) -> chrono::Duration {
4575        chrono::Duration::hours(1)
4576    }
4577
4578    fn db_name(&self) -> &'static str {
4579        "zed-pro:compute-embeddings"
4580    }
4581}
4582
4583struct FreeComputeEmbeddingsRateLimit;
4584
4585impl RateLimit for FreeComputeEmbeddingsRateLimit {
4586    fn capacity(&self) -> usize {
4587        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4588            .ok()
4589            .and_then(|v| v.parse().ok())
4590            .unwrap_or(5000 / 10) // Picked arbitrarily
4591    }
4592
4593    fn refill_duration(&self) -> chrono::Duration {
4594        chrono::Duration::hours(1)
4595    }
4596
4597    fn db_name(&self) -> &'static str {
4598        "free:compute-embeddings"
4599    }
4600}
4601
4602async fn compute_embeddings(
4603    request: proto::ComputeEmbeddings,
4604    response: Response<proto::ComputeEmbeddings>,
4605    session: UserSession,
4606    api_key: Option<Arc<str>>,
4607) -> Result<()> {
4608    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4609    authorize_access_to_legacy_llm_endpoints(&session).await?;
4610
4611    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4612        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4613        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4614    };
4615
4616    session
4617        .app_state
4618        .rate_limiter
4619        .check(&*rate_limit, session.user_id())
4620        .await?;
4621
4622    let embeddings = match request.model.as_str() {
4623        "openai/text-embedding-3-small" => {
4624            open_ai::embed(
4625                session.http_client.as_ref(),
4626                OPEN_AI_API_URL,
4627                &api_key,
4628                OpenAiEmbeddingModel::TextEmbedding3Small,
4629                request.texts.iter().map(|text| text.as_str()),
4630            )
4631            .await?
4632        }
4633        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4634    };
4635
4636    let embeddings = request
4637        .texts
4638        .iter()
4639        .map(|text| {
4640            let mut hasher = sha2::Sha256::new();
4641            hasher.update(text.as_bytes());
4642            let result = hasher.finalize();
4643            result.to_vec()
4644        })
4645        .zip(
4646            embeddings
4647                .data
4648                .into_iter()
4649                .map(|embedding| embedding.embedding),
4650        )
4651        .collect::<HashMap<_, _>>();
4652
4653    let db = session.db().await;
4654    db.save_embeddings(&request.model, &embeddings)
4655        .await
4656        .context("failed to save embeddings")
4657        .trace_err();
4658
4659    response.send(proto::ComputeEmbeddingsResponse {
4660        embeddings: embeddings
4661            .into_iter()
4662            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4663            .collect(),
4664    })?;
4665    Ok(())
4666}
4667
4668async fn get_cached_embeddings(
4669    request: proto::GetCachedEmbeddings,
4670    response: Response<proto::GetCachedEmbeddings>,
4671    session: UserSession,
4672) -> Result<()> {
4673    authorize_access_to_legacy_llm_endpoints(&session).await?;
4674
4675    let db = session.db().await;
4676    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4677
4678    response.send(proto::GetCachedEmbeddingsResponse {
4679        embeddings: embeddings
4680            .into_iter()
4681            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4682            .collect(),
4683    })?;
4684    Ok(())
4685}
4686
4687/// This is leftover from before the LLM service.
4688///
4689/// The endpoints protected by this check will be moved there eventually.
4690async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4691    if session.is_staff() {
4692        Ok(())
4693    } else {
4694        Err(anyhow!("permission denied"))?
4695    }
4696}
4697
4698/// Get a Supermaven API key for the user
4699async fn get_supermaven_api_key(
4700    _request: proto::GetSupermavenApiKey,
4701    response: Response<proto::GetSupermavenApiKey>,
4702    session: UserSession,
4703) -> Result<()> {
4704    let user_id: String = session.user_id().to_string();
4705    if !session.is_staff() {
4706        return Err(anyhow!("supermaven not enabled for this account"))?;
4707    }
4708
4709    let email = session
4710        .email()
4711        .ok_or_else(|| anyhow!("user must have an email"))?;
4712
4713    let supermaven_admin_api = session
4714        .supermaven_client
4715        .as_ref()
4716        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4717
4718    let result = supermaven_admin_api
4719        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4720        .await?;
4721
4722    response.send(proto::GetSupermavenApiKeyResponse {
4723        api_key: result.api_key,
4724    })?;
4725
4726    Ok(())
4727}
4728
4729/// Start receiving chat updates for a channel
4730async fn join_channel_chat(
4731    request: proto::JoinChannelChat,
4732    response: Response<proto::JoinChannelChat>,
4733    session: UserSession,
4734) -> Result<()> {
4735    let channel_id = ChannelId::from_proto(request.channel_id);
4736
4737    let db = session.db().await;
4738    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4739        .await?;
4740    let messages = db
4741        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4742        .await?;
4743    response.send(proto::JoinChannelChatResponse {
4744        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4745        messages,
4746    })?;
4747    Ok(())
4748}
4749
4750/// Stop receiving chat updates for a channel
4751async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4752    let channel_id = ChannelId::from_proto(request.channel_id);
4753    session
4754        .db()
4755        .await
4756        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4757        .await?;
4758    Ok(())
4759}
4760
4761/// Retrieve the chat history for a channel
4762async fn get_channel_messages(
4763    request: proto::GetChannelMessages,
4764    response: Response<proto::GetChannelMessages>,
4765    session: UserSession,
4766) -> Result<()> {
4767    let channel_id = ChannelId::from_proto(request.channel_id);
4768    let messages = session
4769        .db()
4770        .await
4771        .get_channel_messages(
4772            channel_id,
4773            session.user_id(),
4774            MESSAGE_COUNT_PER_PAGE,
4775            Some(MessageId::from_proto(request.before_message_id)),
4776        )
4777        .await?;
4778    response.send(proto::GetChannelMessagesResponse {
4779        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4780        messages,
4781    })?;
4782    Ok(())
4783}
4784
4785/// Retrieve specific chat messages
4786async fn get_channel_messages_by_id(
4787    request: proto::GetChannelMessagesById,
4788    response: Response<proto::GetChannelMessagesById>,
4789    session: UserSession,
4790) -> Result<()> {
4791    let message_ids = request
4792        .message_ids
4793        .iter()
4794        .map(|id| MessageId::from_proto(*id))
4795        .collect::<Vec<_>>();
4796    let messages = session
4797        .db()
4798        .await
4799        .get_channel_messages_by_id(session.user_id(), &message_ids)
4800        .await?;
4801    response.send(proto::GetChannelMessagesResponse {
4802        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4803        messages,
4804    })?;
4805    Ok(())
4806}
4807
4808/// Retrieve the current users notifications
4809async fn get_notifications(
4810    request: proto::GetNotifications,
4811    response: Response<proto::GetNotifications>,
4812    session: UserSession,
4813) -> Result<()> {
4814    let notifications = session
4815        .db()
4816        .await
4817        .get_notifications(
4818            session.user_id(),
4819            NOTIFICATION_COUNT_PER_PAGE,
4820            request.before_id.map(db::NotificationId::from_proto),
4821        )
4822        .await?;
4823    response.send(proto::GetNotificationsResponse {
4824        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4825        notifications,
4826    })?;
4827    Ok(())
4828}
4829
4830/// Mark notifications as read
4831async fn mark_notification_as_read(
4832    request: proto::MarkNotificationRead,
4833    response: Response<proto::MarkNotificationRead>,
4834    session: UserSession,
4835) -> Result<()> {
4836    let database = &session.db().await;
4837    let notifications = database
4838        .mark_notification_as_read_by_id(
4839            session.user_id(),
4840            NotificationId::from_proto(request.notification_id),
4841        )
4842        .await?;
4843    send_notifications(
4844        &*session.connection_pool().await,
4845        &session.peer,
4846        notifications,
4847    );
4848    response.send(proto::Ack {})?;
4849    Ok(())
4850}
4851
4852/// Get the current users information
4853async fn get_private_user_info(
4854    _request: proto::GetPrivateUserInfo,
4855    response: Response<proto::GetPrivateUserInfo>,
4856    session: UserSession,
4857) -> Result<()> {
4858    let db = session.db().await;
4859
4860    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4861    let user = db
4862        .get_user_by_id(session.user_id())
4863        .await?
4864        .ok_or_else(|| anyhow!("user not found"))?;
4865    let flags = db.get_user_flags(session.user_id()).await?;
4866
4867    response.send(proto::GetPrivateUserInfoResponse {
4868        metrics_id,
4869        staff: user.admin,
4870        flags,
4871        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4872    })?;
4873    Ok(())
4874}
4875
4876/// Accept the terms of service (tos) on behalf of the current user
4877async fn accept_terms_of_service(
4878    _request: proto::AcceptTermsOfService,
4879    response: Response<proto::AcceptTermsOfService>,
4880    session: UserSession,
4881) -> Result<()> {
4882    let db = session.db().await;
4883
4884    let accepted_tos_at = Utc::now();
4885    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4886        .await?;
4887
4888    response.send(proto::AcceptTermsOfServiceResponse {
4889        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4890    })?;
4891    Ok(())
4892}
4893
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures age from the earlier of the Zed account's
/// creation time and, when known, the linked GitHub account's creation time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4896
4897async fn get_llm_api_token(
4898    _request: proto::GetLlmToken,
4899    response: Response<proto::GetLlmToken>,
4900    session: UserSession,
4901) -> Result<()> {
4902    let db = session.db().await;
4903
4904    let flags = db.get_user_flags(session.user_id()).await?;
4905    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4906    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4907
4908    if !session.is_staff() && !has_language_models_feature_flag {
4909        Err(anyhow!("permission denied"))?
4910    }
4911
4912    let user_id = session.user_id();
4913    let user = db
4914        .get_user_by_id(user_id)
4915        .await?
4916        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4917
4918    if user.accepted_tos_at.is_none() {
4919        Err(anyhow!("terms of service not accepted"))?
4920    }
4921
4922    let mut account_created_at = user.created_at;
4923    if let Some(github_created_at) = user.github_user_created_at {
4924        account_created_at = account_created_at.min(github_created_at);
4925    }
4926    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4927        Err(anyhow!("account too young"))?
4928    }
4929
4930    let billing_preferences = db.get_billing_preferences(user.id).await?;
4931
4932    let token = LlmTokenClaims::create(
4933        user.id,
4934        user.github_login.clone(),
4935        session.is_staff(),
4936        billing_preferences,
4937        has_llm_closed_beta_feature_flag,
4938        session.has_llm_subscription(&db).await?,
4939        session.current_plan(&db).await?,
4940        &session.app_state.config,
4941    )?;
4942    response.send(proto::GetLlmTokenResponse { token })?;
4943    Ok(())
4944}
4945
4946fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4947    let message = match message {
4948        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4949        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4950        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4951        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4952        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4953            code: frame.code.into(),
4954            reason: frame.reason,
4955        })),
4956        // We should never receive a frame while reading the message, according
4957        // to the `tungstenite` maintainers:
4958        //
4959        // > It cannot occur when you read messages from the WebSocket, but it
4960        // > can be used when you want to send the raw frames (e.g. you want to
4961        // > send the frames to the WebSocket without composing the full message first).
4962        // >
4963        // > — https://github.com/snapview/tungstenite-rs/issues/268
4964        TungsteniteMessage::Frame(_) => {
4965            bail!("received an unexpected frame while reading the message")
4966        }
4967    };
4968
4969    Ok(message)
4970}
4971
4972fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4973    match message {
4974        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4975        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4976        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4977        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4978        AxumMessage::Close(frame) => {
4979            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4980                code: frame.code.into(),
4981                reason: frame.reason,
4982            }))
4983        }
4984    }
4985}
4986
4987fn notify_membership_updated(
4988    connection_pool: &mut ConnectionPool,
4989    result: MembershipUpdated,
4990    user_id: UserId,
4991    peer: &Peer,
4992) {
4993    for membership in &result.new_channels.channel_memberships {
4994        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4995    }
4996    for channel_id in &result.removed_channels {
4997        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4998    }
4999
5000    let user_channels_update = proto::UpdateUserChannels {
5001        channel_memberships: result
5002            .new_channels
5003            .channel_memberships
5004            .iter()
5005            .map(|cm| proto::ChannelMembership {
5006                channel_id: cm.channel_id.to_proto(),
5007                role: cm.role.into(),
5008            })
5009            .collect(),
5010        ..Default::default()
5011    };
5012
5013    let mut update = build_channels_update(result.new_channels);
5014    update.delete_channels = result
5015        .removed_channels
5016        .into_iter()
5017        .map(|id| id.to_proto())
5018        .collect();
5019    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5020
5021    for connection_id in connection_pool.user_connection_ids(user_id) {
5022        peer.send(connection_id, user_channels_update.clone())
5023            .trace_err();
5024        peer.send(connection_id, update.clone()).trace_err();
5025    }
5026}
5027
5028fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5029    proto::UpdateUserChannels {
5030        channel_memberships: channels
5031            .channel_memberships
5032            .iter()
5033            .map(|m| proto::ChannelMembership {
5034                channel_id: m.channel_id.to_proto(),
5035                role: m.role.into(),
5036            })
5037            .collect(),
5038        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5039        observed_channel_message_id: channels.observed_channel_messages.clone(),
5040    }
5041}
5042
5043fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5044    let mut update = proto::UpdateChannels::default();
5045
5046    for channel in channels.channels {
5047        update.channels.push(channel.to_proto());
5048    }
5049
5050    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5051    update.latest_channel_message_ids = channels.latest_channel_messages;
5052
5053    for (channel_id, participants) in channels.channel_participants {
5054        update
5055            .channel_participants
5056            .push(proto::ChannelParticipants {
5057                channel_id: channel_id.to_proto(),
5058                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5059            });
5060    }
5061
5062    for channel in channels.invited_channels {
5063        update.channel_invitations.push(channel.to_proto());
5064    }
5065
5066    update.hosted_projects = channels.hosted_projects;
5067    update
5068}
5069
5070fn build_initial_contacts_update(
5071    contacts: Vec<db::Contact>,
5072    pool: &ConnectionPool,
5073) -> proto::UpdateContacts {
5074    let mut update = proto::UpdateContacts::default();
5075
5076    for contact in contacts {
5077        match contact {
5078            db::Contact::Accepted { user_id, busy } => {
5079                update.contacts.push(contact_for_user(user_id, busy, pool));
5080            }
5081            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5082            db::Contact::Incoming { user_id } => {
5083                update
5084                    .incoming_requests
5085                    .push(proto::IncomingContactRequest {
5086                        requester_id: user_id.to_proto(),
5087                    })
5088            }
5089        }
5090    }
5091
5092    update
5093}
5094
5095fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5096    proto::Contact {
5097        user_id: user_id.to_proto(),
5098        online: pool.is_user_online(user_id),
5099        busy,
5100    }
5101}
5102
5103fn room_updated(room: &proto::Room, peer: &Peer) {
5104    broadcast(
5105        None,
5106        room.participants
5107            .iter()
5108            .filter_map(|participant| Some(participant.peer_id?.into())),
5109        |peer_id| {
5110            peer.send(
5111                peer_id,
5112                proto::RoomUpdated {
5113                    room: Some(room.clone()),
5114                },
5115            )
5116        },
5117    );
5118}
5119
5120fn channel_updated(
5121    channel: &db::channel::Model,
5122    room: &proto::Room,
5123    peer: &Peer,
5124    pool: &ConnectionPool,
5125) {
5126    let participants = room
5127        .participants
5128        .iter()
5129        .map(|p| p.user_id)
5130        .collect::<Vec<_>>();
5131
5132    broadcast(
5133        None,
5134        pool.channel_connection_ids(channel.root_id())
5135            .filter_map(|(channel_id, role)| {
5136                role.can_see_channel(channel.visibility)
5137                    .then_some(channel_id)
5138            }),
5139        |peer_id| {
5140            peer.send(
5141                peer_id,
5142                proto::UpdateChannels {
5143                    channel_participants: vec![proto::ChannelParticipants {
5144                        channel_id: channel.id.to_proto(),
5145                        participant_user_ids: participants.clone(),
5146                    }],
5147                    ..Default::default()
5148                },
5149            )
5150        },
5151    );
5152}
5153
5154async fn send_dev_server_projects_update(
5155    user_id: UserId,
5156    mut status: proto::DevServerProjectsUpdate,
5157    session: &Session,
5158) {
5159    let pool = session.connection_pool().await;
5160    for dev_server in &mut status.dev_servers {
5161        dev_server.status =
5162            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5163    }
5164    let connections = pool.user_connection_ids(user_id);
5165    for connection_id in connections {
5166        session.peer.send(connection_id, status.clone()).trace_err();
5167    }
5168}
5169
5170async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5171    let db = session.db().await;
5172
5173    let contacts = db.get_contacts(user_id).await?;
5174    let busy = db.is_user_busy(user_id).await?;
5175
5176    let pool = session.connection_pool().await;
5177    let updated_contact = contact_for_user(user_id, busy, &pool);
5178    for contact in contacts {
5179        if let db::Contact::Accepted {
5180            user_id: contact_user_id,
5181            ..
5182        } = contact
5183        {
5184            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5185                session
5186                    .peer
5187                    .send(
5188                        contact_conn_id,
5189                        proto::UpdateContacts {
5190                            contacts: vec![updated_contact.clone()],
5191                            remove_contacts: Default::default(),
5192                            incoming_requests: Default::default(),
5193                            remove_incoming_requests: Default::default(),
5194                            outgoing_requests: Default::default(),
5195                            remove_outgoing_requests: Default::default(),
5196                        },
5197                    )
5198                    .trace_err();
5199            }
5200        }
5201    }
5202    Ok(())
5203}
5204
5205async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5206    log::info!("lost dev server connection, unsharing projects");
5207    let project_ids = session
5208        .db()
5209        .await
5210        .get_stale_dev_server_projects(session.connection_id)
5211        .await?;
5212
5213    for project_id in project_ids {
5214        // not unshare re-checks the connection ids match, so we get away with no transaction
5215        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5216    }
5217
5218    let user_id = session.dev_server().user_id;
5219    let update = session
5220        .db()
5221        .await
5222        .dev_server_projects_update(user_id)
5223        .await?;
5224
5225    send_dev_server_projects_update(user_id, update, session).await;
5226
5227    Ok(())
5228}
5229
5230async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5231    let mut contacts_to_update = HashSet::default();
5232
5233    let room_id;
5234    let canceled_calls_to_user_ids;
5235    let live_kit_room;
5236    let delete_live_kit_room;
5237    let room;
5238    let channel;
5239
5240    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5241        contacts_to_update.insert(session.user_id());
5242
5243        for project in left_room.left_projects.values() {
5244            project_left(project, session);
5245        }
5246
5247        room_id = RoomId::from_proto(left_room.room.id);
5248        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5249        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5250        delete_live_kit_room = left_room.deleted;
5251        room = mem::take(&mut left_room.room);
5252        channel = mem::take(&mut left_room.channel);
5253
5254        room_updated(&room, &session.peer);
5255    } else {
5256        return Ok(());
5257    }
5258
5259    if let Some(channel) = channel {
5260        channel_updated(
5261            &channel,
5262            &room,
5263            &session.peer,
5264            &*session.connection_pool().await,
5265        );
5266    }
5267
5268    {
5269        let pool = session.connection_pool().await;
5270        for canceled_user_id in canceled_calls_to_user_ids {
5271            for connection_id in pool.user_connection_ids(canceled_user_id) {
5272                session
5273                    .peer
5274                    .send(
5275                        connection_id,
5276                        proto::CallCanceled {
5277                            room_id: room_id.to_proto(),
5278                        },
5279                    )
5280                    .trace_err();
5281            }
5282            contacts_to_update.insert(canceled_user_id);
5283        }
5284    }
5285
5286    for contact_user_id in contacts_to_update {
5287        update_user_contacts(contact_user_id, session).await?;
5288    }
5289
5290    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5291        live_kit
5292            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5293            .await
5294            .trace_err();
5295
5296        if delete_live_kit_room {
5297            live_kit.delete_room(live_kit_room).await.trace_err();
5298        }
5299    }
5300
5301    Ok(())
5302}
5303
5304async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5305    let left_channel_buffers = session
5306        .db()
5307        .await
5308        .leave_channel_buffers(session.connection_id)
5309        .await?;
5310
5311    for left_buffer in left_channel_buffers {
5312        channel_buffer_updated(
5313            session.connection_id,
5314            left_buffer.connections,
5315            &proto::UpdateChannelBufferCollaborators {
5316                channel_id: left_buffer.channel_id.to_proto(),
5317                collaborators: left_buffer.collaborators,
5318            },
5319            &session.peer,
5320        );
5321    }
5322
5323    Ok(())
5324}
5325
5326fn project_left(project: &db::LeftProject, session: &UserSession) {
5327    for connection_id in &project.connection_ids {
5328        if project.should_unshare {
5329            session
5330                .peer
5331                .send(
5332                    *connection_id,
5333                    proto::UnshareProject {
5334                        project_id: project.id.to_proto(),
5335                    },
5336                )
5337                .trace_err();
5338        } else {
5339            session
5340                .peer
5341                .send(
5342                    *connection_id,
5343                    proto::RemoveProjectCollaborator {
5344                        project_id: project.id.to_proto(),
5345                        peer_id: Some(session.connection_id.into()),
5346                    },
5347                )
5348                .trace_err();
5349        }
5350    }
5351}
5352
5353pub trait ResultExt {
5354    type Ok;
5355
5356    fn trace_err(self) -> Option<Self::Ok>;
5357}
5358
5359impl<T, E> ResultExt for Result<T, E>
5360where
5361    E: std::fmt::Debug,
5362{
5363    type Ok = T;
5364
5365    #[track_caller]
5366    fn trace_err(self) -> Option<T> {
5367        match self {
5368            Ok(value) => Some(value),
5369            Err(error) => {
5370                tracing::error!("{:?}", error);
5371                None
5372            }
5373        }
5374    }
5375}