rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use isahc_http_client::IsahcHttpClient;
  40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
  108#[derive(Clone, Debug)]
     /// The authenticated identity behind a connection.
     ///
     /// `Impersonated` pairs the user being acted as (`user`) with the admin doing the
     /// impersonation (`admin`). `Session::is_staff` treats every such session as staff.
  109pub enum Principal {
         /// A user acting as themselves.
  110    User(User),
         /// `admin` acting as `user`; spans and `Debug` output record `admin` as "impersonator".
  111    Impersonated { user: User, admin: User },
         /// A dev server, identified by its database model.
  112    DevServer(dev_server::Model),
  113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
  134#[derive(Clone)]
     /// Per-connection state passed to every message handler (see `MessageHandler`).
     ///
     /// Shared resources sit behind `Arc`s, so clones of a session share them.
  135struct Session {
         /// Who is on the other end of this connection.
  136    principal: Principal,
  137    connection_id: ConnectionId,
         /// Async (tokio) mutex; lock it through `Session::db`.
  138    db: Arc<tokio::sync::Mutex<DbHandle>>,
  139    peer: Arc<Peer>,
         /// Synchronous (parking_lot) mutex; lock it through `Session::connection_pool`, which
         /// wraps the guard in a `!Send` `ConnectionPoolGuard`.
  140    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  141    app_state: Arc<AppState>,
         /// Presumably `None` when no Supermaven admin API is configured — TODO confirm.
  142    supermaven_client: Option<Arc<SupermavenAdminApi>>,
  143    http_client: Arc<dyn HttpClient>,
  144    /// The GeoIP country code for the user.
         // Stored but not read yet; the `allow` silences the unused-field warning.
  145    #[allow(unused)]
  146    geoip_country_code: Option<String>,
         /// Not read by the visible code (leading underscore); presumably held to keep the
         /// executor alive for the session's lifetime — TODO confirm.
  147    _executor: Executor,
  148}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn has_llm_subscription(
 195        &self,
 196        db: &MutexGuard<'_, DbHandle>,
 197    ) -> anyhow::Result<bool> {
 198        if self.is_staff() {
 199            return Ok(true);
 200        }
 201
 202        let Some(user_id) = self.user_id() else {
 203            return Ok(false);
 204        };
 205
 206        Ok(db.has_active_billing_subscription(user_id).await?)
 207    }
 208
 209    pub async fn current_plan(
 210        &self,
 211        _db: &MutexGuard<'_, DbHandle>,
 212    ) -> anyhow::Result<proto::Plan> {
 213        if self.is_staff() {
 214            Ok(proto::Plan::ZedPro)
 215        } else {
 216            Ok(proto::Plan::Free)
 217        }
 218    }
 219
 220    fn dev_server_id(&self) -> Option<DevServerId> {
 221        match &self.principal {
 222            Principal::User(_) | Principal::Impersonated { .. } => None,
 223            Principal::DevServer(dev_server) => Some(dev_server.id),
 224        }
 225    }
 226
 227    fn principal_id(&self) -> PrincipalId {
 228        match &self.principal {
 229            Principal::User(user) => PrincipalId::UserId(user.id),
 230            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 231            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 232        }
 233    }
 234}
 235
 236impl Debug for Session {
 237    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 238        let mut result = f.debug_struct("Session");
 239        match &self.principal {
 240            Principal::User(user) => {
 241                result.field("user", &user.github_login);
 242            }
 243            Principal::Impersonated { user, admin } => {
 244                result.field("user", &user.github_login);
 245                result.field("impersonator", &admin.github_login);
 246            }
 247            Principal::DevServer(dev_server) => {
 248                result.field("dev_server", &dev_server.id);
 249            }
 250        }
 251        result.field("connection_id", &self.connection_id).finish()
 252    }
 253}
 254
 255struct UserSession(Session);
 256
 257impl UserSession {
 258    pub fn new(s: Session) -> Option<Self> {
 259        s.user_id().map(|_| UserSession(s))
 260    }
 261    pub fn user_id(&self) -> UserId {
 262        self.0.user_id().unwrap()
 263    }
 264
 265    pub fn email(&self) -> Option<String> {
 266        match &self.0.principal {
 267            Principal::User(user) => user.email_address.clone(),
 268            Principal::Impersonated { user, .. } => user.email_address.clone(),
 269            Principal::DevServer(..) => None,
 270        }
 271    }
 272}
 273
 274impl Deref for UserSession {
 275    type Target = Session;
 276
 277    fn deref(&self) -> &Self::Target {
 278        &self.0
 279    }
 280}
 281impl DerefMut for UserSession {
 282    fn deref_mut(&mut self) -> &mut Self::Target {
 283        &mut self.0
 284    }
 285}
 286
 287struct DevServerSession(Session);
 288
 289impl DevServerSession {
 290    pub fn new(s: Session) -> Option<Self> {
 291        s.dev_server_id().map(|_| DevServerSession(s))
 292    }
 293    pub fn dev_server_id(&self) -> DevServerId {
 294        self.0.dev_server_id().unwrap()
 295    }
 296
 297    fn dev_server(&self) -> &dev_server::Model {
 298        match &self.0.principal {
 299            Principal::DevServer(dev_server) => dev_server,
 300            _ => unreachable!(),
 301        }
 302    }
 303}
 304
 305impl Deref for DevServerSession {
 306    type Target = Session;
 307
 308    fn deref(&self) -> &Self::Target {
 309        &self.0
 310    }
 311}
 312impl DerefMut for DevServerSession {
 313    fn deref_mut(&mut self) -> &mut Self::Target {
 314        &mut self.0
 315    }
 316}
 317
 318fn user_handler<M: RequestMessage, Fut>(
 319    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 321where
 322    Fut: Send + Future<Output = Result<()>>,
 323{
 324    let handler = Arc::new(handler);
 325    move |message, response, session| {
 326        let handler = handler.clone();
 327        Box::pin(async move {
 328            if let Some(user_session) = session.for_user() {
 329                Ok(handler(message, response, user_session).await?)
 330            } else {
 331                Err(Error::Internal(anyhow!(
 332                    "must be a user to call {}",
 333                    M::NAME
 334                )))
 335            }
 336        })
 337    }
 338}
 339
 340fn dev_server_handler<M: RequestMessage, Fut>(
 341    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 343where
 344    Fut: Send + Future<Output = Result<()>>,
 345{
 346    let handler = Arc::new(handler);
 347    move |message, response, session| {
 348        let handler = handler.clone();
 349        Box::pin(async move {
 350            if let Some(dev_server_session) = session.for_dev_server() {
 351                Ok(handler(message, response, dev_server_session).await?)
 352            } else {
 353                Err(Error::Internal(anyhow!(
 354                    "must be a dev server to call {}",
 355                    M::NAME
 356                )))
 357            }
 358        })
 359    }
 360}
 361
 362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 363    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 365where
 366    InnertRetFut: Send + Future<Output = Result<()>>,
 367{
 368    let handler = Arc::new(handler);
 369    move |message, session| {
 370        let handler = handler.clone();
 371        Box::pin(async move {
 372            if let Some(user_session) = session.for_user() {
 373                Ok(handler(message, user_session).await?)
 374            } else {
 375                Err(Error::Internal(anyhow!(
 376                    "must be a user to call {}",
 377                    M::NAME
 378                )))
 379            }
 380        })
 381    }
 382}
 383
 384struct DbHandle(Arc<Database>);
 385
 386impl Deref for DbHandle {
 387    type Target = Database;
 388
 389    fn deref(&self) -> &Self::Target {
 390        self.0.as_ref()
 391    }
 392}
 393
  394pub struct Server {
         // The collab RPC server: owns the peer, the connection pool and the handler table
         // that `Server::new` populates.
         /// Mutex-wrapped, presumably so the id can be replaced after construction — TODO
         /// confirm; `start` reads it via `lock()`.
  395    id: parking_lot::Mutex<ServerId>,
         /// RPC peer, created in `new` from `id`.
  396    peer: Arc<Peer>,
         /// Crate-visible; `start` clones this `Arc` for its cleanup task.
  397    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
  398    app_state: Arc<AppState>,
         /// Registered handlers keyed by `TypeId` (presumably of the message type they accept;
         /// the registration methods are outside this view).
  399    handlers: HashMap<TypeId, MessageHandler>,
         /// Watch sender, initialised to `false` in `new`; how it is signalled is outside this
         /// view — TODO confirm.
  400    teardown: watch::Sender<bool>,
  401}
 402
  403pub(crate) struct ConnectionPoolGuard<'a> {
         // A locked `ConnectionPool` that cannot be sent across threads (see `_not_send`).
         /// The held lock; released when this guard is dropped.
  404    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
         /// `Rc` is neither `Send` nor `Sync`, so this marker makes the whole guard `!Send`.
         /// Any future that holds the guard across an `.await` is therefore `!Send` too. That
         /// stops this synchronous lock from being held across suspension points in futures
         /// that must be `Send`.
  405    _not_send: PhantomData<Rc<()>>,
  406}
 407
  408#[derive(Serialize)]
     /// Serializable view of the server: its peer plus the connection pool.
     ///
     /// The snapshot holds the pool's lock (through the guard) for as long as it lives. Fields
     /// are serialized in declaration order.
  409pub struct ServerSnapshot<'a> {
  410    peer: &'a Peer,
         /// Serialized as the value the guard dereferences to (see `serialize_deref`).
  411    #[serde(serialize_with = "serialize_deref")]
  412    connection_pool: ConnectionPoolGuard<'a>,
  413}
 414
 415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 416where
 417    S: Serializer,
 418    T: Deref<Target = U>,
 419    U: Serialize,
 420{
 421    Serialize::serialize(value.deref(), serializer)
 422}
 423
 424impl Server {
  425    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
             // Builds a server for `id` and registers every RPC handler on it, then returns it
             // wrapped in an `Arc`.
             //
             // Wrapper conventions used below:
             // - `user_handler` / `user_message_handler` reject sessions whose principal is not a user.
             // - `dev_server_handler` rejects sessions whose principal is not a dev server.
             // - Handlers registered without a wrapper receive the raw `Session`.
  426        let mut server = Self {
  427            id: parking_lot::Mutex::new(id),
                 // NOTE(review): `as u32` truncates/wraps if `id.0` does not fit in a u32 —
                 // confirm server ids stay in range.
  428            peer: Peer::new(id.0 as u32),
  429            app_state: app_state.clone(),
  430            connection_pool: Default::default(),
  431            handlers: Default::default(),
                 // Only the sender is kept; the initial receiver is dropped here.
  432            teardown: watch::channel(false).0,
  433        };
  434
  435        server
                 // `ping` is not wrapped, so it receives the raw `Session`.
  436            .add_request_handler(ping)
  437            .add_request_handler(user_handler(create_room))
  438            .add_request_handler(user_handler(join_room))
  439            .add_request_handler(user_handler(rejoin_room))
  440            .add_request_handler(user_handler(leave_room))
  441            .add_request_handler(user_handler(set_room_participant_role))
  442            .add_request_handler(user_handler(call))
  443            .add_request_handler(user_handler(cancel_call))
  444            .add_message_handler(user_message_handler(decline_call))
  445            .add_request_handler(user_handler(update_participant_location))
  446            .add_request_handler(user_handler(share_project))
  447            .add_message_handler(unshare_project)
  448            .add_request_handler(user_handler(join_project))
  449            .add_request_handler(user_handler(join_hosted_project))
  450            .add_request_handler(user_handler(rejoin_dev_server_projects))
  451            .add_request_handler(user_handler(create_dev_server_project))
  452            .add_request_handler(user_handler(update_dev_server_project))
  453            .add_request_handler(user_handler(delete_dev_server_project))
  454            .add_request_handler(user_handler(create_dev_server))
  455            .add_request_handler(user_handler(regenerate_dev_server_token))
  456            .add_request_handler(user_handler(rename_dev_server))
  457            .add_request_handler(user_handler(delete_dev_server))
  458            .add_request_handler(user_handler(list_remote_directory))
                 // Requests that only dev-server principals may send.
  459            .add_request_handler(dev_server_handler(share_dev_server_project))
  460            .add_request_handler(dev_server_handler(shutdown_dev_server))
  461            .add_request_handler(dev_server_handler(reconnect_dev_server))
  462            .add_message_handler(user_message_handler(leave_project))
  463            .add_request_handler(update_project)
  464            .add_request_handler(update_worktree)
  465            .add_message_handler(start_language_server)
  466            .add_message_handler(update_language_server)
  467            .add_message_handler(update_diagnostic_summary)
  468            .add_message_handler(update_worktree_settings)
                 // Project requests forwarded elsewhere. The helper names suggest owner-only,
                 // read-only and read-write access respectively; the actual permission checks
                 // live in those helpers, outside this view.
  469            .add_request_handler(user_handler(
  470                forward_project_request_for_owner::<proto::TaskContextForLocation>,
  471            ))
  472            .add_request_handler(user_handler(
  473                forward_project_request_for_owner::<proto::TaskTemplates>,
  474            ))
  475            .add_request_handler(user_handler(
  476                forward_read_only_project_request::<proto::GetHover>,
  477            ))
  478            .add_request_handler(user_handler(
  479                forward_read_only_project_request::<proto::GetDefinition>,
  480            ))
  481            .add_request_handler(user_handler(
  482                forward_read_only_project_request::<proto::GetTypeDefinition>,
  483            ))
  484            .add_request_handler(user_handler(
  485                forward_read_only_project_request::<proto::GetReferences>,
  486            ))
  487            .add_request_handler(user_handler(forward_find_search_candidates_request))
  488            .add_request_handler(user_handler(
  489                forward_read_only_project_request::<proto::GetDocumentHighlights>,
  490            ))
  491            .add_request_handler(user_handler(
  492                forward_read_only_project_request::<proto::GetProjectSymbols>,
  493            ))
  494            .add_request_handler(user_handler(
  495                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
  496            ))
  497            .add_request_handler(user_handler(
  498                forward_read_only_project_request::<proto::OpenBufferById>,
  499            ))
  500            .add_request_handler(user_handler(
  501                forward_read_only_project_request::<proto::SynchronizeBuffers>,
  502            ))
  503            .add_request_handler(user_handler(
  504                forward_read_only_project_request::<proto::InlayHints>,
  505            ))
  506            .add_request_handler(user_handler(
  507                forward_read_only_project_request::<proto::ResolveInlayHint>,
  508            ))
  509            .add_request_handler(user_handler(
  510                forward_read_only_project_request::<proto::OpenBufferByPath>,
  511            ))
  512            .add_request_handler(user_handler(
  513                forward_mutating_project_request::<proto::GetCompletions>,
  514            ))
  515            .add_request_handler(user_handler(
  516                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
  517            ))
  518            .add_request_handler(user_handler(
  519                forward_mutating_project_request::<proto::OpenNewBuffer>,
  520            ))
  521            .add_request_handler(user_handler(
  522                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
  523            ))
  524            .add_request_handler(user_handler(
  525                forward_mutating_project_request::<proto::GetCodeActions>,
  526            ))
  527            .add_request_handler(user_handler(
  528                forward_mutating_project_request::<proto::ApplyCodeAction>,
  529            ))
  530            .add_request_handler(user_handler(
  531                forward_mutating_project_request::<proto::PrepareRename>,
  532            ))
  533            .add_request_handler(user_handler(
  534                forward_mutating_project_request::<proto::PerformRename>,
  535            ))
  536            .add_request_handler(user_handler(
  537                forward_mutating_project_request::<proto::ReloadBuffers>,
  538            ))
  539            .add_request_handler(user_handler(
  540                forward_mutating_project_request::<proto::FormatBuffers>,
  541            ))
  542            .add_request_handler(user_handler(
  543                forward_mutating_project_request::<proto::CreateProjectEntry>,
  544            ))
  545            .add_request_handler(user_handler(
  546                forward_mutating_project_request::<proto::RenameProjectEntry>,
  547            ))
  548            .add_request_handler(user_handler(
  549                forward_mutating_project_request::<proto::CopyProjectEntry>,
  550            ))
  551            .add_request_handler(user_handler(
  552                forward_mutating_project_request::<proto::DeleteProjectEntry>,
  553            ))
  554            .add_request_handler(user_handler(
  555                forward_mutating_project_request::<proto::ExpandProjectEntry>,
  556            ))
  557            .add_request_handler(user_handler(
  558                forward_mutating_project_request::<proto::OnTypeFormatting>,
  559            ))
  560            .add_request_handler(user_handler(
  561                forward_mutating_project_request::<proto::SaveBuffer>,
  562            ))
  563            .add_request_handler(user_handler(
  564                forward_mutating_project_request::<proto::BlameBuffer>,
  565            ))
  566            .add_request_handler(user_handler(
  567                forward_mutating_project_request::<proto::MultiLspQuery>,
  568            ))
  569            .add_request_handler(user_handler(
  570                forward_mutating_project_request::<proto::RestartLanguageServers>,
  571            ))
  572            .add_request_handler(user_handler(
  573                forward_mutating_project_request::<proto::LinkedEditingRange>,
  574            ))
  575            .add_message_handler(create_buffer_for_peer)
  576            .add_request_handler(update_buffer)
                 // Host-originated project events, relayed by
                 // `broadcast_project_message_from_host` (behavior defined outside this view).
  577            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
  578            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
  579            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
  580            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
  581            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
  582            .add_request_handler(get_users)
  583            .add_request_handler(user_handler(fuzzy_search_users))
  584            .add_request_handler(user_handler(request_contact))
  585            .add_request_handler(user_handler(remove_contact))
  586            .add_request_handler(user_handler(respond_to_contact_request))
  587            .add_message_handler(subscribe_to_channels)
  588            .add_request_handler(user_handler(create_channel))
  589            .add_request_handler(user_handler(delete_channel))
  590            .add_request_handler(user_handler(invite_channel_member))
  591            .add_request_handler(user_handler(remove_channel_member))
  592            .add_request_handler(user_handler(set_channel_member_role))
  593            .add_request_handler(user_handler(set_channel_visibility))
  594            .add_request_handler(user_handler(rename_channel))
  595            .add_request_handler(user_handler(join_channel_buffer))
  596            .add_request_handler(user_handler(leave_channel_buffer))
  597            .add_message_handler(user_message_handler(update_channel_buffer))
  598            .add_request_handler(user_handler(rejoin_channel_buffers))
  599            .add_request_handler(user_handler(get_channel_members))
  600            .add_request_handler(user_handler(respond_to_channel_invite))
  601            .add_request_handler(user_handler(join_channel))
  602            .add_request_handler(user_handler(join_channel_chat))
  603            .add_message_handler(user_message_handler(leave_channel_chat))
  604            .add_request_handler(user_handler(send_channel_message))
  605            .add_request_handler(user_handler(remove_channel_message))
  606            .add_request_handler(user_handler(update_channel_message))
  607            .add_request_handler(user_handler(get_channel_messages))
  608            .add_request_handler(user_handler(get_channel_messages_by_id))
  609            .add_request_handler(user_handler(get_notifications))
  610            .add_request_handler(user_handler(mark_notification_as_read))
  611            .add_request_handler(user_handler(move_channel))
  612            .add_request_handler(user_handler(follow))
  613            .add_message_handler(user_message_handler(unfollow))
  614            .add_message_handler(user_message_handler(update_followers))
  615            .add_request_handler(user_handler(get_private_user_info))
  616            .add_request_handler(user_handler(get_llm_api_token))
  617            .add_request_handler(user_handler(accept_terms_of_service))
  618            .add_message_handler(user_message_handler(acknowledge_channel_message))
  619            .add_message_handler(user_message_handler(acknowledge_buffer_version))
  620            .add_request_handler(user_handler(get_supermaven_api_key))
  621            .add_request_handler(user_handler(
  622                forward_mutating_project_request::<proto::OpenContext>,
  623            ))
  624            .add_request_handler(user_handler(
  625                forward_mutating_project_request::<proto::CreateContext>,
  626            ))
  627            .add_request_handler(user_handler(
  628                forward_mutating_project_request::<proto::SynchronizeContexts>,
  629            ))
  630            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
  631            .add_message_handler(update_context)
                 // Not wrapped in `user_handler`: receives the raw `Session` plus the server
                 // config, read from an `AppState` clone taken per request.
  632            .add_request_handler({
  633                let app_state = app_state.clone();
  634                move |request, response, session| {
  635                    let app_state = app_state.clone();
  636                    async move {
  637                        count_language_model_tokens(request, response, session, &app_state.config)
  638                            .await
  639                    }
  640                }
  641            })
  642            .add_request_handler({
  643                user_handler(move |request, response, session| {
  644                    get_cached_embeddings(request, response, session)
  645                })
  646            })
                 // The OpenAI API key is read from the captured config and cloned on every
                 // request.
  647            .add_request_handler({
  648                let app_state = app_state.clone();
  649                user_handler(move |request, response, session| {
  650                    compute_embeddings(
  651                        request,
  652                        response,
  653                        session,
  654                        app_state.config.openai_api_key.clone(),
  655                    )
  656                })
  657            });
  658
  659        Arc::new(server)
  660    }
 661
 662    pub async fn start(&self) -> Result<()> {
 663        let server_id = *self.id.lock();
 664        let app_state = self.app_state.clone();
 665        let peer = self.peer.clone();
 666        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 667        let pool = self.connection_pool.clone();
 668        let live_kit_client = self.app_state.live_kit_client.clone();
 669
 670        let span = info_span!("start server");
 671        self.app_state.executor.spawn_detached(
 672            async move {
 673                tracing::info!("waiting for cleanup timeout");
 674                timeout.await;
 675                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 676                if let Some((room_ids, channel_ids)) = app_state
 677                    .db
 678                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 679                    .await
 680                    .trace_err()
 681                {
 682                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 683                    tracing::info!(
 684                        stale_channel_buffer_count = channel_ids.len(),
 685                        "retrieved stale channel buffers"
 686                    );
 687
 688                    for channel_id in channel_ids {
 689                        if let Some(refreshed_channel_buffer) = app_state
 690                            .db
 691                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 692                            .await
 693                            .trace_err()
 694                        {
 695                            for connection_id in refreshed_channel_buffer.connection_ids {
 696                                peer.send(
 697                                    connection_id,
 698                                    proto::UpdateChannelBufferCollaborators {
 699                                        channel_id: channel_id.to_proto(),
 700                                        collaborators: refreshed_channel_buffer
 701                                            .collaborators
 702                                            .clone(),
 703                                    },
 704                                )
 705                                .trace_err();
 706                            }
 707                        }
 708                    }
 709
 710                    for room_id in room_ids {
 711                        let mut contacts_to_update = HashSet::default();
 712                        let mut canceled_calls_to_user_ids = Vec::new();
 713                        let mut live_kit_room = String::new();
 714                        let mut delete_live_kit_room = false;
 715
 716                        if let Some(mut refreshed_room) = app_state
 717                            .db
 718                            .clear_stale_room_participants(room_id, server_id)
 719                            .await
 720                            .trace_err()
 721                        {
 722                            tracing::info!(
 723                                room_id = room_id.0,
 724                                new_participant_count = refreshed_room.room.participants.len(),
 725                                "refreshed room"
 726                            );
 727                            room_updated(&refreshed_room.room, &peer);
 728                            if let Some(channel) = refreshed_room.channel.as_ref() {
 729                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 730                            }
 731                            contacts_to_update
 732                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 733                            contacts_to_update
 734                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 735                            canceled_calls_to_user_ids =
 736                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 737                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 738                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 739                        }
 740
 741                        {
 742                            let pool = pool.lock();
 743                            for canceled_user_id in canceled_calls_to_user_ids {
 744                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 745                                    peer.send(
 746                                        connection_id,
 747                                        proto::CallCanceled {
 748                                            room_id: room_id.to_proto(),
 749                                        },
 750                                    )
 751                                    .trace_err();
 752                                }
 753                            }
 754                        }
 755
 756                        for user_id in contacts_to_update {
 757                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 758                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 759                            if let Some((busy, contacts)) = busy.zip(contacts) {
 760                                let pool = pool.lock();
 761                                let updated_contact = contact_for_user(user_id, busy, &pool);
 762                                for contact in contacts {
 763                                    if let db::Contact::Accepted {
 764                                        user_id: contact_user_id,
 765                                        ..
 766                                    } = contact
 767                                    {
 768                                        for contact_conn_id in
 769                                            pool.user_connection_ids(contact_user_id)
 770                                        {
 771                                            peer.send(
 772                                                contact_conn_id,
 773                                                proto::UpdateContacts {
 774                                                    contacts: vec![updated_contact.clone()],
 775                                                    remove_contacts: Default::default(),
 776                                                    incoming_requests: Default::default(),
 777                                                    remove_incoming_requests: Default::default(),
 778                                                    outgoing_requests: Default::default(),
 779                                                    remove_outgoing_requests: Default::default(),
 780                                                },
 781                                            )
 782                                            .trace_err();
 783                                        }
 784                                    }
 785                                }
 786                            }
 787                        }
 788
 789                        if let Some(live_kit) = live_kit_client.as_ref() {
 790                            if delete_live_kit_room {
 791                                live_kit.delete_room(live_kit_room).await.trace_err();
 792                            }
 793                        }
 794                    }
 795                }
 796
 797                app_state
 798                    .db
 799                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 800                    .await
 801                    .trace_err();
 802            }
 803            .instrument(span),
 804        );
 805        Ok(())
 806    }
 807
 808    pub fn teardown(&self) {
 809        self.peer.teardown();
 810        self.connection_pool.lock().reset();
 811        let _ = self.teardown.send(true);
 812    }
 813
 814    #[cfg(test)]
 815    pub fn reset(&self, id: ServerId) {
 816        self.teardown();
 817        *self.id.lock() = id;
 818        self.peer.reset(id.0 as u32);
 819        let _ = self.teardown.send(false);
 820    }
 821
 822    #[cfg(test)]
 823    pub fn id(&self) -> ServerId {
 824        *self.id.lock()
 825    }
 826
 827    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 828    where
 829        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 830        Fut: 'static + Send + Future<Output = Result<()>>,
 831        M: EnvelopedMessage,
 832    {
 833        let prev_handler = self.handlers.insert(
 834            TypeId::of::<M>(),
 835            Box::new(move |envelope, session| {
 836                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 837                let received_at = envelope.received_at;
 838                tracing::info!("message received");
 839                let start_time = Instant::now();
 840                let future = (handler)(*envelope, session);
 841                async move {
 842                    let result = future.await;
 843                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 844                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 845                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 846                    let payload_type = M::NAME;
 847
 848                    match result {
 849                        Err(error) => {
 850                            tracing::error!(
 851                                ?error,
 852                                total_duration_ms,
 853                                processing_duration_ms,
 854                                queue_duration_ms,
 855                                payload_type,
 856                                "error handling message"
 857                            )
 858                        }
 859                        Ok(()) => tracing::info!(
 860                            total_duration_ms,
 861                            processing_duration_ms,
 862                            queue_duration_ms,
 863                            "finished handling message"
 864                        ),
 865                    }
 866                }
 867                .boxed()
 868            }),
 869        );
 870        if prev_handler.is_some() {
 871            panic!("registered a handler for the same message twice");
 872        }
 873        self
 874    }
 875
 876    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 879        Fut: 'static + Send + Future<Output = Result<()>>,
 880        M: EnvelopedMessage,
 881    {
 882        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 883        self
 884    }
 885
 886    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 887    where
 888        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 889        Fut: Send + Future<Output = Result<()>>,
 890        M: RequestMessage,
 891    {
 892        let handler = Arc::new(handler);
 893        self.add_handler(move |envelope, session| {
 894            let receipt = envelope.receipt();
 895            let handler = handler.clone();
 896            async move {
 897                let peer = session.peer.clone();
 898                let responded = Arc::new(AtomicBool::default());
 899                let response = Response {
 900                    peer: peer.clone(),
 901                    responded: responded.clone(),
 902                    receipt,
 903                };
 904                match (handler)(envelope.payload, response, session).await {
 905                    Ok(()) => {
 906                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 907                            Ok(())
 908                        } else {
 909                            Err(anyhow!("handler did not send a response"))?
 910                        }
 911                    }
 912                    Err(error) => {
 913                        let proto_err = match &error {
 914                            Error::Internal(err) => err.to_proto(),
 915                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 916                        };
 917                        peer.respond_with_error(receipt, proto_err)?;
 918                        Err(error)
 919                    }
 920                }
 921            }
 922        })
 923    }
 924
    /// Returns a future that serves a single client connection.
    ///
    /// When polled, the future:
    /// 1. Returns immediately if the server is already tearing down.
    /// 2. Registers `connection` with the peer and records the assigned
    ///    connection id on the current ("handle connection") span.
    /// 3. Builds a per-connection HTTP client (User-Agent `Zed Server/<version>`)
    ///    and, when a Supermaven admin API key is configured, a Supermaven client.
    /// 4. Builds the `Session` and runs `send_initial_client_update`.
    /// 5. Dispatches incoming messages to the registered handlers until the I/O
    ///    future finishes, the incoming stream ends, or the server tears down.
    /// 6. Runs `connection_lost` (not reached when it exits because of teardown).
    ///
    /// NOTE(review): the early `return`s after `peer.add_connection` (HTTP client
    /// construction failure, initial-update failure) skip `connection_lost`, so
    /// the connection registered with the peer — and, if the initial update got
    /// that far, the connection-pool entry — is not cleaned up by this function
    /// on those paths. Confirm whether that is handled elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left empty here are filled in later (connection id once assigned,
        // principal details via `update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) for
            // this connection at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    // Wait for a free handler slot before pulling the next message; the
                    // permit is released once that message's handler completes (or
                    // immediately if no handler is registered for it).
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O, then progress on
                // in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Server teardown: stop without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive in-flight foreground handlers forward.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // Only the synchronous part of handler setup runs inside the
                            // entered span; the returned future is instrumented with the
                            // same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached instead of being
                                    // polled by this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Any foreground handlers still pending are dropped (cancelled) here.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1074
