rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use isahc_http_client::IsahcHttpClient;
  40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
/// A 30-second timeout related to reconnection.
/// NOTE(review): not referenced in the visible part of this file — confirm exact usage at call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
// `Server::start` sleeps for this long before it retrieves stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Presumably the page size used when listing channel messages — TODO confirm against the handler.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Presumably the maximum accepted length of a channel message body — TODO confirm unit (bytes vs. chars).
const MAX_MESSAGE_LEN: usize = 1024;
// Presumably the page size used when listing notifications — TODO confirm against the handler.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
/// Type-erased handler as stored in `Server::handlers`: a thread-safe (`Send + Sync`) callable
/// that takes a boxed envelope of any message type together with a `Session`, and returns a
/// boxed `'static` future resolving to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
/// Reply handle given to request handlers for a request of type `R`.
struct Response<R> {
    /// Peer through which the reply is sent.
    peer: Arc<Peer>,
    /// Receipt of the request being answered; consumed when replying.
    receipt: Receipt<R>,
    /// Flag set to `true` by `send`; shared through an `Arc` so another holder can read it.
    responded: Arc<AtomicBool>,
}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
/// The identity a session acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A regular user.
    User(User),
    /// `admin` is acting as `user`; user-facing accessors in this file report `user`'s id.
    Impersonated { user: User, admin: User },
    /// A dev server, identified by its database model.
    DevServer(dev_server::Model),
}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
/// Per-connection state handed to every message handler. Cloning is cheap for the shared
/// parts, which are held behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Identity this session acts as (user, impersonated user, or dev server).
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Shared RPC peer.
    peer: Arc<Peer>,
    /// Shared connection pool behind a `parking_lot` mutex; acquire it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client used by handlers of this session.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor carried along with the session; not read within the visible code.
    _executor: Executor,
}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn has_llm_subscription(
 195        &self,
 196        db: &MutexGuard<'_, DbHandle>,
 197    ) -> anyhow::Result<bool> {
 198        if self.is_staff() {
 199            return Ok(true);
 200        }
 201
 202        let Some(user_id) = self.user_id() else {
 203            return Ok(false);
 204        };
 205
 206        Ok(db.has_active_billing_subscription(user_id).await?)
 207    }
 208
 209    pub async fn current_plan(
 210        &self,
 211        _db: &MutexGuard<'_, DbHandle>,
 212    ) -> anyhow::Result<proto::Plan> {
 213        if self.is_staff() {
 214            Ok(proto::Plan::ZedPro)
 215        } else {
 216            Ok(proto::Plan::Free)
 217        }
 218    }
 219
 220    fn dev_server_id(&self) -> Option<DevServerId> {
 221        match &self.principal {
 222            Principal::User(_) | Principal::Impersonated { .. } => None,
 223            Principal::DevServer(dev_server) => Some(dev_server.id),
 224        }
 225    }
 226
 227    fn principal_id(&self) -> PrincipalId {
 228        match &self.principal {
 229            Principal::User(user) => PrincipalId::UserId(user.id),
 230            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 231            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 232        }
 233    }
 234}
 235
 236impl Debug for Session {
 237    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 238        let mut result = f.debug_struct("Session");
 239        match &self.principal {
 240            Principal::User(user) => {
 241                result.field("user", &user.github_login);
 242            }
 243            Principal::Impersonated { user, admin } => {
 244                result.field("user", &user.github_login);
 245                result.field("impersonator", &admin.github_login);
 246            }
 247            Principal::DevServer(dev_server) => {
 248                result.field("dev_server", &dev_server.id);
 249            }
 250        }
 251        result.field("connection_id", &self.connection_id).finish()
 252    }
 253}
 254
 255struct UserSession(Session);
 256
 257impl UserSession {
 258    pub fn new(s: Session) -> Option<Self> {
 259        s.user_id().map(|_| UserSession(s))
 260    }
 261    pub fn user_id(&self) -> UserId {
 262        self.0.user_id().unwrap()
 263    }
 264
 265    pub fn email(&self) -> Option<String> {
 266        match &self.0.principal {
 267            Principal::User(user) => user.email_address.clone(),
 268            Principal::Impersonated { user, .. } => user.email_address.clone(),
 269            Principal::DevServer(..) => None,
 270        }
 271    }
 272}
 273
// Exposes the wrapped `Session` so its fields and methods are reachable directly on a
// `UserSession`.
impl Deref for UserSession {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Mutable counterpart of the `Deref` impl above: grants `&mut Session` access to the inner value.
impl DerefMut for UserSession {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
 286
 287struct DevServerSession(Session);
 288
 289impl DevServerSession {
 290    pub fn new(s: Session) -> Option<Self> {
 291        s.dev_server_id().map(|_| DevServerSession(s))
 292    }
 293    pub fn dev_server_id(&self) -> DevServerId {
 294        self.0.dev_server_id().unwrap()
 295    }
 296
 297    fn dev_server(&self) -> &dev_server::Model {
 298        match &self.0.principal {
 299            Principal::DevServer(dev_server) => dev_server,
 300            _ => unreachable!(),
 301        }
 302    }
 303}
 304
// Exposes the wrapped `Session` so its fields and methods are reachable directly on a
// `DevServerSession`.
impl Deref for DevServerSession {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
// Mutable counterpart of the `Deref` impl above: grants `&mut Session` access to the inner value.
impl DerefMut for DevServerSession {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}
 317
 318fn user_handler<M: RequestMessage, Fut>(
 319    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 321where
 322    Fut: Send + Future<Output = Result<()>>,
 323{
 324    let handler = Arc::new(handler);
 325    move |message, response, session| {
 326        let handler = handler.clone();
 327        Box::pin(async move {
 328            if let Some(user_session) = session.for_user() {
 329                Ok(handler(message, response, user_session).await?)
 330            } else {
 331                Err(Error::Internal(anyhow!(
 332                    "must be a user to call {}",
 333                    M::NAME
 334                )))
 335            }
 336        })
 337    }
 338}
 339
 340fn dev_server_handler<M: RequestMessage, Fut>(
 341    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 343where
 344    Fut: Send + Future<Output = Result<()>>,
 345{
 346    let handler = Arc::new(handler);
 347    move |message, response, session| {
 348        let handler = handler.clone();
 349        Box::pin(async move {
 350            if let Some(dev_server_session) = session.for_dev_server() {
 351                Ok(handler(message, response, dev_server_session).await?)
 352            } else {
 353                Err(Error::Internal(anyhow!(
 354                    "must be a dev server to call {}",
 355                    M::NAME
 356                )))
 357            }
 358        })
 359    }
 360}
 361
 362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 363    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 365where
 366    InnertRetFut: Send + Future<Output = Result<()>>,
 367{
 368    let handler = Arc::new(handler);
 369    move |message, session| {
 370        let handler = handler.clone();
 371        Box::pin(async move {
 372            if let Some(user_session) = session.for_user() {
 373                Ok(handler(message, user_session).await?)
 374            } else {
 375                Err(Error::Internal(anyhow!(
 376                    "must be a user to call {}",
 377                    M::NAME
 378                )))
 379            }
 380        })
 381    }
 382}
 383
 384struct DbHandle(Arc<Database>);
 385
 386impl Deref for DbHandle {
 387    type Target = Database;
 388
 389    fn deref(&self) -> &Self::Target {
 390        self.0.as_ref()
 391    }
 392}
 393
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// This server's id, behind a mutex.
    id: parking_lot::Mutex<ServerId>,
    /// Shared RPC peer; `Server::new` creates it from the server id.
    peer: Arc<Peer>,
    /// Connection pool shared with sessions.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Type-erased handlers keyed by `TypeId` (presumably the message type's id — the
    /// registration code is outside this section).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `watch` channel, initialised to `false` in `Server::new`.
    teardown: watch::Sender<bool>,
}
 402
/// Wrapper around the `parking_lot` guard for the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send` (because `Rc` is `!Send`), so a
/// future that must be `Send` cannot keep the guard alive across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The held lock on the pool.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Zero-sized marker whose only job is to opt this type out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 407
/// Serializable view of a server: a borrowed peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer, serialized through its own `Serialize` impl.
    peer: &'a Peer,
    /// The locked pool; serialized via `serialize_deref`, i.e. as the value the guard
    /// dereferences to (the guard's `Deref` impl is not in this section).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 414
 415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 416where
 417    S: Serializer,
 418    T: Deref<Target = U>,
 419    U: Serialize,
 420{
 421    Serialize::serialize(value.deref(), serializer)
 422}
 423
 424impl Server {
    /// Builds a server for `id` and `app_state`, wires up every message and request handler,
    /// and returns it wrapped in an `Arc`.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject non-user sessions;
    /// those wrapped in `dev_server_handler` reject sessions that are not dev servers.
    /// Unwrapped handlers receive the plain `Session`.
    ///
    /// NOTE(review): `add_request_handler` / `add_message_handler` are defined outside this
    /// section; they are assumed to store the handler in `handlers` — confirm there.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` is a plain numeric cast of the inner id — confirm the
            // inner type's range makes this lossless.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial receiver is dropped right here.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Project sharing and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            // Project state updates; these handlers take the plain `Session`.
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded on behalf of users, grouped by forwarding function
            // (owner-only, read-only, mutating).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer traffic and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications, following, and account-level requests.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Context-related project requests and broadcasts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Token counting needs the app config, so the closure captures a clone of
            // `app_state` and clones it again per call for the returned future to own.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Embedding computation is handed a clone of the configured OpenAI API key on
            // every call.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 658
 659    pub async fn start(&self) -> Result<()> {
 660        let server_id = *self.id.lock();
 661        let app_state = self.app_state.clone();
 662        let peer = self.peer.clone();
 663        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 664        let pool = self.connection_pool.clone();
 665        let live_kit_client = self.app_state.live_kit_client.clone();
 666
 667        let span = info_span!("start server");
 668        self.app_state.executor.spawn_detached(
 669            async move {
 670                tracing::info!("waiting for cleanup timeout");
 671                timeout.await;
 672                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 673                if let Some((room_ids, channel_ids)) = app_state
 674                    .db
 675                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 676                    .await
 677                    .trace_err()
 678                {
 679                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 680                    tracing::info!(
 681                        stale_channel_buffer_count = channel_ids.len(),
 682                        "retrieved stale channel buffers"
 683                    );
 684
 685                    for channel_id in channel_ids {
 686                        if let Some(refreshed_channel_buffer) = app_state
 687                            .db
 688                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 689                            .await
 690                            .trace_err()
 691                        {
 692                            for connection_id in refreshed_channel_buffer.connection_ids {
 693                                peer.send(
 694                                    connection_id,
 695                                    proto::UpdateChannelBufferCollaborators {
 696                                        channel_id: channel_id.to_proto(),
 697                                        collaborators: refreshed_channel_buffer
 698                                            .collaborators
 699                                            .clone(),
 700                                    },
 701                                )
 702                                .trace_err();
 703                            }
 704                        }
 705                    }
 706
 707                    for room_id in room_ids {
 708                        let mut contacts_to_update = HashSet::default();
 709                        let mut canceled_calls_to_user_ids = Vec::new();
 710                        let mut live_kit_room = String::new();
 711                        let mut delete_live_kit_room = false;
 712
 713                        if let Some(mut refreshed_room) = app_state
 714                            .db
 715                            .clear_stale_room_participants(room_id, server_id)
 716                            .await
 717                            .trace_err()
 718                        {
 719                            tracing::info!(
 720                                room_id = room_id.0,
 721                                new_participant_count = refreshed_room.room.participants.len(),
 722                                "refreshed room"
 723                            );
 724                            room_updated(&refreshed_room.room, &peer);
 725                            if let Some(channel) = refreshed_room.channel.as_ref() {
 726                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 727                            }
 728                            contacts_to_update
 729                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 730                            contacts_to_update
 731                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 732                            canceled_calls_to_user_ids =
 733                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 734                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 735                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 736                        }
 737
 738                        {
 739                            let pool = pool.lock();
 740                            for canceled_user_id in canceled_calls_to_user_ids {
 741                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 742                                    peer.send(
 743                                        connection_id,
 744                                        proto::CallCanceled {
 745                                            room_id: room_id.to_proto(),
 746                                        },
 747                                    )
 748                                    .trace_err();
 749                                }
 750                            }
 751                        }
 752
 753                        for user_id in contacts_to_update {
 754                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 755                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 756                            if let Some((busy, contacts)) = busy.zip(contacts) {
 757                                let pool = pool.lock();
 758                                let updated_contact = contact_for_user(user_id, busy, &pool);
 759                                for contact in contacts {
 760                                    if let db::Contact::Accepted {
 761                                        user_id: contact_user_id,
 762                                        ..
 763                                    } = contact
 764                                    {
 765                                        for contact_conn_id in
 766                                            pool.user_connection_ids(contact_user_id)
 767                                        {
 768                                            peer.send(
 769                                                contact_conn_id,
 770                                                proto::UpdateContacts {
 771                                                    contacts: vec![updated_contact.clone()],
 772                                                    remove_contacts: Default::default(),
 773                                                    incoming_requests: Default::default(),
 774                                                    remove_incoming_requests: Default::default(),
 775                                                    outgoing_requests: Default::default(),
 776                                                    remove_outgoing_requests: Default::default(),
 777                                                },
 778                                            )
 779                                            .trace_err();
 780                                        }
 781                                    }
 782                                }
 783                            }
 784                        }
 785
 786                        if let Some(live_kit) = live_kit_client.as_ref() {
 787                            if delete_live_kit_room {
 788                                live_kit.delete_room(live_kit_room).await.trace_err();
 789                            }
 790                        }
 791                    }
 792                }
 793
 794                app_state
 795                    .db
 796                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 797                    .await
 798                    .trace_err();
 799            }
 800            .instrument(span),
 801        );
 802        Ok(())
 803    }
 804
 805    pub fn teardown(&self) {
 806        self.peer.teardown();
 807        self.connection_pool.lock().reset();
 808        let _ = self.teardown.send(true);
 809    }
 810
 811    #[cfg(test)]
 812    pub fn reset(&self, id: ServerId) {
 813        self.teardown();
 814        *self.id.lock() = id;
 815        self.peer.reset(id.0 as u32);
 816        let _ = self.teardown.send(false);
 817    }
 818
 819    #[cfg(test)]
 820    pub fn id(&self) -> ServerId {
 821        *self.id.lock()
 822    }
 823
 824    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 825    where
 826        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 827        Fut: 'static + Send + Future<Output = Result<()>>,
 828        M: EnvelopedMessage,
 829    {
 830        let prev_handler = self.handlers.insert(
 831            TypeId::of::<M>(),
 832            Box::new(move |envelope, session| {
 833                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 834                let received_at = envelope.received_at;
 835                tracing::info!("message received");
 836                let start_time = Instant::now();
 837                let future = (handler)(*envelope, session);
 838                async move {
 839                    let result = future.await;
 840                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 841                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 842                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 843                    let payload_type = M::NAME;
 844
 845                    match result {
 846                        Err(error) => {
 847                            tracing::error!(
 848                                ?error,
 849                                total_duration_ms,
 850                                processing_duration_ms,
 851                                queue_duration_ms,
 852                                payload_type,
 853                                "error handling message"
 854                            )
 855                        }
 856                        Ok(()) => tracing::info!(
 857                            total_duration_ms,
 858                            processing_duration_ms,
 859                            queue_duration_ms,
 860                            "finished handling message"
 861                        ),
 862                    }
 863                }
 864                .boxed()
 865            }),
 866        );
 867        if prev_handler.is_some() {
 868            panic!("registered a handler for the same message twice");
 869        }
 870        self
 871    }
 872
 873    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 874    where
 875        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 876        Fut: 'static + Send + Future<Output = Result<()>>,
 877        M: EnvelopedMessage,
 878    {
 879        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 880        self
 881    }
 882
 883    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 884    where
 885        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 886        Fut: Send + Future<Output = Result<()>>,
 887        M: RequestMessage,
 888    {
 889        let handler = Arc::new(handler);
 890        self.add_handler(move |envelope, session| {
 891            let receipt = envelope.receipt();
 892            let handler = handler.clone();
 893            async move {
 894                let peer = session.peer.clone();
 895                let responded = Arc::new(AtomicBool::default());
 896                let response = Response {
 897                    peer: peer.clone(),
 898                    responded: responded.clone(),
 899                    receipt,
 900                };
 901                match (handler)(envelope.payload, response, session).await {
 902                    Ok(()) => {
 903                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 904                            Ok(())
 905                        } else {
 906                            Err(anyhow!("handler did not send a response"))?
 907                        }
 908                    }
 909                    Err(error) => {
 910                        let proto_err = match &error {
 911                            Error::Internal(err) => err.to_proto(),
 912                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 913                        };
 914                        peer.respond_with_error(receipt, proto_err)?;
 915                        Err(error)
 916                    }
 917                }
 918            }
 919        })
 920    }
 921
 922    #[allow(clippy::too_many_arguments)]
 923    pub fn handle_connection(
 924        self: &Arc<Self>,
 925        connection: Connection,
 926        address: String,
 927        principal: Principal,
 928        zed_version: ZedVersion,
 929        geoip_country_code: Option<String>,
 930        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 931        executor: Executor,
 932    ) -> impl Future<Output = ()> {
 933        let this = self.clone();
 934        let span = info_span!("handle connection", %address,
 935            connection_id=field::Empty,
 936            user_id=field::Empty,
 937            login=field::Empty,
 938            impersonator=field::Empty,
 939            dev_server_id=field::Empty,
 940            geoip_country_code=field::Empty
 941        );
 942        principal.update_span(&span);
 943        if let Some(country_code) = geoip_country_code.as_ref() {
 944            span.record("geoip_country_code", country_code);
 945        }
 946
 947        let mut teardown = self.teardown.subscribe();
 948        async move {
 949            if *teardown.borrow() {
 950                tracing::error!("server is tearing down");
 951                return
 952            }
 953            let (connection_id, handle_io, mut incoming_rx) = this
 954                .peer
 955                .add_connection(connection, {
 956                    let executor = executor.clone();
 957                    move |duration| executor.sleep(duration)
 958                });
 959            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 960
 961            tracing::info!("connection opened");
 962
 963            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 964            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 965                Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
 966                Err(error) => {
 967                    tracing::error!(?error, "failed to create HTTP client");
 968                    return;
 969                }
 970            };
 971
 972            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 973                    supermaven_admin_api_key.to_string(),
 974                    http_client.clone(),
 975                )));
 976
 977            let session = Session {
 978                principal: principal.clone(),
 979                connection_id,
 980                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 981                peer: this.peer.clone(),
 982                connection_pool: this.connection_pool.clone(),
 983                app_state: this.app_state.clone(),
 984                http_client,
 985                geoip_country_code,
 986                _executor: executor.clone(),
 987                supermaven_client,
 988            };
 989
 990            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 991                tracing::error!(?error, "failed to send initial client update");
 992                return;
 993            }
 994
 995            let handle_io = handle_io.fuse();
 996            futures::pin_mut!(handle_io);
 997
 998            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 999            // This prevents deadlocks when e.g., client A performs a request to client B and
1000            // client B performs a request to client A. If both clients stop processing further
1001            // messages until their respective request completes, they won't have a chance to
1002            // respond to the other client's request and cause a deadlock.
1003            //
1004            // This arrangement ensures we will attempt to process earlier messages first, but fall
1005            // back to processing messages arrived later in the spirit of making progress.
1006            let mut foreground_message_handlers = FuturesUnordered::new();
1007            let concurrent_handlers = Arc::new(Semaphore::new(256));
1008            loop {
1009                let next_message = async {
1010                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1011                    let message = incoming_rx.next().await;
1012                    (permit, message)
1013                }.fuse();
1014                futures::pin_mut!(next_message);
1015                futures::select_biased! {
1016                    _ = teardown.changed().fuse() => return,
1017                    result = handle_io => {
1018                        if let Err(error) = result {
1019                            tracing::error!(?error, "error handling I/O");
1020                        }
1021                        break;
1022                    }
1023                    _ = foreground_message_handlers.next() => {}
1024                    next_message = next_message => {
1025                        let (permit, message) = next_message;
1026                        if let Some(message) = message {
1027                            let type_name = message.payload_type_name();
1028                            // note: we copy all the fields from the parent span so we can query them in the logs.
1029                            // (https://github.com/tokio-rs/tracing/issues/2670).
1030                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1031                                user_id=field::Empty,
1032                                login=field::Empty,
1033                                impersonator=field::Empty,
1034                                dev_server_id=field::Empty
1035                            );
1036                            principal.update_span(&span);
1037                            let span_enter = span.enter();
1038                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1039                                let is_background = message.is_background();
1040                                let handle_message = (handler)(message, session.clone());
1041                                drop(span_enter);
1042
1043                                let handle_message = async move {
1044                                    handle_message.await;
1045                                    drop(permit);
1046                                }.instrument(span);
1047                                if is_background {
1048                                    executor.spawn_detached(handle_message);
1049                                } else {
1050                                    foreground_message_handlers.push(handle_message);
1051                                }
1052                            } else {
1053                                tracing::error!("no message handler");
1054                            }
1055                        } else {
1056                            tracing::info!("connection closed");
1057                            break;
1058                        }
1059                    }
1060                }
1061            }
1062
1063            drop(foreground_message_handlers);
1064            tracing::info!("signing out");
1065            if let Err(error) = connection_lost(session, teardown, executor).await {
1066                tracing::error!(?error, "error signing out");
1067            }
1068
1069        }.instrument(span)
1070    }
1071
1072    async fn send_initial_client_update(
1073        &self,
1074        connection_id: ConnectionId,
1075        principal: &Principal,
1076        zed_version: ZedVersion,
1077        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1078        session: &Session,
1079    ) -> Result<()> {
1080        self.peer.send(
1081            connection_id,
1082            proto::Hello {
1083                peer_id: Some(connection_id.into()),
1084            },
1085        )?;
1086        tracing::info!("sent hello message");
1087        if let Some(send_connection_id) = send_connection_id.take() {
1088            let _ = send_connection_id.send(connection_id);
1089        }
1090
1091        match principal {
1092            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1093                if !user.connected_once {
1094                    self.peer.send(connection_id, proto::ShowContacts {})?;
1095                    self.app_state
1096                        .db
1097                        .set_user_connected_once(user.id, true)
1098                        .await?;
1099                }
1100
1101                update_user_plan(user.id, session).await?;
1102
1103                let (contacts, dev_server_projects) = future::try_join(
1104                    self.app_state.db.get_contacts(user.id),
1105                    self.app_state.db.dev_server_projects_update(user.id),
1106                )
1107                .await?;
1108
1109                {
1110                    let mut pool = self.connection_pool.lock();
1111                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1112                    self.peer.send(
1113                        connection_id,
1114                        build_initial_contacts_update(contacts, &pool),
1115                    )?;
1116                }
1117
1118                if should_auto_subscribe_to_channels(zed_version) {
1119                    subscribe_user_to_channels(user.id, session).await?;
1120                }
1121
1122                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1123
1124                if let Some(incoming_call) =
1125                    self.app_state.db.incoming_call_for_user(user.id).await?
1126                {
1127                    self.peer.send(connection_id, incoming_call)?;
1128                }
1129
1130                update_user_contacts(user.id, session).await?;
1131            }
1132            Principal::DevServer(dev_server) => {
1133                {
1134                    let mut pool = self.connection_pool.lock();
1135                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1136                    {
1137                        self.peer.send(
1138                            stale_connection_id,
1139                            proto::ShutdownDevServer {
1140                                reason: Some(
1141                                    "another dev server connected with the same token".to_string(),
1142                                ),
1143                            },
1144                        )?;
1145                        pool.remove_connection(stale_connection_id)?;
1146                    };
1147                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1148                }
1149
1150                let projects = self
1151                    .app_state
1152                    .db
1153                    .get_projects_for_dev_server(dev_server.id)
1154                    .await?;
1155                self.peer
1156                    .send(connection_id, proto::DevServerInstructions { projects })?;
1157
1158                let status = self
1159                    .app_state
1160                    .db
1161                    .dev_server_projects_update(dev_server.user_id)
1162                    .await?;
1163                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1164            }
1165        }
1166
1167        Ok(())
1168    }
1169
1170    pub async fn invite_code_redeemed(
1171        self: &Arc<Self>,
1172        inviter_id: UserId,
1173        invitee_id: UserId,
1174    ) -> Result<()> {
1175        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1176            if let Some(code) = &user.invite_code {
1177                let pool = self.connection_pool.lock();
1178                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1179                for connection_id in pool.user_connection_ids(inviter_id) {
1180                    self.peer.send(
1181                        connection_id,
1182                        proto::UpdateContacts {
1183                            contacts: vec![invitee_contact.clone()],
1184                            ..Default::default()
1185                        },
1186                    )?;
1187                    self.peer.send(
1188                        connection_id,
1189                        proto::UpdateInviteInfo {
1190                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1191                            count: user.invite_count as u32,
1192                        },
1193                    )?;
1194                }
1195            }
1196        }
1197        Ok(())
1198    }
1199
1200    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1201        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1202            if let Some(invite_code) = &user.invite_code {
1203                let pool = self.connection_pool.lock();
1204                for connection_id in pool.user_connection_ids(user_id) {
1205                    self.peer.send(
1206                        connection_id,
1207                        proto::UpdateInviteInfo {
1208                            url: format!(
1209                                "{}{}",
1210                                self.app_state.config.invite_link_prefix, invite_code
1211                            ),
1212                            count: user.invite_count as u32,
1213                        },
1214                    )?;
1215                }
1216            }
1217        }
1218        Ok(())
1219    }
1220
1221    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1222        ServerSnapshot {
1223            connection_pool: ConnectionPoolGuard {
1224                guard: self.connection_pool.lock(),
1225                _not_send: PhantomData,
1226            },
1227            peer: &self.peer,
1228        }
1229    }
1230}
1231
1232impl<'a> Deref for ConnectionPoolGuard<'a> {
1233    type Target = ConnectionPool;
1234
1235    fn deref(&self) -> &Self::Target {
1236        &self.guard
1237    }
1238}
1239
1240impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1241    fn deref_mut(&mut self) -> &mut Self::Target {
1242        &mut self.guard
1243    }
1244}
1245
1246impl<'a> Drop for ConnectionPoolGuard<'a> {
1247    fn drop(&mut self) {
1248        #[cfg(test)]
1249        self.check_invariants();
1250    }
1251}
1252
1253fn broadcast<F>(
1254    sender_id: Option<ConnectionId>,
1255    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1256    mut f: F,
1257) where
1258    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1259{
1260    for receiver_id in receiver_ids {
1261        if Some(receiver_id) != sender_id {
1262            if let Err(error) = f(receiver_id) {
1263                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1264            }
1265        }
1266    }
1267}
1268
1269pub struct ProtocolVersion(u32);
1270
1271impl Header for ProtocolVersion {
1272    fn name() -> &'static HeaderName {
1273        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1274        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1275    }
1276
1277    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1278    where
1279        Self: Sized,
1280        I: Iterator<Item = &'i axum::http::HeaderValue>,
1281    {
1282        let version = values
1283            .next()
1284            .ok_or_else(axum::headers::Error::invalid)?
1285            .to_str()
1286            .map_err(|_| axum::headers::Error::invalid())?
1287            .parse()
1288            .map_err(|_| axum::headers::Error::invalid())?;
1289        Ok(Self(version))
1290    }
1291
1292    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1293        values.extend([self.0.to_string().parse().unwrap()]);
1294    }
1295}
1296
1297pub struct AppVersionHeader(SemanticVersion);
1298impl Header for AppVersionHeader {
1299    fn name() -> &'static HeaderName {
1300        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1301        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1302    }
1303
1304    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1305    where
1306        Self: Sized,
1307        I: Iterator<Item = &'i axum::http::HeaderValue>,
1308    {
1309        let version = values
1310            .next()
1311            .ok_or_else(axum::headers::Error::invalid)?
1312            .to_str()
1313            .map_err(|_| axum::headers::Error::invalid())?
1314            .parse()
1315            .map_err(|_| axum::headers::Error::invalid())?;
1316        Ok(Self(version))
1317    }
1318
1319    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1320        values.extend([self.0.to_string().parse().unwrap()]);
1321    }
1322}
1323
1324pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1325    Router::new()
1326        .route("/rpc", get(handle_websocket_request))
1327        .layer(
1328            ServiceBuilder::new()
1329                .layer(Extension(server.app_state.clone()))
1330                .layer(middleware::from_fn(auth::validate_header)),
1331        )
1332        .route("/metrics", get(handle_metrics))
1333        .layer(Extension(server))
1334}
1335
1336pub async fn handle_websocket_request(
1337    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1338    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1339    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1340    Extension(server): Extension<Arc<Server>>,
1341    Extension(principal): Extension<Principal>,
1342    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1343    ws: WebSocketUpgrade,
1344) -> axum::response::Response {
1345    if protocol_version != rpc::PROTOCOL_VERSION {
1346        return (
1347            StatusCode::UPGRADE_REQUIRED,
1348            "client must be upgraded".to_string(),
1349        )
1350            .into_response();
1351    }
1352
1353    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1354        return (
1355            StatusCode::UPGRADE_REQUIRED,
1356            "no version header found".to_string(),
1357        )
1358            .into_response();
1359    };
1360
1361    if !version.can_collaborate() {
1362        return (
1363            StatusCode::UPGRADE_REQUIRED,
1364            "client must be upgraded".to_string(),
1365        )
1366            .into_response();
1367    }
1368
1369    let socket_address = socket_address.to_string();
1370    ws.on_upgrade(move |socket| {
1371        let socket = socket
1372            .map_ok(to_tungstenite_message)
1373            .err_into()
1374            .with(|message| async move { to_axum_message(message) });
1375        let connection = Connection::new(Box::pin(socket));
1376        async move {
1377            server
1378                .handle_connection(
1379                    connection,
1380                    socket_address,
1381                    principal,
1382                    version,
1383                    country_code_header.map(|header| header.to_string()),
1384                    None,
1385                    Executor::Production,
1386                )
1387                .await;
1388        }
1389    })
1390}
1391
1392pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1393    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1394    let connections_metric = CONNECTIONS_METRIC
1395        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1396
1397    let connections = server
1398        .connection_pool
1399        .lock()
1400        .connections()
1401        .filter(|connection| !connection.admin)
1402        .count();
1403    connections_metric.set(connections as _);
1404
1405    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1406    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1407        register_int_gauge!(
1408            "shared_projects",
1409            "number of open projects with one or more guests"
1410        )
1411        .unwrap()
1412    });
1413
1414    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1415    shared_projects_metric.set(shared_projects as _);
1416
1417    let encoder = prometheus::TextEncoder::new();
1418    let metric_families = prometheus::gather();
1419    let encoded_metrics = encoder
1420        .encode_to_string(&metric_families)
1421        .map_err(|err| anyhow!("{}", err))?;
1422    Ok(encoded_metrics)
1423}
1424
1425#[instrument(err, skip(executor))]
1426async fn connection_lost(
1427    session: Session,
1428    mut teardown: watch::Receiver<bool>,
1429    executor: Executor,
1430) -> Result<()> {
1431    session.peer.disconnect(session.connection_id);
1432    session
1433        .connection_pool()
1434        .await
1435        .remove_connection(session.connection_id)?;
1436
1437    session
1438        .db()
1439        .await
1440        .connection_lost(session.connection_id)
1441        .await
1442        .trace_err();
1443
1444    futures::select_biased! {
1445        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1446            match &session.principal {
1447                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1448                    let session = session.for_user().unwrap();
1449
1450                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1451                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1452                    leave_channel_buffers_for_session(&session)
1453                        .await
1454                        .trace_err();
1455
1456                    if !session
1457                        .connection_pool()
1458                        .await
1459                        .is_user_online(session.user_id())
1460                    {
1461                        let db = session.db().await;
1462                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1463                            room_updated(&room, &session.peer);
1464                        }
1465                    }
1466
1467                    update_user_contacts(session.user_id(), &session).await?;
1468                },
1469            Principal::DevServer(_) => {
1470                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1471            },
1472        }
1473        },
1474        _ = teardown.changed().fuse() => {}
1475    }
1476
1477    Ok(())
1478}
1479
1480/// Acknowledges a ping from a client, used to keep the connection alive.
1481async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1482    response.send(proto::Ack {})?;
1483    Ok(())
1484}
1485
1486/// Creates a new room for calling (outside of channels)
1487async fn create_room(
1488    _request: proto::CreateRoom,
1489    response: Response<proto::CreateRoom>,
1490    session: UserSession,
1491) -> Result<()> {
1492    let live_kit_room = nanoid::nanoid!(30);
1493
1494    let live_kit_connection_info = util::maybe!(async {
1495        let live_kit = session.app_state.live_kit_client.as_ref();
1496        let live_kit = live_kit?;
1497        let user_id = session.user_id().to_string();
1498
1499        let token = live_kit
1500            .room_token(&live_kit_room, &user_id.to_string())
1501            .trace_err()?;
1502
1503        Some(proto::LiveKitConnectionInfo {
1504            server_url: live_kit.url().into(),
1505            token,
1506            can_publish: true,
1507        })
1508    })
1509    .await;
1510
1511    let room = session
1512        .db()
1513        .await
1514        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1515        .await?;
1516
1517    response.send(proto::CreateRoomResponse {
1518        room: Some(room.clone()),
1519        live_kit_connection_info,
1520    })?;
1521
1522    update_user_contacts(session.user_id(), &session).await?;
1523    Ok(())
1524}
1525
1526/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1527async fn join_room(
1528    request: proto::JoinRoom,
1529    response: Response<proto::JoinRoom>,
1530    session: UserSession,
1531) -> Result<()> {
1532    let room_id = RoomId::from_proto(request.id);
1533
1534    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1535
1536    if let Some(channel_id) = channel_id {
1537        return join_channel_internal(channel_id, Box::new(response), session).await;
1538    }
1539
1540    let joined_room = {
1541        let room = session
1542            .db()
1543            .await
1544            .join_room(room_id, session.user_id(), session.connection_id)
1545            .await?;
1546        room_updated(&room.room, &session.peer);
1547        room.into_inner()
1548    };
1549
1550    for connection_id in session
1551        .connection_pool()
1552        .await
1553        .user_connection_ids(session.user_id())
1554    {
1555        session
1556            .peer
1557            .send(
1558                connection_id,
1559                proto::CallCanceled {
1560                    room_id: room_id.to_proto(),
1561                },
1562            )
1563            .trace_err();
1564    }
1565
1566    let live_kit_connection_info =
1567        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1568            live_kit
1569                .room_token(
1570                    &joined_room.room.live_kit_room,
1571                    &session.user_id().to_string(),
1572                )
1573                .trace_err()
1574                .map(|token| proto::LiveKitConnectionInfo {
1575                    server_url: live_kit.url().into(),
1576                    token,
1577                    can_publish: true,
1578                })
1579        } else {
1580            None
1581        };
1582
1583    response.send(proto::JoinRoomResponse {
1584        room: Some(joined_room.room),
1585        channel_id: None,
1586        live_kit_connection_info,
1587    })?;
1588
1589    update_user_contacts(session.user_id(), &session).await?;
1590    Ok(())
1591}
1592
1593/// Rejoin room is used to reconnect to a room after connection errors.
1594async fn rejoin_room(
1595    request: proto::RejoinRoom,
1596    response: Response<proto::RejoinRoom>,
1597    session: UserSession,
1598) -> Result<()> {
1599    let room;
1600    let channel;
1601    {
1602        let mut rejoined_room = session
1603            .db()
1604            .await
1605            .rejoin_room(request, session.user_id(), session.connection_id)
1606            .await?;
1607
1608        response.send(proto::RejoinRoomResponse {
1609            room: Some(rejoined_room.room.clone()),
1610            reshared_projects: rejoined_room
1611                .reshared_projects
1612                .iter()
1613                .map(|project| proto::ResharedProject {
1614                    id: project.id.to_proto(),
1615                    collaborators: project
1616                        .collaborators
1617                        .iter()
1618                        .map(|collaborator| collaborator.to_proto())
1619                        .collect(),
1620                })
1621                .collect(),
1622            rejoined_projects: rejoined_room
1623                .rejoined_projects
1624                .iter()
1625                .map(|rejoined_project| rejoined_project.to_proto())
1626                .collect(),
1627        })?;
1628        room_updated(&rejoined_room.room, &session.peer);
1629
1630        for project in &rejoined_room.reshared_projects {
1631            for collaborator in &project.collaborators {
1632                session
1633                    .peer
1634                    .send(
1635                        collaborator.connection_id,
1636                        proto::UpdateProjectCollaborator {
1637                            project_id: project.id.to_proto(),
1638                            old_peer_id: Some(project.old_connection_id.into()),
1639                            new_peer_id: Some(session.connection_id.into()),
1640                        },
1641                    )
1642                    .trace_err();
1643            }
1644
1645            broadcast(
1646                Some(session.connection_id),
1647                project
1648                    .collaborators
1649                    .iter()
1650                    .map(|collaborator| collaborator.connection_id),
1651                |connection_id| {
1652                    session.peer.forward_send(
1653                        session.connection_id,
1654                        connection_id,
1655                        proto::UpdateProject {
1656                            project_id: project.id.to_proto(),
1657                            worktrees: project.worktrees.clone(),
1658                        },
1659                    )
1660                },
1661            );
1662        }
1663
1664        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1665
1666        let rejoined_room = rejoined_room.into_inner();
1667
1668        room = rejoined_room.room;
1669        channel = rejoined_room.channel;
1670    }
1671
1672    if let Some(channel) = channel {
1673        channel_updated(
1674            &channel,
1675            &room,
1676            &session.peer,
1677            &*session.connection_pool().await,
1678        );
1679    }
1680
1681    update_user_contacts(session.user_id(), &session).await?;
1682    Ok(())
1683}
1684
1685fn notify_rejoined_projects(
1686    rejoined_projects: &mut Vec<RejoinedProject>,
1687    session: &UserSession,
1688) -> Result<()> {
1689    for project in rejoined_projects.iter() {
1690        for collaborator in &project.collaborators {
1691            session
1692                .peer
1693                .send(
1694                    collaborator.connection_id,
1695                    proto::UpdateProjectCollaborator {
1696                        project_id: project.id.to_proto(),
1697                        old_peer_id: Some(project.old_connection_id.into()),
1698                        new_peer_id: Some(session.connection_id.into()),
1699                    },
1700                )
1701                .trace_err();
1702        }
1703    }
1704
1705    for project in rejoined_projects {
1706        for worktree in mem::take(&mut project.worktrees) {
1707            #[cfg(any(test, feature = "test-support"))]
1708            const MAX_CHUNK_SIZE: usize = 2;
1709            #[cfg(not(any(test, feature = "test-support")))]
1710            const MAX_CHUNK_SIZE: usize = 256;
1711
1712            // Stream this worktree's entries.
1713            let message = proto::UpdateWorktree {
1714                project_id: project.id.to_proto(),
1715                worktree_id: worktree.id,
1716                abs_path: worktree.abs_path.clone(),
1717                root_name: worktree.root_name,
1718                updated_entries: worktree.updated_entries,
1719                removed_entries: worktree.removed_entries,
1720                scan_id: worktree.scan_id,
1721                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722                updated_repositories: worktree.updated_repositories,
1723                removed_repositories: worktree.removed_repositories,
1724            };
1725            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726                session.peer.send(session.connection_id, update.clone())?;
1727            }
1728
1729            // Stream this worktree's diagnostics.
1730            for summary in worktree.diagnostic_summaries {
1731                session.peer.send(
1732                    session.connection_id,
1733                    proto::UpdateDiagnosticSummary {
1734                        project_id: project.id.to_proto(),
1735                        worktree_id: worktree.id,
1736                        summary: Some(summary),
1737                    },
1738                )?;
1739            }
1740
1741            for settings_file in worktree.settings_files {
1742                session.peer.send(
1743                    session.connection_id,
1744                    proto::UpdateWorktreeSettings {
1745                        project_id: project.id.to_proto(),
1746                        worktree_id: worktree.id,
1747                        path: settings_file.path,
1748                        content: Some(settings_file.content),
1749                        kind: Some(settings_file.kind.to_proto().into()),
1750                    },
1751                )?;
1752            }
1753        }
1754
1755        for language_server in &project.language_servers {
1756            session.peer.send(
1757                session.connection_id,
1758                proto::UpdateLanguageServer {
1759                    project_id: project.id.to_proto(),
1760                    language_server_id: language_server.id,
1761                    variant: Some(
1762                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1763                            proto::LspDiskBasedDiagnosticsUpdated {},
1764                        ),
1765                    ),
1766                },
1767            )?;
1768        }
1769    }
1770    Ok(())
1771}
1772
1773/// leave room disconnects from the room.
1774async fn leave_room(
1775    _: proto::LeaveRoom,
1776    response: Response<proto::LeaveRoom>,
1777    session: UserSession,
1778) -> Result<()> {
1779    leave_room_for_session(&session, session.connection_id).await?;
1780    response.send(proto::Ack {})?;
1781    Ok(())
1782}
1783
1784/// Updates the permissions of someone else in the room.
1785async fn set_room_participant_role(
1786    request: proto::SetRoomParticipantRole,
1787    response: Response<proto::SetRoomParticipantRole>,
1788    session: UserSession,
1789) -> Result<()> {
1790    let user_id = UserId::from_proto(request.user_id);
1791    let role = ChannelRole::from(request.role());
1792
1793    let (live_kit_room, can_publish) = {
1794        let room = session
1795            .db()
1796            .await
1797            .set_room_participant_role(
1798                session.user_id(),
1799                RoomId::from_proto(request.room_id),
1800                user_id,
1801                role,
1802            )
1803            .await?;
1804
1805        let live_kit_room = room.live_kit_room.clone();
1806        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1807        room_updated(&room, &session.peer);
1808        (live_kit_room, can_publish)
1809    };
1810
1811    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1812        live_kit
1813            .update_participant(
1814                live_kit_room.clone(),
1815                request.user_id.to_string(),
1816                live_kit_server::proto::ParticipantPermission {
1817                    can_subscribe: true,
1818                    can_publish,
1819                    can_publish_data: can_publish,
1820                    hidden: false,
1821                    recorder: false,
1822                },
1823            )
1824            .await
1825            .trace_err();
1826    }
1827
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832/// Call someone else into the current room
1833async fn call(
1834    request: proto::Call,
1835    response: Response<proto::Call>,
1836    session: UserSession,
1837) -> Result<()> {
1838    let room_id = RoomId::from_proto(request.room_id);
1839    let calling_user_id = session.user_id();
1840    let calling_connection_id = session.connection_id;
1841    let called_user_id = UserId::from_proto(request.called_user_id);
1842    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1843    if !session
1844        .db()
1845        .await
1846        .has_contact(calling_user_id, called_user_id)
1847        .await?
1848    {
1849        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1850    }
1851
1852    let incoming_call = {
1853        let (room, incoming_call) = &mut *session
1854            .db()
1855            .await
1856            .call(
1857                room_id,
1858                calling_user_id,
1859                calling_connection_id,
1860                called_user_id,
1861                initial_project_id,
1862            )
1863            .await?;
1864        room_updated(room, &session.peer);
1865        mem::take(incoming_call)
1866    };
1867    update_user_contacts(called_user_id, &session).await?;
1868
1869    let mut calls = session
1870        .connection_pool()
1871        .await
1872        .user_connection_ids(called_user_id)
1873        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1874        .collect::<FuturesUnordered<_>>();
1875
1876    while let Some(call_response) = calls.next().await {
1877        match call_response.as_ref() {
1878            Ok(_) => {
1879                response.send(proto::Ack {})?;
1880                return Ok(());
1881            }
1882            Err(_) => {
1883                call_response.trace_err();
1884            }
1885        }
1886    }
1887
1888    {
1889        let room = session
1890            .db()
1891            .await
1892            .call_failed(room_id, called_user_id)
1893            .await?;
1894        room_updated(&room, &session.peer);
1895    }
1896    update_user_contacts(called_user_id, &session).await?;
1897
1898    Err(anyhow!("failed to ring user"))?
1899}
1900
1901/// Cancel an outgoing call.
1902async fn cancel_call(
1903    request: proto::CancelCall,
1904    response: Response<proto::CancelCall>,
1905    session: UserSession,
1906) -> Result<()> {
1907    let called_user_id = UserId::from_proto(request.called_user_id);
1908    let room_id = RoomId::from_proto(request.room_id);
1909    {
1910        let room = session
1911            .db()
1912            .await
1913            .cancel_call(room_id, session.connection_id, called_user_id)
1914            .await?;
1915        room_updated(&room, &session.peer);
1916    }
1917
1918    for connection_id in session
1919        .connection_pool()
1920        .await
1921        .user_connection_ids(called_user_id)
1922    {
1923        session
1924            .peer
1925            .send(
1926                connection_id,
1927                proto::CallCanceled {
1928                    room_id: room_id.to_proto(),
1929                },
1930            )
1931            .trace_err();
1932    }
1933    response.send(proto::Ack {})?;
1934
1935    update_user_contacts(called_user_id, &session).await?;
1936    Ok(())
1937}
1938
1939/// Decline an incoming call.
1940async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1941    let room_id = RoomId::from_proto(message.room_id);
1942    {
1943        let room = session
1944            .db()
1945            .await
1946            .decline_call(Some(room_id), session.user_id())
1947            .await?
1948            .ok_or_else(|| anyhow!("failed to decline call"))?;
1949        room_updated(&room, &session.peer);
1950    }
1951
1952    for connection_id in session
1953        .connection_pool()
1954        .await
1955        .user_connection_ids(session.user_id())
1956    {
1957        session
1958            .peer
1959            .send(
1960                connection_id,
1961                proto::CallCanceled {
1962                    room_id: room_id.to_proto(),
1963                },
1964            )
1965            .trace_err();
1966    }
1967    update_user_contacts(session.user_id(), &session).await?;
1968    Ok(())
1969}
1970
1971/// Updates other participants in the room with your current location.
1972async fn update_participant_location(
1973    request: proto::UpdateParticipantLocation,
1974    response: Response<proto::UpdateParticipantLocation>,
1975    session: UserSession,
1976) -> Result<()> {
1977    let room_id = RoomId::from_proto(request.room_id);
1978    let location = request
1979        .location
1980        .ok_or_else(|| anyhow!("invalid location"))?;
1981
1982    let db = session.db().await;
1983    let room = db
1984        .update_room_participant_location(room_id, session.connection_id, location)
1985        .await?;
1986
1987    room_updated(&room, &session.peer);
1988    response.send(proto::Ack {})?;
1989    Ok(())
1990}
1991
1992/// Share a project into the room.
1993async fn share_project(
1994    request: proto::ShareProject,
1995    response: Response<proto::ShareProject>,
1996    session: UserSession,
1997) -> Result<()> {
1998    let (project_id, room) = &*session
1999        .db()
2000        .await
2001        .share_project(
2002            RoomId::from_proto(request.room_id),
2003            session.connection_id,
2004            &request.worktrees,
2005            request.is_ssh_project,
2006            request
2007                .dev_server_project_id
2008                .map(DevServerProjectId::from_proto),
2009        )
2010        .await?;
2011    response.send(proto::ShareProjectResponse {
2012        project_id: project_id.to_proto(),
2013    })?;
2014    room_updated(room, &session.peer);
2015
2016    Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021    let project_id = ProjectId::from_proto(message.project_id);
2022    unshare_project_internal(
2023        project_id,
2024        session.connection_id,
2025        session.user_id(),
2026        &session,
2027    )
2028    .await
2029}
2030
2031async fn unshare_project_internal(
2032    project_id: ProjectId,
2033    connection_id: ConnectionId,
2034    user_id: Option<UserId>,
2035    session: &Session,
2036) -> Result<()> {
2037    let delete = {
2038        let room_guard = session
2039            .db()
2040            .await
2041            .unshare_project(project_id, connection_id, user_id)
2042            .await?;
2043
2044        let (delete, room, guest_connection_ids) = &*room_guard;
2045
2046        let message = proto::UnshareProject {
2047            project_id: project_id.to_proto(),
2048        };
2049
2050        broadcast(
2051            Some(connection_id),
2052            guest_connection_ids.iter().copied(),
2053            |conn_id| session.peer.send(conn_id, message.clone()),
2054        );
2055        if let Some(room) = room {
2056            room_updated(room, &session.peer);
2057        }
2058
2059        *delete
2060    };
2061
2062    if delete {
2063        let db = session.db().await;
2064        db.delete_project(project_id).await?;
2065    }
2066
2067    Ok(())
2068}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072    request: proto::ShareDevServerProject,
2073    response: Response<proto::ShareDevServerProject>,
2074    session: DevServerSession,
2075) -> Result<()> {
2076    let (dev_server_project, user_id, status) = session
2077        .db()
2078        .await
2079        .share_dev_server_project(
2080            DevServerProjectId::from_proto(request.dev_server_project_id),
2081            session.dev_server_id(),
2082            session.connection_id,
2083            &request.worktrees,
2084        )
2085        .await?;
2086    let Some(project_id) = dev_server_project.project_id else {
2087        return Err(anyhow!("failed to share remote project"))?;
2088    };
2089
2090    send_dev_server_projects_update(user_id, status, &session).await;
2091
2092    response.send(proto::ShareProjectResponse { project_id })?;
2093
2094    Ok(())
2095}
2096
2097/// Join someone elses shared project.
2098async fn join_project(
2099    request: proto::JoinProject,
2100    response: Response<proto::JoinProject>,
2101    session: UserSession,
2102) -> Result<()> {
2103    let project_id = ProjectId::from_proto(request.project_id);
2104
2105    tracing::info!(%project_id, "join project");
2106
2107    let db = session.db().await;
2108    let (project, replica_id) = &mut *db
2109        .join_project(project_id, session.connection_id, session.user_id())
2110        .await?;
2111    drop(db);
2112    tracing::info!(%project_id, "join remote project");
2113    join_project_internal(response, session, project, replica_id)
2114}
2115
2116trait JoinProjectInternalResponse {
2117    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2118}
2119impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2120    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2121        Response::<proto::JoinProject>::send(self, result)
2122    }
2123}
2124impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2125    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2126        Response::<proto::JoinHostedProject>::send(self, result)
2127    }
2128}
2129
2130fn join_project_internal(
2131    response: impl JoinProjectInternalResponse,
2132    session: UserSession,
2133    project: &mut Project,
2134    replica_id: &ReplicaId,
2135) -> Result<()> {
2136    let collaborators = project
2137        .collaborators
2138        .iter()
2139        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140        .map(|collaborator| collaborator.to_proto())
2141        .collect::<Vec<_>>();
2142    let project_id = project.id;
2143    let guest_user_id = session.user_id();
2144
2145    let worktrees = project
2146        .worktrees
2147        .iter()
2148        .map(|(id, worktree)| proto::WorktreeMetadata {
2149            id: *id,
2150            root_name: worktree.root_name.clone(),
2151            visible: worktree.visible,
2152            abs_path: worktree.abs_path.clone(),
2153        })
2154        .collect::<Vec<_>>();
2155
2156    let add_project_collaborator = proto::AddProjectCollaborator {
2157        project_id: project_id.to_proto(),
2158        collaborator: Some(proto::Collaborator {
2159            peer_id: Some(session.connection_id.into()),
2160            replica_id: replica_id.0 as u32,
2161            user_id: guest_user_id.to_proto(),
2162        }),
2163    };
2164
2165    for collaborator in &collaborators {
2166        session
2167            .peer
2168            .send(
2169                collaborator.peer_id.unwrap().into(),
2170                add_project_collaborator.clone(),
2171            )
2172            .trace_err();
2173    }
2174
2175    // First, we send the metadata associated with each worktree.
2176    response.send(proto::JoinProjectResponse {
2177        project_id: project.id.0 as u64,
2178        worktrees: worktrees.clone(),
2179        replica_id: replica_id.0 as u32,
2180        collaborators: collaborators.clone(),
2181        language_servers: project.language_servers.clone(),
2182        role: project.role.into(),
2183        dev_server_project_id: project
2184            .dev_server_project_id
2185            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186    })?;
2187
2188    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189        #[cfg(any(test, feature = "test-support"))]
2190        const MAX_CHUNK_SIZE: usize = 2;
2191        #[cfg(not(any(test, feature = "test-support")))]
2192        const MAX_CHUNK_SIZE: usize = 256;
2193
2194        // Stream this worktree's entries.
2195        let message = proto::UpdateWorktree {
2196            project_id: project_id.to_proto(),
2197            worktree_id,
2198            abs_path: worktree.abs_path.clone(),
2199            root_name: worktree.root_name,
2200            updated_entries: worktree.entries,
2201            removed_entries: Default::default(),
2202            scan_id: worktree.scan_id,
2203            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204            updated_repositories: worktree.repository_entries.into_values().collect(),
2205            removed_repositories: Default::default(),
2206        };
2207        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208            session.peer.send(session.connection_id, update.clone())?;
2209        }
2210
2211        // Stream this worktree's diagnostics.
2212        for summary in worktree.diagnostic_summaries {
2213            session.peer.send(
2214                session.connection_id,
2215                proto::UpdateDiagnosticSummary {
2216                    project_id: project_id.to_proto(),
2217                    worktree_id: worktree.id,
2218                    summary: Some(summary),
2219                },
2220            )?;
2221        }
2222
2223        for settings_file in worktree.settings_files {
2224            session.peer.send(
2225                session.connection_id,
2226                proto::UpdateWorktreeSettings {
2227                    project_id: project_id.to_proto(),
2228                    worktree_id: worktree.id,
2229                    path: settings_file.path,
2230                    content: Some(settings_file.content),
2231                    kind: Some(proto::update_user_settings::Kind::Settings.into()),
2232                },
2233            )?;
2234        }
2235    }
2236
2237    for language_server in &project.language_servers {
2238        session.peer.send(
2239            session.connection_id,
2240            proto::UpdateLanguageServer {
2241                project_id: project_id.to_proto(),
2242                language_server_id: language_server.id,
2243                variant: Some(
2244                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2245                        proto::LspDiskBasedDiagnosticsUpdated {},
2246                    ),
2247                ),
2248            },
2249        )?;
2250    }
2251
2252    Ok(())
2253}
2254
2255/// Leave someone elses shared project.
2256async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2257    let sender_id = session.connection_id;
2258    let project_id = ProjectId::from_proto(request.project_id);
2259    let db = session.db().await;
2260    if db.is_hosted_project(project_id).await? {
2261        let project = db.leave_hosted_project(project_id, sender_id).await?;
2262        project_left(&project, &session);
2263        return Ok(());
2264    }
2265
2266    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2267    tracing::info!(
2268        %project_id,
2269        "leave project"
2270    );
2271
2272    project_left(project, &session);
2273    if let Some(room) = room {
2274        room_updated(room, &session.peer);
2275    }
2276
2277    Ok(())
2278}
2279
2280async fn join_hosted_project(
2281    request: proto::JoinHostedProject,
2282    response: Response<proto::JoinHostedProject>,
2283    session: UserSession,
2284) -> Result<()> {
2285    let (mut project, replica_id) = session
2286        .db()
2287        .await
2288        .join_hosted_project(
2289            ProjectId(request.project_id as i32),
2290            session.user_id(),
2291            session.connection_id,
2292        )
2293        .await?;
2294
2295    join_project_internal(response, session, &mut project, &replica_id)
2296}
2297
2298async fn list_remote_directory(
2299    request: proto::ListRemoteDirectory,
2300    response: Response<proto::ListRemoteDirectory>,
2301    session: UserSession,
2302) -> Result<()> {
2303    let dev_server_id = DevServerId(request.dev_server_id as i32);
2304    let dev_server_connection_id = session
2305        .connection_pool()
2306        .await
2307        .online_dev_server_connection_id(dev_server_id)?;
2308
2309    session
2310        .db()
2311        .await
2312        .get_dev_server_for_user(dev_server_id, session.user_id())
2313        .await?;
2314
2315    response.send(
2316        session
2317            .peer
2318            .forward_request(session.connection_id, dev_server_connection_id, request)
2319            .await?,
2320    )?;
2321    Ok(())
2322}
2323
2324async fn update_dev_server_project(
2325    request: proto::UpdateDevServerProject,
2326    response: Response<proto::UpdateDevServerProject>,
2327    session: UserSession,
2328) -> Result<()> {
2329    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2330
2331    let (dev_server_project, update) = session
2332        .db()
2333        .await
2334        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2335        .await?;
2336
2337    let projects = session
2338        .db()
2339        .await
2340        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2341        .await?;
2342
2343    let dev_server_connection_id = session
2344        .connection_pool()
2345        .await
2346        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2347
2348    session.peer.send(
2349        dev_server_connection_id,
2350        proto::DevServerInstructions { projects },
2351    )?;
2352
2353    send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355    response.send(proto::Ack {})
2356}
2357
2358async fn create_dev_server_project(
2359    request: proto::CreateDevServerProject,
2360    response: Response<proto::CreateDevServerProject>,
2361    session: UserSession,
2362) -> Result<()> {
2363    let dev_server_id = DevServerId(request.dev_server_id as i32);
2364    let dev_server_connection_id = session
2365        .connection_pool()
2366        .await
2367        .dev_server_connection_id(dev_server_id);
2368    let Some(dev_server_connection_id) = dev_server_connection_id else {
2369        Err(ErrorCode::DevServerOffline
2370            .message("Cannot create a remote project when the dev server is offline".to_string())
2371            .anyhow())?
2372    };
2373
2374    let path = request.path.clone();
2375    //Check that the path exists on the dev server
2376    session
2377        .peer
2378        .forward_request(
2379            session.connection_id,
2380            dev_server_connection_id,
2381            proto::ValidateDevServerProjectRequest { path: path.clone() },
2382        )
2383        .await?;
2384
2385    let (dev_server_project, update) = session
2386        .db()
2387        .await
2388        .create_dev_server_project(
2389            DevServerId(request.dev_server_id as i32),
2390            &request.path,
2391            session.user_id(),
2392        )
2393        .await?;
2394
2395    let projects = session
2396        .db()
2397        .await
2398        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2399        .await?;
2400
2401    session.peer.send(
2402        dev_server_connection_id,
2403        proto::DevServerInstructions { projects },
2404    )?;
2405
2406    send_dev_server_projects_update(session.user_id(), update, &session).await;
2407
2408    response.send(proto::CreateDevServerProjectResponse {
2409        dev_server_project: Some(dev_server_project.to_proto(None)),
2410    })?;
2411    Ok(())
2412}
2413
2414async fn create_dev_server(
2415    request: proto::CreateDevServer,
2416    response: Response<proto::CreateDevServer>,
2417    session: UserSession,
2418) -> Result<()> {
2419    let access_token = auth::random_token();
2420    let hashed_access_token = auth::hash_access_token(&access_token);
2421
2422    if request.name.is_empty() {
2423        return Err(proto::ErrorCode::Forbidden
2424            .message("Dev server name cannot be empty".to_string())
2425            .anyhow())?;
2426    }
2427
2428    let (dev_server, status) = session
2429        .db()
2430        .await
2431        .create_dev_server(
2432            &request.name,
2433            request.ssh_connection_string.as_deref(),
2434            &hashed_access_token,
2435            session.user_id(),
2436        )
2437        .await?;
2438
2439    send_dev_server_projects_update(session.user_id(), status, &session).await;
2440
2441    response.send(proto::CreateDevServerResponse {
2442        dev_server_id: dev_server.id.0 as u64,
2443        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2444        name: request.name,
2445    })?;
2446    Ok(())
2447}
2448
2449async fn regenerate_dev_server_token(
2450    request: proto::RegenerateDevServerToken,
2451    response: Response<proto::RegenerateDevServerToken>,
2452    session: UserSession,
2453) -> Result<()> {
2454    let dev_server_id = DevServerId(request.dev_server_id as i32);
2455    let access_token = auth::random_token();
2456    let hashed_access_token = auth::hash_access_token(&access_token);
2457
2458    let connection_id = session
2459        .connection_pool()
2460        .await
2461        .dev_server_connection_id(dev_server_id);
2462    if let Some(connection_id) = connection_id {
2463        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2464        session.peer.send(
2465            connection_id,
2466            proto::ShutdownDevServer {
2467                reason: Some("dev server token was regenerated".to_string()),
2468            },
2469        )?;
2470        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2471    }
2472
2473    let status = session
2474        .db()
2475        .await
2476        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2477        .await?;
2478
2479    send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481    response.send(proto::RegenerateDevServerTokenResponse {
2482        dev_server_id: dev_server_id.to_proto(),
2483        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2484    })?;
2485    Ok(())
2486}
2487
2488async fn rename_dev_server(
2489    request: proto::RenameDevServer,
2490    response: Response<proto::RenameDevServer>,
2491    session: UserSession,
2492) -> Result<()> {
2493    if request.name.trim().is_empty() {
2494        return Err(proto::ErrorCode::Forbidden
2495            .message("Dev server name cannot be empty".to_string())
2496            .anyhow())?;
2497    }
2498
2499    let dev_server_id = DevServerId(request.dev_server_id as i32);
2500    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2501    if dev_server.user_id != session.user_id() {
2502        return Err(anyhow!(ErrorCode::Forbidden))?;
2503    }
2504
2505    let status = session
2506        .db()
2507        .await
2508        .rename_dev_server(
2509            dev_server_id,
2510            &request.name,
2511            request.ssh_connection_string.as_deref(),
2512            session.user_id(),
2513        )
2514        .await?;
2515
2516    send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518    response.send(proto::Ack {})?;
2519    Ok(())
2520}
2521
2522async fn delete_dev_server(
2523    request: proto::DeleteDevServer,
2524    response: Response<proto::DeleteDevServer>,
2525    session: UserSession,
2526) -> Result<()> {
2527    let dev_server_id = DevServerId(request.dev_server_id as i32);
2528    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2529    if dev_server.user_id != session.user_id() {
2530        return Err(anyhow!(ErrorCode::Forbidden))?;
2531    }
2532
2533    let connection_id = session
2534        .connection_pool()
2535        .await
2536        .dev_server_connection_id(dev_server_id);
2537    if let Some(connection_id) = connection_id {
2538        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2539        session.peer.send(
2540            connection_id,
2541            proto::ShutdownDevServer {
2542                reason: Some("dev server was deleted".to_string()),
2543            },
2544        )?;
2545        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2546    }
2547
2548    let status = session
2549        .db()
2550        .await
2551        .delete_dev_server(dev_server_id, session.user_id())
2552        .await?;
2553
2554    send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556    response.send(proto::Ack {})?;
2557    Ok(())
2558}
2559
2560async fn delete_dev_server_project(
2561    request: proto::DeleteDevServerProject,
2562    response: Response<proto::DeleteDevServerProject>,
2563    session: UserSession,
2564) -> Result<()> {
2565    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2566    let dev_server_project = session
2567        .db()
2568        .await
2569        .get_dev_server_project(dev_server_project_id)
2570        .await?;
2571
2572    let dev_server = session
2573        .db()
2574        .await
2575        .get_dev_server(dev_server_project.dev_server_id)
2576        .await?;
2577    if dev_server.user_id != session.user_id() {
2578        return Err(anyhow!(ErrorCode::Forbidden))?;
2579    }
2580
2581    let dev_server_connection_id = session
2582        .connection_pool()
2583        .await
2584        .dev_server_connection_id(dev_server.id);
2585
2586    if let Some(dev_server_connection_id) = dev_server_connection_id {
2587        let project = session
2588            .db()
2589            .await
2590            .find_dev_server_project(dev_server_project_id)
2591            .await;
2592        if let Ok(project) = project {
2593            unshare_project_internal(
2594                project.id,
2595                dev_server_connection_id,
2596                Some(session.user_id()),
2597                &session,
2598            )
2599            .await?;
2600        }
2601    }
2602
2603    let (projects, status) = session
2604        .db()
2605        .await
2606        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2607        .await?;
2608
2609    if let Some(dev_server_connection_id) = dev_server_connection_id {
2610        session.peer.send(
2611            dev_server_connection_id,
2612            proto::DevServerInstructions { projects },
2613        )?;
2614    }
2615
2616    send_dev_server_projects_update(session.user_id(), status, &session).await;
2617
2618    response.send(proto::Ack {})?;
2619    Ok(())
2620}
2621
2622async fn rejoin_dev_server_projects(
2623    request: proto::RejoinRemoteProjects,
2624    response: Response<proto::RejoinRemoteProjects>,
2625    session: UserSession,
2626) -> Result<()> {
2627    let mut rejoined_projects = {
2628        let db = session.db().await;
2629        db.rejoin_dev_server_projects(
2630            &request.rejoined_projects,
2631            session.user_id(),
2632            session.0.connection_id,
2633        )
2634        .await?
2635    };
2636    response.send(proto::RejoinRemoteProjectsResponse {
2637        rejoined_projects: rejoined_projects
2638            .iter()
2639            .map(|project| project.to_proto())
2640            .collect(),
2641    })?;
2642    notify_rejoined_projects(&mut rejoined_projects, &session)
2643}
2644
2645async fn reconnect_dev_server(
2646    request: proto::ReconnectDevServer,
2647    response: Response<proto::ReconnectDevServer>,
2648    session: DevServerSession,
2649) -> Result<()> {
2650    let reshared_projects = {
2651        let db = session.db().await;
2652        db.reshare_dev_server_projects(
2653            &request.reshared_projects,
2654            session.dev_server_id(),
2655            session.0.connection_id,
2656        )
2657        .await?
2658    };
2659
2660    for project in &reshared_projects {
2661        for collaborator in &project.collaborators {
2662            session
2663                .peer
2664                .send(
2665                    collaborator.connection_id,
2666                    proto::UpdateProjectCollaborator {
2667                        project_id: project.id.to_proto(),
2668                        old_peer_id: Some(project.old_connection_id.into()),
2669                        new_peer_id: Some(session.connection_id.into()),
2670                    },
2671                )
2672                .trace_err();
2673        }
2674
2675        broadcast(
2676            Some(session.connection_id),
2677            project
2678                .collaborators
2679                .iter()
2680                .map(|collaborator| collaborator.connection_id),
2681            |connection_id| {
2682                session.peer.forward_send(
2683                    session.connection_id,
2684                    connection_id,
2685                    proto::UpdateProject {
2686                        project_id: project.id.to_proto(),
2687                        worktrees: project.worktrees.clone(),
2688                    },
2689                )
2690            },
2691        );
2692    }
2693
2694    response.send(proto::ReconnectDevServerResponse {
2695        reshared_projects: reshared_projects
2696            .iter()
2697            .map(|project| proto::ResharedProject {
2698                id: project.id.to_proto(),
2699                collaborators: project
2700                    .collaborators
2701                    .iter()
2702                    .map(|collaborator| collaborator.to_proto())
2703                    .collect(),
2704            })
2705            .collect(),
2706    })?;
2707
2708    Ok(())
2709}
2710
2711async fn shutdown_dev_server(
2712    _: proto::ShutdownDevServer,
2713    response: Response<proto::ShutdownDevServer>,
2714    session: DevServerSession,
2715) -> Result<()> {
2716    response.send(proto::Ack {})?;
2717    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2718    remove_dev_server_connection(session.dev_server_id(), &session).await
2719}
2720
2721async fn shutdown_dev_server_internal(
2722    dev_server_id: DevServerId,
2723    connection_id: ConnectionId,
2724    session: &Session,
2725) -> Result<()> {
2726    let (dev_server_projects, dev_server) = {
2727        let db = session.db().await;
2728        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2729        let dev_server = db.get_dev_server(dev_server_id).await?;
2730        (dev_server_projects, dev_server)
2731    };
2732
2733    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2734        unshare_project_internal(
2735            ProjectId::from_proto(project_id),
2736            connection_id,
2737            None,
2738            session,
2739        )
2740        .await?;
2741    }
2742
2743    session
2744        .connection_pool()
2745        .await
2746        .set_dev_server_offline(dev_server_id);
2747
2748    let status = session
2749        .db()
2750        .await
2751        .dev_server_projects_update(dev_server.user_id)
2752        .await?;
2753    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2754
2755    Ok(())
2756}
2757
2758async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2759    let dev_server_connection = session
2760        .connection_pool()
2761        .await
2762        .dev_server_connection_id(dev_server_id);
2763
2764    if let Some(dev_server_connection) = dev_server_connection {
2765        session
2766            .connection_pool()
2767            .await
2768            .remove_connection(dev_server_connection)?;
2769    }
2770    Ok(())
2771}
2772
2773/// Updates other participants with changes to the project
2774async fn update_project(
2775    request: proto::UpdateProject,
2776    response: Response<proto::UpdateProject>,
2777    session: Session,
2778) -> Result<()> {
2779    let project_id = ProjectId::from_proto(request.project_id);
2780    let (room, guest_connection_ids) = &*session
2781        .db()
2782        .await
2783        .update_project(project_id, session.connection_id, &request.worktrees)
2784        .await?;
2785    broadcast(
2786        Some(session.connection_id),
2787        guest_connection_ids.iter().copied(),
2788        |connection_id| {
2789            session
2790                .peer
2791                .forward_send(session.connection_id, connection_id, request.clone())
2792        },
2793    );
2794    if let Some(room) = room {
2795        room_updated(room, &session.peer);
2796    }
2797    response.send(proto::Ack {})?;
2798
2799    Ok(())
2800}
2801
2802/// Updates other participants with changes to the worktree
2803async fn update_worktree(
2804    request: proto::UpdateWorktree,
2805    response: Response<proto::UpdateWorktree>,
2806    session: Session,
2807) -> Result<()> {
2808    let guest_connection_ids = session
2809        .db()
2810        .await
2811        .update_worktree(&request, session.connection_id)
2812        .await?;
2813
2814    broadcast(
2815        Some(session.connection_id),
2816        guest_connection_ids.iter().copied(),
2817        |connection_id| {
2818            session
2819                .peer
2820                .forward_send(session.connection_id, connection_id, request.clone())
2821        },
2822    );
2823    response.send(proto::Ack {})?;
2824    Ok(())
2825}
2826
2827/// Updates other participants with changes to the diagnostics
2828async fn update_diagnostic_summary(
2829    message: proto::UpdateDiagnosticSummary,
2830    session: Session,
2831) -> Result<()> {
2832    let guest_connection_ids = session
2833        .db()
2834        .await
2835        .update_diagnostic_summary(&message, session.connection_id)
2836        .await?;
2837
2838    broadcast(
2839        Some(session.connection_id),
2840        guest_connection_ids.iter().copied(),
2841        |connection_id| {
2842            session
2843                .peer
2844                .forward_send(session.connection_id, connection_id, message.clone())
2845        },
2846    );
2847
2848    Ok(())
2849}
2850
2851/// Updates other participants with changes to the worktree settings
2852async fn update_worktree_settings(
2853    message: proto::UpdateWorktreeSettings,
2854    session: Session,
2855) -> Result<()> {
2856    let guest_connection_ids = session
2857        .db()
2858        .await
2859        .update_worktree_settings(&message, session.connection_id)
2860        .await?;
2861
2862    broadcast(
2863        Some(session.connection_id),
2864        guest_connection_ids.iter().copied(),
2865        |connection_id| {
2866            session
2867                .peer
2868                .forward_send(session.connection_id, connection_id, message.clone())
2869        },
2870    );
2871
2872    Ok(())
2873}
2874
2875/// Notify other participants that a  language server has started.
2876async fn start_language_server(
2877    request: proto::StartLanguageServer,
2878    session: Session,
2879) -> Result<()> {
2880    let guest_connection_ids = session
2881        .db()
2882        .await
2883        .start_language_server(&request, session.connection_id)
2884        .await?;
2885
2886    broadcast(
2887        Some(session.connection_id),
2888        guest_connection_ids.iter().copied(),
2889        |connection_id| {
2890            session
2891                .peer
2892                .forward_send(session.connection_id, connection_id, request.clone())
2893        },
2894    );
2895    Ok(())
2896}
2897
2898/// Notify other participants that a language server has changed.
2899async fn update_language_server(
2900    request: proto::UpdateLanguageServer,
2901    session: Session,
2902) -> Result<()> {
2903    let project_id = ProjectId::from_proto(request.project_id);
2904    let project_connection_ids = session
2905        .db()
2906        .await
2907        .project_connection_ids(project_id, session.connection_id, true)
2908        .await?;
2909    broadcast(
2910        Some(session.connection_id),
2911        project_connection_ids.iter().copied(),
2912        |connection_id| {
2913            session
2914                .peer
2915                .forward_send(session.connection_id, connection_id, request.clone())
2916        },
2917    );
2918    Ok(())
2919}
2920
2921/// forward a project request to the host. These requests should be read only
2922/// as guests are allowed to send them.
2923async fn forward_read_only_project_request<T>(
2924    request: T,
2925    response: Response<T>,
2926    session: UserSession,
2927) -> Result<()>
2928where
2929    T: EntityMessage + RequestMessage,
2930{
2931    let project_id = ProjectId::from_proto(request.remote_entity_id());
2932    let host_connection_id = session
2933        .db()
2934        .await
2935        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2936        .await?;
2937    let payload = session
2938        .peer
2939        .forward_request(session.connection_id, host_connection_id, request)
2940        .await?;
2941    response.send(payload)?;
2942    Ok(())
2943}
2944
2945async fn forward_find_search_candidates_request(
2946    request: proto::FindSearchCandidates,
2947    response: Response<proto::FindSearchCandidates>,
2948    session: UserSession,
2949) -> Result<()> {
2950    let project_id = ProjectId::from_proto(request.remote_entity_id());
2951    let host_connection_id = session
2952        .db()
2953        .await
2954        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2955        .await?;
2956    let payload = session
2957        .peer
2958        .forward_request(session.connection_id, host_connection_id, request)
2959        .await?;
2960    response.send(payload)?;
2961    Ok(())
2962}
2963
2964/// forward a project request to the dev server. Only allowed
2965/// if it's your dev server.
2966async fn forward_project_request_for_owner<T>(
2967    request: T,
2968    response: Response<T>,
2969    session: UserSession,
2970) -> Result<()>
2971where
2972    T: EntityMessage + RequestMessage,
2973{
2974    let project_id = ProjectId::from_proto(request.remote_entity_id());
2975
2976    let host_connection_id = session
2977        .db()
2978        .await
2979        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2980        .await?;
2981    let payload = session
2982        .peer
2983        .forward_request(session.connection_id, host_connection_id, request)
2984        .await?;
2985    response.send(payload)?;
2986    Ok(())
2987}
2988
2989/// forward a project request to the host. These requests are disallowed
2990/// for guests.
2991async fn forward_mutating_project_request<T>(
2992    request: T,
2993    response: Response<T>,
2994    session: UserSession,
2995) -> Result<()>
2996where
2997    T: EntityMessage + RequestMessage,
2998{
2999    let project_id = ProjectId::from_proto(request.remote_entity_id());
3000
3001    let host_connection_id = session
3002        .db()
3003        .await
3004        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3005        .await?;
3006    let payload = session
3007        .peer
3008        .forward_request(session.connection_id, host_connection_id, request)
3009        .await?;
3010    response.send(payload)?;
3011    Ok(())
3012}
3013
3014/// Notify other participants that a new buffer has been created
3015async fn create_buffer_for_peer(
3016    request: proto::CreateBufferForPeer,
3017    session: Session,
3018) -> Result<()> {
3019    session
3020        .db()
3021        .await
3022        .check_user_is_project_host(
3023            ProjectId::from_proto(request.project_id),
3024            session.connection_id,
3025        )
3026        .await?;
3027    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3028    session
3029        .peer
3030        .forward_send(session.connection_id, peer_id.into(), request)?;
3031    Ok(())
3032}
3033
3034/// Notify other participants that a buffer has been updated. This is
3035/// allowed for guests as long as the update is limited to selections.
3036async fn update_buffer(
3037    request: proto::UpdateBuffer,
3038    response: Response<proto::UpdateBuffer>,
3039    session: Session,
3040) -> Result<()> {
3041    let project_id = ProjectId::from_proto(request.project_id);
3042    let mut capability = Capability::ReadOnly;
3043
3044    for op in request.operations.iter() {
3045        match op.variant {
3046            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3047            Some(_) => capability = Capability::ReadWrite,
3048        }
3049    }
3050
3051    let host = {
3052        let guard = session
3053            .db()
3054            .await
3055            .connections_for_buffer_update(
3056                project_id,
3057                session.principal_id(),
3058                session.connection_id,
3059                capability,
3060            )
3061            .await?;
3062
3063        let (host, guests) = &*guard;
3064
3065        broadcast(
3066            Some(session.connection_id),
3067            guests.clone(),
3068            |connection_id| {
3069                session
3070                    .peer
3071                    .forward_send(session.connection_id, connection_id, request.clone())
3072            },
3073        );
3074
3075        *host
3076    };
3077
3078    if host != session.connection_id {
3079        session
3080            .peer
3081            .forward_request(session.connection_id, host, request.clone())
3082            .await?;
3083    }
3084
3085    response.send(proto::Ack {})?;
3086    Ok(())
3087}
3088
3089async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3090    let project_id = ProjectId::from_proto(message.project_id);
3091
3092    let operation = message.operation.as_ref().context("invalid operation")?;
3093    let capability = match operation.variant.as_ref() {
3094        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3095            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3096                match buffer_op.variant {
3097                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3098                        Capability::ReadOnly
3099                    }
3100                    _ => Capability::ReadWrite,
3101                }
3102            } else {
3103                Capability::ReadWrite
3104            }
3105        }
3106        Some(_) => Capability::ReadWrite,
3107        None => Capability::ReadOnly,
3108    };
3109
3110    let guard = session
3111        .db()
3112        .await
3113        .connections_for_buffer_update(
3114            project_id,
3115            session.principal_id(),
3116            session.connection_id,
3117            capability,
3118        )
3119        .await?;
3120
3121    let (host, guests) = &*guard;
3122
3123    broadcast(
3124        Some(session.connection_id),
3125        guests.iter().chain([host]).copied(),
3126        |connection_id| {
3127            session
3128                .peer
3129                .forward_send(session.connection_id, connection_id, message.clone())
3130        },
3131    );
3132
3133    Ok(())
3134}
3135
3136/// Notify other participants that a project has been updated.
3137async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3138    request: T,
3139    session: Session,
3140) -> Result<()> {
3141    let project_id = ProjectId::from_proto(request.remote_entity_id());
3142    let project_connection_ids = session
3143        .db()
3144        .await
3145        .project_connection_ids(project_id, session.connection_id, false)
3146        .await?;
3147
3148    broadcast(
3149        Some(session.connection_id),
3150        project_connection_ids.iter().copied(),
3151        |connection_id| {
3152            session
3153                .peer
3154                .forward_send(session.connection_id, connection_id, request.clone())
3155        },
3156    );
3157    Ok(())
3158}
3159
3160/// Start following another user in a call.
3161async fn follow(
3162    request: proto::Follow,
3163    response: Response<proto::Follow>,
3164    session: UserSession,
3165) -> Result<()> {
3166    let room_id = RoomId::from_proto(request.room_id);
3167    let project_id = request.project_id.map(ProjectId::from_proto);
3168    let leader_id = request
3169        .leader_id
3170        .ok_or_else(|| anyhow!("invalid leader id"))?
3171        .into();
3172    let follower_id = session.connection_id;
3173
3174    session
3175        .db()
3176        .await
3177        .check_room_participants(room_id, leader_id, session.connection_id)
3178        .await?;
3179
3180    let response_payload = session
3181        .peer
3182        .forward_request(session.connection_id, leader_id, request)
3183        .await?;
3184    response.send(response_payload)?;
3185
3186    if let Some(project_id) = project_id {
3187        let room = session
3188            .db()
3189            .await
3190            .follow(room_id, project_id, leader_id, follower_id)
3191            .await?;
3192        room_updated(&room, &session.peer);
3193    }
3194
3195    Ok(())
3196}
3197
3198/// Stop following another user in a call.
3199async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3200    let room_id = RoomId::from_proto(request.room_id);
3201    let project_id = request.project_id.map(ProjectId::from_proto);
3202    let leader_id = request
3203        .leader_id
3204        .ok_or_else(|| anyhow!("invalid leader id"))?
3205        .into();
3206    let follower_id = session.connection_id;
3207
3208    session
3209        .db()
3210        .await
3211        .check_room_participants(room_id, leader_id, session.connection_id)
3212        .await?;
3213
3214    session
3215        .peer
3216        .forward_send(session.connection_id, leader_id, request)?;
3217
3218    if let Some(project_id) = project_id {
3219        let room = session
3220            .db()
3221            .await
3222            .unfollow(room_id, project_id, leader_id, follower_id)
3223            .await?;
3224        room_updated(&room, &session.peer);
3225    }
3226
3227    Ok(())
3228}
3229
3230/// Notify everyone following you of your current location.
3231async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3232    let room_id = RoomId::from_proto(request.room_id);
3233    let database = session.db.lock().await;
3234
3235    let connection_ids = if let Some(project_id) = request.project_id {
3236        let project_id = ProjectId::from_proto(project_id);
3237        database
3238            .project_connection_ids(project_id, session.connection_id, true)
3239            .await?
3240    } else {
3241        database
3242            .room_connection_ids(room_id, session.connection_id)
3243            .await?
3244    };
3245
3246    // For now, don't send view update messages back to that view's current leader.
3247    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3248        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3249        _ => None,
3250    });
3251
3252    for connection_id in connection_ids.iter().cloned() {
3253        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3254            session
3255                .peer
3256                .forward_send(session.connection_id, connection_id, request.clone())?;
3257        }
3258    }
3259    Ok(())
3260}
3261
3262/// Get public data about users.
3263async fn get_users(
3264    request: proto::GetUsers,
3265    response: Response<proto::GetUsers>,
3266    session: Session,
3267) -> Result<()> {
3268    let user_ids = request
3269        .user_ids
3270        .into_iter()
3271        .map(UserId::from_proto)
3272        .collect();
3273    let users = session
3274        .db()
3275        .await
3276        .get_users_by_ids(user_ids)
3277        .await?
3278        .into_iter()
3279        .map(|user| proto::User {
3280            id: user.id.to_proto(),
3281            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3282            github_login: user.github_login,
3283        })
3284        .collect();
3285    response.send(proto::UsersResponse { users })?;
3286    Ok(())
3287}
3288
3289/// Search for users (to invite) buy Github login
3290async fn fuzzy_search_users(
3291    request: proto::FuzzySearchUsers,
3292    response: Response<proto::FuzzySearchUsers>,
3293    session: UserSession,
3294) -> Result<()> {
3295    let query = request.query;
3296    let users = match query.len() {
3297        0 => vec![],
3298        1 | 2 => session
3299            .db()
3300            .await
3301            .get_user_by_github_login(&query)
3302            .await?
3303            .into_iter()
3304            .collect(),
3305        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3306    };
3307    let users = users
3308        .into_iter()
3309        .filter(|user| user.id != session.user_id())
3310        .map(|user| proto::User {
3311            id: user.id.to_proto(),
3312            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3313            github_login: user.github_login,
3314        })
3315        .collect();
3316    response.send(proto::UsersResponse { users })?;
3317    Ok(())
3318}
3319
3320/// Send a contact request to another user.
3321async fn request_contact(
3322    request: proto::RequestContact,
3323    response: Response<proto::RequestContact>,
3324    session: UserSession,
3325) -> Result<()> {
3326    let requester_id = session.user_id();
3327    let responder_id = UserId::from_proto(request.responder_id);
3328    if requester_id == responder_id {
3329        return Err(anyhow!("cannot add yourself as a contact"))?;
3330    }
3331
3332    let notifications = session
3333        .db()
3334        .await
3335        .send_contact_request(requester_id, responder_id)
3336        .await?;
3337
3338    // Update outgoing contact requests of requester
3339    let mut update = proto::UpdateContacts::default();
3340    update.outgoing_requests.push(responder_id.to_proto());
3341    for connection_id in session
3342        .connection_pool()
3343        .await
3344        .user_connection_ids(requester_id)
3345    {
3346        session.peer.send(connection_id, update.clone())?;
3347    }
3348
3349    // Update incoming contact requests of responder
3350    let mut update = proto::UpdateContacts::default();
3351    update
3352        .incoming_requests
3353        .push(proto::IncomingContactRequest {
3354            requester_id: requester_id.to_proto(),
3355        });
3356    let connection_pool = session.connection_pool().await;
3357    for connection_id in connection_pool.user_connection_ids(responder_id) {
3358        session.peer.send(connection_id, update.clone())?;
3359    }
3360
3361    send_notifications(&connection_pool, &session.peer, notifications);
3362
3363    response.send(proto::Ack {})?;
3364    Ok(())
3365}
3366
3367/// Accept or decline a contact request
3368async fn respond_to_contact_request(
3369    request: proto::RespondToContactRequest,
3370    response: Response<proto::RespondToContactRequest>,
3371    session: UserSession,
3372) -> Result<()> {
3373    let responder_id = session.user_id();
3374    let requester_id = UserId::from_proto(request.requester_id);
3375    let db = session.db().await;
3376    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3377        db.dismiss_contact_notification(responder_id, requester_id)
3378            .await?;
3379    } else {
3380        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3381
3382        let notifications = db
3383            .respond_to_contact_request(responder_id, requester_id, accept)
3384            .await?;
3385        let requester_busy = db.is_user_busy(requester_id).await?;
3386        let responder_busy = db.is_user_busy(responder_id).await?;
3387
3388        let pool = session.connection_pool().await;
3389        // Update responder with new contact
3390        let mut update = proto::UpdateContacts::default();
3391        if accept {
3392            update
3393                .contacts
3394                .push(contact_for_user(requester_id, requester_busy, &pool));
3395        }
3396        update
3397            .remove_incoming_requests
3398            .push(requester_id.to_proto());
3399        for connection_id in pool.user_connection_ids(responder_id) {
3400            session.peer.send(connection_id, update.clone())?;
3401        }
3402
3403        // Update requester with new contact
3404        let mut update = proto::UpdateContacts::default();
3405        if accept {
3406            update
3407                .contacts
3408                .push(contact_for_user(responder_id, responder_busy, &pool));
3409        }
3410        update
3411            .remove_outgoing_requests
3412            .push(responder_id.to_proto());
3413
3414        for connection_id in pool.user_connection_ids(requester_id) {
3415            session.peer.send(connection_id, update.clone())?;
3416        }
3417
3418        send_notifications(&pool, &session.peer, notifications);
3419    }
3420
3421    response.send(proto::Ack {})?;
3422    Ok(())
3423}
3424
3425/// Remove a contact.
3426async fn remove_contact(
3427    request: proto::RemoveContact,
3428    response: Response<proto::RemoveContact>,
3429    session: UserSession,
3430) -> Result<()> {
3431    let requester_id = session.user_id();
3432    let responder_id = UserId::from_proto(request.user_id);
3433    let db = session.db().await;
3434    let (contact_accepted, deleted_notification_id) =
3435        db.remove_contact(requester_id, responder_id).await?;
3436
3437    let pool = session.connection_pool().await;
3438    // Update outgoing contact requests of requester
3439    let mut update = proto::UpdateContacts::default();
3440    if contact_accepted {
3441        update.remove_contacts.push(responder_id.to_proto());
3442    } else {
3443        update
3444            .remove_outgoing_requests
3445            .push(responder_id.to_proto());
3446    }
3447    for connection_id in pool.user_connection_ids(requester_id) {
3448        session.peer.send(connection_id, update.clone())?;
3449    }
3450
3451    // Update incoming contact requests of responder
3452    let mut update = proto::UpdateContacts::default();
3453    if contact_accepted {
3454        update.remove_contacts.push(requester_id.to_proto());
3455    } else {
3456        update
3457            .remove_incoming_requests
3458            .push(requester_id.to_proto());
3459    }
3460    for connection_id in pool.user_connection_ids(responder_id) {
3461        session.peer.send(connection_id, update.clone())?;
3462        if let Some(notification_id) = deleted_notification_id {
3463            session.peer.send(
3464                connection_id,
3465                proto::DeleteNotification {
3466                    notification_id: notification_id.to_proto(),
3467                },
3468            )?;
3469        }
3470    }
3471
3472    response.send(proto::Ack {})?;
3473    Ok(())
3474}
3475
3476fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3477    version.0.minor() < 139
3478}
3479
3480async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3481    let plan = session.current_plan(&session.db().await).await?;
3482
3483    session
3484        .peer
3485        .send(
3486            session.connection_id,
3487            proto::UpdateUserPlan { plan: plan.into() },
3488        )
3489        .trace_err();
3490
3491    Ok(())
3492}
3493
3494async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3495    subscribe_user_to_channels(
3496        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3497        &session,
3498    )
3499    .await?;
3500    Ok(())
3501}
3502
3503async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3504    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3505    let mut pool = session.connection_pool().await;
3506    for membership in &channels_for_user.channel_memberships {
3507        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3508    }
3509    session.peer.send(
3510        session.connection_id,
3511        build_update_user_channels(&channels_for_user),
3512    )?;
3513    session.peer.send(
3514        session.connection_id,
3515        build_channels_update(channels_for_user),
3516    )?;
3517    Ok(())
3518}
3519
3520/// Creates a new channel.
3521async fn create_channel(
3522    request: proto::CreateChannel,
3523    response: Response<proto::CreateChannel>,
3524    session: UserSession,
3525) -> Result<()> {
3526    let db = session.db().await;
3527
3528    let parent_id = request.parent_id.map(ChannelId::from_proto);
3529    let (channel, membership) = db
3530        .create_channel(&request.name, parent_id, session.user_id())
3531        .await?;
3532
3533    let root_id = channel.root_id();
3534    let channel = Channel::from_model(channel);
3535
3536    response.send(proto::CreateChannelResponse {
3537        channel: Some(channel.to_proto()),
3538        parent_id: request.parent_id,
3539    })?;
3540
3541    let mut connection_pool = session.connection_pool().await;
3542    if let Some(membership) = membership {
3543        connection_pool.subscribe_to_channel(
3544            membership.user_id,
3545            membership.channel_id,
3546            membership.role,
3547        );
3548        let update = proto::UpdateUserChannels {
3549            channel_memberships: vec![proto::ChannelMembership {
3550                channel_id: membership.channel_id.to_proto(),
3551                role: membership.role.into(),
3552            }],
3553            ..Default::default()
3554        };
3555        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3556            session.peer.send(connection_id, update.clone())?;
3557        }
3558    }
3559
3560    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3561        if !role.can_see_channel(channel.visibility) {
3562            continue;
3563        }
3564
3565        let update = proto::UpdateChannels {
3566            channels: vec![channel.to_proto()],
3567            ..Default::default()
3568        };
3569        session.peer.send(connection_id, update.clone())?;
3570    }
3571
3572    Ok(())
3573}
3574
3575/// Delete a channel
3576async fn delete_channel(
3577    request: proto::DeleteChannel,
3578    response: Response<proto::DeleteChannel>,
3579    session: UserSession,
3580) -> Result<()> {
3581    let db = session.db().await;
3582
3583    let channel_id = request.channel_id;
3584    let (root_channel, removed_channels) = db
3585        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3586        .await?;
3587    response.send(proto::Ack {})?;
3588
3589    // Notify members of removed channels
3590    let mut update = proto::UpdateChannels::default();
3591    update
3592        .delete_channels
3593        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3594
3595    let connection_pool = session.connection_pool().await;
3596    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3597        session.peer.send(connection_id, update.clone())?;
3598    }
3599
3600    Ok(())
3601}
3602
3603/// Invite someone to join a channel.
3604async fn invite_channel_member(
3605    request: proto::InviteChannelMember,
3606    response: Response<proto::InviteChannelMember>,
3607    session: UserSession,
3608) -> Result<()> {
3609    let db = session.db().await;
3610    let channel_id = ChannelId::from_proto(request.channel_id);
3611    let invitee_id = UserId::from_proto(request.user_id);
3612    let InviteMemberResult {
3613        channel,
3614        notifications,
3615    } = db
3616        .invite_channel_member(
3617            channel_id,
3618            invitee_id,
3619            session.user_id(),
3620            request.role().into(),
3621        )
3622        .await?;
3623
3624    let update = proto::UpdateChannels {
3625        channel_invitations: vec![channel.to_proto()],
3626        ..Default::default()
3627    };
3628
3629    let connection_pool = session.connection_pool().await;
3630    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3631        session.peer.send(connection_id, update.clone())?;
3632    }
3633
3634    send_notifications(&connection_pool, &session.peer, notifications);
3635
3636    response.send(proto::Ack {})?;
3637    Ok(())
3638}
3639
3640/// remove someone from a channel
3641async fn remove_channel_member(
3642    request: proto::RemoveChannelMember,
3643    response: Response<proto::RemoveChannelMember>,
3644    session: UserSession,
3645) -> Result<()> {
3646    let db = session.db().await;
3647    let channel_id = ChannelId::from_proto(request.channel_id);
3648    let member_id = UserId::from_proto(request.user_id);
3649
3650    let RemoveChannelMemberResult {
3651        membership_update,
3652        notification_id,
3653    } = db
3654        .remove_channel_member(channel_id, member_id, session.user_id())
3655        .await?;
3656
3657    let mut connection_pool = session.connection_pool().await;
3658    notify_membership_updated(
3659        &mut connection_pool,
3660        membership_update,
3661        member_id,
3662        &session.peer,
3663    );
3664    for connection_id in connection_pool.user_connection_ids(member_id) {
3665        if let Some(notification_id) = notification_id {
3666            session
3667                .peer
3668                .send(
3669                    connection_id,
3670                    proto::DeleteNotification {
3671                        notification_id: notification_id.to_proto(),
3672                    },
3673                )
3674                .trace_err();
3675        }
3676    }
3677
3678    response.send(proto::Ack {})?;
3679    Ok(())
3680}
3681
3682/// Toggle the channel between public and private.
3683/// Care is taken to maintain the invariant that public channels only descend from public channels,
3684/// (though members-only channels can appear at any point in the hierarchy).
3685async fn set_channel_visibility(
3686    request: proto::SetChannelVisibility,
3687    response: Response<proto::SetChannelVisibility>,
3688    session: UserSession,
3689) -> Result<()> {
3690    let db = session.db().await;
3691    let channel_id = ChannelId::from_proto(request.channel_id);
3692    let visibility = request.visibility().into();
3693
3694    let channel_model = db
3695        .set_channel_visibility(channel_id, visibility, session.user_id())
3696        .await?;
3697    let root_id = channel_model.root_id();
3698    let channel = Channel::from_model(channel_model);
3699
3700    let mut connection_pool = session.connection_pool().await;
3701    for (user_id, role) in connection_pool
3702        .channel_user_ids(root_id)
3703        .collect::<Vec<_>>()
3704        .into_iter()
3705    {
3706        let update = if role.can_see_channel(channel.visibility) {
3707            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3708            proto::UpdateChannels {
3709                channels: vec![channel.to_proto()],
3710                ..Default::default()
3711            }
3712        } else {
3713            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3714            proto::UpdateChannels {
3715                delete_channels: vec![channel.id.to_proto()],
3716                ..Default::default()
3717            }
3718        };
3719
3720        for connection_id in connection_pool.user_connection_ids(user_id) {
3721            session.peer.send(connection_id, update.clone())?;
3722        }
3723    }
3724
3725    response.send(proto::Ack {})?;
3726    Ok(())
3727}
3728
3729/// Alter the role for a user in the channel.
3730async fn set_channel_member_role(
3731    request: proto::SetChannelMemberRole,
3732    response: Response<proto::SetChannelMemberRole>,
3733    session: UserSession,
3734) -> Result<()> {
3735    let db = session.db().await;
3736    let channel_id = ChannelId::from_proto(request.channel_id);
3737    let member_id = UserId::from_proto(request.user_id);
3738    let result = db
3739        .set_channel_member_role(
3740            channel_id,
3741            session.user_id(),
3742            member_id,
3743            request.role().into(),
3744        )
3745        .await?;
3746
3747    match result {
3748        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3749            let mut connection_pool = session.connection_pool().await;
3750            notify_membership_updated(
3751                &mut connection_pool,
3752                membership_update,
3753                member_id,
3754                &session.peer,
3755            )
3756        }
3757        db::SetMemberRoleResult::InviteUpdated(channel) => {
3758            let update = proto::UpdateChannels {
3759                channel_invitations: vec![channel.to_proto()],
3760                ..Default::default()
3761            };
3762
3763            for connection_id in session
3764                .connection_pool()
3765                .await
3766                .user_connection_ids(member_id)
3767            {
3768                session.peer.send(connection_id, update.clone())?;
3769            }
3770        }
3771    }
3772
3773    response.send(proto::Ack {})?;
3774    Ok(())
3775}
3776
3777/// Change the name of a channel
3778async fn rename_channel(
3779    request: proto::RenameChannel,
3780    response: Response<proto::RenameChannel>,
3781    session: UserSession,
3782) -> Result<()> {
3783    let db = session.db().await;
3784    let channel_id = ChannelId::from_proto(request.channel_id);
3785    let channel_model = db
3786        .rename_channel(channel_id, session.user_id(), &request.name)
3787        .await?;
3788    let root_id = channel_model.root_id();
3789    let channel = Channel::from_model(channel_model);
3790
3791    response.send(proto::RenameChannelResponse {
3792        channel: Some(channel.to_proto()),
3793    })?;
3794
3795    let connection_pool = session.connection_pool().await;
3796    let update = proto::UpdateChannels {
3797        channels: vec![channel.to_proto()],
3798        ..Default::default()
3799    };
3800    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3801        if role.can_see_channel(channel.visibility) {
3802            session.peer.send(connection_id, update.clone())?;
3803        }
3804    }
3805
3806    Ok(())
3807}
3808
3809/// Move a channel to a new parent.
3810async fn move_channel(
3811    request: proto::MoveChannel,
3812    response: Response<proto::MoveChannel>,
3813    session: UserSession,
3814) -> Result<()> {
3815    let channel_id = ChannelId::from_proto(request.channel_id);
3816    let to = ChannelId::from_proto(request.to);
3817
3818    let (root_id, channels) = session
3819        .db()
3820        .await
3821        .move_channel(channel_id, to, session.user_id())
3822        .await?;
3823
3824    let connection_pool = session.connection_pool().await;
3825    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3826        let channels = channels
3827            .iter()
3828            .filter_map(|channel| {
3829                if role.can_see_channel(channel.visibility) {
3830                    Some(channel.to_proto())
3831                } else {
3832                    None
3833                }
3834            })
3835            .collect::<Vec<_>>();
3836        if channels.is_empty() {
3837            continue;
3838        }
3839
3840        let update = proto::UpdateChannels {
3841            channels,
3842            ..Default::default()
3843        };
3844
3845        session.peer.send(connection_id, update.clone())?;
3846    }
3847
3848    response.send(Ack {})?;
3849    Ok(())
3850}
3851
3852/// Get the list of channel members
3853async fn get_channel_members(
3854    request: proto::GetChannelMembers,
3855    response: Response<proto::GetChannelMembers>,
3856    session: UserSession,
3857) -> Result<()> {
3858    let db = session.db().await;
3859    let channel_id = ChannelId::from_proto(request.channel_id);
3860    let limit = if request.limit == 0 {
3861        u16::MAX as u64
3862    } else {
3863        request.limit
3864    };
3865    let (members, users) = db
3866        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3867        .await?;
3868    response.send(proto::GetChannelMembersResponse { members, users })?;
3869    Ok(())
3870}
3871
3872/// Accept or decline a channel invitation.
3873async fn respond_to_channel_invite(
3874    request: proto::RespondToChannelInvite,
3875    response: Response<proto::RespondToChannelInvite>,
3876    session: UserSession,
3877) -> Result<()> {
3878    let db = session.db().await;
3879    let channel_id = ChannelId::from_proto(request.channel_id);
3880    let RespondToChannelInvite {
3881        membership_update,
3882        notifications,
3883    } = db
3884        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3885        .await?;
3886
3887    let mut connection_pool = session.connection_pool().await;
3888    if let Some(membership_update) = membership_update {
3889        notify_membership_updated(
3890            &mut connection_pool,
3891            membership_update,
3892            session.user_id(),
3893            &session.peer,
3894        );
3895    } else {
3896        let update = proto::UpdateChannels {
3897            remove_channel_invitations: vec![channel_id.to_proto()],
3898            ..Default::default()
3899        };
3900
3901        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3902            session.peer.send(connection_id, update.clone())?;
3903        }
3904    };
3905
3906    send_notifications(&connection_pool, &session.peer, notifications);
3907
3908    response.send(proto::Ack {})?;
3909
3910    Ok(())
3911}
3912
3913/// Join the channels' room
3914async fn join_channel(
3915    request: proto::JoinChannel,
3916    response: Response<proto::JoinChannel>,
3917    session: UserSession,
3918) -> Result<()> {
3919    let channel_id = ChannelId::from_proto(request.channel_id);
3920    join_channel_internal(channel_id, Box::new(response), session).await
3921}
3922
3923trait JoinChannelInternalResponse {
3924    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3925}
3926impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3927    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3928        Response::<proto::JoinChannel>::send(self, result)
3929    }
3930}
3931impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3932    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3933        Response::<proto::JoinRoom>::send(self, result)
3934    }
3935}
3936
3937async fn join_channel_internal(
3938    channel_id: ChannelId,
3939    response: Box<impl JoinChannelInternalResponse>,
3940    session: UserSession,
3941) -> Result<()> {
3942    let joined_room = {
3943        let mut db = session.db().await;
3944        // If zed quits without leaving the room, and the user re-opens zed before the
3945        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3946        // room they were in.
3947        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3948            tracing::info!(
3949                stale_connection_id = %connection,
3950                "cleaning up stale connection",
3951            );
3952            drop(db);
3953            leave_room_for_session(&session, connection).await?;
3954            db = session.db().await;
3955        }
3956
3957        let (joined_room, membership_updated, role) = db
3958            .join_channel(channel_id, session.user_id(), session.connection_id)
3959            .await?;
3960
3961        let live_kit_connection_info =
3962            session
3963                .app_state
3964                .live_kit_client
3965                .as_ref()
3966                .and_then(|live_kit| {
3967                    let (can_publish, token) = if role == ChannelRole::Guest {
3968                        (
3969                            false,
3970                            live_kit
3971                                .guest_token(
3972                                    &joined_room.room.live_kit_room,
3973                                    &session.user_id().to_string(),
3974                                )
3975                                .trace_err()?,
3976                        )
3977                    } else {
3978                        (
3979                            true,
3980                            live_kit
3981                                .room_token(
3982                                    &joined_room.room.live_kit_room,
3983                                    &session.user_id().to_string(),
3984                                )
3985                                .trace_err()?,
3986                        )
3987                    };
3988
3989                    Some(LiveKitConnectionInfo {
3990                        server_url: live_kit.url().into(),
3991                        token,
3992                        can_publish,
3993                    })
3994                });
3995
3996        response.send(proto::JoinRoomResponse {
3997            room: Some(joined_room.room.clone()),
3998            channel_id: joined_room
3999                .channel
4000                .as_ref()
4001                .map(|channel| channel.id.to_proto()),
4002            live_kit_connection_info,
4003        })?;
4004
4005        let mut connection_pool = session.connection_pool().await;
4006        if let Some(membership_updated) = membership_updated {
4007            notify_membership_updated(
4008                &mut connection_pool,
4009                membership_updated,
4010                session.user_id(),
4011                &session.peer,
4012            );
4013        }
4014
4015        room_updated(&joined_room.room, &session.peer);
4016
4017        joined_room
4018    };
4019
4020    channel_updated(
4021        &joined_room
4022            .channel
4023            .ok_or_else(|| anyhow!("channel not returned"))?,
4024        &joined_room.room,
4025        &session.peer,
4026        &*session.connection_pool().await,
4027    );
4028
4029    update_user_contacts(session.user_id(), &session).await?;
4030    Ok(())
4031}
4032
4033/// Start editing the channel notes
4034async fn join_channel_buffer(
4035    request: proto::JoinChannelBuffer,
4036    response: Response<proto::JoinChannelBuffer>,
4037    session: UserSession,
4038) -> Result<()> {
4039    let db = session.db().await;
4040    let channel_id = ChannelId::from_proto(request.channel_id);
4041
4042    let open_response = db
4043        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4044        .await?;
4045
4046    let collaborators = open_response.collaborators.clone();
4047    response.send(open_response)?;
4048
4049    let update = UpdateChannelBufferCollaborators {
4050        channel_id: channel_id.to_proto(),
4051        collaborators: collaborators.clone(),
4052    };
4053    channel_buffer_updated(
4054        session.connection_id,
4055        collaborators
4056            .iter()
4057            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4058        &update,
4059        &session.peer,
4060    );
4061
4062    Ok(())
4063}
4064
4065/// Edit the channel notes
4066async fn update_channel_buffer(
4067    request: proto::UpdateChannelBuffer,
4068    session: UserSession,
4069) -> Result<()> {
4070    let db = session.db().await;
4071    let channel_id = ChannelId::from_proto(request.channel_id);
4072
4073    let (collaborators, epoch, version) = db
4074        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4075        .await?;
4076
4077    channel_buffer_updated(
4078        session.connection_id,
4079        collaborators.clone(),
4080        &proto::UpdateChannelBuffer {
4081            channel_id: channel_id.to_proto(),
4082            operations: request.operations,
4083        },
4084        &session.peer,
4085    );
4086
4087    let pool = &*session.connection_pool().await;
4088
4089    let non_collaborators =
4090        pool.channel_connection_ids(channel_id)
4091            .filter_map(|(connection_id, _)| {
4092                if collaborators.contains(&connection_id) {
4093                    None
4094                } else {
4095                    Some(connection_id)
4096                }
4097            });
4098
4099    broadcast(None, non_collaborators, |peer_id| {
4100        session.peer.send(
4101            peer_id,
4102            proto::UpdateChannels {
4103                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4104                    channel_id: channel_id.to_proto(),
4105                    epoch: epoch as u64,
4106                    version: version.clone(),
4107                }],
4108                ..Default::default()
4109            },
4110        )
4111    });
4112
4113    Ok(())
4114}
4115
4116/// Rejoin the channel notes after a connection blip
4117async fn rejoin_channel_buffers(
4118    request: proto::RejoinChannelBuffers,
4119    response: Response<proto::RejoinChannelBuffers>,
4120    session: UserSession,
4121) -> Result<()> {
4122    let db = session.db().await;
4123    let buffers = db
4124        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4125        .await?;
4126
4127    for rejoined_buffer in &buffers {
4128        let collaborators_to_notify = rejoined_buffer
4129            .buffer
4130            .collaborators
4131            .iter()
4132            .filter_map(|c| Some(c.peer_id?.into()));
4133        channel_buffer_updated(
4134            session.connection_id,
4135            collaborators_to_notify,
4136            &proto::UpdateChannelBufferCollaborators {
4137                channel_id: rejoined_buffer.buffer.channel_id,
4138                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4139            },
4140            &session.peer,
4141        );
4142    }
4143
4144    response.send(proto::RejoinChannelBuffersResponse {
4145        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4146    })?;
4147
4148    Ok(())
4149}
4150
4151/// Stop editing the channel notes
4152async fn leave_channel_buffer(
4153    request: proto::LeaveChannelBuffer,
4154    response: Response<proto::LeaveChannelBuffer>,
4155    session: UserSession,
4156) -> Result<()> {
4157    let db = session.db().await;
4158    let channel_id = ChannelId::from_proto(request.channel_id);
4159
4160    let left_buffer = db
4161        .leave_channel_buffer(channel_id, session.connection_id)
4162        .await?;
4163
4164    response.send(Ack {})?;
4165
4166    channel_buffer_updated(
4167        session.connection_id,
4168        left_buffer.connections,
4169        &proto::UpdateChannelBufferCollaborators {
4170            channel_id: channel_id.to_proto(),
4171            collaborators: left_buffer.collaborators,
4172        },
4173        &session.peer,
4174    );
4175
4176    Ok(())
4177}
4178
4179fn channel_buffer_updated<T: EnvelopedMessage>(
4180    sender_id: ConnectionId,
4181    collaborators: impl IntoIterator<Item = ConnectionId>,
4182    message: &T,
4183    peer: &Peer,
4184) {
4185    broadcast(Some(sender_id), collaborators, |peer_id| {
4186        peer.send(peer_id, message.clone())
4187    });
4188}
4189
4190fn send_notifications(
4191    connection_pool: &ConnectionPool,
4192    peer: &Peer,
4193    notifications: db::NotificationBatch,
4194) {
4195    for (user_id, notification) in notifications {
4196        for connection_id in connection_pool.user_connection_ids(user_id) {
4197            if let Err(error) = peer.send(
4198                connection_id,
4199                proto::AddNotification {
4200                    notification: Some(notification.clone()),
4201                },
4202            ) {
4203                tracing::error!(
4204                    "failed to send notification to {:?} {}",
4205                    connection_id,
4206                    error
4207                );
4208            }
4209        }
4210    }
4211}
4212
4213/// Send a message to the channel
4214async fn send_channel_message(
4215    request: proto::SendChannelMessage,
4216    response: Response<proto::SendChannelMessage>,
4217    session: UserSession,
4218) -> Result<()> {
4219    // Validate the message body.
4220    let body = request.body.trim().to_string();
4221    if body.len() > MAX_MESSAGE_LEN {
4222        return Err(anyhow!("message is too long"))?;
4223    }
4224    if body.is_empty() {
4225        return Err(anyhow!("message can't be blank"))?;
4226    }
4227
4228    // TODO: adjust mentions if body is trimmed
4229
4230    let timestamp = OffsetDateTime::now_utc();
4231    let nonce = request
4232        .nonce
4233        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4234
4235    let channel_id = ChannelId::from_proto(request.channel_id);
4236    let CreatedChannelMessage {
4237        message_id,
4238        participant_connection_ids,
4239        notifications,
4240    } = session
4241        .db()
4242        .await
4243        .create_channel_message(
4244            channel_id,
4245            session.user_id(),
4246            &body,
4247            &request.mentions,
4248            timestamp,
4249            nonce.clone().into(),
4250            request.reply_to_message_id.map(MessageId::from_proto),
4251        )
4252        .await?;
4253
4254    let message = proto::ChannelMessage {
4255        sender_id: session.user_id().to_proto(),
4256        id: message_id.to_proto(),
4257        body,
4258        mentions: request.mentions,
4259        timestamp: timestamp.unix_timestamp() as u64,
4260        nonce: Some(nonce),
4261        reply_to_message_id: request.reply_to_message_id,
4262        edited_at: None,
4263    };
4264    broadcast(
4265        Some(session.connection_id),
4266        participant_connection_ids.clone(),
4267        |connection| {
4268            session.peer.send(
4269                connection,
4270                proto::ChannelMessageSent {
4271                    channel_id: channel_id.to_proto(),
4272                    message: Some(message.clone()),
4273                },
4274            )
4275        },
4276    );
4277    response.send(proto::SendChannelMessageResponse {
4278        message: Some(message),
4279    })?;
4280
4281    let pool = &*session.connection_pool().await;
4282    let non_participants =
4283        pool.channel_connection_ids(channel_id)
4284            .filter_map(|(connection_id, _)| {
4285                if participant_connection_ids.contains(&connection_id) {
4286                    None
4287                } else {
4288                    Some(connection_id)
4289                }
4290            });
4291    broadcast(None, non_participants, |peer_id| {
4292        session.peer.send(
4293            peer_id,
4294            proto::UpdateChannels {
4295                latest_channel_message_ids: vec![proto::ChannelMessageId {
4296                    channel_id: channel_id.to_proto(),
4297                    message_id: message_id.to_proto(),
4298                }],
4299                ..Default::default()
4300            },
4301        )
4302    });
4303    send_notifications(pool, &session.peer, notifications);
4304
4305    Ok(())
4306}
4307
4308/// Delete a channel message
4309async fn remove_channel_message(
4310    request: proto::RemoveChannelMessage,
4311    response: Response<proto::RemoveChannelMessage>,
4312    session: UserSession,
4313) -> Result<()> {
4314    let channel_id = ChannelId::from_proto(request.channel_id);
4315    let message_id = MessageId::from_proto(request.message_id);
4316    let (connection_ids, existing_notification_ids) = session
4317        .db()
4318        .await
4319        .remove_channel_message(channel_id, message_id, session.user_id())
4320        .await?;
4321
4322    broadcast(
4323        Some(session.connection_id),
4324        connection_ids,
4325        move |connection| {
4326            session.peer.send(connection, request.clone())?;
4327
4328            for notification_id in &existing_notification_ids {
4329                session.peer.send(
4330                    connection,
4331                    proto::DeleteNotification {
4332                        notification_id: (*notification_id).to_proto(),
4333                    },
4334                )?;
4335            }
4336
4337            Ok(())
4338        },
4339    );
4340    response.send(proto::Ack {})?;
4341    Ok(())
4342}
4343
4344async fn update_channel_message(
4345    request: proto::UpdateChannelMessage,
4346    response: Response<proto::UpdateChannelMessage>,
4347    session: UserSession,
4348) -> Result<()> {
4349    let channel_id = ChannelId::from_proto(request.channel_id);
4350    let message_id = MessageId::from_proto(request.message_id);
4351    let updated_at = OffsetDateTime::now_utc();
4352    let UpdatedChannelMessage {
4353        message_id,
4354        participant_connection_ids,
4355        notifications,
4356        reply_to_message_id,
4357        timestamp,
4358        deleted_mention_notification_ids,
4359        updated_mention_notifications,
4360    } = session
4361        .db()
4362        .await
4363        .update_channel_message(
4364            channel_id,
4365            message_id,
4366            session.user_id(),
4367            request.body.as_str(),
4368            &request.mentions,
4369            updated_at,
4370        )
4371        .await?;
4372
4373    let nonce = request
4374        .nonce
4375        .clone()
4376        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4377
4378    let message = proto::ChannelMessage {
4379        sender_id: session.user_id().to_proto(),
4380        id: message_id.to_proto(),
4381        body: request.body.clone(),
4382        mentions: request.mentions.clone(),
4383        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4384        nonce: Some(nonce),
4385        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4386        edited_at: Some(updated_at.unix_timestamp() as u64),
4387    };
4388
4389    response.send(proto::Ack {})?;
4390
4391    let pool = &*session.connection_pool().await;
4392    broadcast(
4393        Some(session.connection_id),
4394        participant_connection_ids,
4395        |connection| {
4396            session.peer.send(
4397                connection,
4398                proto::ChannelMessageUpdate {
4399                    channel_id: channel_id.to_proto(),
4400                    message: Some(message.clone()),
4401                },
4402            )?;
4403
4404            for notification_id in &deleted_mention_notification_ids {
4405                session.peer.send(
4406                    connection,
4407                    proto::DeleteNotification {
4408                        notification_id: (*notification_id).to_proto(),
4409                    },
4410                )?;
4411            }
4412
4413            for notification in &updated_mention_notifications {
4414                session.peer.send(
4415                    connection,
4416                    proto::UpdateNotification {
4417                        notification: Some(notification.clone()),
4418                    },
4419                )?;
4420            }
4421
4422            Ok(())
4423        },
4424    );
4425
4426    send_notifications(pool, &session.peer, notifications);
4427
4428    Ok(())
4429}
4430
4431/// Mark a channel message as read
4432async fn acknowledge_channel_message(
4433    request: proto::AckChannelMessage,
4434    session: UserSession,
4435) -> Result<()> {
4436    let channel_id = ChannelId::from_proto(request.channel_id);
4437    let message_id = MessageId::from_proto(request.message_id);
4438    let notifications = session
4439        .db()
4440        .await
4441        .observe_channel_message(channel_id, session.user_id(), message_id)
4442        .await?;
4443    send_notifications(
4444        &*session.connection_pool().await,
4445        &session.peer,
4446        notifications,
4447    );
4448    Ok(())
4449}
4450
4451/// Mark a buffer version as synced
4452async fn acknowledge_buffer_version(
4453    request: proto::AckBufferOperation,
4454    session: UserSession,
4455) -> Result<()> {
4456    let buffer_id = BufferId::from_proto(request.buffer_id);
4457    session
4458        .db()
4459        .await
4460        .observe_buffer_version(
4461            buffer_id,
4462            session.user_id(),
4463            request.epoch as i32,
4464            &request.version,
4465        )
4466        .await?;
4467    Ok(())
4468}
4469
4470async fn count_language_model_tokens(
4471    request: proto::CountLanguageModelTokens,
4472    response: Response<proto::CountLanguageModelTokens>,
4473    session: Session,
4474    config: &Config,
4475) -> Result<()> {
4476    let Some(session) = session.for_user() else {
4477        return Err(anyhow!("user not found"))?;
4478    };
4479    authorize_access_to_legacy_llm_endpoints(&session).await?;
4480
4481    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4482        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4483        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4484    };
4485
4486    session
4487        .app_state
4488        .rate_limiter
4489        .check(&*rate_limit, session.user_id())
4490        .await?;
4491
4492    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4493        Some(proto::LanguageModelProvider::Google) => {
4494            let api_key = config
4495                .google_ai_api_key
4496                .as_ref()
4497                .context("no Google AI API key configured on the server")?;
4498            google_ai::count_tokens(
4499                session.http_client.as_ref(),
4500                google_ai::API_URL,
4501                api_key,
4502                serde_json::from_str(&request.request)?,
4503                None,
4504            )
4505            .await?
4506        }
4507        _ => return Err(anyhow!("unsupported provider"))?,
4508    };
4509
4510    response.send(proto::CountLanguageModelTokensResponse {
4511        token_count: result.total_tokens as u32,
4512    })?;
4513
4514    Ok(())
4515}
4516
4517struct ZedProCountLanguageModelTokensRateLimit;
4518
4519impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4520    fn capacity(&self) -> usize {
4521        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4522            .ok()
4523            .and_then(|v| v.parse().ok())
4524            .unwrap_or(600) // Picked arbitrarily
4525    }
4526
4527    fn refill_duration(&self) -> chrono::Duration {
4528        chrono::Duration::hours(1)
4529    }
4530
4531    fn db_name(&self) -> &'static str {
4532        "zed-pro:count-language-model-tokens"
4533    }
4534}
4535
4536struct FreeCountLanguageModelTokensRateLimit;
4537
4538impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4539    fn capacity(&self) -> usize {
4540        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4541            .ok()
4542            .and_then(|v| v.parse().ok())
4543            .unwrap_or(600 / 10) // Picked arbitrarily
4544    }
4545
4546    fn refill_duration(&self) -> chrono::Duration {
4547        chrono::Duration::hours(1)
4548    }
4549
4550    fn db_name(&self) -> &'static str {
4551        "free:count-language-model-tokens"
4552    }
4553}
4554
4555struct ZedProComputeEmbeddingsRateLimit;
4556
4557impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4558    fn capacity(&self) -> usize {
4559        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4560            .ok()
4561            .and_then(|v| v.parse().ok())
4562            .unwrap_or(5000) // Picked arbitrarily
4563    }
4564
4565    fn refill_duration(&self) -> chrono::Duration {
4566        chrono::Duration::hours(1)
4567    }
4568
4569    fn db_name(&self) -> &'static str {
4570        "zed-pro:compute-embeddings"
4571    }
4572}
4573
4574struct FreeComputeEmbeddingsRateLimit;
4575
4576impl RateLimit for FreeComputeEmbeddingsRateLimit {
4577    fn capacity(&self) -> usize {
4578        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4579            .ok()
4580            .and_then(|v| v.parse().ok())
4581            .unwrap_or(5000 / 10) // Picked arbitrarily
4582    }
4583
4584    fn refill_duration(&self) -> chrono::Duration {
4585        chrono::Duration::hours(1)
4586    }
4587
4588    fn db_name(&self) -> &'static str {
4589        "free:compute-embeddings"
4590    }
4591}
4592
4593async fn compute_embeddings(
4594    request: proto::ComputeEmbeddings,
4595    response: Response<proto::ComputeEmbeddings>,
4596    session: UserSession,
4597    api_key: Option<Arc<str>>,
4598) -> Result<()> {
4599    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4600    authorize_access_to_legacy_llm_endpoints(&session).await?;
4601
4602    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4603        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4604        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4605    };
4606
4607    session
4608        .app_state
4609        .rate_limiter
4610        .check(&*rate_limit, session.user_id())
4611        .await?;
4612
4613    let embeddings = match request.model.as_str() {
4614        "openai/text-embedding-3-small" => {
4615            open_ai::embed(
4616                session.http_client.as_ref(),
4617                OPEN_AI_API_URL,
4618                &api_key,
4619                OpenAiEmbeddingModel::TextEmbedding3Small,
4620                request.texts.iter().map(|text| text.as_str()),
4621            )
4622            .await?
4623        }
4624        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4625    };
4626
4627    let embeddings = request
4628        .texts
4629        .iter()
4630        .map(|text| {
4631            let mut hasher = sha2::Sha256::new();
4632            hasher.update(text.as_bytes());
4633            let result = hasher.finalize();
4634            result.to_vec()
4635        })
4636        .zip(
4637            embeddings
4638                .data
4639                .into_iter()
4640                .map(|embedding| embedding.embedding),
4641        )
4642        .collect::<HashMap<_, _>>();
4643
4644    let db = session.db().await;
4645    db.save_embeddings(&request.model, &embeddings)
4646        .await
4647        .context("failed to save embeddings")
4648        .trace_err();
4649
4650    response.send(proto::ComputeEmbeddingsResponse {
4651        embeddings: embeddings
4652            .into_iter()
4653            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4654            .collect(),
4655    })?;
4656    Ok(())
4657}
4658
4659async fn get_cached_embeddings(
4660    request: proto::GetCachedEmbeddings,
4661    response: Response<proto::GetCachedEmbeddings>,
4662    session: UserSession,
4663) -> Result<()> {
4664    authorize_access_to_legacy_llm_endpoints(&session).await?;
4665
4666    let db = session.db().await;
4667    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4668
4669    response.send(proto::GetCachedEmbeddingsResponse {
4670        embeddings: embeddings
4671            .into_iter()
4672            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4673            .collect(),
4674    })?;
4675    Ok(())
4676}
4677
4678/// This is leftover from before the LLM service.
4679///
4680/// The endpoints protected by this check will be moved there eventually.
4681async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4682    if session.is_staff() {
4683        Ok(())
4684    } else {
4685        Err(anyhow!("permission denied"))?
4686    }
4687}
4688
4689/// Get a Supermaven API key for the user
4690async fn get_supermaven_api_key(
4691    _request: proto::GetSupermavenApiKey,
4692    response: Response<proto::GetSupermavenApiKey>,
4693    session: UserSession,
4694) -> Result<()> {
4695    let user_id: String = session.user_id().to_string();
4696    if !session.is_staff() {
4697        return Err(anyhow!("supermaven not enabled for this account"))?;
4698    }
4699
4700    let email = session
4701        .email()
4702        .ok_or_else(|| anyhow!("user must have an email"))?;
4703
4704    let supermaven_admin_api = session
4705        .supermaven_client
4706        .as_ref()
4707        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4708
4709    let result = supermaven_admin_api
4710        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4711        .await?;
4712
4713    response.send(proto::GetSupermavenApiKeyResponse {
4714        api_key: result.api_key,
4715    })?;
4716
4717    Ok(())
4718}
4719
4720/// Start receiving chat updates for a channel
4721async fn join_channel_chat(
4722    request: proto::JoinChannelChat,
4723    response: Response<proto::JoinChannelChat>,
4724    session: UserSession,
4725) -> Result<()> {
4726    let channel_id = ChannelId::from_proto(request.channel_id);
4727
4728    let db = session.db().await;
4729    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4730        .await?;
4731    let messages = db
4732        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4733        .await?;
4734    response.send(proto::JoinChannelChatResponse {
4735        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4736        messages,
4737    })?;
4738    Ok(())
4739}
4740
4741/// Stop receiving chat updates for a channel
4742async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4743    let channel_id = ChannelId::from_proto(request.channel_id);
4744    session
4745        .db()
4746        .await
4747        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4748        .await?;
4749    Ok(())
4750}
4751
4752/// Retrieve the chat history for a channel
4753async fn get_channel_messages(
4754    request: proto::GetChannelMessages,
4755    response: Response<proto::GetChannelMessages>,
4756    session: UserSession,
4757) -> Result<()> {
4758    let channel_id = ChannelId::from_proto(request.channel_id);
4759    let messages = session
4760        .db()
4761        .await
4762        .get_channel_messages(
4763            channel_id,
4764            session.user_id(),
4765            MESSAGE_COUNT_PER_PAGE,
4766            Some(MessageId::from_proto(request.before_message_id)),
4767        )
4768        .await?;
4769    response.send(proto::GetChannelMessagesResponse {
4770        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4771        messages,
4772    })?;
4773    Ok(())
4774}
4775
4776/// Retrieve specific chat messages
4777async fn get_channel_messages_by_id(
4778    request: proto::GetChannelMessagesById,
4779    response: Response<proto::GetChannelMessagesById>,
4780    session: UserSession,
4781) -> Result<()> {
4782    let message_ids = request
4783        .message_ids
4784        .iter()
4785        .map(|id| MessageId::from_proto(*id))
4786        .collect::<Vec<_>>();
4787    let messages = session
4788        .db()
4789        .await
4790        .get_channel_messages_by_id(session.user_id(), &message_ids)
4791        .await?;
4792    response.send(proto::GetChannelMessagesResponse {
4793        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4794        messages,
4795    })?;
4796    Ok(())
4797}
4798
4799/// Retrieve the current users notifications
4800async fn get_notifications(
4801    request: proto::GetNotifications,
4802    response: Response<proto::GetNotifications>,
4803    session: UserSession,
4804) -> Result<()> {
4805    let notifications = session
4806        .db()
4807        .await
4808        .get_notifications(
4809            session.user_id(),
4810            NOTIFICATION_COUNT_PER_PAGE,
4811            request.before_id.map(db::NotificationId::from_proto),
4812        )
4813        .await?;
4814    response.send(proto::GetNotificationsResponse {
4815        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4816        notifications,
4817    })?;
4818    Ok(())
4819}
4820
4821/// Mark notifications as read
4822async fn mark_notification_as_read(
4823    request: proto::MarkNotificationRead,
4824    response: Response<proto::MarkNotificationRead>,
4825    session: UserSession,
4826) -> Result<()> {
4827    let database = &session.db().await;
4828    let notifications = database
4829        .mark_notification_as_read_by_id(
4830            session.user_id(),
4831            NotificationId::from_proto(request.notification_id),
4832        )
4833        .await?;
4834    send_notifications(
4835        &*session.connection_pool().await,
4836        &session.peer,
4837        notifications,
4838    );
4839    response.send(proto::Ack {})?;
4840    Ok(())
4841}
4842
4843/// Get the current users information
4844async fn get_private_user_info(
4845    _request: proto::GetPrivateUserInfo,
4846    response: Response<proto::GetPrivateUserInfo>,
4847    session: UserSession,
4848) -> Result<()> {
4849    let db = session.db().await;
4850
4851    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4852    let user = db
4853        .get_user_by_id(session.user_id())
4854        .await?
4855        .ok_or_else(|| anyhow!("user not found"))?;
4856    let flags = db.get_user_flags(session.user_id()).await?;
4857
4858    response.send(proto::GetPrivateUserInfoResponse {
4859        metrics_id,
4860        staff: user.admin,
4861        flags,
4862        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4863    })?;
4864    Ok(())
4865}
4866
4867/// Accept the terms of service (tos) on behalf of the current user
4868async fn accept_terms_of_service(
4869    _request: proto::AcceptTermsOfService,
4870    response: Response<proto::AcceptTermsOfService>,
4871    session: UserSession,
4872) -> Result<()> {
4873    let db = session.db().await;
4874
4875    let accepted_tos_at = Utc::now();
4876    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4877        .await?;
4878
4879    response.send(proto::AcceptTermsOfServiceResponse {
4880        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4881    })?;
4882    Ok(())
4883}
4884
4885/// The minimum account age an account must have in order to use the LLM service.
4886const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4887
4888async fn get_llm_api_token(
4889    _request: proto::GetLlmToken,
4890    response: Response<proto::GetLlmToken>,
4891    session: UserSession,
4892) -> Result<()> {
4893    let db = session.db().await;
4894
4895    let flags = db.get_user_flags(session.user_id()).await?;
4896    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4897    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4898
4899    if !session.is_staff() && !has_language_models_feature_flag {
4900        Err(anyhow!("permission denied"))?
4901    }
4902
4903    let user_id = session.user_id();
4904    let user = db
4905        .get_user_by_id(user_id)
4906        .await?
4907        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4908
4909    if user.accepted_tos_at.is_none() {
4910        Err(anyhow!("terms of service not accepted"))?
4911    }
4912
4913    let mut account_created_at = user.created_at;
4914    if let Some(github_created_at) = user.github_user_created_at {
4915        account_created_at = account_created_at.min(github_created_at);
4916    }
4917    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4918        Err(anyhow!("account too young"))?
4919    }
4920
4921    let billing_preferences = db.get_billing_preferences(user.id).await?;
4922
4923    let token = LlmTokenClaims::create(
4924        user.id,
4925        user.github_login.clone(),
4926        session.is_staff(),
4927        billing_preferences,
4928        has_llm_closed_beta_feature_flag,
4929        session.has_llm_subscription(&db).await?,
4930        session.current_plan(&db).await?,
4931        &session.app_state.config,
4932    )?;
4933    response.send(proto::GetLlmTokenResponse { token })?;
4934    Ok(())
4935}
4936
4937fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4938    let message = match message {
4939        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4940        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4941        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4942        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4943        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4944            code: frame.code.into(),
4945            reason: frame.reason,
4946        })),
4947        // We should never receive a frame while reading the message, according
4948        // to the `tungstenite` maintainers:
4949        //
4950        // > It cannot occur when you read messages from the WebSocket, but it
4951        // > can be used when you want to send the raw frames (e.g. you want to
4952        // > send the frames to the WebSocket without composing the full message first).
4953        // >
4954        // > — https://github.com/snapview/tungstenite-rs/issues/268
4955        TungsteniteMessage::Frame(_) => {
4956            bail!("received an unexpected frame while reading the message")
4957        }
4958    };
4959
4960    Ok(message)
4961}
4962
4963fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4964    match message {
4965        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4966        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4967        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4968        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4969        AxumMessage::Close(frame) => {
4970            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4971                code: frame.code.into(),
4972                reason: frame.reason,
4973            }))
4974        }
4975    }
4976}
4977
4978fn notify_membership_updated(
4979    connection_pool: &mut ConnectionPool,
4980    result: MembershipUpdated,
4981    user_id: UserId,
4982    peer: &Peer,
4983) {
4984    for membership in &result.new_channels.channel_memberships {
4985        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4986    }
4987    for channel_id in &result.removed_channels {
4988        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4989    }
4990
4991    let user_channels_update = proto::UpdateUserChannels {
4992        channel_memberships: result
4993            .new_channels
4994            .channel_memberships
4995            .iter()
4996            .map(|cm| proto::ChannelMembership {
4997                channel_id: cm.channel_id.to_proto(),
4998                role: cm.role.into(),
4999            })
5000            .collect(),
5001        ..Default::default()
5002    };
5003
5004    let mut update = build_channels_update(result.new_channels);
5005    update.delete_channels = result
5006        .removed_channels
5007        .into_iter()
5008        .map(|id| id.to_proto())
5009        .collect();
5010    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5011
5012    for connection_id in connection_pool.user_connection_ids(user_id) {
5013        peer.send(connection_id, user_channels_update.clone())
5014            .trace_err();
5015        peer.send(connection_id, update.clone()).trace_err();
5016    }
5017}
5018
5019fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5020    proto::UpdateUserChannels {
5021        channel_memberships: channels
5022            .channel_memberships
5023            .iter()
5024            .map(|m| proto::ChannelMembership {
5025                channel_id: m.channel_id.to_proto(),
5026                role: m.role.into(),
5027            })
5028            .collect(),
5029        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5030        observed_channel_message_id: channels.observed_channel_messages.clone(),
5031    }
5032}
5033
5034fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5035    let mut update = proto::UpdateChannels::default();
5036
5037    for channel in channels.channels {
5038        update.channels.push(channel.to_proto());
5039    }
5040
5041    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5042    update.latest_channel_message_ids = channels.latest_channel_messages;
5043
5044    for (channel_id, participants) in channels.channel_participants {
5045        update
5046            .channel_participants
5047            .push(proto::ChannelParticipants {
5048                channel_id: channel_id.to_proto(),
5049                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5050            });
5051    }
5052
5053    for channel in channels.invited_channels {
5054        update.channel_invitations.push(channel.to_proto());
5055    }
5056
5057    update.hosted_projects = channels.hosted_projects;
5058    update
5059}
5060
5061fn build_initial_contacts_update(
5062    contacts: Vec<db::Contact>,
5063    pool: &ConnectionPool,
5064) -> proto::UpdateContacts {
5065    let mut update = proto::UpdateContacts::default();
5066
5067    for contact in contacts {
5068        match contact {
5069            db::Contact::Accepted { user_id, busy } => {
5070                update.contacts.push(contact_for_user(user_id, busy, pool));
5071            }
5072            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5073            db::Contact::Incoming { user_id } => {
5074                update
5075                    .incoming_requests
5076                    .push(proto::IncomingContactRequest {
5077                        requester_id: user_id.to_proto(),
5078                    })
5079            }
5080        }
5081    }
5082
5083    update
5084}
5085
5086fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5087    proto::Contact {
5088        user_id: user_id.to_proto(),
5089        online: pool.is_user_online(user_id),
5090        busy,
5091    }
5092}
5093
5094fn room_updated(room: &proto::Room, peer: &Peer) {
5095    broadcast(
5096        None,
5097        room.participants
5098            .iter()
5099            .filter_map(|participant| Some(participant.peer_id?.into())),
5100        |peer_id| {
5101            peer.send(
5102                peer_id,
5103                proto::RoomUpdated {
5104                    room: Some(room.clone()),
5105                },
5106            )
5107        },
5108    );
5109}
5110
5111fn channel_updated(
5112    channel: &db::channel::Model,
5113    room: &proto::Room,
5114    peer: &Peer,
5115    pool: &ConnectionPool,
5116) {
5117    let participants = room
5118        .participants
5119        .iter()
5120        .map(|p| p.user_id)
5121        .collect::<Vec<_>>();
5122
5123    broadcast(
5124        None,
5125        pool.channel_connection_ids(channel.root_id())
5126            .filter_map(|(channel_id, role)| {
5127                role.can_see_channel(channel.visibility)
5128                    .then_some(channel_id)
5129            }),
5130        |peer_id| {
5131            peer.send(
5132                peer_id,
5133                proto::UpdateChannels {
5134                    channel_participants: vec![proto::ChannelParticipants {
5135                        channel_id: channel.id.to_proto(),
5136                        participant_user_ids: participants.clone(),
5137                    }],
5138                    ..Default::default()
5139                },
5140            )
5141        },
5142    );
5143}
5144
5145async fn send_dev_server_projects_update(
5146    user_id: UserId,
5147    mut status: proto::DevServerProjectsUpdate,
5148    session: &Session,
5149) {
5150    let pool = session.connection_pool().await;
5151    for dev_server in &mut status.dev_servers {
5152        dev_server.status =
5153            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5154    }
5155    let connections = pool.user_connection_ids(user_id);
5156    for connection_id in connections {
5157        session.peer.send(connection_id, status.clone()).trace_err();
5158    }
5159}
5160
5161async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5162    let db = session.db().await;
5163
5164    let contacts = db.get_contacts(user_id).await?;
5165    let busy = db.is_user_busy(user_id).await?;
5166
5167    let pool = session.connection_pool().await;
5168    let updated_contact = contact_for_user(user_id, busy, &pool);
5169    for contact in contacts {
5170        if let db::Contact::Accepted {
5171            user_id: contact_user_id,
5172            ..
5173        } = contact
5174        {
5175            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5176                session
5177                    .peer
5178                    .send(
5179                        contact_conn_id,
5180                        proto::UpdateContacts {
5181                            contacts: vec![updated_contact.clone()],
5182                            remove_contacts: Default::default(),
5183                            incoming_requests: Default::default(),
5184                            remove_incoming_requests: Default::default(),
5185                            outgoing_requests: Default::default(),
5186                            remove_outgoing_requests: Default::default(),
5187                        },
5188                    )
5189                    .trace_err();
5190            }
5191        }
5192    }
5193    Ok(())
5194}
5195
5196async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5197    log::info!("lost dev server connection, unsharing projects");
5198    let project_ids = session
5199        .db()
5200        .await
5201        .get_stale_dev_server_projects(session.connection_id)
5202        .await?;
5203
5204    for project_id in project_ids {
5205        // not unshare re-checks the connection ids match, so we get away with no transaction
5206        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5207    }
5208
5209    let user_id = session.dev_server().user_id;
5210    let update = session
5211        .db()
5212        .await
5213        .dev_server_projects_update(user_id)
5214        .await?;
5215
5216    send_dev_server_projects_update(user_id, update, session).await;
5217
5218    Ok(())
5219}
5220
5221async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5222    let mut contacts_to_update = HashSet::default();
5223
5224    let room_id;
5225    let canceled_calls_to_user_ids;
5226    let live_kit_room;
5227    let delete_live_kit_room;
5228    let room;
5229    let channel;
5230
5231    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5232        contacts_to_update.insert(session.user_id());
5233
5234        for project in left_room.left_projects.values() {
5235            project_left(project, session);
5236        }
5237
5238        room_id = RoomId::from_proto(left_room.room.id);
5239        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5240        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5241        delete_live_kit_room = left_room.deleted;
5242        room = mem::take(&mut left_room.room);
5243        channel = mem::take(&mut left_room.channel);
5244
5245        room_updated(&room, &session.peer);
5246    } else {
5247        return Ok(());
5248    }
5249
5250    if let Some(channel) = channel {
5251        channel_updated(
5252            &channel,
5253            &room,
5254            &session.peer,
5255            &*session.connection_pool().await,
5256        );
5257    }
5258
5259    {
5260        let pool = session.connection_pool().await;
5261        for canceled_user_id in canceled_calls_to_user_ids {
5262            for connection_id in pool.user_connection_ids(canceled_user_id) {
5263                session
5264                    .peer
5265                    .send(
5266                        connection_id,
5267                        proto::CallCanceled {
5268                            room_id: room_id.to_proto(),
5269                        },
5270                    )
5271                    .trace_err();
5272            }
5273            contacts_to_update.insert(canceled_user_id);
5274        }
5275    }
5276
5277    for contact_user_id in contacts_to_update {
5278        update_user_contacts(contact_user_id, session).await?;
5279    }
5280
5281    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5282        live_kit
5283            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5284            .await
5285            .trace_err();
5286
5287        if delete_live_kit_room {
5288            live_kit.delete_room(live_kit_room).await.trace_err();
5289        }
5290    }
5291
5292    Ok(())
5293}
5294
5295async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5296    let left_channel_buffers = session
5297        .db()
5298        .await
5299        .leave_channel_buffers(session.connection_id)
5300        .await?;
5301
5302    for left_buffer in left_channel_buffers {
5303        channel_buffer_updated(
5304            session.connection_id,
5305            left_buffer.connections,
5306            &proto::UpdateChannelBufferCollaborators {
5307                channel_id: left_buffer.channel_id.to_proto(),
5308                collaborators: left_buffer.collaborators,
5309            },
5310            &session.peer,
5311        );
5312    }
5313
5314    Ok(())
5315}
5316
5317fn project_left(project: &db::LeftProject, session: &UserSession) {
5318    for connection_id in &project.connection_ids {
5319        if project.should_unshare {
5320            session
5321                .peer
5322                .send(
5323                    *connection_id,
5324                    proto::UnshareProject {
5325                        project_id: project.id.to_proto(),
5326                    },
5327                )
5328                .trace_err();
5329        } else {
5330            session
5331                .peer
5332                .send(
5333                    *connection_id,
5334                    proto::RemoveProjectCollaborator {
5335                        project_id: project.id.to_proto(),
5336                        peer_id: Some(session.connection_id.into()),
5337                    },
5338                )
5339                .trace_err();
5340        }
5341    }
5342}
5343
5344pub trait ResultExt {
5345    type Ok;
5346
5347    fn trace_err(self) -> Option<Self::Ok>;
5348}
5349
5350impl<T, E> ResultExt for Result<T, E>
5351where
5352    E: std::fmt::Debug,
5353{
5354    type Ok = T;
5355
5356    #[track_caller]
5357    fn trace_err(self) -> Option<T> {
5358        match self {
5359            Ok(value) => Some(value),
5360            Err(error) => {
5361                tracing::error!("{:?}", error);
5362                None
5363            }
5364        }
5365    }
5366}