1075    async fn send_initial_client_update(
1076        &self,
1077        connection_id: ConnectionId,
1078        principal: &Principal,
1079        zed_version: ZedVersion,
1080        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1081        session: &Session,
1082    ) -> Result<()> {
1083        self.peer.send(
1084            connection_id,
1085            proto::Hello {
1086                peer_id: Some(connection_id.into()),
1087            },
1088        )?;
1089        tracing::info!("sent hello message");
1090        if let Some(send_connection_id) = send_connection_id.take() {
1091            let _ = send_connection_id.send(connection_id);
1092        }
1093
1094        match principal {
1095            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1096                if !user.connected_once {
1097                    self.peer.send(connection_id, proto::ShowContacts {})?;
1098                    self.app_state
1099                        .db
1100                        .set_user_connected_once(user.id, true)
1101                        .await?;
1102                }
1103
1104                update_user_plan(user.id, session).await?;
1105
1106                let (contacts, dev_server_projects) = future::try_join(
1107                    self.app_state.db.get_contacts(user.id),
1108                    self.app_state.db.dev_server_projects_update(user.id),
1109                )
1110                .await?;
1111
1112                {
1113                    let mut pool = self.connection_pool.lock();
1114                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1115                    self.peer.send(
1116                        connection_id,
1117                        build_initial_contacts_update(contacts, &pool),
1118                    )?;
1119                }
1120
1121                if should_auto_subscribe_to_channels(zed_version) {
1122                    subscribe_user_to_channels(user.id, session).await?;
1123                }
1124
1125                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1126
1127                if let Some(incoming_call) =
1128                    self.app_state.db.incoming_call_for_user(user.id).await?
1129                {
1130                    self.peer.send(connection_id, incoming_call)?;
1131                }
1132
1133                update_user_contacts(user.id, session).await?;
1134            }
1135            Principal::DevServer(dev_server) => {
1136                {
1137                    let mut pool = self.connection_pool.lock();
1138                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1139                    {
1140                        self.peer.send(
1141                            stale_connection_id,
1142                            proto::ShutdownDevServer {
1143                                reason: Some(
1144                                    "another dev server connected with the same token".to_string(),
1145                                ),
1146                            },
1147                        )?;
1148                        pool.remove_connection(stale_connection_id)?;
1149                    };
1150                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1151                }
1152
1153                let projects = self
1154                    .app_state
1155                    .db
1156                    .get_projects_for_dev_server(dev_server.id)
1157                    .await?;
1158                self.peer
1159                    .send(connection_id, proto::DevServerInstructions { projects })?;
1160
1161                let status = self
1162                    .app_state
1163                    .db
1164                    .dev_server_projects_update(dev_server.user_id)
1165                    .await?;
1166                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1167            }
1168        }
1169
1170        Ok(())
1171    }
1172
1173    pub async fn invite_code_redeemed(
1174        self: &Arc<Self>,
1175        inviter_id: UserId,
1176        invitee_id: UserId,
1177    ) -> Result<()> {
1178        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1179            if let Some(code) = &user.invite_code {
1180                let pool = self.connection_pool.lock();
1181                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1182                for connection_id in pool.user_connection_ids(inviter_id) {
1183                    self.peer.send(
1184                        connection_id,
1185                        proto::UpdateContacts {
1186                            contacts: vec![invitee_contact.clone()],
1187                            ..Default::default()
1188                        },
1189                    )?;
1190                    self.peer.send(
1191                        connection_id,
1192                        proto::UpdateInviteInfo {
1193                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1194                            count: user.invite_count as u32,
1195                        },
1196                    )?;
1197                }
1198            }
1199        }
1200        Ok(())
1201    }
1202
1203    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1204        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1205            if let Some(invite_code) = &user.invite_code {
1206                let pool = self.connection_pool.lock();
1207                for connection_id in pool.user_connection_ids(user_id) {
1208                    self.peer.send(
1209                        connection_id,
1210                        proto::UpdateInviteInfo {
1211                            url: format!(
1212                                "{}{}",
1213                                self.app_state.config.invite_link_prefix, invite_code
1214                            ),
1215                            count: user.invite_count as u32,
1216                        },
1217                    )?;
1218                }
1219            }
1220        }
1221        Ok(())
1222    }
1223
1224    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1225        ServerSnapshot {
1226            connection_pool: ConnectionPoolGuard {
1227                guard: self.connection_pool.lock(),
1228                _not_send: PhantomData,
1229            },
1230            peer: &self.peer,
1231        }
1232    }
1233}
1234
1235impl<'a> Deref for ConnectionPoolGuard<'a> {
1236    type Target = ConnectionPool;
1237
1238    fn deref(&self) -> &Self::Target {
1239        &self.guard
1240    }
1241}
1242
1243impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1244    fn deref_mut(&mut self) -> &mut Self::Target {
1245        &mut self.guard
1246    }
1247}
1248
1249impl<'a> Drop for ConnectionPoolGuard<'a> {
1250    fn drop(&mut self) {
1251        #[cfg(test)]
1252        self.check_invariants();
1253    }
1254}
1255
1256fn broadcast<F>(
1257    sender_id: Option<ConnectionId>,
1258    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1259    mut f: F,
1260) where
1261    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1262{
1263    for receiver_id in receiver_ids {
1264        if Some(receiver_id) != sender_id {
1265            if let Err(error) = f(receiver_id) {
1266                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1267            }
1268        }
1269    }
1270}
1271
1272pub struct ProtocolVersion(u32);
1273
1274impl Header for ProtocolVersion {
1275    fn name() -> &'static HeaderName {
1276        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1277        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1278    }
1279
1280    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1281    where
1282        Self: Sized,
1283        I: Iterator<Item = &'i axum::http::HeaderValue>,
1284    {
1285        let version = values
1286            .next()
1287            .ok_or_else(axum::headers::Error::invalid)?
1288            .to_str()
1289            .map_err(|_| axum::headers::Error::invalid())?
1290            .parse()
1291            .map_err(|_| axum::headers::Error::invalid())?;
1292        Ok(Self(version))
1293    }
1294
1295    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1296        values.extend([self.0.to_string().parse().unwrap()]);
1297    }
1298}
1299
1300pub struct AppVersionHeader(SemanticVersion);
1301impl Header for AppVersionHeader {
1302    fn name() -> &'static HeaderName {
1303        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1304        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1305    }
1306
1307    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1308    where
1309        Self: Sized,
1310        I: Iterator<Item = &'i axum::http::HeaderValue>,
1311    {
1312        let version = values
1313            .next()
1314            .ok_or_else(axum::headers::Error::invalid)?
1315            .to_str()
1316            .map_err(|_| axum::headers::Error::invalid())?
1317            .parse()
1318            .map_err(|_| axum::headers::Error::invalid())?;
1319        Ok(Self(version))
1320    }
1321
1322    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1323        values.extend([self.0.to_string().parse().unwrap()]);
1324    }
1325}
1326
1327pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1328    Router::new()
1329        .route("/rpc", get(handle_websocket_request))
1330        .layer(
1331            ServiceBuilder::new()
1332                .layer(Extension(server.app_state.clone()))
1333                .layer(middleware::from_fn(auth::validate_header)),
1334        )
1335        .route("/metrics", get(handle_metrics))
1336        .layer(Extension(server))
1337}
1338
1339pub async fn handle_websocket_request(
1340    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1341    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1342    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1343    Extension(server): Extension<Arc<Server>>,
1344    Extension(principal): Extension<Principal>,
1345    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1346    ws: WebSocketUpgrade,
1347) -> axum::response::Response {
1348    if protocol_version != rpc::PROTOCOL_VERSION {
1349        return (
1350            StatusCode::UPGRADE_REQUIRED,
1351            "client must be upgraded".to_string(),
1352        )
1353            .into_response();
1354    }
1355
1356    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1357        return (
1358            StatusCode::UPGRADE_REQUIRED,
1359            "no version header found".to_string(),
1360        )
1361            .into_response();
1362    };
1363
1364    if !version.can_collaborate() {
1365        return (
1366            StatusCode::UPGRADE_REQUIRED,
1367            "client must be upgraded".to_string(),
1368        )
1369            .into_response();
1370    }
1371
1372    let socket_address = socket_address.to_string();
1373    ws.on_upgrade(move |socket| {
1374        let socket = socket
1375            .map_ok(to_tungstenite_message)
1376            .err_into()
1377            .with(|message| async move { to_axum_message(message) });
1378        let connection = Connection::new(Box::pin(socket));
1379        async move {
1380            server
1381                .handle_connection(
1382                    connection,
1383                    socket_address,
1384                    principal,
1385                    version,
1386                    country_code_header.map(|header| header.to_string()),
1387                    None,
1388                    Executor::Production,
1389                )
1390                .await;
1391        }
1392    })
1393}
1394
1395pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1396    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1397    let connections_metric = CONNECTIONS_METRIC
1398        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1399
1400    let connections = server
1401        .connection_pool
1402        .lock()
1403        .connections()
1404        .filter(|connection| !connection.admin)
1405        .count();
1406    connections_metric.set(connections as _);
1407
1408    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1409    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1410        register_int_gauge!(
1411            "shared_projects",
1412            "number of open projects with one or more guests"
1413        )
1414        .unwrap()
1415    });
1416
1417    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1418    shared_projects_metric.set(shared_projects as _);
1419
1420    let encoder = prometheus::TextEncoder::new();
1421    let metric_families = prometheus::gather();
1422    let encoded_metrics = encoder
1423        .encode_to_string(&metric_families)
1424        .map_err(|err| anyhow!("{}", err))?;
1425    Ok(encoded_metrics)
1426}
1427
/// Cleans up after a connection ends.
///
/// Immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (an error here is returned);
/// - marks it lost in the database (errors are only logged).
///
/// It then waits `RECONNECT_TIMEOUT` to give the client a chance to reconnect.
/// After the timeout:
/// - For user principals, it leaves the room and channel buffers held by this
///   connection. It declines any pending call if the user has no other
///   connection online, then calls `update_user_contacts`.
/// - For dev server principals, it runs `lost_dev_server_connection`.
///
/// If the server's teardown flag changes before the timeout elapses, the
/// post-timeout cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Reconnect grace period elapsed: release the resources this connection held.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Server is tearing down: skip the per-connection cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1482
1483/// Acknowledges a ping from a client, used to keep the connection alive.
1484async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1485    response.send(proto::Ack {})?;
1486    Ok(())
1487}
1488
1489/// Creates a new room for calling (outside of channels)
1490async fn create_room(
1491    _request: proto::CreateRoom,
1492    response: Response<proto::CreateRoom>,
1493    session: UserSession,
1494) -> Result<()> {
1495    let live_kit_room = nanoid::nanoid!(30);
1496
1497    let live_kit_connection_info = util::maybe!(async {
1498        let live_kit = session.app_state.live_kit_client.as_ref();
1499        let live_kit = live_kit?;
1500        let user_id = session.user_id().to_string();
1501
1502        let token = live_kit
1503            .room_token(&live_kit_room, &user_id.to_string())
1504            .trace_err()?;
1505
1506        Some(proto::LiveKitConnectionInfo {
1507            server_url: live_kit.url().into(),
1508            token,
1509            can_publish: true,
1510        })
1511    })
1512    .await;
1513
1514    let room = session
1515        .db()
1516        .await
1517        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1518        .await?;
1519
1520    response.send(proto::CreateRoomResponse {
1521        room: Some(room.clone()),
1522        live_kit_connection_info,
1523    })?;
1524
1525    update_user_contacts(session.user_id(), &session).await?;
1526    Ok(())
1527}
1528
1529/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1530async fn join_room(
1531    request: proto::JoinRoom,
1532    response: Response<proto::JoinRoom>,
1533    session: UserSession,
1534) -> Result<()> {
1535    let room_id = RoomId::from_proto(request.id);
1536
1537    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1538
1539    if let Some(channel_id) = channel_id {
1540        return join_channel_internal(channel_id, Box::new(response), session).await;
1541    }
1542
1543    let joined_room = {
1544        let room = session
1545            .db()
1546            .await
1547            .join_room(room_id, session.user_id(), session.connection_id)
1548            .await?;
1549        room_updated(&room.room, &session.peer);
1550        room.into_inner()
1551    };
1552
1553    for connection_id in session
1554        .connection_pool()
1555        .await
1556        .user_connection_ids(session.user_id())
1557    {
1558        session
1559            .peer
1560            .send(
1561                connection_id,
1562                proto::CallCanceled {
1563                    room_id: room_id.to_proto(),
1564                },
1565            )
1566            .trace_err();
1567    }
1568
1569    let live_kit_connection_info =
1570        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1571            live_kit
1572                .room_token(
1573                    &joined_room.room.live_kit_room,
1574                    &session.user_id().to_string(),
1575                )
1576                .trace_err()
1577                .map(|token| proto::LiveKitConnectionInfo {
1578                    server_url: live_kit.url().into(),
1579                    token,
1580                    can_publish: true,
1581                })
1582        } else {
1583            None
1584        };
1585
1586    response.send(proto::JoinRoomResponse {
1587        room: Some(joined_room.room),
1588        channel_id: None,
1589        live_kit_connection_info,
1590    })?;
1591
1592    update_user_contacts(session.user_id(), &session).await?;
1593    Ok(())
1594}
1595
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin records `session.connection_id` for this user. The
/// response carries the room, the reshared projects (with collaborators) and
/// the rejoined projects. Afterwards:
/// - room participants are notified of the updated room;
/// - for each reshared project (presumably ones this user hosts — confirm in
///   the database layer), collaborators are told the peer id moved from the
///   project's old connection to the new one, and are forwarded its worktrees;
/// - rejoined projects are replayed to this connection via
///   `notify_rejoined_projects`;
/// - the room's channel, if any, is broadcast via `channel_updated`;
/// - `update_user_contacts` is run for this user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The rejoin guard is confined to this block and consumed by `into_inner`
    // before the connection pool is locked for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this user's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to its collaborators,
            // attributed to this connection as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1687
1688fn notify_rejoined_projects(
1689    rejoined_projects: &mut Vec<RejoinedProject>,
1690    session: &UserSession,
1691) -> Result<()> {
1692    for project in rejoined_projects.iter() {
1693        for collaborator in &project.collaborators {
1694            session
1695                .peer
1696                .send(
1697                    collaborator.connection_id,
1698                    proto::UpdateProjectCollaborator {
1699                        project_id: project.id.to_proto(),
1700                        old_peer_id: Some(project.old_connection_id.into()),
1701                        new_peer_id: Some(session.connection_id.into()),
1702                    },
1703                )
1704                .trace_err();
1705        }
1706    }
1707
1708    for project in rejoined_projects {
1709        for worktree in mem::take(&mut project.worktrees) {
1710            #[cfg(any(test, feature = "test-support"))]
1711            const MAX_CHUNK_SIZE: usize = 2;
1712            #[cfg(not(any(test, feature = "test-support")))]
1713            const MAX_CHUNK_SIZE: usize = 256;
1714
1715            // Stream this worktree's entries.
1716            let message = proto::UpdateWorktree {
1717                project_id: project.id.to_proto(),
1718                worktree_id: worktree.id,
1719                abs_path: worktree.abs_path.clone(),
1720                root_name: worktree.root_name,
1721                updated_entries: worktree.updated_entries,
1722                removed_entries: worktree.removed_entries,
1723                scan_id: worktree.scan_id,
1724                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1725                updated_repositories: worktree.updated_repositories,
1726                removed_repositories: worktree.removed_repositories,
1727            };
1728            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1729                session.peer.send(session.connection_id, update.clone())?;
1730            }
1731
1732            // Stream this worktree's diagnostics.
1733            for summary in worktree.diagnostic_summaries {
1734                session.peer.send(
1735                    session.connection_id,
1736                    proto::UpdateDiagnosticSummary {
1737                        project_id: project.id.to_proto(),
1738                        worktree_id: worktree.id,
1739                        summary: Some(summary),
1740                    },
1741                )?;
1742            }
1743
1744            for settings_file in worktree.settings_files {
1745                session.peer.send(
1746                    session.connection_id,
1747                    proto::UpdateWorktreeSettings {
1748                        project_id: project.id.to_proto(),
1749                        worktree_id: worktree.id,
1750                        path: settings_file.path,
1751                        content: Some(settings_file.content),
1752                        kind: Some(settings_file.kind.to_proto().into()),
1753                    },
1754                )?;
1755            }
1756        }
1757
1758        for language_server in &project.language_servers {
1759            session.peer.send(
1760                session.connection_id,
1761                proto::UpdateLanguageServer {
1762                    project_id: project.id.to_proto(),
1763                    language_server_id: language_server.id,
1764                    variant: Some(
1765                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1766                            proto::LspDiskBasedDiagnosticsUpdated {},
1767                        ),
1768                    ),
1769                },
1770            )?;
1771        }
1772    }
1773    Ok(())
1774}
1775
1776/// leave room disconnects from the room.
1777async fn leave_room(
1778    _: proto::LeaveRoom,
1779    response: Response<proto::LeaveRoom>,
1780    session: UserSession,
1781) -> Result<()> {
1782    leave_room_for_session(&session, session.connection_id).await?;
1783    response.send(proto::Ack {})?;
1784    Ok(())
1785}
1786
1787/// Updates the permissions of someone else in the room.
1788async fn set_room_participant_role(
1789    request: proto::SetRoomParticipantRole,
1790    response: Response<proto::SetRoomParticipantRole>,
1791    session: UserSession,
1792) -> Result<()> {
1793    let user_id = UserId::from_proto(request.user_id);
1794    let role = ChannelRole::from(request.role());
1795
1796    let (live_kit_room, can_publish) = {
1797        let room = session
1798            .db()
1799            .await
1800            .set_room_participant_role(
1801                session.user_id(),
1802                RoomId::from_proto(request.room_id),
1803                user_id,
1804                role,
1805            )
1806            .await?;
1807
1808        let live_kit_room = room.live_kit_room.clone();
1809        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1810        room_updated(&room, &session.peer);
1811        (live_kit_room, can_publish)
1812    };
1813
1814    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1815        live_kit
1816            .update_participant(
1817                live_kit_room.clone(),
1818                request.user_id.to_string(),
1819                live_kit_server::proto::ParticipantPermission {
1820                    can_subscribe: true,
1821                    can_publish,
1822                    can_publish_data: can_publish,
1823                    hidden: false,
1824                    recorder: false,
1825                },
1826            )
1827            .await
1828            .trace_err();
1829    }
1830
1831    response.send(proto::Ack {})?;
1832    Ok(())
1833}
1834
/// Call someone else into the current room
///
/// Only contacts of the caller may be called. The call is recorded in the
/// database, room participants are notified, and `update_user_contacts` runs
/// for the callee. Every connection of the callee is then sent the incoming
/// call concurrently; the first successful reply acknowledges this request.
/// If no connection replies successfully, the call is marked as failed in the
/// database, the room and the callee's contacts are updated again, and
/// "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the room guard is released before contacts are updated and
    // the callee is rung.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once; results arrive in
    // completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log this connection's failure and keep waiting for the others.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1903
1904/// Cancel an outgoing call.
1905async fn cancel_call(
1906    request: proto::CancelCall,
1907    response: Response<proto::CancelCall>,
1908    session: UserSession,
1909) -> Result<()> {
1910    let called_user_id = UserId::from_proto(request.called_user_id);
1911    let room_id = RoomId::from_proto(request.room_id);
1912    {
1913        let room = session
1914            .db()
1915            .await
1916            .cancel_call(room_id, session.connection_id, called_user_id)
1917            .await?;
1918        room_updated(&room, &session.peer);
1919    }
1920
1921    for connection_id in session
1922        .connection_pool()
1923        .await
1924        .user_connection_ids(called_user_id)
1925    {
1926        session
1927            .peer
1928            .send(
1929                connection_id,
1930                proto::CallCanceled {
1931                    room_id: room_id.to_proto(),
1932                },
1933            )
1934            .trace_err();
1935    }
1936    response.send(proto::Ack {})?;
1937
1938    update_user_contacts(called_user_id, &session).await?;
1939    Ok(())
1940}
1941
1942/// Decline an incoming call.
1943async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1944    let room_id = RoomId::from_proto(message.room_id);
1945    {
1946        let room = session
1947            .db()
1948            .await
1949            .decline_call(Some(room_id), session.user_id())
1950            .await?
1951            .ok_or_else(|| anyhow!("failed to decline call"))?;
1952        room_updated(&room, &session.peer);
1953    }
1954
1955    for connection_id in session
1956        .connection_pool()
1957        .await
1958        .user_connection_ids(session.user_id())
1959    {
1960        session
1961            .peer
1962            .send(
1963                connection_id,
1964                proto::CallCanceled {
1965                    room_id: room_id.to_proto(),
1966                },
1967            )
1968            .trace_err();
1969    }
1970    update_user_contacts(session.user_id(), &session).await?;
1971    Ok(())
1972}
1973
1974/// Updates other participants in the room with your current location.
1975async fn update_participant_location(
1976    request: proto::UpdateParticipantLocation,
1977    response: Response<proto::UpdateParticipantLocation>,
1978    session: UserSession,
1979) -> Result<()> {
1980    let room_id = RoomId::from_proto(request.room_id);
1981    let location = request
1982        .location
1983        .ok_or_else(|| anyhow!("invalid location"))?;
1984
1985    let db = session.db().await;
1986    let room = db
1987        .update_room_participant_location(room_id, session.connection_id, location)
1988        .await?;
1989
1990    room_updated(&room, &session.peer);
1991    response.send(proto::Ack {})?;
1992    Ok(())
1993}
1994
1995/// Share a project into the room.
1996async fn share_project(
1997    request: proto::ShareProject,
1998    response: Response<proto::ShareProject>,
1999    session: UserSession,
2000) -> Result<()> {
2001    let (project_id, room) = &*session
2002        .db()
2003        .await
2004        .share_project(
2005            RoomId::from_proto(request.room_id),
2006            session.connection_id,
2007            &request.worktrees,
2008            request.is_ssh_project,
2009            request
2010                .dev_server_project_id
2011                .map(DevServerProjectId::from_proto),
2012        )
2013        .await?;
2014    response.send(proto::ShareProjectResponse {
2015        project_id: project_id.to_proto(),
2016    })?;
2017    room_updated(room, &session.peer);
2018
2019    Ok(())
2020}
2021
2022/// Unshare a project from the room.
2023async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2024    let project_id = ProjectId::from_proto(message.project_id);
2025    unshare_project_internal(
2026        project_id,
2027        session.connection_id,
2028        session.user_id(),
2029        &session,
2030    )
2031    .await
2032}
2033
/// Stops sharing `project_id` on behalf of `connection_id` (and `user_id`, if known).
///
/// The database returns whether the project should be deleted, the affected
/// room (if any), and the guest connections. Guests are sent `UnshareProject`
/// via `broadcast`, which is given `connection_id` as its first argument
/// (presumably so the sender is skipped). The room, if any, is re-broadcast.
/// If deletion was requested, the project is deleted after the room guard has
/// been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so the room guard is dropped before the delete below
    // re-acquires the database.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2072
2073/// DevServer makes a project available online
2074async fn share_dev_server_project(
2075    request: proto::ShareDevServerProject,
2076    response: Response<proto::ShareDevServerProject>,
2077    session: DevServerSession,
2078) -> Result<()> {
2079    let (dev_server_project, user_id, status) = session
2080        .db()
2081        .await
2082        .share_dev_server_project(
2083            DevServerProjectId::from_proto(request.dev_server_project_id),
2084            session.dev_server_id(),
2085            session.connection_id,
2086            &request.worktrees,
2087        )
2088        .await?;
2089    let Some(project_id) = dev_server_project.project_id else {
2090        return Err(anyhow!("failed to share remote project"))?;
2091    };
2092
2093    send_dev_server_projects_update(user_id, status, &session).await;
2094
2095    response.send(proto::ShareProjectResponse { project_id })?;
2096
2097    Ok(())
2098}
2099
/// Join someone else's shared project.
///
/// Joins the project in the database for this connection, then hands the
/// project and assigned replica id to `join_project_internal`, which notifies
/// the other collaborators and streams the project state to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The guard returned by `join_project` is a temporary whose lifetime is
    // extended by the `&mut *` borrow, so `project` and `replica_id` remain
    // valid after the database handle is dropped below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2118
/// Abstracts over the response handles that answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Each impl forwards to `Response::send` for its concrete request type using a
// fully-qualified path, which keeps the target explicit next to this trait's
// identically-named `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2132
2133fn join_project_internal(
2134    response: impl JoinProjectInternalResponse,
2135    session: UserSession,
2136    project: &mut Project,
2137    replica_id: &ReplicaId,
2138) -> Result<()> {
2139    let collaborators = project
2140        .collaborators
2141        .iter()
2142        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2143        .map(|collaborator| collaborator.to_proto())
2144        .collect::<Vec<_>>();
2145    let project_id = project.id;
2146    let guest_user_id = session.user_id();
2147
2148    let worktrees = project
2149        .worktrees
2150        .iter()
2151        .map(|(id, worktree)| proto::WorktreeMetadata {
2152            id: *id,
2153            root_name: worktree.root_name.clone(),
2154            visible: worktree.visible,
2155            abs_path: worktree.abs_path.clone(),
2156        })
2157        .collect::<Vec<_>>();
2158
2159    let add_project_collaborator = proto::AddProjectCollaborator {
2160        project_id: project_id.to_proto(),
2161        collaborator: Some(proto::Collaborator {
2162            peer_id: Some(session.connection_id.into()),
2163            replica_id: replica_id.0 as u32,
2164            user_id: guest_user_id.to_proto(),
2165        }),
2166    };
2167
2168    for collaborator in &collaborators {
2169        session
2170            .peer
2171            .send(
2172                collaborator.peer_id.unwrap().into(),
2173                add_project_collaborator.clone(),
2174            )
2175            .trace_err();
2176    }
2177
2178    // First, we send the metadata associated with each worktree.
2179    response.send(proto::JoinProjectResponse {
2180        project_id: project.id.0 as u64,
2181        worktrees: worktrees.clone(),
2182        replica_id: replica_id.0 as u32,
2183        collaborators: collaborators.clone(),
2184        language_servers: project.language_servers.clone(),
2185        role: project.role.into(),
2186        dev_server_project_id: project
2187            .dev_server_project_id
2188            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2189    })?;
2190
2191    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2192        #[cfg(any(test, feature = "test-support"))]
2193        const MAX_CHUNK_SIZE: usize = 2;
2194        #[cfg(not(any(test, feature = "test-support")))]
2195        const MAX_CHUNK_SIZE: usize = 256;
2196
2197        // Stream this worktree's entries.
2198        let message = proto::UpdateWorktree {
2199            project_id: project_id.to_proto(),
2200            worktree_id,
2201            abs_path: worktree.abs_path.clone(),
2202            root_name: worktree.root_name,
2203            updated_entries: worktree.entries,
2204            removed_entries: Default::default(),
2205            scan_id: worktree.scan_id,
2206            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2207            updated_repositories: worktree.repository_entries.into_values().collect(),
2208            removed_repositories: Default::default(),
2209        };
2210        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2211            session.peer.send(session.connection_id, update.clone())?;
2212        }
2213
2214        // Stream this worktree's diagnostics.
2215        for summary in worktree.diagnostic_summaries {
2216            session.peer.send(
2217                session.connection_id,
2218                proto::UpdateDiagnosticSummary {
2219                    project_id: project_id.to_proto(),
2220                    worktree_id: worktree.id,
2221                    summary: Some(summary),
2222                },
2223            )?;
2224        }
2225
2226        for settings_file in worktree.settings_files {
2227            session.peer.send(
2228                session.connection_id,
2229                proto::UpdateWorktreeSettings {
2230                    project_id: project_id.to_proto(),
2231                    worktree_id: worktree.id,
2232                    path: settings_file.path,
2233                    content: Some(settings_file.content),
2234                    kind: Some(proto::update_user_settings::Kind::Settings.into()),
2235                },
2236            )?;
2237        }
2238    }
2239
2240    for language_server in &project.language_servers {
2241        session.peer.send(
2242            session.connection_id,
2243            proto::UpdateLanguageServer {
2244                project_id: project_id.to_proto(),
2245                language_server_id: language_server.id,
2246                variant: Some(
2247                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2248                        proto::LspDiskBasedDiagnosticsUpdated {},
2249                    ),
2250                ),
2251            },
2252        )?;
2253    }
2254
2255    Ok(())
2256}
2257
2258/// Leave someone elses shared project.
2259async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2260    let sender_id = session.connection_id;
2261    let project_id = ProjectId::from_proto(request.project_id);
2262    let db = session.db().await;
2263    if db.is_hosted_project(project_id).await? {
2264        let project = db.leave_hosted_project(project_id, sender_id).await?;
2265        project_left(&project, &session);
2266        return Ok(());
2267    }
2268
2269    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2270    tracing::info!(
2271        %project_id,
2272        "leave project"
2273    );
2274
2275    project_left(project, &session);
2276    if let Some(room) = room {
2277        room_updated(room, &session.peer);
2278    }
2279
2280    Ok(())
2281}
2282
2283async fn join_hosted_project(
2284    request: proto::JoinHostedProject,
2285    response: Response<proto::JoinHostedProject>,
2286    session: UserSession,
2287) -> Result<()> {
2288    let (mut project, replica_id) = session
2289        .db()
2290        .await
2291        .join_hosted_project(
2292            ProjectId(request.project_id as i32),
2293            session.user_id(),
2294            session.connection_id,
2295        )
2296        .await?;
2297
2298    join_project_internal(response, session, &mut project, &replica_id)
2299}
2300
2301async fn list_remote_directory(
2302    request: proto::ListRemoteDirectory,
2303    response: Response<proto::ListRemoteDirectory>,
2304    session: UserSession,
2305) -> Result<()> {
2306    let dev_server_id = DevServerId(request.dev_server_id as i32);
2307    let dev_server_connection_id = session
2308        .connection_pool()
2309        .await
2310        .online_dev_server_connection_id(dev_server_id)?;
2311
2312    session
2313        .db()
2314        .await
2315        .get_dev_server_for_user(dev_server_id, session.user_id())
2316        .await?;
2317
2318    response.send(
2319        session
2320            .peer
2321            .forward_request(session.connection_id, dev_server_connection_id, request)
2322            .await?,
2323    )?;
2324    Ok(())
2325}
2326
2327async fn update_dev_server_project(
2328    request: proto::UpdateDevServerProject,
2329    response: Response<proto::UpdateDevServerProject>,
2330    session: UserSession,
2331) -> Result<()> {
2332    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2333
2334    let (dev_server_project, update) = session
2335        .db()
2336        .await
2337        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2338        .await?;
2339
2340    let projects = session
2341        .db()
2342        .await
2343        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2344        .await?;
2345
2346    let dev_server_connection_id = session
2347        .connection_pool()
2348        .await
2349        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2350
2351    session.peer.send(
2352        dev_server_connection_id,
2353        proto::DevServerInstructions { projects },
2354    )?;
2355
2356    send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358    response.send(proto::Ack {})
2359}
2360
2361async fn create_dev_server_project(
2362    request: proto::CreateDevServerProject,
2363    response: Response<proto::CreateDevServerProject>,
2364    session: UserSession,
2365) -> Result<()> {
2366    let dev_server_id = DevServerId(request.dev_server_id as i32);
2367    let dev_server_connection_id = session
2368        .connection_pool()
2369        .await
2370        .dev_server_connection_id(dev_server_id);
2371    let Some(dev_server_connection_id) = dev_server_connection_id else {
2372        Err(ErrorCode::DevServerOffline
2373            .message("Cannot create a remote project when the dev server is offline".to_string())
2374            .anyhow())?
2375    };
2376
2377    let path = request.path.clone();
2378    //Check that the path exists on the dev server
2379    session
2380        .peer
2381        .forward_request(
2382            session.connection_id,
2383            dev_server_connection_id,
2384            proto::ValidateDevServerProjectRequest { path: path.clone() },
2385        )
2386        .await?;
2387
2388    let (dev_server_project, update) = session
2389        .db()
2390        .await
2391        .create_dev_server_project(
2392            DevServerId(request.dev_server_id as i32),
2393            &request.path,
2394            session.user_id(),
2395        )
2396        .await?;
2397
2398    let projects = session
2399        .db()
2400        .await
2401        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2402        .await?;
2403
2404    session.peer.send(
2405        dev_server_connection_id,
2406        proto::DevServerInstructions { projects },
2407    )?;
2408
2409    send_dev_server_projects_update(session.user_id(), update, &session).await;
2410
2411    response.send(proto::CreateDevServerProjectResponse {
2412        dev_server_project: Some(dev_server_project.to_proto(None)),
2413    })?;
2414    Ok(())
2415}
2416
2417async fn create_dev_server(
2418    request: proto::CreateDevServer,
2419    response: Response<proto::CreateDevServer>,
2420    session: UserSession,
2421) -> Result<()> {
2422    let access_token = auth::random_token();
2423    let hashed_access_token = auth::hash_access_token(&access_token);
2424
2425    if request.name.is_empty() {
2426        return Err(proto::ErrorCode::Forbidden
2427            .message("Dev server name cannot be empty".to_string())
2428            .anyhow())?;
2429    }
2430
2431    let (dev_server, status) = session
2432        .db()
2433        .await
2434        .create_dev_server(
2435            &request.name,
2436            request.ssh_connection_string.as_deref(),
2437            &hashed_access_token,
2438            session.user_id(),
2439        )
2440        .await?;
2441
2442    send_dev_server_projects_update(session.user_id(), status, &session).await;
2443
2444    response.send(proto::CreateDevServerResponse {
2445        dev_server_id: dev_server.id.0 as u64,
2446        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2447        name: request.name,
2448    })?;
2449    Ok(())
2450}
2451
2452async fn regenerate_dev_server_token(
2453    request: proto::RegenerateDevServerToken,
2454    response: Response<proto::RegenerateDevServerToken>,
2455    session: UserSession,
2456) -> Result<()> {
2457    let dev_server_id = DevServerId(request.dev_server_id as i32);
2458    let access_token = auth::random_token();
2459    let hashed_access_token = auth::hash_access_token(&access_token);
2460
2461    let connection_id = session
2462        .connection_pool()
2463        .await
2464        .dev_server_connection_id(dev_server_id);
2465    if let Some(connection_id) = connection_id {
2466        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2467        session.peer.send(
2468            connection_id,
2469            proto::ShutdownDevServer {
2470                reason: Some("dev server token was regenerated".to_string()),
2471            },
2472        )?;
2473        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2474    }
2475
2476    let status = session
2477        .db()
2478        .await
2479        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2480        .await?;
2481
2482    send_dev_server_projects_update(session.user_id(), status, &session).await;
2483
2484    response.send(proto::RegenerateDevServerTokenResponse {
2485        dev_server_id: dev_server_id.to_proto(),
2486        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2487    })?;
2488    Ok(())
2489}
2490
2491async fn rename_dev_server(
2492    request: proto::RenameDevServer,
2493    response: Response<proto::RenameDevServer>,
2494    session: UserSession,
2495) -> Result<()> {
2496    if request.name.trim().is_empty() {
2497        return Err(proto::ErrorCode::Forbidden
2498            .message("Dev server name cannot be empty".to_string())
2499            .anyhow())?;
2500    }
2501
2502    let dev_server_id = DevServerId(request.dev_server_id as i32);
2503    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2504    if dev_server.user_id != session.user_id() {
2505        return Err(anyhow!(ErrorCode::Forbidden))?;
2506    }
2507
2508    let status = session
2509        .db()
2510        .await
2511        .rename_dev_server(
2512            dev_server_id,
2513            &request.name,
2514            request.ssh_connection_string.as_deref(),
2515            session.user_id(),
2516        )
2517        .await?;
2518
2519    send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521    response.send(proto::Ack {})?;
2522    Ok(())
2523}
2524
2525async fn delete_dev_server(
2526    request: proto::DeleteDevServer,
2527    response: Response<proto::DeleteDevServer>,
2528    session: UserSession,
2529) -> Result<()> {
2530    let dev_server_id = DevServerId(request.dev_server_id as i32);
2531    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2532    if dev_server.user_id != session.user_id() {
2533        return Err(anyhow!(ErrorCode::Forbidden))?;
2534    }
2535
2536    let connection_id = session
2537        .connection_pool()
2538        .await
2539        .dev_server_connection_id(dev_server_id);
2540    if let Some(connection_id) = connection_id {
2541        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2542        session.peer.send(
2543            connection_id,
2544            proto::ShutdownDevServer {
2545                reason: Some("dev server was deleted".to_string()),
2546            },
2547        )?;
2548        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2549    }
2550
2551    let status = session
2552        .db()
2553        .await
2554        .delete_dev_server(dev_server_id, session.user_id())
2555        .await?;
2556
2557    send_dev_server_projects_update(session.user_id(), status, &session).await;
2558
2559    response.send(proto::Ack {})?;
2560    Ok(())
2561}
2562
2563async fn delete_dev_server_project(
2564    request: proto::DeleteDevServerProject,
2565    response: Response<proto::DeleteDevServerProject>,
2566    session: UserSession,
2567) -> Result<()> {
2568    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2569    let dev_server_project = session
2570        .db()
2571        .await
2572        .get_dev_server_project(dev_server_project_id)
2573        .await?;
2574
2575    let dev_server = session
2576        .db()
2577        .await
2578        .get_dev_server(dev_server_project.dev_server_id)
2579        .await?;
2580    if dev_server.user_id != session.user_id() {
2581        return Err(anyhow!(ErrorCode::Forbidden))?;
2582    }
2583
2584    let dev_server_connection_id = session
2585        .connection_pool()
2586        .await
2587        .dev_server_connection_id(dev_server.id);
2588
2589    if let Some(dev_server_connection_id) = dev_server_connection_id {
2590        let project = session
2591            .db()
2592            .await
2593            .find_dev_server_project(dev_server_project_id)
2594            .await;
2595        if let Ok(project) = project {
2596            unshare_project_internal(
2597                project.id,
2598                dev_server_connection_id,
2599                Some(session.user_id()),
2600                &session,
2601            )
2602            .await?;
2603        }
2604    }
2605
2606    let (projects, status) = session
2607        .db()
2608        .await
2609        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2610        .await?;
2611
2612    if let Some(dev_server_connection_id) = dev_server_connection_id {
2613        session.peer.send(
2614            dev_server_connection_id,
2615            proto::DevServerInstructions { projects },
2616        )?;
2617    }
2618
2619    send_dev_server_projects_update(session.user_id(), status, &session).await;
2620
2621    response.send(proto::Ack {})?;
2622    Ok(())
2623}
2624
2625async fn rejoin_dev_server_projects(
2626    request: proto::RejoinRemoteProjects,
2627    response: Response<proto::RejoinRemoteProjects>,
2628    session: UserSession,
2629) -> Result<()> {
2630    let mut rejoined_projects = {
2631        let db = session.db().await;
2632        db.rejoin_dev_server_projects(
2633            &request.rejoined_projects,
2634            session.user_id(),
2635            session.0.connection_id,
2636        )
2637        .await?
2638    };
2639    response.send(proto::RejoinRemoteProjectsResponse {
2640        rejoined_projects: rejoined_projects
2641            .iter()
2642            .map(|project| project.to_proto())
2643            .collect(),
2644    })?;
2645    notify_rejoined_projects(&mut rejoined_projects, &session)
2646}
2647
2648async fn reconnect_dev_server(
2649    request: proto::ReconnectDevServer,
2650    response: Response<proto::ReconnectDevServer>,
2651    session: DevServerSession,
2652) -> Result<()> {
2653    let reshared_projects = {
2654        let db = session.db().await;
2655        db.reshare_dev_server_projects(
2656            &request.reshared_projects,
2657            session.dev_server_id(),
2658            session.0.connection_id,
2659        )
2660        .await?
2661    };
2662
2663    for project in &reshared_projects {
2664        for collaborator in &project.collaborators {
2665            session
2666                .peer
2667                .send(
2668                    collaborator.connection_id,
2669                    proto::UpdateProjectCollaborator {
2670                        project_id: project.id.to_proto(),
2671                        old_peer_id: Some(project.old_connection_id.into()),
2672                        new_peer_id: Some(session.connection_id.into()),
2673                    },
2674                )
2675                .trace_err();
2676        }
2677
2678        broadcast(
2679            Some(session.connection_id),
2680            project
2681                .collaborators
2682                .iter()
2683                .map(|collaborator| collaborator.connection_id),
2684            |connection_id| {
2685                session.peer.forward_send(
2686                    session.connection_id,
2687                    connection_id,
2688                    proto::UpdateProject {
2689                        project_id: project.id.to_proto(),
2690                        worktrees: project.worktrees.clone(),
2691                    },
2692                )
2693            },
2694        );
2695    }
2696
2697    response.send(proto::ReconnectDevServerResponse {
2698        reshared_projects: reshared_projects
2699            .iter()
2700            .map(|project| proto::ResharedProject {
2701                id: project.id.to_proto(),
2702                collaborators: project
2703                    .collaborators
2704                    .iter()
2705                    .map(|collaborator| collaborator.to_proto())
2706                    .collect(),
2707            })
2708            .collect(),
2709    })?;
2710
2711    Ok(())
2712}
2713
2714async fn shutdown_dev_server(
2715    _: proto::ShutdownDevServer,
2716    response: Response<proto::ShutdownDevServer>,
2717    session: DevServerSession,
2718) -> Result<()> {
2719    response.send(proto::Ack {})?;
2720    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2721    remove_dev_server_connection(session.dev_server_id(), &session).await
2722}
2723
2724async fn shutdown_dev_server_internal(
2725    dev_server_id: DevServerId,
2726    connection_id: ConnectionId,
2727    session: &Session,
2728) -> Result<()> {
2729    let (dev_server_projects, dev_server) = {
2730        let db = session.db().await;
2731        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2732        let dev_server = db.get_dev_server(dev_server_id).await?;
2733        (dev_server_projects, dev_server)
2734    };
2735
2736    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2737        unshare_project_internal(
2738            ProjectId::from_proto(project_id),
2739            connection_id,
2740            None,
2741            session,
2742        )
2743        .await?;
2744    }
2745
2746    session
2747        .connection_pool()
2748        .await
2749        .set_dev_server_offline(dev_server_id);
2750
2751    let status = session
2752        .db()
2753        .await
2754        .dev_server_projects_update(dev_server.user_id)
2755        .await?;
2756    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2757
2758    Ok(())
2759}
2760
2761async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2762    let dev_server_connection = session
2763        .connection_pool()
2764        .await
2765        .dev_server_connection_id(dev_server_id);
2766
2767    if let Some(dev_server_connection) = dev_server_connection {
2768        session
2769            .connection_pool()
2770            .await
2771            .remove_connection(dev_server_connection)?;
2772    }
2773    Ok(())
2774}
2775
2776/// Updates other participants with changes to the project
2777async fn update_project(
2778    request: proto::UpdateProject,
2779    response: Response<proto::UpdateProject>,
2780    session: Session,
2781) -> Result<()> {
2782    let project_id = ProjectId::from_proto(request.project_id);
2783    let (room, guest_connection_ids) = &*session
2784        .db()
2785        .await
2786        .update_project(project_id, session.connection_id, &request.worktrees)
2787        .await?;
2788    broadcast(
2789        Some(session.connection_id),
2790        guest_connection_ids.iter().copied(),
2791        |connection_id| {
2792            session
2793                .peer
2794                .forward_send(session.connection_id, connection_id, request.clone())
2795        },
2796    );
2797    if let Some(room) = room {
2798        room_updated(room, &session.peer);
2799    }
2800    response.send(proto::Ack {})?;
2801
2802    Ok(())
2803}
2804
2805/// Updates other participants with changes to the worktree
2806async fn update_worktree(
2807    request: proto::UpdateWorktree,
2808    response: Response<proto::UpdateWorktree>,
2809    session: Session,
2810) -> Result<()> {
2811    let guest_connection_ids = session
2812        .db()
2813        .await
2814        .update_worktree(&request, session.connection_id)
2815        .await?;
2816
2817    broadcast(
2818        Some(session.connection_id),
2819        guest_connection_ids.iter().copied(),
2820        |connection_id| {
2821            session
2822                .peer
2823                .forward_send(session.connection_id, connection_id, request.clone())
2824        },
2825    );
2826    response.send(proto::Ack {})?;
2827    Ok(())
2828}
2829
2830/// Updates other participants with changes to the diagnostics
2831async fn update_diagnostic_summary(
2832    message: proto::UpdateDiagnosticSummary,
2833    session: Session,
2834) -> Result<()> {
2835    let guest_connection_ids = session
2836        .db()
2837        .await
2838        .update_diagnostic_summary(&message, session.connection_id)
2839        .await?;
2840
2841    broadcast(
2842        Some(session.connection_id),
2843        guest_connection_ids.iter().copied(),
2844        |connection_id| {
2845            session
2846                .peer
2847                .forward_send(session.connection_id, connection_id, message.clone())
2848        },
2849    );
2850
2851    Ok(())
2852}
2853
2854/// Updates other participants with changes to the worktree settings
2855async fn update_worktree_settings(
2856    message: proto::UpdateWorktreeSettings,
2857    session: Session,
2858) -> Result<()> {
2859    let guest_connection_ids = session
2860        .db()
2861        .await
2862        .update_worktree_settings(&message, session.connection_id)
2863        .await?;
2864
2865    broadcast(
2866        Some(session.connection_id),
2867        guest_connection_ids.iter().copied(),
2868        |connection_id| {
2869            session
2870                .peer
2871                .forward_send(session.connection_id, connection_id, message.clone())
2872        },
2873    );
2874
2875    Ok(())
2876}
2877
2878/// Notify other participants that a  language server has started.
2879async fn start_language_server(
2880    request: proto::StartLanguageServer,
2881    session: Session,
2882) -> Result<()> {
2883    let guest_connection_ids = session
2884        .db()
2885        .await
2886        .start_language_server(&request, session.connection_id)
2887        .await?;
2888
2889    broadcast(
2890        Some(session.connection_id),
2891        guest_connection_ids.iter().copied(),
2892        |connection_id| {
2893            session
2894                .peer
2895                .forward_send(session.connection_id, connection_id, request.clone())
2896        },
2897    );
2898    Ok(())
2899}
2900
2901/// Notify other participants that a language server has changed.
2902async fn update_language_server(
2903    request: proto::UpdateLanguageServer,
2904    session: Session,
2905) -> Result<()> {
2906    let project_id = ProjectId::from_proto(request.project_id);
2907    let project_connection_ids = session
2908        .db()
2909        .await
2910        .project_connection_ids(project_id, session.connection_id, true)
2911        .await?;
2912    broadcast(
2913        Some(session.connection_id),
2914        project_connection_ids.iter().copied(),
2915        |connection_id| {
2916            session
2917                .peer
2918                .forward_send(session.connection_id, connection_id, request.clone())
2919        },
2920    );
2921    Ok(())
2922}
2923
2924/// forward a project request to the host. These requests should be read only
2925/// as guests are allowed to send them.
2926async fn forward_read_only_project_request<T>(
2927    request: T,
2928    response: Response<T>,
2929    session: UserSession,
2930) -> Result<()>
2931where
2932    T: EntityMessage + RequestMessage,
2933{
2934    let project_id = ProjectId::from_proto(request.remote_entity_id());
2935    let host_connection_id = session
2936        .db()
2937        .await
2938        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2939        .await?;
2940    let payload = session
2941        .peer
2942        .forward_request(session.connection_id, host_connection_id, request)
2943        .await?;
2944    response.send(payload)?;
2945    Ok(())
2946}
2947
2948async fn forward_find_search_candidates_request(
2949    request: proto::FindSearchCandidates,
2950    response: Response<proto::FindSearchCandidates>,
2951    session: UserSession,
2952) -> Result<()> {
2953    let project_id = ProjectId::from_proto(request.remote_entity_id());
2954    let host_connection_id = session
2955        .db()
2956        .await
2957        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2958        .await?;
2959    let payload = session
2960        .peer
2961        .forward_request(session.connection_id, host_connection_id, request)
2962        .await?;
2963    response.send(payload)?;
2964    Ok(())
2965}
2966
2967/// forward a project request to the dev server. Only allowed
2968/// if it's your dev server.
2969async fn forward_project_request_for_owner<T>(
2970    request: T,
2971    response: Response<T>,
2972    session: UserSession,
2973) -> Result<()>
2974where
2975    T: EntityMessage + RequestMessage,
2976{
2977    let project_id = ProjectId::from_proto(request.remote_entity_id());
2978
2979    let host_connection_id = session
2980        .db()
2981        .await
2982        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2983        .await?;
2984    let payload = session
2985        .peer
2986        .forward_request(session.connection_id, host_connection_id, request)
2987        .await?;
2988    response.send(payload)?;
2989    Ok(())
2990}
2991
2992/// forward a project request to the host. These requests are disallowed
2993/// for guests.
2994async fn forward_mutating_project_request<T>(
2995    request: T,
2996    response: Response<T>,
2997    session: UserSession,
2998) -> Result<()>
2999where
3000    T: EntityMessage + RequestMessage,
3001{
3002    let project_id = ProjectId::from_proto(request.remote_entity_id());
3003
3004    let host_connection_id = session
3005        .db()
3006        .await
3007        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3008        .await?;
3009    let payload = session
3010        .peer
3011        .forward_request(session.connection_id, host_connection_id, request)
3012        .await?;
3013    response.send(payload)?;
3014    Ok(())
3015}
3016
3017/// Notify other participants that a new buffer has been created
3018async fn create_buffer_for_peer(
3019    request: proto::CreateBufferForPeer,
3020    session: Session,
3021) -> Result<()> {
3022    session
3023        .db()
3024        .await
3025        .check_user_is_project_host(
3026            ProjectId::from_proto(request.project_id),
3027            session.connection_id,
3028        )
3029        .await?;
3030    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3031    session
3032        .peer
3033        .forward_send(session.connection_id, peer_id.into(), request)?;
3034    Ok(())
3035}
3036
3037/// Notify other participants that a buffer has been updated. This is
3038/// allowed for guests as long as the update is limited to selections.
3039async fn update_buffer(
3040    request: proto::UpdateBuffer,
3041    response: Response<proto::UpdateBuffer>,
3042    session: Session,
3043) -> Result<()> {
3044    let project_id = ProjectId::from_proto(request.project_id);
3045    let mut capability = Capability::ReadOnly;
3046
3047    for op in request.operations.iter() {
3048        match op.variant {
3049            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3050            Some(_) => capability = Capability::ReadWrite,
3051        }
3052    }
3053
3054    let host = {
3055        let guard = session
3056            .db()
3057            .await
3058            .connections_for_buffer_update(
3059                project_id,
3060                session.principal_id(),
3061                session.connection_id,
3062                capability,
3063            )
3064            .await?;
3065
3066        let (host, guests) = &*guard;
3067
3068        broadcast(
3069            Some(session.connection_id),
3070            guests.clone(),
3071            |connection_id| {
3072                session
3073                    .peer
3074                    .forward_send(session.connection_id, connection_id, request.clone())
3075            },
3076        );
3077
3078        *host
3079    };
3080
3081    if host != session.connection_id {
3082        session
3083            .peer
3084            .forward_request(session.connection_id, host, request.clone())
3085            .await?;
3086    }
3087
3088    response.send(proto::Ack {})?;
3089    Ok(())
3090}
3091
3092async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3093    let project_id = ProjectId::from_proto(message.project_id);
3094
3095    let operation = message.operation.as_ref().context("invalid operation")?;
3096    let capability = match operation.variant.as_ref() {
3097        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3098            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3099                match buffer_op.variant {
3100                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3101                        Capability::ReadOnly
3102                    }
3103                    _ => Capability::ReadWrite,
3104                }
3105            } else {
3106                Capability::ReadWrite
3107            }
3108        }
3109        Some(_) => Capability::ReadWrite,
3110        None => Capability::ReadOnly,
3111    };
3112
3113    let guard = session
3114        .db()
3115        .await
3116        .connections_for_buffer_update(
3117            project_id,
3118            session.principal_id(),
3119            session.connection_id,
3120            capability,
3121        )
3122        .await?;
3123
3124    let (host, guests) = &*guard;
3125
3126    broadcast(
3127        Some(session.connection_id),
3128        guests.iter().chain([host]).copied(),
3129        |connection_id| {
3130            session
3131                .peer
3132                .forward_send(session.connection_id, connection_id, message.clone())
3133        },
3134    );
3135
3136    Ok(())
3137}
3138
3139/// Notify other participants that a project has been updated.
3140async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3141    request: T,
3142    session: Session,
3143) -> Result<()> {
3144    let project_id = ProjectId::from_proto(request.remote_entity_id());
3145    let project_connection_ids = session
3146        .db()
3147        .await
3148        .project_connection_ids(project_id, session.connection_id, false)
3149        .await?;
3150
3151    broadcast(
3152        Some(session.connection_id),
3153        project_connection_ids.iter().copied(),
3154        |connection_id| {
3155            session
3156                .peer
3157                .forward_send(session.connection_id, connection_id, request.clone())
3158        },
3159    );
3160    Ok(())
3161}
3162
3163/// Start following another user in a call.
3164async fn follow(
3165    request: proto::Follow,
3166    response: Response<proto::Follow>,
3167    session: UserSession,
3168) -> Result<()> {
3169    let room_id = RoomId::from_proto(request.room_id);
3170    let project_id = request.project_id.map(ProjectId::from_proto);
3171    let leader_id = request
3172        .leader_id
3173        .ok_or_else(|| anyhow!("invalid leader id"))?
3174        .into();
3175    let follower_id = session.connection_id;
3176
3177    session
3178        .db()
3179        .await
3180        .check_room_participants(room_id, leader_id, session.connection_id)
3181        .await?;
3182
3183    let response_payload = session
3184        .peer
3185        .forward_request(session.connection_id, leader_id, request)
3186        .await?;
3187    response.send(response_payload)?;
3188
3189    if let Some(project_id) = project_id {
3190        let room = session
3191            .db()
3192            .await
3193            .follow(room_id, project_id, leader_id, follower_id)
3194            .await?;
3195        room_updated(&room, &session.peer);
3196    }
3197
3198    Ok(())
3199}
3200
3201/// Stop following another user in a call.
3202async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3203    let room_id = RoomId::from_proto(request.room_id);
3204    let project_id = request.project_id.map(ProjectId::from_proto);
3205    let leader_id = request
3206        .leader_id
3207        .ok_or_else(|| anyhow!("invalid leader id"))?
3208        .into();
3209    let follower_id = session.connection_id;
3210
3211    session
3212        .db()
3213        .await
3214        .check_room_participants(room_id, leader_id, session.connection_id)
3215        .await?;
3216
3217    session
3218        .peer
3219        .forward_send(session.connection_id, leader_id, request)?;
3220
3221    if let Some(project_id) = project_id {
3222        let room = session
3223            .db()
3224            .await
3225            .unfollow(room_id, project_id, leader_id, follower_id)
3226            .await?;
3227        room_updated(&room, &session.peer);
3228    }
3229
3230    Ok(())
3231}
3232
3233/// Notify everyone following you of your current location.
3234async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3235    let room_id = RoomId::from_proto(request.room_id);
3236    let database = session.db.lock().await;
3237
3238    let connection_ids = if let Some(project_id) = request.project_id {
3239        let project_id = ProjectId::from_proto(project_id);
3240        database
3241            .project_connection_ids(project_id, session.connection_id, true)
3242            .await?
3243    } else {
3244        database
3245            .room_connection_ids(room_id, session.connection_id)
3246            .await?
3247    };
3248
3249    // For now, don't send view update messages back to that view's current leader.
3250    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3251        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3252        _ => None,
3253    });
3254
3255    for connection_id in connection_ids.iter().cloned() {
3256        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3257            session
3258                .peer
3259                .forward_send(session.connection_id, connection_id, request.clone())?;
3260        }
3261    }
3262    Ok(())
3263}
3264
3265/// Get public data about users.
3266async fn get_users(
3267    request: proto::GetUsers,
3268    response: Response<proto::GetUsers>,
3269    session: Session,
3270) -> Result<()> {
3271    let user_ids = request
3272        .user_ids
3273        .into_iter()
3274        .map(UserId::from_proto)
3275        .collect();
3276    let users = session
3277        .db()
3278        .await
3279        .get_users_by_ids(user_ids)
3280        .await?
3281        .into_iter()
3282        .map(|user| proto::User {
3283            id: user.id.to_proto(),
3284            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3285            github_login: user.github_login,
3286        })
3287        .collect();
3288    response.send(proto::UsersResponse { users })?;
3289    Ok(())
3290}
3291
3292/// Search for users (to invite) buy Github login
3293async fn fuzzy_search_users(
3294    request: proto::FuzzySearchUsers,
3295    response: Response<proto::FuzzySearchUsers>,
3296    session: UserSession,
3297) -> Result<()> {
3298    let query = request.query;
3299    let users = match query.len() {
3300        0 => vec![],
3301        1 | 2 => session
3302            .db()
3303            .await
3304            .get_user_by_github_login(&query)
3305            .await?
3306            .into_iter()
3307            .collect(),
3308        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3309    };
3310    let users = users
3311        .into_iter()
3312        .filter(|user| user.id != session.user_id())
3313        .map(|user| proto::User {
3314            id: user.id.to_proto(),
3315            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3316            github_login: user.github_login,
3317        })
3318        .collect();
3319    response.send(proto::UsersResponse { users })?;
3320    Ok(())
3321}
3322
3323/// Send a contact request to another user.
3324async fn request_contact(
3325    request: proto::RequestContact,
3326    response: Response<proto::RequestContact>,
3327    session: UserSession,
3328) -> Result<()> {
3329    let requester_id = session.user_id();
3330    let responder_id = UserId::from_proto(request.responder_id);
3331    if requester_id == responder_id {
3332        return Err(anyhow!("cannot add yourself as a contact"))?;
3333    }
3334
3335    let notifications = session
3336        .db()
3337        .await
3338        .send_contact_request(requester_id, responder_id)
3339        .await?;
3340
3341    // Update outgoing contact requests of requester
3342    let mut update = proto::UpdateContacts::default();
3343    update.outgoing_requests.push(responder_id.to_proto());
3344    for connection_id in session
3345        .connection_pool()
3346        .await
3347        .user_connection_ids(requester_id)
3348    {
3349        session.peer.send(connection_id, update.clone())?;
3350    }
3351
3352    // Update incoming contact requests of responder
3353    let mut update = proto::UpdateContacts::default();
3354    update
3355        .incoming_requests
3356        .push(proto::IncomingContactRequest {
3357            requester_id: requester_id.to_proto(),
3358        });
3359    let connection_pool = session.connection_pool().await;
3360    for connection_id in connection_pool.user_connection_ids(responder_id) {
3361        session.peer.send(connection_id, update.clone())?;
3362    }
3363
3364    send_notifications(&connection_pool, &session.peer, notifications);
3365
3366    response.send(proto::Ack {})?;
3367    Ok(())
3368}
3369
3370/// Accept or decline a contact request
3371async fn respond_to_contact_request(
3372    request: proto::RespondToContactRequest,
3373    response: Response<proto::RespondToContactRequest>,
3374    session: UserSession,
3375) -> Result<()> {
3376    let responder_id = session.user_id();
3377    let requester_id = UserId::from_proto(request.requester_id);
3378    let db = session.db().await;
3379    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3380        db.dismiss_contact_notification(responder_id, requester_id)
3381            .await?;
3382    } else {
3383        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3384
3385        let notifications = db
3386            .respond_to_contact_request(responder_id, requester_id, accept)
3387            .await?;
3388        let requester_busy = db.is_user_busy(requester_id).await?;
3389        let responder_busy = db.is_user_busy(responder_id).await?;
3390
3391        let pool = session.connection_pool().await;
3392        // Update responder with new contact
3393        let mut update = proto::UpdateContacts::default();
3394        if accept {
3395            update
3396                .contacts
3397                .push(contact_for_user(requester_id, requester_busy, &pool));
3398        }
3399        update
3400            .remove_incoming_requests
3401            .push(requester_id.to_proto());
3402        for connection_id in pool.user_connection_ids(responder_id) {
3403            session.peer.send(connection_id, update.clone())?;
3404        }
3405
3406        // Update requester with new contact
3407        let mut update = proto::UpdateContacts::default();
3408        if accept {
3409            update
3410                .contacts
3411                .push(contact_for_user(responder_id, responder_busy, &pool));
3412        }
3413        update
3414            .remove_outgoing_requests
3415            .push(responder_id.to_proto());
3416
3417        for connection_id in pool.user_connection_ids(requester_id) {
3418            session.peer.send(connection_id, update.clone())?;
3419        }
3420
3421        send_notifications(&pool, &session.peer, notifications);
3422    }
3423
3424    response.send(proto::Ack {})?;
3425    Ok(())
3426}
3427
3428/// Remove a contact.
3429async fn remove_contact(
3430    request: proto::RemoveContact,
3431    response: Response<proto::RemoveContact>,
3432    session: UserSession,
3433) -> Result<()> {
3434    let requester_id = session.user_id();
3435    let responder_id = UserId::from_proto(request.user_id);
3436    let db = session.db().await;
3437    let (contact_accepted, deleted_notification_id) =
3438        db.remove_contact(requester_id, responder_id).await?;
3439
3440    let pool = session.connection_pool().await;
3441    // Update outgoing contact requests of requester
3442    let mut update = proto::UpdateContacts::default();
3443    if contact_accepted {
3444        update.remove_contacts.push(responder_id.to_proto());
3445    } else {
3446        update
3447            .remove_outgoing_requests
3448            .push(responder_id.to_proto());
3449    }
3450    for connection_id in pool.user_connection_ids(requester_id) {
3451        session.peer.send(connection_id, update.clone())?;
3452    }
3453
3454    // Update incoming contact requests of responder
3455    let mut update = proto::UpdateContacts::default();
3456    if contact_accepted {
3457        update.remove_contacts.push(requester_id.to_proto());
3458    } else {
3459        update
3460            .remove_incoming_requests
3461            .push(requester_id.to_proto());
3462    }
3463    for connection_id in pool.user_connection_ids(responder_id) {
3464        session.peer.send(connection_id, update.clone())?;
3465        if let Some(notification_id) = deleted_notification_id {
3466            session.peer.send(
3467                connection_id,
3468                proto::DeleteNotification {
3469                    notification_id: notification_id.to_proto(),
3470                },
3471            )?;
3472        }
3473    }
3474
3475    response.send(proto::Ack {})?;
3476    Ok(())
3477}
3478
3479fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3480    version.0.minor() < 139
3481}
3482
3483async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3484    let plan = session.current_plan(&session.db().await).await?;
3485
3486    session
3487        .peer
3488        .send(
3489            session.connection_id,
3490            proto::UpdateUserPlan { plan: plan.into() },
3491        )
3492        .trace_err();
3493
3494    Ok(())
3495}
3496
3497async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3498    subscribe_user_to_channels(
3499        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3500        &session,
3501    )
3502    .await?;
3503    Ok(())
3504}
3505
3506async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3507    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3508    let mut pool = session.connection_pool().await;
3509    for membership in &channels_for_user.channel_memberships {
3510        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3511    }
3512    session.peer.send(
3513        session.connection_id,
3514        build_update_user_channels(&channels_for_user),
3515    )?;
3516    session.peer.send(
3517        session.connection_id,
3518        build_channels_update(channels_for_user),
3519    )?;
3520    Ok(())
3521}
3522
3523/// Creates a new channel.
3524async fn create_channel(
3525    request: proto::CreateChannel,
3526    response: Response<proto::CreateChannel>,
3527    session: UserSession,
3528) -> Result<()> {
3529    let db = session.db().await;
3530
3531    let parent_id = request.parent_id.map(ChannelId::from_proto);
3532    let (channel, membership) = db
3533        .create_channel(&request.name, parent_id, session.user_id())
3534        .await?;
3535
3536    let root_id = channel.root_id();
3537    let channel = Channel::from_model(channel);
3538
3539    response.send(proto::CreateChannelResponse {
3540        channel: Some(channel.to_proto()),
3541        parent_id: request.parent_id,
3542    })?;
3543
3544    let mut connection_pool = session.connection_pool().await;
3545    if let Some(membership) = membership {
3546        connection_pool.subscribe_to_channel(
3547            membership.user_id,
3548            membership.channel_id,
3549            membership.role,
3550        );
3551        let update = proto::UpdateUserChannels {
3552            channel_memberships: vec![proto::ChannelMembership {
3553                channel_id: membership.channel_id.to_proto(),
3554                role: membership.role.into(),
3555            }],
3556            ..Default::default()
3557        };
3558        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3559            session.peer.send(connection_id, update.clone())?;
3560        }
3561    }
3562
3563    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3564        if !role.can_see_channel(channel.visibility) {
3565            continue;
3566        }
3567
3568        let update = proto::UpdateChannels {
3569            channels: vec![channel.to_proto()],
3570            ..Default::default()
3571        };
3572        session.peer.send(connection_id, update.clone())?;
3573    }
3574
3575    Ok(())
3576}
3577
3578/// Delete a channel
3579async fn delete_channel(
3580    request: proto::DeleteChannel,
3581    response: Response<proto::DeleteChannel>,
3582    session: UserSession,
3583) -> Result<()> {
3584    let db = session.db().await;
3585
3586    let channel_id = request.channel_id;
3587    let (root_channel, removed_channels) = db
3588        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3589        .await?;
3590    response.send(proto::Ack {})?;
3591
3592    // Notify members of removed channels
3593    let mut update = proto::UpdateChannels::default();
3594    update
3595        .delete_channels
3596        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3597
3598    let connection_pool = session.connection_pool().await;
3599    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3600        session.peer.send(connection_id, update.clone())?;
3601    }
3602
3603    Ok(())
3604}
3605
3606/// Invite someone to join a channel.
3607async fn invite_channel_member(
3608    request: proto::InviteChannelMember,
3609    response: Response<proto::InviteChannelMember>,
3610    session: UserSession,
3611) -> Result<()> {
3612    let db = session.db().await;
3613    let channel_id = ChannelId::from_proto(request.channel_id);
3614    let invitee_id = UserId::from_proto(request.user_id);
3615    let InviteMemberResult {
3616        channel,
3617        notifications,
3618    } = db
3619        .invite_channel_member(
3620            channel_id,
3621            invitee_id,
3622            session.user_id(),
3623            request.role().into(),
3624        )
3625        .await?;
3626
3627    let update = proto::UpdateChannels {
3628        channel_invitations: vec![channel.to_proto()],
3629        ..Default::default()
3630    };
3631
3632    let connection_pool = session.connection_pool().await;
3633    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3634        session.peer.send(connection_id, update.clone())?;
3635    }
3636
3637    send_notifications(&connection_pool, &session.peer, notifications);
3638
3639    response.send(proto::Ack {})?;
3640    Ok(())
3641}
3642
3643/// remove someone from a channel
3644async fn remove_channel_member(
3645    request: proto::RemoveChannelMember,
3646    response: Response<proto::RemoveChannelMember>,
3647    session: UserSession,
3648) -> Result<()> {
3649    let db = session.db().await;
3650    let channel_id = ChannelId::from_proto(request.channel_id);
3651    let member_id = UserId::from_proto(request.user_id);
3652
3653    let RemoveChannelMemberResult {
3654        membership_update,
3655        notification_id,
3656    } = db
3657        .remove_channel_member(channel_id, member_id, session.user_id())
3658        .await?;
3659
3660    let mut connection_pool = session.connection_pool().await;
3661    notify_membership_updated(
3662        &mut connection_pool,
3663        membership_update,
3664        member_id,
3665        &session.peer,
3666    );
3667    for connection_id in connection_pool.user_connection_ids(member_id) {
3668        if let Some(notification_id) = notification_id {
3669            session
3670                .peer
3671                .send(
3672                    connection_id,
3673                    proto::DeleteNotification {
3674                        notification_id: notification_id.to_proto(),
3675                    },
3676                )
3677                .trace_err();
3678        }
3679    }
3680
3681    response.send(proto::Ack {})?;
3682    Ok(())
3683}
3684
3685/// Toggle the channel between public and private.
3686/// Care is taken to maintain the invariant that public channels only descend from public channels,
3687/// (though members-only channels can appear at any point in the hierarchy).
3688async fn set_channel_visibility(
3689    request: proto::SetChannelVisibility,
3690    response: Response<proto::SetChannelVisibility>,
3691    session: UserSession,
3692) -> Result<()> {
3693    let db = session.db().await;
3694    let channel_id = ChannelId::from_proto(request.channel_id);
3695    let visibility = request.visibility().into();
3696
3697    let channel_model = db
3698        .set_channel_visibility(channel_id, visibility, session.user_id())
3699        .await?;
3700    let root_id = channel_model.root_id();
3701    let channel = Channel::from_model(channel_model);
3702
3703    let mut connection_pool = session.connection_pool().await;
3704    for (user_id, role) in connection_pool
3705        .channel_user_ids(root_id)
3706        .collect::<Vec<_>>()
3707        .into_iter()
3708    {
3709        let update = if role.can_see_channel(channel.visibility) {
3710            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3711            proto::UpdateChannels {
3712                channels: vec![channel.to_proto()],
3713                ..Default::default()
3714            }
3715        } else {
3716            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3717            proto::UpdateChannels {
3718                delete_channels: vec![channel.id.to_proto()],
3719                ..Default::default()
3720            }
3721        };
3722
3723        for connection_id in connection_pool.user_connection_ids(user_id) {
3724            session.peer.send(connection_id, update.clone())?;
3725        }
3726    }
3727
3728    response.send(proto::Ack {})?;
3729    Ok(())
3730}
3731
3732/// Alter the role for a user in the channel.
3733async fn set_channel_member_role(
3734    request: proto::SetChannelMemberRole,
3735    response: Response<proto::SetChannelMemberRole>,
3736    session: UserSession,
3737) -> Result<()> {
3738    let db = session.db().await;
3739    let channel_id = ChannelId::from_proto(request.channel_id);
3740    let member_id = UserId::from_proto(request.user_id);
3741    let result = db
3742        .set_channel_member_role(
3743            channel_id,
3744            session.user_id(),
3745            member_id,
3746            request.role().into(),
3747        )
3748        .await?;
3749
3750    match result {
3751        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3752            let mut connection_pool = session.connection_pool().await;
3753            notify_membership_updated(
3754                &mut connection_pool,
3755                membership_update,
3756                member_id,
3757                &session.peer,
3758            )
3759        }
3760        db::SetMemberRoleResult::InviteUpdated(channel) => {
3761            let update = proto::UpdateChannels {
3762                channel_invitations: vec![channel.to_proto()],
3763                ..Default::default()
3764            };
3765
3766            for connection_id in session
3767                .connection_pool()
3768                .await
3769                .user_connection_ids(member_id)
3770            {
3771                session.peer.send(connection_id, update.clone())?;
3772            }
3773        }
3774    }
3775
3776    response.send(proto::Ack {})?;
3777    Ok(())
3778}
3779
3780/// Change the name of a channel
3781async fn rename_channel(
3782    request: proto::RenameChannel,
3783    response: Response<proto::RenameChannel>,
3784    session: UserSession,
3785) -> Result<()> {
3786    let db = session.db().await;
3787    let channel_id = ChannelId::from_proto(request.channel_id);
3788    let channel_model = db
3789        .rename_channel(channel_id, session.user_id(), &request.name)
3790        .await?;
3791    let root_id = channel_model.root_id();
3792    let channel = Channel::from_model(channel_model);
3793
3794    response.send(proto::RenameChannelResponse {
3795        channel: Some(channel.to_proto()),
3796    })?;
3797
3798    let connection_pool = session.connection_pool().await;
3799    let update = proto::UpdateChannels {
3800        channels: vec![channel.to_proto()],
3801        ..Default::default()
3802    };
3803    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3804        if role.can_see_channel(channel.visibility) {
3805            session.peer.send(connection_id, update.clone())?;
3806        }
3807    }
3808
3809    Ok(())
3810}
3811
3812/// Move a channel to a new parent.
3813async fn move_channel(
3814    request: proto::MoveChannel,
3815    response: Response<proto::MoveChannel>,
3816    session: UserSession,
3817) -> Result<()> {
3818    let channel_id = ChannelId::from_proto(request.channel_id);
3819    let to = ChannelId::from_proto(request.to);
3820
3821    let (root_id, channels) = session
3822        .db()
3823        .await
3824        .move_channel(channel_id, to, session.user_id())
3825        .await?;
3826
3827    let connection_pool = session.connection_pool().await;
3828    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3829        let channels = channels
3830            .iter()
3831            .filter_map(|channel| {
3832                if role.can_see_channel(channel.visibility) {
3833                    Some(channel.to_proto())
3834                } else {
3835                    None
3836                }
3837            })
3838            .collect::<Vec<_>>();
3839        if channels.is_empty() {
3840            continue;
3841        }
3842
3843        let update = proto::UpdateChannels {
3844            channels,
3845            ..Default::default()
3846        };
3847
3848        session.peer.send(connection_id, update.clone())?;
3849    }
3850
3851    response.send(Ack {})?;
3852    Ok(())
3853}
3854
3855/// Get the list of channel members
3856async fn get_channel_members(
3857    request: proto::GetChannelMembers,
3858    response: Response<proto::GetChannelMembers>,
3859    session: UserSession,
3860) -> Result<()> {
3861    let db = session.db().await;
3862    let channel_id = ChannelId::from_proto(request.channel_id);
3863    let limit = if request.limit == 0 {
3864        u16::MAX as u64
3865    } else {
3866        request.limit
3867    };
3868    let (members, users) = db
3869        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3870        .await?;
3871    response.send(proto::GetChannelMembersResponse { members, users })?;
3872    Ok(())
3873}
3874
3875/// Accept or decline a channel invitation.
3876async fn respond_to_channel_invite(
3877    request: proto::RespondToChannelInvite,
3878    response: Response<proto::RespondToChannelInvite>,
3879    session: UserSession,
3880) -> Result<()> {
3881    let db = session.db().await;
3882    let channel_id = ChannelId::from_proto(request.channel_id);
3883    let RespondToChannelInvite {
3884        membership_update,
3885        notifications,
3886    } = db
3887        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3888        .await?;
3889
3890    let mut connection_pool = session.connection_pool().await;
3891    if let Some(membership_update) = membership_update {
3892        notify_membership_updated(
3893            &mut connection_pool,
3894            membership_update,
3895            session.user_id(),
3896            &session.peer,
3897        );
3898    } else {
3899        let update = proto::UpdateChannels {
3900            remove_channel_invitations: vec![channel_id.to_proto()],
3901            ..Default::default()
3902        };
3903
3904        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3905            session.peer.send(connection_id, update.clone())?;
3906        }
3907    };
3908
3909    send_notifications(&connection_pool, &session.peer, notifications);
3910
3911    response.send(proto::Ack {})?;
3912
3913    Ok(())
3914}
3915
3916/// Join the channels' room
3917async fn join_channel(
3918    request: proto::JoinChannel,
3919    response: Response<proto::JoinChannel>,
3920    session: UserSession,
3921) -> Result<()> {
3922    let channel_id = ChannelId::from_proto(request.channel_id);
3923    join_channel_internal(channel_id, Box::new(response), session).await
3924}
3925
3926trait JoinChannelInternalResponse {
3927    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3928}
3929impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3930    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3931        Response::<proto::JoinChannel>::send(self, result)
3932    }
3933}
3934impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3935    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3936        Response::<proto::JoinRoom>::send(self, result)
3937    }
3938}
3939
/// Shared implementation for joining a channel's room. It replies through
/// `JoinChannelInternalResponse`, which is implemented for both `JoinChannel` and `JoinRoom`
/// responses.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database. This yields the joined room, an optional
///    membership update, and the user's role.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish: false`;
///    - every other role gets a room token and `can_publish: true`.
///
///    A token error goes through `trace_err` and leaves `live_kit_connection_info` as `None`
///    instead of failing the join.
/// 4. Reply with the room, its channel id, and the LiveKit info. Then propagate any membership
///    update and call `room_updated`.
/// 5. Once the block holding the database and pool guards has ended, call `channel_updated`
///    and `update_user_contacts`.
///
/// # Errors
///
/// Returns an error if any of these fail:
/// - the stale-connection lookup or its cleanup,
/// - the database join,
/// - sending the reply,
/// - the contact update.
///
/// Also returns an error ("channel not returned") if the joined room has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are released when it
    // ends, i.e. before the channel and contact broadcasts below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard around `leave_room_for_session`, then re-acquire it.
            // NOTE(review): presumably because that helper acquires the database itself —
            // confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when token creation fails
        // (`trace_err()?` turns the error into `None` inside this closure).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish, and receive a guest token instead of a
                    // room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): this "channel not returned" check runs only after the reply was sent and
    // `room_updated` was called above.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4035
4036/// Start editing the channel notes
4037async fn join_channel_buffer(
4038    request: proto::JoinChannelBuffer,
4039    response: Response<proto::JoinChannelBuffer>,
4040    session: UserSession,
4041) -> Result<()> {
4042    let db = session.db().await;
4043    let channel_id = ChannelId::from_proto(request.channel_id);
4044
4045    let open_response = db
4046        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4047        .await?;
4048
4049    let collaborators = open_response.collaborators.clone();
4050    response.send(open_response)?;
4051
4052    let update = UpdateChannelBufferCollaborators {
4053        channel_id: channel_id.to_proto(),
4054        collaborators: collaborators.clone(),
4055    };
4056    channel_buffer_updated(
4057        session.connection_id,
4058        collaborators
4059            .iter()
4060            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4061        &update,
4062        &session.peer,
4063    );
4064
4065    Ok(())
4066}
4067
4068/// Edit the channel notes
4069async fn update_channel_buffer(
4070    request: proto::UpdateChannelBuffer,
4071    session: UserSession,
4072) -> Result<()> {
4073    let db = session.db().await;
4074    let channel_id = ChannelId::from_proto(request.channel_id);
4075
4076    let (collaborators, epoch, version) = db
4077        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4078        .await?;
4079
4080    channel_buffer_updated(
4081        session.connection_id,
4082        collaborators.clone(),
4083        &proto::UpdateChannelBuffer {
4084            channel_id: channel_id.to_proto(),
4085            operations: request.operations,
4086        },
4087        &session.peer,
4088    );
4089
4090    let pool = &*session.connection_pool().await;
4091
4092    let non_collaborators =
4093        pool.channel_connection_ids(channel_id)
4094            .filter_map(|(connection_id, _)| {
4095                if collaborators.contains(&connection_id) {
4096                    None
4097                } else {
4098                    Some(connection_id)
4099                }
4100            });
4101
4102    broadcast(None, non_collaborators, |peer_id| {
4103        session.peer.send(
4104            peer_id,
4105            proto::UpdateChannels {
4106                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4107                    channel_id: channel_id.to_proto(),
4108                    epoch: epoch as u64,
4109                    version: version.clone(),
4110                }],
4111                ..Default::default()
4112            },
4113        )
4114    });
4115
4116    Ok(())
4117}
4118
4119/// Rejoin the channel notes after a connection blip
4120async fn rejoin_channel_buffers(
4121    request: proto::RejoinChannelBuffers,
4122    response: Response<proto::RejoinChannelBuffers>,
4123    session: UserSession,
4124) -> Result<()> {
4125    let db = session.db().await;
4126    let buffers = db
4127        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4128        .await?;
4129
4130    for rejoined_buffer in &buffers {
4131        let collaborators_to_notify = rejoined_buffer
4132            .buffer
4133            .collaborators
4134            .iter()
4135            .filter_map(|c| Some(c.peer_id?.into()));
4136        channel_buffer_updated(
4137            session.connection_id,
4138            collaborators_to_notify,
4139            &proto::UpdateChannelBufferCollaborators {
4140                channel_id: rejoined_buffer.buffer.channel_id,
4141                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4142            },
4143            &session.peer,
4144        );
4145    }
4146
4147    response.send(proto::RejoinChannelBuffersResponse {
4148        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4149    })?;
4150
4151    Ok(())
4152}
4153
4154/// Stop editing the channel notes
4155async fn leave_channel_buffer(
4156    request: proto::LeaveChannelBuffer,
4157    response: Response<proto::LeaveChannelBuffer>,
4158    session: UserSession,
4159) -> Result<()> {
4160    let db = session.db().await;
4161    let channel_id = ChannelId::from_proto(request.channel_id);
4162
4163    let left_buffer = db
4164        .leave_channel_buffer(channel_id, session.connection_id)
4165        .await?;
4166
4167    response.send(Ack {})?;
4168
4169    channel_buffer_updated(
4170        session.connection_id,
4171        left_buffer.connections,
4172        &proto::UpdateChannelBufferCollaborators {
4173            channel_id: channel_id.to_proto(),
4174            collaborators: left_buffer.collaborators,
4175        },
4176        &session.peer,
4177    );
4178
4179    Ok(())
4180}
4181
4182fn channel_buffer_updated<T: EnvelopedMessage>(
4183    sender_id: ConnectionId,
4184    collaborators: impl IntoIterator<Item = ConnectionId>,
4185    message: &T,
4186    peer: &Peer,
4187) {
4188    broadcast(Some(sender_id), collaborators, |peer_id| {
4189        peer.send(peer_id, message.clone())
4190    });
4191}
4192
4193fn send_notifications(
4194    connection_pool: &ConnectionPool,
4195    peer: &Peer,
4196    notifications: db::NotificationBatch,
4197) {
4198    for (user_id, notification) in notifications {
4199        for connection_id in connection_pool.user_connection_ids(user_id) {
4200            if let Err(error) = peer.send(
4201                connection_id,
4202                proto::AddNotification {
4203                    notification: Some(notification.clone()),
4204                },
4205            ) {
4206                tracing::error!(
4207                    "failed to send notification to {:?} {}",
4208                    connection_id,
4209                    error
4210                );
4211            }
4212        }
4213    }
4214}
4215
/// Send a message to the channel.
///
/// The body is trimmed and must be non-empty and at most `MAX_MESSAGE_LEN`
/// bytes; a nonce is required. After the message is persisted:
/// - the chat participants returned by the database receive a
///   `ChannelMessageSent` (the sender's own connection is passed to
///   `broadcast` as the excluded sender — presumably it is skipped),
/// - the sender receives a `SendChannelMessageResponse` with the same message,
/// - connections the pool associates with the channel that are *not* chat
///   participants only receive an `UpdateChannels` with the latest message id,
/// - any notifications produced by the database are delivered via
///   `send_notifications`.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is forwarded unchanged below. If mentions
    // are offsets into the untrimmed body, trimming leading whitespace shifts
    // them — confirm against the client's mention encoding.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    // The `db()` handle is a temporary here and is dropped at the end of this
    // statement, before any broadcasting happens.
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to the chat participants. The id list is cloned
    // because it is needed again below to exclude participants.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // The pool guard stays alive until the end of the function, so it also
    // covers the `send_notifications` call below.
    let pool = &*session.connection_pool().await;
    // Channel connections that are not in the chat only need to learn that a
    // newer message exists (e.g. to mark the channel as unread).
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4310
4311/// Delete a channel message
4312async fn remove_channel_message(
4313    request: proto::RemoveChannelMessage,
4314    response: Response<proto::RemoveChannelMessage>,
4315    session: UserSession,
4316) -> Result<()> {
4317    let channel_id = ChannelId::from_proto(request.channel_id);
4318    let message_id = MessageId::from_proto(request.message_id);
4319    let (connection_ids, existing_notification_ids) = session
4320        .db()
4321        .await
4322        .remove_channel_message(channel_id, message_id, session.user_id())
4323        .await?;
4324
4325    broadcast(
4326        Some(session.connection_id),
4327        connection_ids,
4328        move |connection| {
4329            session.peer.send(connection, request.clone())?;
4330
4331            for notification_id in &existing_notification_ids {
4332                session.peer.send(
4333                    connection,
4334                    proto::DeleteNotification {
4335                        notification_id: (*notification_id).to_proto(),
4336                    },
4337                )?;
4338            }
4339
4340            Ok(())
4341        },
4342    );
4343    response.send(proto::Ack {})?;
4344    Ok(())
4345}
4346
4347async fn update_channel_message(
4348    request: proto::UpdateChannelMessage,
4349    response: Response<proto::UpdateChannelMessage>,
4350    session: UserSession,
4351) -> Result<()> {
4352    let channel_id = ChannelId::from_proto(request.channel_id);
4353    let message_id = MessageId::from_proto(request.message_id);
4354    let updated_at = OffsetDateTime::now_utc();
4355    let UpdatedChannelMessage {
4356        message_id,
4357        participant_connection_ids,
4358        notifications,
4359        reply_to_message_id,
4360        timestamp,
4361        deleted_mention_notification_ids,
4362        updated_mention_notifications,
4363    } = session
4364        .db()
4365        .await
4366        .update_channel_message(
4367            channel_id,
4368            message_id,
4369            session.user_id(),
4370            request.body.as_str(),
4371            &request.mentions,
4372            updated_at,
4373        )
4374        .await?;
4375
4376    let nonce = request
4377        .nonce
4378        .clone()
4379        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4380
4381    let message = proto::ChannelMessage {
4382        sender_id: session.user_id().to_proto(),
4383        id: message_id.to_proto(),
4384        body: request.body.clone(),
4385        mentions: request.mentions.clone(),
4386        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4387        nonce: Some(nonce),
4388        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4389        edited_at: Some(updated_at.unix_timestamp() as u64),
4390    };
4391
4392    response.send(proto::Ack {})?;
4393
4394    let pool = &*session.connection_pool().await;
4395    broadcast(
4396        Some(session.connection_id),
4397        participant_connection_ids,
4398        |connection| {
4399            session.peer.send(
4400                connection,
4401                proto::ChannelMessageUpdate {
4402                    channel_id: channel_id.to_proto(),
4403                    message: Some(message.clone()),
4404                },
4405            )?;
4406
4407            for notification_id in &deleted_mention_notification_ids {
4408                session.peer.send(
4409                    connection,
4410                    proto::DeleteNotification {
4411                        notification_id: (*notification_id).to_proto(),
4412                    },
4413                )?;
4414            }
4415
4416            for notification in &updated_mention_notifications {
4417                session.peer.send(
4418                    connection,
4419                    proto::UpdateNotification {
4420                        notification: Some(notification.clone()),
4421                    },
4422                )?;
4423            }
4424
4425            Ok(())
4426        },
4427    );
4428
4429    send_notifications(pool, &session.peer, notifications);
4430
4431    Ok(())
4432}
4433
4434/// Mark a channel message as read
4435async fn acknowledge_channel_message(
4436    request: proto::AckChannelMessage,
4437    session: UserSession,
4438) -> Result<()> {
4439    let channel_id = ChannelId::from_proto(request.channel_id);
4440    let message_id = MessageId::from_proto(request.message_id);
4441    let notifications = session
4442        .db()
4443        .await
4444        .observe_channel_message(channel_id, session.user_id(), message_id)
4445        .await?;
4446    send_notifications(
4447        &*session.connection_pool().await,
4448        &session.peer,
4449        notifications,
4450    );
4451    Ok(())
4452}
4453
4454/// Mark a buffer version as synced
4455async fn acknowledge_buffer_version(
4456    request: proto::AckBufferOperation,
4457    session: UserSession,
4458) -> Result<()> {
4459    let buffer_id = BufferId::from_proto(request.buffer_id);
4460    session
4461        .db()
4462        .await
4463        .observe_buffer_version(
4464            buffer_id,
4465            session.user_id(),
4466            request.epoch as i32,
4467            &request.version,
4468        )
4469        .await?;
4470    Ok(())
4471}
4472
4473async fn count_language_model_tokens(
4474    request: proto::CountLanguageModelTokens,
4475    response: Response<proto::CountLanguageModelTokens>,
4476    session: Session,
4477    config: &Config,
4478) -> Result<()> {
4479    let Some(session) = session.for_user() else {
4480        return Err(anyhow!("user not found"))?;
4481    };
4482    authorize_access_to_legacy_llm_endpoints(&session).await?;
4483
4484    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4485        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4486        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4487    };
4488
4489    session
4490        .app_state
4491        .rate_limiter
4492        .check(&*rate_limit, session.user_id())
4493        .await?;
4494
4495    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4496        Some(proto::LanguageModelProvider::Google) => {
4497            let api_key = config
4498                .google_ai_api_key
4499                .as_ref()
4500                .context("no Google AI API key configured on the server")?;
4501            google_ai::count_tokens(
4502                session.http_client.as_ref(),
4503                google_ai::API_URL,
4504                api_key,
4505                serde_json::from_str(&request.request)?,
4506                None,
4507            )
4508            .await?
4509        }
4510        _ => return Err(anyhow!("unsupported provider"))?,
4511    };
4512
4513    response.send(proto::CountLanguageModelTokensResponse {
4514        token_count: result.total_tokens as u32,
4515    })?;
4516
4517    Ok(())
4518}
4519
4520struct ZedProCountLanguageModelTokensRateLimit;
4521
4522impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4523    fn capacity(&self) -> usize {
4524        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4525            .ok()
4526            .and_then(|v| v.parse().ok())
4527            .unwrap_or(600) // Picked arbitrarily
4528    }
4529
4530    fn refill_duration(&self) -> chrono::Duration {
4531        chrono::Duration::hours(1)
4532    }
4533
4534    fn db_name(&self) -> &'static str {
4535        "zed-pro:count-language-model-tokens"
4536    }
4537}
4538
4539struct FreeCountLanguageModelTokensRateLimit;
4540
4541impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4542    fn capacity(&self) -> usize {
4543        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4544            .ok()
4545            .and_then(|v| v.parse().ok())
4546            .unwrap_or(600 / 10) // Picked arbitrarily
4547    }
4548
4549    fn refill_duration(&self) -> chrono::Duration {
4550        chrono::Duration::hours(1)
4551    }
4552
4553    fn db_name(&self) -> &'static str {
4554        "free:count-language-model-tokens"
4555    }
4556}
4557
4558struct ZedProComputeEmbeddingsRateLimit;
4559
4560impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4561    fn capacity(&self) -> usize {
4562        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4563            .ok()
4564            .and_then(|v| v.parse().ok())
4565            .unwrap_or(5000) // Picked arbitrarily
4566    }
4567
4568    fn refill_duration(&self) -> chrono::Duration {
4569        chrono::Duration::hours(1)
4570    }
4571
4572    fn db_name(&self) -> &'static str {
4573        "zed-pro:compute-embeddings"
4574    }
4575}
4576
4577struct FreeComputeEmbeddingsRateLimit;
4578
4579impl RateLimit for FreeComputeEmbeddingsRateLimit {
4580    fn capacity(&self) -> usize {
4581        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4582            .ok()
4583            .and_then(|v| v.parse().ok())
4584            .unwrap_or(5000 / 10) // Picked arbitrarily
4585    }
4586
4587    fn refill_duration(&self) -> chrono::Duration {
4588        chrono::Duration::hours(1)
4589    }
4590
4591    fn db_name(&self) -> &'static str {
4592        "free:compute-embeddings"
4593    }
4594}
4595
4596async fn compute_embeddings(
4597    request: proto::ComputeEmbeddings,
4598    response: Response<proto::ComputeEmbeddings>,
4599    session: UserSession,
4600    api_key: Option<Arc<str>>,
4601) -> Result<()> {
4602    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4603    authorize_access_to_legacy_llm_endpoints(&session).await?;
4604
4605    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4606        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4607        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4608    };
4609
4610    session
4611        .app_state
4612        .rate_limiter
4613        .check(&*rate_limit, session.user_id())
4614        .await?;
4615
4616    let embeddings = match request.model.as_str() {
4617        "openai/text-embedding-3-small" => {
4618            open_ai::embed(
4619                session.http_client.as_ref(),
4620                OPEN_AI_API_URL,
4621                &api_key,
4622                OpenAiEmbeddingModel::TextEmbedding3Small,
4623                request.texts.iter().map(|text| text.as_str()),
4624            )
4625            .await?
4626        }
4627        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4628    };
4629
4630    let embeddings = request
4631        .texts
4632        .iter()
4633        .map(|text| {
4634            let mut hasher = sha2::Sha256::new();
4635            hasher.update(text.as_bytes());
4636            let result = hasher.finalize();
4637            result.to_vec()
4638        })
4639        .zip(
4640            embeddings
4641                .data
4642                .into_iter()
4643                .map(|embedding| embedding.embedding),
4644        )
4645        .collect::<HashMap<_, _>>();
4646
4647    let db = session.db().await;
4648    db.save_embeddings(&request.model, &embeddings)
4649        .await
4650        .context("failed to save embeddings")
4651        .trace_err();
4652
4653    response.send(proto::ComputeEmbeddingsResponse {
4654        embeddings: embeddings
4655            .into_iter()
4656            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4657            .collect(),
4658    })?;
4659    Ok(())
4660}
4661
4662async fn get_cached_embeddings(
4663    request: proto::GetCachedEmbeddings,
4664    response: Response<proto::GetCachedEmbeddings>,
4665    session: UserSession,
4666) -> Result<()> {
4667    authorize_access_to_legacy_llm_endpoints(&session).await?;
4668
4669    let db = session.db().await;
4670    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4671
4672    response.send(proto::GetCachedEmbeddingsResponse {
4673        embeddings: embeddings
4674            .into_iter()
4675            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4676            .collect(),
4677    })?;
4678    Ok(())
4679}
4680
4681/// This is leftover from before the LLM service.
4682///
4683/// The endpoints protected by this check will be moved there eventually.
4684async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4685    if session.is_staff() {
4686        Ok(())
4687    } else {
4688        Err(anyhow!("permission denied"))?
4689    }
4690}
4691
4692/// Get a Supermaven API key for the user
4693async fn get_supermaven_api_key(
4694    _request: proto::GetSupermavenApiKey,
4695    response: Response<proto::GetSupermavenApiKey>,
4696    session: UserSession,
4697) -> Result<()> {
4698    let user_id: String = session.user_id().to_string();
4699    if !session.is_staff() {
4700        return Err(anyhow!("supermaven not enabled for this account"))?;
4701    }
4702
4703    let email = session
4704        .email()
4705        .ok_or_else(|| anyhow!("user must have an email"))?;
4706
4707    let supermaven_admin_api = session
4708        .supermaven_client
4709        .as_ref()
4710        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4711
4712    let result = supermaven_admin_api
4713        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4714        .await?;
4715
4716    response.send(proto::GetSupermavenApiKeyResponse {
4717        api_key: result.api_key,
4718    })?;
4719
4720    Ok(())
4721}
4722
4723/// Start receiving chat updates for a channel
4724async fn join_channel_chat(
4725    request: proto::JoinChannelChat,
4726    response: Response<proto::JoinChannelChat>,
4727    session: UserSession,
4728) -> Result<()> {
4729    let channel_id = ChannelId::from_proto(request.channel_id);
4730
4731    let db = session.db().await;
4732    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4733        .await?;
4734    let messages = db
4735        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4736        .await?;
4737    response.send(proto::JoinChannelChatResponse {
4738        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4739        messages,
4740    })?;
4741    Ok(())
4742}
4743
4744/// Stop receiving chat updates for a channel
4745async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4746    let channel_id = ChannelId::from_proto(request.channel_id);
4747    session
4748        .db()
4749        .await
4750        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4751        .await?;
4752    Ok(())
4753}
4754
4755/// Retrieve the chat history for a channel
4756async fn get_channel_messages(
4757    request: proto::GetChannelMessages,
4758    response: Response<proto::GetChannelMessages>,
4759    session: UserSession,
4760) -> Result<()> {
4761    let channel_id = ChannelId::from_proto(request.channel_id);
4762    let messages = session
4763        .db()
4764        .await
4765        .get_channel_messages(
4766            channel_id,
4767            session.user_id(),
4768            MESSAGE_COUNT_PER_PAGE,
4769            Some(MessageId::from_proto(request.before_message_id)),
4770        )
4771        .await?;
4772    response.send(proto::GetChannelMessagesResponse {
4773        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4774        messages,
4775    })?;
4776    Ok(())
4777}
4778
4779/// Retrieve specific chat messages
4780async fn get_channel_messages_by_id(
4781    request: proto::GetChannelMessagesById,
4782    response: Response<proto::GetChannelMessagesById>,
4783    session: UserSession,
4784) -> Result<()> {
4785    let message_ids = request
4786        .message_ids
4787        .iter()
4788        .map(|id| MessageId::from_proto(*id))
4789        .collect::<Vec<_>>();
4790    let messages = session
4791        .db()
4792        .await
4793        .get_channel_messages_by_id(session.user_id(), &message_ids)
4794        .await?;
4795    response.send(proto::GetChannelMessagesResponse {
4796        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4797        messages,
4798    })?;
4799    Ok(())
4800}
4801
4802/// Retrieve the current users notifications
4803async fn get_notifications(
4804    request: proto::GetNotifications,
4805    response: Response<proto::GetNotifications>,
4806    session: UserSession,
4807) -> Result<()> {
4808    let notifications = session
4809        .db()
4810        .await
4811        .get_notifications(
4812            session.user_id(),
4813            NOTIFICATION_COUNT_PER_PAGE,
4814            request.before_id.map(db::NotificationId::from_proto),
4815        )
4816        .await?;
4817    response.send(proto::GetNotificationsResponse {
4818        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4819        notifications,
4820    })?;
4821    Ok(())
4822}
4823
4824/// Mark notifications as read
4825async fn mark_notification_as_read(
4826    request: proto::MarkNotificationRead,
4827    response: Response<proto::MarkNotificationRead>,
4828    session: UserSession,
4829) -> Result<()> {
4830    let database = &session.db().await;
4831    let notifications = database
4832        .mark_notification_as_read_by_id(
4833            session.user_id(),
4834            NotificationId::from_proto(request.notification_id),
4835        )
4836        .await?;
4837    send_notifications(
4838        &*session.connection_pool().await,
4839        &session.peer,
4840        notifications,
4841    );
4842    response.send(proto::Ack {})?;
4843    Ok(())
4844}
4845
4846/// Get the current users information
4847async fn get_private_user_info(
4848    _request: proto::GetPrivateUserInfo,
4849    response: Response<proto::GetPrivateUserInfo>,
4850    session: UserSession,
4851) -> Result<()> {
4852    let db = session.db().await;
4853
4854    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4855    let user = db
4856        .get_user_by_id(session.user_id())
4857        .await?
4858        .ok_or_else(|| anyhow!("user not found"))?;
4859    let flags = db.get_user_flags(session.user_id()).await?;
4860
4861    response.send(proto::GetPrivateUserInfoResponse {
4862        metrics_id,
4863        staff: user.admin,
4864        flags,
4865        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4866    })?;
4867    Ok(())
4868}
4869
4870/// Accept the terms of service (tos) on behalf of the current user
4871async fn accept_terms_of_service(
4872    _request: proto::AcceptTermsOfService,
4873    response: Response<proto::AcceptTermsOfService>,
4874    session: UserSession,
4875) -> Result<()> {
4876    let db = session.db().await;
4877
4878    let accepted_tos_at = Utc::now();
4879    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4880        .await?;
4881
4882    response.send(proto::AcceptTermsOfServiceResponse {
4883        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4884    })?;
4885    Ok(())
4886}
4887
4888/// The minimum account age an account must have in order to use the LLM service.
4889const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4890
4891async fn get_llm_api_token(
4892    _request: proto::GetLlmToken,
4893    response: Response<proto::GetLlmToken>,
4894    session: UserSession,
4895) -> Result<()> {
4896    let db = session.db().await;
4897
4898    let flags = db.get_user_flags(session.user_id()).await?;
4899    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4900    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4901
4902    if !session.is_staff() && !has_language_models_feature_flag {
4903        Err(anyhow!("permission denied"))?
4904    }
4905
4906    let user_id = session.user_id();
4907    let user = db
4908        .get_user_by_id(user_id)
4909        .await?
4910        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4911
4912    if user.accepted_tos_at.is_none() {
4913        Err(anyhow!("terms of service not accepted"))?
4914    }
4915
4916    let mut account_created_at = user.created_at;
4917    if let Some(github_created_at) = user.github_user_created_at {
4918        account_created_at = account_created_at.min(github_created_at);
4919    }
4920    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4921        Err(anyhow!("account too young"))?
4922    }
4923    let token = LlmTokenClaims::create(
4924        user.id,
4925        user.github_login.clone(),
4926        session.is_staff(),
4927        has_llm_closed_beta_feature_flag,
4928        session.has_llm_subscription(&db).await?,
4929        session.current_plan(&db).await?,
4930        &session.app_state.config,
4931    )?;
4932    response.send(proto::GetLlmTokenResponse { token })?;
4933    Ok(())
4934}
4935
4936fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4937    let message = match message {
4938        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4939        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4940        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4941        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4942        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4943            code: frame.code.into(),
4944            reason: frame.reason,
4945        })),
4946        // We should never receive a frame while reading the message, according
4947        // to the `tungstenite` maintainers:
4948        //
4949        // > It cannot occur when you read messages from the WebSocket, but it
4950        // > can be used when you want to send the raw frames (e.g. you want to
4951        // > send the frames to the WebSocket without composing the full message first).
4952        // >
4953        // > — https://github.com/snapview/tungstenite-rs/issues/268
4954        TungsteniteMessage::Frame(_) => {
4955            bail!("received an unexpected frame while reading the message")
4956        }
4957    };
4958
4959    Ok(message)
4960}
4961
4962fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4963    match message {
4964        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4965        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4966        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4967        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4968        AxumMessage::Close(frame) => {
4969            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4970                code: frame.code.into(),
4971                reason: frame.reason,
4972            }))
4973        }
4974    }
4975}
4976
/// Apply a channel-membership change for `user_id` to the connection pool and
/// push the resulting channel state to all of that user's connections.
///
/// Steps:
/// 1. Subscribe the user to every channel in `result.new_channels`, with the
///    listed role.
/// 2. Unsubscribe the user from `result.removed_channels`.
/// 3. Send each of the user's connections an `UpdateUserChannels` (memberships
///    and roles), followed by an `UpdateChannels`. The latter carries the new
///    channel data, the deleted channels, and removes the invitation for
///    `result.channel_id`.
///
/// Send failures are passed to `trace_err` and otherwise ignored.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Built from a borrow of `result.new_channels` before it is moved into
    // `build_channels_update` below.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
5017
5018fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5019    proto::UpdateUserChannels {
5020        channel_memberships: channels
5021            .channel_memberships
5022            .iter()
5023            .map(|m| proto::ChannelMembership {
5024                channel_id: m.channel_id.to_proto(),
5025                role: m.role.into(),
5026            })
5027            .collect(),
5028        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5029        observed_channel_message_id: channels.observed_channel_messages.clone(),
5030    }
5031}
5032
5033fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5034    let mut update = proto::UpdateChannels::default();
5035
5036    for channel in channels.channels {
5037        update.channels.push(channel.to_proto());
5038    }
5039
5040    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5041    update.latest_channel_message_ids = channels.latest_channel_messages;
5042
5043    for (channel_id, participants) in channels.channel_participants {
5044        update
5045            .channel_participants
5046            .push(proto::ChannelParticipants {
5047                channel_id: channel_id.to_proto(),
5048                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5049            });
5050    }
5051
5052    for channel in channels.invited_channels {
5053        update.channel_invitations.push(channel.to_proto());
5054    }
5055
5056    update.hosted_projects = channels.hosted_projects;
5057    update
5058}
5059
5060fn build_initial_contacts_update(
5061    contacts: Vec<db::Contact>,
5062    pool: &ConnectionPool,
5063) -> proto::UpdateContacts {
5064    let mut update = proto::UpdateContacts::default();
5065
5066    for contact in contacts {
5067        match contact {
5068            db::Contact::Accepted { user_id, busy } => {
5069                update.contacts.push(contact_for_user(user_id, busy, pool));
5070            }
5071            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5072            db::Contact::Incoming { user_id } => {
5073                update
5074                    .incoming_requests
5075                    .push(proto::IncomingContactRequest {
5076                        requester_id: user_id.to_proto(),
5077                    })
5078            }
5079        }
5080    }
5081
5082    update
5083}
5084
5085fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5086    proto::Contact {
5087        user_id: user_id.to_proto(),
5088        online: pool.is_user_online(user_id),
5089        busy,
5090    }
5091}
5092
5093fn room_updated(room: &proto::Room, peer: &Peer) {
5094    broadcast(
5095        None,
5096        room.participants
5097            .iter()
5098            .filter_map(|participant| Some(participant.peer_id?.into())),
5099        |peer_id| {
5100            peer.send(
5101                peer_id,
5102                proto::RoomUpdated {
5103                    room: Some(room.clone()),
5104                },
5105            )
5106        },
5107    );
5108}
5109
5110fn channel_updated(
5111    channel: &db::channel::Model,
5112    room: &proto::Room,
5113    peer: &Peer,
5114    pool: &ConnectionPool,
5115) {
5116    let participants = room
5117        .participants
5118        .iter()
5119        .map(|p| p.user_id)
5120        .collect::<Vec<_>>();
5121
5122    broadcast(
5123        None,
5124        pool.channel_connection_ids(channel.root_id())
5125            .filter_map(|(channel_id, role)| {
5126                role.can_see_channel(channel.visibility)
5127                    .then_some(channel_id)
5128            }),
5129        |peer_id| {
5130            peer.send(
5131                peer_id,
5132                proto::UpdateChannels {
5133                    channel_participants: vec![proto::ChannelParticipants {
5134                        channel_id: channel.id.to_proto(),
5135                        participant_user_ids: participants.clone(),
5136                    }],
5137                    ..Default::default()
5138                },
5139            )
5140        },
5141    );
5142}
5143
5144async fn send_dev_server_projects_update(
5145    user_id: UserId,
5146    mut status: proto::DevServerProjectsUpdate,
5147    session: &Session,
5148) {
5149    let pool = session.connection_pool().await;
5150    for dev_server in &mut status.dev_servers {
5151        dev_server.status =
5152            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5153    }
5154    let connections = pool.user_connection_ids(user_id);
5155    for connection_id in connections {
5156        session.peer.send(connection_id, status.clone()).trace_err();
5157    }
5158}
5159
5160async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5161    let db = session.db().await;
5162
5163    let contacts = db.get_contacts(user_id).await?;
5164    let busy = db.is_user_busy(user_id).await?;
5165
5166    let pool = session.connection_pool().await;
5167    let updated_contact = contact_for_user(user_id, busy, &pool);
5168    for contact in contacts {
5169        if let db::Contact::Accepted {
5170            user_id: contact_user_id,
5171            ..
5172        } = contact
5173        {
5174            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5175                session
5176                    .peer
5177                    .send(
5178                        contact_conn_id,
5179                        proto::UpdateContacts {
5180                            contacts: vec![updated_contact.clone()],
5181                            remove_contacts: Default::default(),
5182                            incoming_requests: Default::default(),
5183                            remove_incoming_requests: Default::default(),
5184                            outgoing_requests: Default::default(),
5185                            remove_outgoing_requests: Default::default(),
5186                        },
5187                    )
5188                    .trace_err();
5189            }
5190        }
5191    }
5192    Ok(())
5193}
5194
5195async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5196    log::info!("lost dev server connection, unsharing projects");
5197    let project_ids = session
5198        .db()
5199        .await
5200        .get_stale_dev_server_projects(session.connection_id)
5201        .await?;
5202
5203    for project_id in project_ids {
5204        // not unshare re-checks the connection ids match, so we get away with no transaction
5205        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5206    }
5207
5208    let user_id = session.dev_server().user_id;
5209    let update = session
5210        .db()
5211        .await
5212        .dev_server_projects_update(user_id)
5213        .await?;
5214
5215    send_dev_server_projects_update(user_id, update, session).await;
5216
5217    Ok(())
5218}
5219
5220async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5221    let mut contacts_to_update = HashSet::default();
5222
5223    let room_id;
5224    let canceled_calls_to_user_ids;
5225    let live_kit_room;
5226    let delete_live_kit_room;
5227    let room;
5228    let channel;
5229
5230    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5231        contacts_to_update.insert(session.user_id());
5232
5233        for project in left_room.left_projects.values() {
5234            project_left(project, session);
5235        }
5236
5237        room_id = RoomId::from_proto(left_room.room.id);
5238        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5239        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5240        delete_live_kit_room = left_room.deleted;
5241        room = mem::take(&mut left_room.room);
5242        channel = mem::take(&mut left_room.channel);
5243
5244        room_updated(&room, &session.peer);
5245    } else {
5246        return Ok(());
5247    }
5248
5249    if let Some(channel) = channel {
5250        channel_updated(
5251            &channel,
5252            &room,
5253            &session.peer,
5254            &*session.connection_pool().await,
5255        );
5256    }
5257
5258    {
5259        let pool = session.connection_pool().await;
5260        for canceled_user_id in canceled_calls_to_user_ids {
5261            for connection_id in pool.user_connection_ids(canceled_user_id) {
5262                session
5263                    .peer
5264                    .send(
5265                        connection_id,
5266                        proto::CallCanceled {
5267                            room_id: room_id.to_proto(),
5268                        },
5269                    )
5270                    .trace_err();
5271            }
5272            contacts_to_update.insert(canceled_user_id);
5273        }
5274    }
5275
5276    for contact_user_id in contacts_to_update {
5277        update_user_contacts(contact_user_id, session).await?;
5278    }
5279
5280    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5281        live_kit
5282            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5283            .await
5284            .trace_err();
5285
5286        if delete_live_kit_room {
5287            live_kit.delete_room(live_kit_room).await.trace_err();
5288        }
5289    }
5290
5291    Ok(())
5292}
5293
5294async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5295    let left_channel_buffers = session
5296        .db()
5297        .await
5298        .leave_channel_buffers(session.connection_id)
5299        .await?;
5300
5301    for left_buffer in left_channel_buffers {
5302        channel_buffer_updated(
5303            session.connection_id,
5304            left_buffer.connections,
5305            &proto::UpdateChannelBufferCollaborators {
5306                channel_id: left_buffer.channel_id.to_proto(),
5307                collaborators: left_buffer.collaborators,
5308            },
5309            &session.peer,
5310        );
5311    }
5312
5313    Ok(())
5314}
5315
5316fn project_left(project: &db::LeftProject, session: &UserSession) {
5317    for connection_id in &project.connection_ids {
5318        if project.should_unshare {
5319            session
5320                .peer
5321                .send(
5322                    *connection_id,
5323                    proto::UnshareProject {
5324                        project_id: project.id.to_proto(),
5325                    },
5326                )
5327                .trace_err();
5328        } else {
5329            session
5330                .peer
5331                .send(
5332                    *connection_id,
5333                    proto::RemoveProjectCollaborator {
5334                        project_id: project.id.to_proto(),
5335                        peer_id: Some(session.connection_id.into()),
5336                    },
5337                )
5338                .trace_err();
5339        }
5340    }
5341}
5342
5343pub trait ResultExt {
5344    type Ok;
5345
5346    fn trace_err(self) -> Option<Self::Ok>;
5347}
5348
5349impl<T, E> ResultExt for Result<T, E>
5350where
5351    E: std::fmt::Debug,
5352{
5353    type Ok = T;
5354
5355    #[track_caller]
5356    fn trace_err(self) -> Option<T> {
5357        match self {
5358            Ok(value) => Some(value),
5359            Err(error) => {
5360                tracing::error!("{:?}", error);
5361                None
5362            }
5363        }
5364    }
5365}