rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81use self::connection_pool::VersionedMessage;
  82
/// Reconnect timeout of 30 seconds.
/// NOTE(review): not referenced in the visible part of this file — confirm the exact
/// semantics at its use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources. `Server::start` sleeps for this long before it retrieves stale
/// rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this file;
// presumably they bound channel-message paging, message body length, and notification
// paging respectively — confirm against the handlers that reference them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  91
/// Type-erased message handler as stored in `Server::handlers` (keyed by `TypeId`).
///
/// It receives the boxed envelope together with the `Session` it arrived on and returns a
/// boxed `'static` future that produces no value.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  94
  95struct Response<R> {
  96    peer: Arc<Peer>,
  97    receipt: Receipt<R>,
  98    responded: Arc<AtomicBool>,
  99}
 100
 101impl<R: RequestMessage> Response<R> {
 102    fn send(self, payload: R::Response) -> Result<()> {
 103        self.responded.store(true, SeqCst);
 104        self.peer.respond(self.receipt, payload)?;
 105        Ok(())
 106    }
 107}
 108
 109#[derive(Clone, Debug)]
 110pub enum Principal {
 111    User(User),
 112    Impersonated { user: User, admin: User },
 113    DevServer(dev_server::Model),
 114}
 115
 116impl Principal {
 117    fn update_span(&self, span: &tracing::Span) {
 118        match &self {
 119            Principal::User(user) => {
 120                span.record("user_id", &user.id.0);
 121                span.record("login", &user.github_login);
 122            }
 123            Principal::Impersonated { user, admin } => {
 124                span.record("user_id", &user.id.0);
 125                span.record("login", &user.github_login);
 126                span.record("impersonator", &admin.github_login);
 127            }
 128            Principal::DevServer(dev_server) => {
 129                span.record("dev_server_id", &dev_server.id.0);
 130            }
 131        }
 132    }
 133}
 134
/// Per-connection state that is passed (by value) to every message handler.
///
/// The struct is `Clone`; its `Arc`-wrapped fields are shared between clones.
#[derive(Clone)]
struct Session {
    /// Who is connected: a user, an impersonated user, or a dev server.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer. NOTE(review): presumably the server-wide `Peer` — confirm where sessions
    /// are constructed.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; acquire it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Supermaven admin API client, when one is available (`None` otherwise).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client shared between clones of this session.
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Executor kept alongside the session; not read in the visible part of this file.
    _executor: Executor,
}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn for_dev_server(self) -> Option<DevServerSession> {
 176        DevServerSession::new(self)
 177    }
 178
 179    fn user_id(&self) -> Option<UserId> {
 180        match &self.principal {
 181            Principal::User(user) => Some(user.id),
 182            Principal::Impersonated { user, .. } => Some(user.id),
 183            Principal::DevServer(_) => None,
 184        }
 185    }
 186
 187    fn is_staff(&self) -> bool {
 188        match &self.principal {
 189            Principal::User(user) => user.admin,
 190            Principal::Impersonated { .. } => true,
 191            Principal::DevServer(_) => false,
 192        }
 193    }
 194
 195    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 196        if self.is_staff() {
 197            return Ok(proto::Plan::ZedPro);
 198        }
 199
 200        let Some(user_id) = self.user_id() else {
 201            return Ok(proto::Plan::Free);
 202        };
 203
 204        if db.has_active_billing_subscription(user_id).await? {
 205            Ok(proto::Plan::ZedPro)
 206        } else {
 207            Ok(proto::Plan::Free)
 208        }
 209    }
 210
 211    fn dev_server_id(&self) -> Option<DevServerId> {
 212        match &self.principal {
 213            Principal::User(_) | Principal::Impersonated { .. } => None,
 214            Principal::DevServer(dev_server) => Some(dev_server.id),
 215        }
 216    }
 217
 218    fn principal_id(&self) -> PrincipalId {
 219        match &self.principal {
 220            Principal::User(user) => PrincipalId::UserId(user.id),
 221            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 222            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 223        }
 224    }
 225}
 226
 227impl Debug for Session {
 228    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 229        let mut result = f.debug_struct("Session");
 230        match &self.principal {
 231            Principal::User(user) => {
 232                result.field("user", &user.github_login);
 233            }
 234            Principal::Impersonated { user, admin } => {
 235                result.field("user", &user.github_login);
 236                result.field("impersonator", &admin.github_login);
 237            }
 238            Principal::DevServer(dev_server) => {
 239                result.field("dev_server", &dev_server.id);
 240            }
 241        }
 242        result.field("connection_id", &self.connection_id).finish()
 243    }
 244}
 245
 246struct UserSession(Session);
 247
 248impl UserSession {
 249    pub fn new(s: Session) -> Option<Self> {
 250        s.user_id().map(|_| UserSession(s))
 251    }
 252    pub fn user_id(&self) -> UserId {
 253        self.0.user_id().unwrap()
 254    }
 255
 256    pub fn email(&self) -> Option<String> {
 257        match &self.0.principal {
 258            Principal::User(user) => user.email_address.clone(),
 259            Principal::Impersonated { user, .. } => user.email_address.clone(),
 260            Principal::DevServer(..) => None,
 261        }
 262    }
 263}
 264
 265impl Deref for UserSession {
 266    type Target = Session;
 267
 268    fn deref(&self) -> &Self::Target {
 269        &self.0
 270    }
 271}
 272impl DerefMut for UserSession {
 273    fn deref_mut(&mut self) -> &mut Self::Target {
 274        &mut self.0
 275    }
 276}
 277
 278struct DevServerSession(Session);
 279
 280impl DevServerSession {
 281    pub fn new(s: Session) -> Option<Self> {
 282        s.dev_server_id().map(|_| DevServerSession(s))
 283    }
 284    pub fn dev_server_id(&self) -> DevServerId {
 285        self.0.dev_server_id().unwrap()
 286    }
 287
 288    fn dev_server(&self) -> &dev_server::Model {
 289        match &self.0.principal {
 290            Principal::DevServer(dev_server) => dev_server,
 291            _ => unreachable!(),
 292        }
 293    }
 294}
 295
 296impl Deref for DevServerSession {
 297    type Target = Session;
 298
 299    fn deref(&self) -> &Self::Target {
 300        &self.0
 301    }
 302}
 303impl DerefMut for DevServerSession {
 304    fn deref_mut(&mut self) -> &mut Self::Target {
 305        &mut self.0
 306    }
 307}
 308
 309fn user_handler<M: RequestMessage, Fut>(
 310    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 311) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 312where
 313    Fut: Send + Future<Output = Result<()>>,
 314{
 315    let handler = Arc::new(handler);
 316    move |message, response, session| {
 317        let handler = handler.clone();
 318        Box::pin(async move {
 319            if let Some(user_session) = session.for_user() {
 320                Ok(handler(message, response, user_session).await?)
 321            } else {
 322                Err(Error::Internal(anyhow!(
 323                    "must be a user to call {}",
 324                    M::NAME
 325                )))
 326            }
 327        })
 328    }
 329}
 330
 331fn dev_server_handler<M: RequestMessage, Fut>(
 332    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 333) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 334where
 335    Fut: Send + Future<Output = Result<()>>,
 336{
 337    let handler = Arc::new(handler);
 338    move |message, response, session| {
 339        let handler = handler.clone();
 340        Box::pin(async move {
 341            if let Some(dev_server_session) = session.for_dev_server() {
 342                Ok(handler(message, response, dev_server_session).await?)
 343            } else {
 344                Err(Error::Internal(anyhow!(
 345                    "must be a dev server to call {}",
 346                    M::NAME
 347                )))
 348            }
 349        })
 350    }
 351}
 352
 353fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 354    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 355) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 356where
 357    InnertRetFut: Send + Future<Output = Result<()>>,
 358{
 359    let handler = Arc::new(handler);
 360    move |message, session| {
 361        let handler = handler.clone();
 362        Box::pin(async move {
 363            if let Some(user_session) = session.for_user() {
 364                Ok(handler(message, user_session).await?)
 365            } else {
 366                Err(Error::Internal(anyhow!(
 367                    "must be a user to call {}",
 368                    M::NAME
 369                )))
 370            }
 371        })
 372    }
 373}
 374
 375struct DbHandle(Arc<Database>);
 376
 377impl Deref for DbHandle {
 378    type Target = Database;
 379
 380    fn deref(&self) -> &Self::Target {
 381        self.0.as_ref()
 382    }
 383}
 384
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    /// This server's id, behind a mutex so it can be read (and replaced) through `&self`.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer; `Server::new` creates it from the numeric server id.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with sessions and background tasks.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by `TypeId`.
    /// NOTE(review): presumably the `TypeId` of the message type each handler accepts —
    /// the registration helpers are outside this chunk.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `watch` channel that starts out as `false`.
    /// NOTE(review): presumably flipped to `true` to signal teardown — confirm.
    teardown: watch::Sender<bool>,
}
 393
/// A held lock on the `ConnectionPool`, as handed out by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 398
/// Serializable view of the server: a borrowed peer plus a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer, borrowed for the lifetime of the snapshot.
    peer: &'a Peer,
    /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    /// This relies on a `Deref` impl for `ConnectionPoolGuard` that is not visible in
    /// this part of the file.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 405
 406pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 407where
 408    S: Serializer,
 409    T: Deref<Target = U>,
 410    U: Serialize,
 411{
 412    Serialize::serialize(value.deref(), serializer)
 413}
 414
 415impl Server {
    /// Builds the server, registers every RPC handler, and returns it wrapped in an `Arc`.
    ///
    /// The peer is created from the numeric server id (cast to `u32`), the connection
    /// pool and handler table start out empty via `Default`, and only the sender half of
    /// the teardown `watch` channel (initial value `false`) is kept.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail with an internal
    /// error unless the session's principal is a user; those wrapped in
    /// `dev_server_handler` fail unless it is a dev server. Unwrapped handlers receive the
    /// plain `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Project sharing, joining, and dev server management.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests that only a dev server principal may send.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            // Project state updates.
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through the `forward_*` helpers, one registration
            // per proto message type.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer traffic and messages broadcast through
            // `broadcast_project_message_from_host`.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following, account info, and acknowledgements.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Context requests, forwarded like the project requests above.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The closures below capture a clone of `app_state` so the handlers can read
            // the server configuration.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    // Cloned again per call so the returned future owns its own handle.
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 651
 652    pub async fn start(&self) -> Result<()> {
 653        let server_id = *self.id.lock();
 654        let app_state = self.app_state.clone();
 655        let peer = self.peer.clone();
 656        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 657        let pool = self.connection_pool.clone();
 658        let live_kit_client = self.app_state.live_kit_client.clone();
 659
 660        let span = info_span!("start server");
 661        self.app_state.executor.spawn_detached(
 662            async move {
 663                tracing::info!("waiting for cleanup timeout");
 664                timeout.await;
 665                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 666                if let Some((room_ids, channel_ids)) = app_state
 667                    .db
 668                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 669                    .await
 670                    .trace_err()
 671                {
 672                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 673                    tracing::info!(
 674                        stale_channel_buffer_count = channel_ids.len(),
 675                        "retrieved stale channel buffers"
 676                    );
 677
 678                    for channel_id in channel_ids {
 679                        if let Some(refreshed_channel_buffer) = app_state
 680                            .db
 681                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 682                            .await
 683                            .trace_err()
 684                        {
 685                            for connection_id in refreshed_channel_buffer.connection_ids {
 686                                peer.send(
 687                                    connection_id,
 688                                    proto::UpdateChannelBufferCollaborators {
 689                                        channel_id: channel_id.to_proto(),
 690                                        collaborators: refreshed_channel_buffer
 691                                            .collaborators
 692                                            .clone(),
 693                                    },
 694                                )
 695                                .trace_err();
 696                            }
 697                        }
 698                    }
 699
 700                    for room_id in room_ids {
 701                        let mut contacts_to_update = HashSet::default();
 702                        let mut canceled_calls_to_user_ids = Vec::new();
 703                        let mut live_kit_room = String::new();
 704                        let mut delete_live_kit_room = false;
 705
 706                        if let Some(mut refreshed_room) = app_state
 707                            .db
 708                            .clear_stale_room_participants(room_id, server_id)
 709                            .await
 710                            .trace_err()
 711                        {
 712                            tracing::info!(
 713                                room_id = room_id.0,
 714                                new_participant_count = refreshed_room.room.participants.len(),
 715                                "refreshed room"
 716                            );
 717                            room_updated(&refreshed_room.room, &peer);
 718                            if let Some(channel) = refreshed_room.channel.as_ref() {
 719                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 720                            }
 721                            contacts_to_update
 722                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 723                            contacts_to_update
 724                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 725                            canceled_calls_to_user_ids =
 726                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 727                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 728                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 729                        }
 730
 731                        {
 732                            let pool = pool.lock();
 733                            for canceled_user_id in canceled_calls_to_user_ids {
 734                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 735                                    peer.send(
 736                                        connection_id,
 737                                        proto::CallCanceled {
 738                                            room_id: room_id.to_proto(),
 739                                        },
 740                                    )
 741                                    .trace_err();
 742                                }
 743                            }
 744                        }
 745
 746                        for user_id in contacts_to_update {
 747                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 748                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 749                            if let Some((busy, contacts)) = busy.zip(contacts) {
 750                                let pool = pool.lock();
 751                                let updated_contact = contact_for_user(user_id, busy, &pool);
 752                                for contact in contacts {
 753                                    if let db::Contact::Accepted {
 754                                        user_id: contact_user_id,
 755                                        ..
 756                                    } = contact
 757                                    {
 758                                        for contact_conn_id in
 759                                            pool.user_connection_ids(contact_user_id)
 760                                        {
 761                                            peer.send(
 762                                                contact_conn_id,
 763                                                proto::UpdateContacts {
 764                                                    contacts: vec![updated_contact.clone()],
 765                                                    remove_contacts: Default::default(),
 766                                                    incoming_requests: Default::default(),
 767                                                    remove_incoming_requests: Default::default(),
 768                                                    outgoing_requests: Default::default(),
 769                                                    remove_outgoing_requests: Default::default(),
 770                                                },
 771                                            )
 772                                            .trace_err();
 773                                        }
 774                                    }
 775                                }
 776                            }
 777                        }
 778
 779                        if let Some(live_kit) = live_kit_client.as_ref() {
 780                            if delete_live_kit_room {
 781                                live_kit.delete_room(live_kit_room).await.trace_err();
 782                            }
 783                        }
 784                    }
 785                }
 786
 787                app_state
 788                    .db
 789                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 790                    .await
 791                    .trace_err();
 792            }
 793            .instrument(span),
 794        );
 795        Ok(())
 796    }
 797
 798    pub fn teardown(&self) {
 799        self.peer.teardown();
 800        self.connection_pool.lock().reset();
 801        let _ = self.teardown.send(true);
 802    }
 803
 804    #[cfg(test)]
 805    pub fn reset(&self, id: ServerId) {
 806        self.teardown();
 807        *self.id.lock() = id;
 808        self.peer.reset(id.0 as u32);
 809        let _ = self.teardown.send(false);
 810    }
 811
 812    #[cfg(test)]
 813    pub fn id(&self) -> ServerId {
 814        *self.id.lock()
 815    }
 816
 817    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 818    where
 819        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 820        Fut: 'static + Send + Future<Output = Result<()>>,
 821        M: EnvelopedMessage,
 822    {
 823        let prev_handler = self.handlers.insert(
 824            TypeId::of::<M>(),
 825            Box::new(move |envelope, session| {
 826                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 827                let received_at = envelope.received_at;
 828                tracing::info!("message received");
 829                let start_time = Instant::now();
 830                let future = (handler)(*envelope, session);
 831                async move {
 832                    let result = future.await;
 833                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 834                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 835                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 836                    let payload_type = M::NAME;
 837
 838                    match result {
 839                        Err(error) => {
 840                            tracing::error!(
 841                                ?error,
 842                                total_duration_ms,
 843                                processing_duration_ms,
 844                                queue_duration_ms,
 845                                payload_type,
 846                                "error handling message"
 847                            )
 848                        }
 849                        Ok(()) => tracing::info!(
 850                            total_duration_ms,
 851                            processing_duration_ms,
 852                            queue_duration_ms,
 853                            "finished handling message"
 854                        ),
 855                    }
 856                }
 857                .boxed()
 858            }),
 859        );
 860        if prev_handler.is_some() {
 861            panic!("registered a handler for the same message twice");
 862        }
 863        self
 864    }
 865
 866    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 869        Fut: 'static + Send + Future<Output = Result<()>>,
 870        M: EnvelopedMessage,
 871    {
 872        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 873        self
 874    }
 875
 876    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 879        Fut: Send + Future<Output = Result<()>>,
 880        M: RequestMessage,
 881    {
 882        let handler = Arc::new(handler);
 883        self.add_handler(move |envelope, session| {
 884            let receipt = envelope.receipt();
 885            let handler = handler.clone();
 886            async move {
 887                let peer = session.peer.clone();
 888                let responded = Arc::new(AtomicBool::default());
 889                let response = Response {
 890                    peer: peer.clone(),
 891                    responded: responded.clone(),
 892                    receipt,
 893                };
 894                match (handler)(envelope.payload, response, session).await {
 895                    Ok(()) => {
 896                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 897                            Ok(())
 898                        } else {
 899                            Err(anyhow!("handler did not send a response"))?
 900                        }
 901                    }
 902                    Err(error) => {
 903                        let proto_err = match &error {
 904                            Error::Internal(err) => err.to_proto(),
 905                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 906                        };
 907                        peer.respond_with_error(receipt, proto_err)?;
 908                        Err(error)
 909                    }
 910                }
 911            }
 912        })
 913    }
 914
    /// Drives a single client connection from registration until sign-out.
    ///
    /// The returned future:
    /// 1. bails out immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the resulting
    ///    connection id on the tracing span;
    /// 3. builds the per-connection `Session` (HTTP client, optional
    ///    Supermaven client, database handle, ...);
    /// 4. sends the initial client update, returning early on failure;
    /// 5. loops, dispatching incoming messages to the registered handlers
    ///    until the I/O task finishes, the incoming stream ends, or the
    ///    teardown flag changes;
    /// 6. finally calls `connection_lost` to sign the connection out.
    ///
    /// If `send_connection_id` is provided it is forwarded to
    /// `send_initial_client_update`, which uses it to report the assigned
    /// connection id.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `Empty` are filled in later, once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the teardown flag is set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection: a permit is acquired
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the branches in the order written: teardown, I/O
                // completion, finished foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // On teardown, return without running the sign-out path below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before moving it into the instrumented future.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the concurrency permit only after the handler is done.
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // messages are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1068
1069    async fn send_initial_client_update(
1070        &self,
1071        connection_id: ConnectionId,
1072        principal: &Principal,
1073        zed_version: ZedVersion,
1074        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1075        session: &Session,
1076    ) -> Result<()> {
1077        self.peer.send(
1078            connection_id,
1079            proto::Hello {
1080                peer_id: Some(connection_id.into()),
1081            },
1082        )?;
1083        tracing::info!("sent hello message");
1084        if let Some(send_connection_id) = send_connection_id.take() {
1085            let _ = send_connection_id.send(connection_id);
1086        }
1087
1088        match principal {
1089            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1090                if !user.connected_once {
1091                    self.peer.send(connection_id, proto::ShowContacts {})?;
1092                    self.app_state
1093                        .db
1094                        .set_user_connected_once(user.id, true)
1095                        .await?;
1096                }
1097
1098                update_user_plan(user.id, session).await?;
1099
1100                let (contacts, dev_server_projects) = future::try_join(
1101                    self.app_state.db.get_contacts(user.id),
1102                    self.app_state.db.dev_server_projects_update(user.id),
1103                )
1104                .await?;
1105
1106                {
1107                    let mut pool = self.connection_pool.lock();
1108                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1109                    self.peer.send(
1110                        connection_id,
1111                        build_initial_contacts_update(contacts, &pool),
1112                    )?;
1113                }
1114
1115                if should_auto_subscribe_to_channels(zed_version) {
1116                    subscribe_user_to_channels(user.id, session).await?;
1117                }
1118
1119                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1120
1121                if let Some(incoming_call) =
1122                    self.app_state.db.incoming_call_for_user(user.id).await?
1123                {
1124                    self.peer.send(connection_id, incoming_call)?;
1125                }
1126
1127                update_user_contacts(user.id, &session).await?;
1128            }
1129            Principal::DevServer(dev_server) => {
1130                {
1131                    let mut pool = self.connection_pool.lock();
1132                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1133                    {
1134                        self.peer.send(
1135                            stale_connection_id,
1136                            proto::ShutdownDevServer {
1137                                reason: Some(
1138                                    "another dev server connected with the same token".to_string(),
1139                                ),
1140                            },
1141                        )?;
1142                        pool.remove_connection(stale_connection_id)?;
1143                    };
1144                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1145                }
1146
1147                let projects = self
1148                    .app_state
1149                    .db
1150                    .get_projects_for_dev_server(dev_server.id)
1151                    .await?;
1152                self.peer
1153                    .send(connection_id, proto::DevServerInstructions { projects })?;
1154
1155                let status = self
1156                    .app_state
1157                    .db
1158                    .dev_server_projects_update(dev_server.user_id)
1159                    .await?;
1160                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1161            }
1162        }
1163
1164        Ok(())
1165    }
1166
1167    pub async fn invite_code_redeemed(
1168        self: &Arc<Self>,
1169        inviter_id: UserId,
1170        invitee_id: UserId,
1171    ) -> Result<()> {
1172        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173            if let Some(code) = &user.invite_code {
1174                let pool = self.connection_pool.lock();
1175                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176                for connection_id in pool.user_connection_ids(inviter_id) {
1177                    self.peer.send(
1178                        connection_id,
1179                        proto::UpdateContacts {
1180                            contacts: vec![invitee_contact.clone()],
1181                            ..Default::default()
1182                        },
1183                    )?;
1184                    self.peer.send(
1185                        connection_id,
1186                        proto::UpdateInviteInfo {
1187                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188                            count: user.invite_count as u32,
1189                        },
1190                    )?;
1191                }
1192            }
1193        }
1194        Ok(())
1195    }
1196
1197    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199            if let Some(invite_code) = &user.invite_code {
1200                let pool = self.connection_pool.lock();
1201                for connection_id in pool.user_connection_ids(user_id) {
1202                    self.peer.send(
1203                        connection_id,
1204                        proto::UpdateInviteInfo {
1205                            url: format!(
1206                                "{}{}",
1207                                self.app_state.config.invite_link_prefix, invite_code
1208                            ),
1209                            count: user.invite_count as u32,
1210                        },
1211                    )?;
1212                }
1213            }
1214        }
1215        Ok(())
1216    }
1217
1218    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219        ServerSnapshot {
1220            connection_pool: ConnectionPoolGuard {
1221                guard: self.connection_pool.lock(),
1222                _not_send: PhantomData,
1223            },
1224            peer: &self.peer,
1225        }
1226    }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230    type Target = ConnectionPool;
1231
1232    fn deref(&self) -> &Self::Target {
1233        &self.guard
1234    }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238    fn deref_mut(&mut self) -> &mut Self::Target {
1239        &mut self.guard
1240    }
1241}
1242
1243impl<'a> Drop for ConnectionPoolGuard<'a> {
1244    fn drop(&mut self) {
1245        #[cfg(test)]
1246        self.check_invariants();
1247    }
1248}
1249
1250fn broadcast<F>(
1251    sender_id: Option<ConnectionId>,
1252    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253    mut f: F,
1254) where
1255    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257    for receiver_id in receiver_ids {
1258        if Some(receiver_id) != sender_id {
1259            if let Err(error) = f(receiver_id) {
1260                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261            }
1262        }
1263    }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269    fn name() -> &'static HeaderName {
1270        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272    }
1273
1274    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275    where
1276        Self: Sized,
1277        I: Iterator<Item = &'i axum::http::HeaderValue>,
1278    {
1279        let version = values
1280            .next()
1281            .ok_or_else(axum::headers::Error::invalid)?
1282            .to_str()
1283            .map_err(|_| axum::headers::Error::invalid())?
1284            .parse()
1285            .map_err(|_| axum::headers::Error::invalid())?;
1286        Ok(Self(version))
1287    }
1288
1289    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290        values.extend([self.0.to_string().parse().unwrap()]);
1291    }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296    fn name() -> &'static HeaderName {
1297        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299    }
1300
1301    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302    where
1303        Self: Sized,
1304        I: Iterator<Item = &'i axum::http::HeaderValue>,
1305    {
1306        let version = values
1307            .next()
1308            .ok_or_else(axum::headers::Error::invalid)?
1309            .to_str()
1310            .map_err(|_| axum::headers::Error::invalid())?
1311            .parse()
1312            .map_err(|_| axum::headers::Error::invalid())?;
1313        Ok(Self(version))
1314    }
1315
1316    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317        values.extend([self.0.to_string().parse().unwrap()]);
1318    }
1319}
1320
1321pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1322    Router::new()
1323        .route("/rpc", get(handle_websocket_request))
1324        .layer(
1325            ServiceBuilder::new()
1326                .layer(Extension(server.app_state.clone()))
1327                .layer(middleware::from_fn(auth::validate_header)),
1328        )
1329        .route("/metrics", get(handle_metrics))
1330        .layer(Extension(server))
1331}
1332
1333pub async fn handle_websocket_request(
1334    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337    Extension(server): Extension<Arc<Server>>,
1338    Extension(principal): Extension<Principal>,
1339    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1340    ws: WebSocketUpgrade,
1341) -> axum::response::Response {
1342    if protocol_version != rpc::PROTOCOL_VERSION {
1343        return (
1344            StatusCode::UPGRADE_REQUIRED,
1345            "client must be upgraded".to_string(),
1346        )
1347            .into_response();
1348    }
1349
1350    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1351        return (
1352            StatusCode::UPGRADE_REQUIRED,
1353            "no version header found".to_string(),
1354        )
1355            .into_response();
1356    };
1357
1358    if !version.can_collaborate() {
1359        return (
1360            StatusCode::UPGRADE_REQUIRED,
1361            "client must be upgraded".to_string(),
1362        )
1363            .into_response();
1364    }
1365
1366    let socket_address = socket_address.to_string();
1367    ws.on_upgrade(move |socket| {
1368        let socket = socket
1369            .map_ok(to_tungstenite_message)
1370            .err_into()
1371            .with(|message| async move { to_axum_message(message) });
1372        let connection = Connection::new(Box::pin(socket));
1373        async move {
1374            server
1375                .handle_connection(
1376                    connection,
1377                    socket_address,
1378                    principal,
1379                    version,
1380                    country_code_header.map(|header| header.to_string()),
1381                    None,
1382                    Executor::Production,
1383                )
1384                .await;
1385        }
1386    })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391    let connections_metric = CONNECTIONS_METRIC
1392        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394    let connections = server
1395        .connection_pool
1396        .lock()
1397        .connections()
1398        .filter(|connection| !connection.admin)
1399        .count();
1400    connections_metric.set(connections as _);
1401
1402    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404        register_int_gauge!(
1405            "shared_projects",
1406            "number of open projects with one or more guests"
1407        )
1408        .unwrap()
1409    });
1410
1411    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412    shared_projects_metric.set(shared_projects as _);
1413
1414    let encoder = prometheus::TextEncoder::new();
1415    let metric_families = prometheus::gather();
1416    let encoded_metrics = encoder
1417        .encode_to_string(&metric_families)
1418        .map_err(|err| anyhow!("{}", err))?;
1419    Ok(encoded_metrics)
1420}
1421
1422#[instrument(err, skip(executor))]
1423async fn connection_lost(
1424    session: Session,
1425    mut teardown: watch::Receiver<bool>,
1426    executor: Executor,
1427) -> Result<()> {
1428    session.peer.disconnect(session.connection_id);
1429    session
1430        .connection_pool()
1431        .await
1432        .remove_connection(session.connection_id)?;
1433
1434    session
1435        .db()
1436        .await
1437        .connection_lost(session.connection_id)
1438        .await
1439        .trace_err();
1440
1441    futures::select_biased! {
1442        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1443            match &session.principal {
1444                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1445                    let session = session.for_user().unwrap();
1446
1447                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1448                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1449                    leave_channel_buffers_for_session(&session)
1450                        .await
1451                        .trace_err();
1452
1453                    if !session
1454                        .connection_pool()
1455                        .await
1456                        .is_user_online(session.user_id())
1457                    {
1458                        let db = session.db().await;
1459                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1460                            room_updated(&room, &session.peer);
1461                        }
1462                    }
1463
1464                    update_user_contacts(session.user_id(), &session).await?;
1465                },
1466            Principal::DevServer(_) => {
1467                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1468            },
1469        }
1470        },
1471        _ = teardown.changed().fuse() => {}
1472    }
1473
1474    Ok(())
1475}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479    response.send(proto::Ack {})?;
1480    Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485    _request: proto::CreateRoom,
1486    response: Response<proto::CreateRoom>,
1487    session: UserSession,
1488) -> Result<()> {
1489    let live_kit_room = nanoid::nanoid!(30);
1490
1491    let live_kit_connection_info = util::maybe!(async {
1492        let live_kit = session.app_state.live_kit_client.as_ref();
1493        let live_kit = live_kit?;
1494        let user_id = session.user_id().to_string();
1495
1496        let token = live_kit
1497            .room_token(&live_kit_room, &user_id.to_string())
1498            .trace_err()?;
1499
1500        Some(proto::LiveKitConnectionInfo {
1501            server_url: live_kit.url().into(),
1502            token,
1503            can_publish: true,
1504        })
1505    })
1506    .await;
1507
1508    let room = session
1509        .db()
1510        .await
1511        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512        .await?;
1513
1514    response.send(proto::CreateRoomResponse {
1515        room: Some(room.clone()),
1516        live_kit_connection_info,
1517    })?;
1518
1519    update_user_contacts(session.user_id(), &session).await?;
1520    Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525    request: proto::JoinRoom,
1526    response: Response<proto::JoinRoom>,
1527    session: UserSession,
1528) -> Result<()> {
1529    let room_id = RoomId::from_proto(request.id);
1530
1531    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533    if let Some(channel_id) = channel_id {
1534        return join_channel_internal(channel_id, Box::new(response), session).await;
1535    }
1536
1537    let joined_room = {
1538        let room = session
1539            .db()
1540            .await
1541            .join_room(room_id, session.user_id(), session.connection_id)
1542            .await?;
1543        room_updated(&room.room, &session.peer);
1544        room.into_inner()
1545    };
1546
1547    for connection_id in session
1548        .connection_pool()
1549        .await
1550        .user_connection_ids(session.user_id())
1551    {
1552        session
1553            .peer
1554            .send(
1555                connection_id,
1556                proto::CallCanceled {
1557                    room_id: room_id.to_proto(),
1558                },
1559            )
1560            .trace_err();
1561    }
1562
1563    let live_kit_connection_info =
1564        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1565            if let Some(token) = live_kit
1566                .room_token(
1567                    &joined_room.room.live_kit_room,
1568                    &session.user_id().to_string(),
1569                )
1570                .trace_err()
1571            {
1572                Some(proto::LiveKitConnectionInfo {
1573                    server_url: live_kit.url().into(),
1574                    token,
1575                    can_publish: true,
1576                })
1577            } else {
1578                None
1579            }
1580        } else {
1581            None
1582        };
1583
1584    response.send(proto::JoinRoomResponse {
1585        room: Some(joined_room.room),
1586        channel_id: None,
1587        live_kit_connection_info,
1588    })?;
1589
1590    update_user_contacts(session.user_id(), &session).await?;
1591    Ok(())
1592}
1593
1594/// Rejoin room is used to reconnect to a room after connection errors.
1595async fn rejoin_room(
1596    request: proto::RejoinRoom,
1597    response: Response<proto::RejoinRoom>,
1598    session: UserSession,
1599) -> Result<()> {
1600    let room;
1601    let channel;
1602    {
1603        let mut rejoined_room = session
1604            .db()
1605            .await
1606            .rejoin_room(request, session.user_id(), session.connection_id)
1607            .await?;
1608
1609        response.send(proto::RejoinRoomResponse {
1610            room: Some(rejoined_room.room.clone()),
1611            reshared_projects: rejoined_room
1612                .reshared_projects
1613                .iter()
1614                .map(|project| proto::ResharedProject {
1615                    id: project.id.to_proto(),
1616                    collaborators: project
1617                        .collaborators
1618                        .iter()
1619                        .map(|collaborator| collaborator.to_proto())
1620                        .collect(),
1621                })
1622                .collect(),
1623            rejoined_projects: rejoined_room
1624                .rejoined_projects
1625                .iter()
1626                .map(|rejoined_project| rejoined_project.to_proto())
1627                .collect(),
1628        })?;
1629        room_updated(&rejoined_room.room, &session.peer);
1630
1631        for project in &rejoined_room.reshared_projects {
1632            for collaborator in &project.collaborators {
1633                session
1634                    .peer
1635                    .send(
1636                        collaborator.connection_id,
1637                        proto::UpdateProjectCollaborator {
1638                            project_id: project.id.to_proto(),
1639                            old_peer_id: Some(project.old_connection_id.into()),
1640                            new_peer_id: Some(session.connection_id.into()),
1641                        },
1642                    )
1643                    .trace_err();
1644            }
1645
1646            broadcast(
1647                Some(session.connection_id),
1648                project
1649                    .collaborators
1650                    .iter()
1651                    .map(|collaborator| collaborator.connection_id),
1652                |connection_id| {
1653                    session.peer.forward_send(
1654                        session.connection_id,
1655                        connection_id,
1656                        proto::UpdateProject {
1657                            project_id: project.id.to_proto(),
1658                            worktrees: project.worktrees.clone(),
1659                        },
1660                    )
1661                },
1662            );
1663        }
1664
1665        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1666
1667        let rejoined_room = rejoined_room.into_inner();
1668
1669        room = rejoined_room.room;
1670        channel = rejoined_room.channel;
1671    }
1672
1673    if let Some(channel) = channel {
1674        channel_updated(
1675            &channel,
1676            &room,
1677            &session.peer,
1678            &*session.connection_pool().await,
1679        );
1680    }
1681
1682    update_user_contacts(session.user_id(), &session).await?;
1683    Ok(())
1684}
1685
1686fn notify_rejoined_projects(
1687    rejoined_projects: &mut Vec<RejoinedProject>,
1688    session: &UserSession,
1689) -> Result<()> {
1690    for project in rejoined_projects.iter() {
1691        for collaborator in &project.collaborators {
1692            session
1693                .peer
1694                .send(
1695                    collaborator.connection_id,
1696                    proto::UpdateProjectCollaborator {
1697                        project_id: project.id.to_proto(),
1698                        old_peer_id: Some(project.old_connection_id.into()),
1699                        new_peer_id: Some(session.connection_id.into()),
1700                    },
1701                )
1702                .trace_err();
1703        }
1704    }
1705
1706    for project in rejoined_projects {
1707        for worktree in mem::take(&mut project.worktrees) {
1708            #[cfg(any(test, feature = "test-support"))]
1709            const MAX_CHUNK_SIZE: usize = 2;
1710            #[cfg(not(any(test, feature = "test-support")))]
1711            const MAX_CHUNK_SIZE: usize = 256;
1712
1713            // Stream this worktree's entries.
1714            let message = proto::UpdateWorktree {
1715                project_id: project.id.to_proto(),
1716                worktree_id: worktree.id,
1717                abs_path: worktree.abs_path.clone(),
1718                root_name: worktree.root_name,
1719                updated_entries: worktree.updated_entries,
1720                removed_entries: worktree.removed_entries,
1721                scan_id: worktree.scan_id,
1722                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1723                updated_repositories: worktree.updated_repositories,
1724                removed_repositories: worktree.removed_repositories,
1725            };
1726            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1727                session.peer.send(session.connection_id, update.clone())?;
1728            }
1729
1730            // Stream this worktree's diagnostics.
1731            for summary in worktree.diagnostic_summaries {
1732                session.peer.send(
1733                    session.connection_id,
1734                    proto::UpdateDiagnosticSummary {
1735                        project_id: project.id.to_proto(),
1736                        worktree_id: worktree.id,
1737                        summary: Some(summary),
1738                    },
1739                )?;
1740            }
1741
1742            for settings_file in worktree.settings_files {
1743                session.peer.send(
1744                    session.connection_id,
1745                    proto::UpdateWorktreeSettings {
1746                        project_id: project.id.to_proto(),
1747                        worktree_id: worktree.id,
1748                        path: settings_file.path,
1749                        content: Some(settings_file.content),
1750                    },
1751                )?;
1752            }
1753        }
1754
1755        for language_server in &project.language_servers {
1756            session.peer.send(
1757                session.connection_id,
1758                proto::UpdateLanguageServer {
1759                    project_id: project.id.to_proto(),
1760                    language_server_id: language_server.id,
1761                    variant: Some(
1762                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1763                            proto::LspDiskBasedDiagnosticsUpdated {},
1764                        ),
1765                    ),
1766                },
1767            )?;
1768        }
1769    }
1770    Ok(())
1771}
1772
1773/// leave room disconnects from the room.
1774async fn leave_room(
1775    _: proto::LeaveRoom,
1776    response: Response<proto::LeaveRoom>,
1777    session: UserSession,
1778) -> Result<()> {
1779    leave_room_for_session(&session, session.connection_id).await?;
1780    response.send(proto::Ack {})?;
1781    Ok(())
1782}
1783
1784/// Updates the permissions of someone else in the room.
1785async fn set_room_participant_role(
1786    request: proto::SetRoomParticipantRole,
1787    response: Response<proto::SetRoomParticipantRole>,
1788    session: UserSession,
1789) -> Result<()> {
1790    let user_id = UserId::from_proto(request.user_id);
1791    let role = ChannelRole::from(request.role());
1792
1793    let (live_kit_room, can_publish) = {
1794        let room = session
1795            .db()
1796            .await
1797            .set_room_participant_role(
1798                session.user_id(),
1799                RoomId::from_proto(request.room_id),
1800                user_id,
1801                role,
1802            )
1803            .await?;
1804
1805        let live_kit_room = room.live_kit_room.clone();
1806        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1807        room_updated(&room, &session.peer);
1808        (live_kit_room, can_publish)
1809    };
1810
1811    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1812        live_kit
1813            .update_participant(
1814                live_kit_room.clone(),
1815                request.user_id.to_string(),
1816                live_kit_server::proto::ParticipantPermission {
1817                    can_subscribe: true,
1818                    can_publish,
1819                    can_publish_data: can_publish,
1820                    hidden: false,
1821                    recorder: false,
1822                },
1823            )
1824            .await
1825            .trace_err();
1826    }
1827
1828    response.send(proto::Ack {})?;
1829    Ok(())
1830}
1831
1832/// Call someone else into the current room
1833async fn call(
1834    request: proto::Call,
1835    response: Response<proto::Call>,
1836    session: UserSession,
1837) -> Result<()> {
1838    let room_id = RoomId::from_proto(request.room_id);
1839    let calling_user_id = session.user_id();
1840    let calling_connection_id = session.connection_id;
1841    let called_user_id = UserId::from_proto(request.called_user_id);
1842    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1843    if !session
1844        .db()
1845        .await
1846        .has_contact(calling_user_id, called_user_id)
1847        .await?
1848    {
1849        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1850    }
1851
1852    let incoming_call = {
1853        let (room, incoming_call) = &mut *session
1854            .db()
1855            .await
1856            .call(
1857                room_id,
1858                calling_user_id,
1859                calling_connection_id,
1860                called_user_id,
1861                initial_project_id,
1862            )
1863            .await?;
1864        room_updated(&room, &session.peer);
1865        mem::take(incoming_call)
1866    };
1867    update_user_contacts(called_user_id, &session).await?;
1868
1869    let mut calls = session
1870        .connection_pool()
1871        .await
1872        .user_connection_ids(called_user_id)
1873        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1874        .collect::<FuturesUnordered<_>>();
1875
1876    while let Some(call_response) = calls.next().await {
1877        match call_response.as_ref() {
1878            Ok(_) => {
1879                response.send(proto::Ack {})?;
1880                return Ok(());
1881            }
1882            Err(_) => {
1883                call_response.trace_err();
1884            }
1885        }
1886    }
1887
1888    {
1889        let room = session
1890            .db()
1891            .await
1892            .call_failed(room_id, called_user_id)
1893            .await?;
1894        room_updated(&room, &session.peer);
1895    }
1896    update_user_contacts(called_user_id, &session).await?;
1897
1898    Err(anyhow!("failed to ring user"))?
1899}
1900
1901/// Cancel an outgoing call.
1902async fn cancel_call(
1903    request: proto::CancelCall,
1904    response: Response<proto::CancelCall>,
1905    session: UserSession,
1906) -> Result<()> {
1907    let called_user_id = UserId::from_proto(request.called_user_id);
1908    let room_id = RoomId::from_proto(request.room_id);
1909    {
1910        let room = session
1911            .db()
1912            .await
1913            .cancel_call(room_id, session.connection_id, called_user_id)
1914            .await?;
1915        room_updated(&room, &session.peer);
1916    }
1917
1918    for connection_id in session
1919        .connection_pool()
1920        .await
1921        .user_connection_ids(called_user_id)
1922    {
1923        session
1924            .peer
1925            .send(
1926                connection_id,
1927                proto::CallCanceled {
1928                    room_id: room_id.to_proto(),
1929                },
1930            )
1931            .trace_err();
1932    }
1933    response.send(proto::Ack {})?;
1934
1935    update_user_contacts(called_user_id, &session).await?;
1936    Ok(())
1937}
1938
1939/// Decline an incoming call.
1940async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1941    let room_id = RoomId::from_proto(message.room_id);
1942    {
1943        let room = session
1944            .db()
1945            .await
1946            .decline_call(Some(room_id), session.user_id())
1947            .await?
1948            .ok_or_else(|| anyhow!("failed to decline call"))?;
1949        room_updated(&room, &session.peer);
1950    }
1951
1952    for connection_id in session
1953        .connection_pool()
1954        .await
1955        .user_connection_ids(session.user_id())
1956    {
1957        session
1958            .peer
1959            .send(
1960                connection_id,
1961                proto::CallCanceled {
1962                    room_id: room_id.to_proto(),
1963                },
1964            )
1965            .trace_err();
1966    }
1967    update_user_contacts(session.user_id(), &session).await?;
1968    Ok(())
1969}
1970
1971/// Updates other participants in the room with your current location.
1972async fn update_participant_location(
1973    request: proto::UpdateParticipantLocation,
1974    response: Response<proto::UpdateParticipantLocation>,
1975    session: UserSession,
1976) -> Result<()> {
1977    let room_id = RoomId::from_proto(request.room_id);
1978    let location = request
1979        .location
1980        .ok_or_else(|| anyhow!("invalid location"))?;
1981
1982    let db = session.db().await;
1983    let room = db
1984        .update_room_participant_location(room_id, session.connection_id, location)
1985        .await?;
1986
1987    room_updated(&room, &session.peer);
1988    response.send(proto::Ack {})?;
1989    Ok(())
1990}
1991
1992/// Share a project into the room.
1993async fn share_project(
1994    request: proto::ShareProject,
1995    response: Response<proto::ShareProject>,
1996    session: UserSession,
1997) -> Result<()> {
1998    let (project_id, room) = &*session
1999        .db()
2000        .await
2001        .share_project(
2002            RoomId::from_proto(request.room_id),
2003            session.connection_id,
2004            &request.worktrees,
2005            request
2006                .dev_server_project_id
2007                .map(|id| DevServerProjectId::from_proto(id)),
2008        )
2009        .await?;
2010    response.send(proto::ShareProjectResponse {
2011        project_id: project_id.to_proto(),
2012    })?;
2013    room_updated(&room, &session.peer);
2014
2015    Ok(())
2016}
2017
2018/// Unshare a project from the room.
2019async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2020    let project_id = ProjectId::from_proto(message.project_id);
2021    unshare_project_internal(
2022        project_id,
2023        session.connection_id,
2024        session.user_id(),
2025        &session,
2026    )
2027    .await
2028}
2029
2030async fn unshare_project_internal(
2031    project_id: ProjectId,
2032    connection_id: ConnectionId,
2033    user_id: Option<UserId>,
2034    session: &Session,
2035) -> Result<()> {
2036    let delete = {
2037        let room_guard = session
2038            .db()
2039            .await
2040            .unshare_project(project_id, connection_id, user_id)
2041            .await?;
2042
2043        let (delete, room, guest_connection_ids) = &*room_guard;
2044
2045        let message = proto::UnshareProject {
2046            project_id: project_id.to_proto(),
2047        };
2048
2049        broadcast(
2050            Some(connection_id),
2051            guest_connection_ids.iter().copied(),
2052            |conn_id| session.peer.send(conn_id, message.clone()),
2053        );
2054        if let Some(room) = room {
2055            room_updated(room, &session.peer);
2056        }
2057
2058        *delete
2059    };
2060
2061    if delete {
2062        let db = session.db().await;
2063        db.delete_project(project_id).await?;
2064    }
2065
2066    Ok(())
2067}
2068
2069/// DevServer makes a project available online
2070async fn share_dev_server_project(
2071    request: proto::ShareDevServerProject,
2072    response: Response<proto::ShareDevServerProject>,
2073    session: DevServerSession,
2074) -> Result<()> {
2075    let (dev_server_project, user_id, status) = session
2076        .db()
2077        .await
2078        .share_dev_server_project(
2079            DevServerProjectId::from_proto(request.dev_server_project_id),
2080            session.dev_server_id(),
2081            session.connection_id,
2082            &request.worktrees,
2083        )
2084        .await?;
2085    let Some(project_id) = dev_server_project.project_id else {
2086        return Err(anyhow!("failed to share remote project"))?;
2087    };
2088
2089    send_dev_server_projects_update(user_id, status, &session).await;
2090
2091    response.send(proto::ShareProjectResponse { project_id })?;
2092
2093    Ok(())
2094}
2095
2096/// Join someone elses shared project.
2097async fn join_project(
2098    request: proto::JoinProject,
2099    response: Response<proto::JoinProject>,
2100    session: UserSession,
2101) -> Result<()> {
2102    let project_id = ProjectId::from_proto(request.project_id);
2103
2104    tracing::info!(%project_id, "join project");
2105
2106    let db = session.db().await;
2107    let (project, replica_id) = &mut *db
2108        .join_project(project_id, session.connection_id, session.user_id())
2109        .await?;
2110    drop(db);
2111    tracing::info!(%project_id, "join remote project");
2112    join_project_internal(response, session, project, replica_id)
2113}
2114
2115trait JoinProjectInternalResponse {
2116    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2117}
2118impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2119    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2120        Response::<proto::JoinProject>::send(self, result)
2121    }
2122}
2123impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2124    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2125        Response::<proto::JoinHostedProject>::send(self, result)
2126    }
2127}
2128
2129fn join_project_internal(
2130    response: impl JoinProjectInternalResponse,
2131    session: UserSession,
2132    project: &mut Project,
2133    replica_id: &ReplicaId,
2134) -> Result<()> {
2135    let collaborators = project
2136        .collaborators
2137        .iter()
2138        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2139        .map(|collaborator| collaborator.to_proto())
2140        .collect::<Vec<_>>();
2141    let project_id = project.id;
2142    let guest_user_id = session.user_id();
2143
2144    let worktrees = project
2145        .worktrees
2146        .iter()
2147        .map(|(id, worktree)| proto::WorktreeMetadata {
2148            id: *id,
2149            root_name: worktree.root_name.clone(),
2150            visible: worktree.visible,
2151            abs_path: worktree.abs_path.clone(),
2152        })
2153        .collect::<Vec<_>>();
2154
2155    let add_project_collaborator = proto::AddProjectCollaborator {
2156        project_id: project_id.to_proto(),
2157        collaborator: Some(proto::Collaborator {
2158            peer_id: Some(session.connection_id.into()),
2159            replica_id: replica_id.0 as u32,
2160            user_id: guest_user_id.to_proto(),
2161        }),
2162    };
2163
2164    for collaborator in &collaborators {
2165        session
2166            .peer
2167            .send(
2168                collaborator.peer_id.unwrap().into(),
2169                add_project_collaborator.clone(),
2170            )
2171            .trace_err();
2172    }
2173
2174    // First, we send the metadata associated with each worktree.
2175    response.send(proto::JoinProjectResponse {
2176        project_id: project.id.0 as u64,
2177        worktrees: worktrees.clone(),
2178        replica_id: replica_id.0 as u32,
2179        collaborators: collaborators.clone(),
2180        language_servers: project.language_servers.clone(),
2181        role: project.role.into(),
2182        dev_server_project_id: project
2183            .dev_server_project_id
2184            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2185    })?;
2186
2187    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2188        #[cfg(any(test, feature = "test-support"))]
2189        const MAX_CHUNK_SIZE: usize = 2;
2190        #[cfg(not(any(test, feature = "test-support")))]
2191        const MAX_CHUNK_SIZE: usize = 256;
2192
2193        // Stream this worktree's entries.
2194        let message = proto::UpdateWorktree {
2195            project_id: project_id.to_proto(),
2196            worktree_id,
2197            abs_path: worktree.abs_path.clone(),
2198            root_name: worktree.root_name,
2199            updated_entries: worktree.entries,
2200            removed_entries: Default::default(),
2201            scan_id: worktree.scan_id,
2202            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2203            updated_repositories: worktree.repository_entries.into_values().collect(),
2204            removed_repositories: Default::default(),
2205        };
2206        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2207            session.peer.send(session.connection_id, update.clone())?;
2208        }
2209
2210        // Stream this worktree's diagnostics.
2211        for summary in worktree.diagnostic_summaries {
2212            session.peer.send(
2213                session.connection_id,
2214                proto::UpdateDiagnosticSummary {
2215                    project_id: project_id.to_proto(),
2216                    worktree_id: worktree.id,
2217                    summary: Some(summary),
2218                },
2219            )?;
2220        }
2221
2222        for settings_file in worktree.settings_files {
2223            session.peer.send(
2224                session.connection_id,
2225                proto::UpdateWorktreeSettings {
2226                    project_id: project_id.to_proto(),
2227                    worktree_id: worktree.id,
2228                    path: settings_file.path,
2229                    content: Some(settings_file.content),
2230                },
2231            )?;
2232        }
2233    }
2234
2235    for language_server in &project.language_servers {
2236        session.peer.send(
2237            session.connection_id,
2238            proto::UpdateLanguageServer {
2239                project_id: project_id.to_proto(),
2240                language_server_id: language_server.id,
2241                variant: Some(
2242                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2243                        proto::LspDiskBasedDiagnosticsUpdated {},
2244                    ),
2245                ),
2246            },
2247        )?;
2248    }
2249
2250    Ok(())
2251}
2252
2253/// Leave someone elses shared project.
2254async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2255    let sender_id = session.connection_id;
2256    let project_id = ProjectId::from_proto(request.project_id);
2257    let db = session.db().await;
2258    if db.is_hosted_project(project_id).await? {
2259        let project = db.leave_hosted_project(project_id, sender_id).await?;
2260        project_left(&project, &session);
2261        return Ok(());
2262    }
2263
2264    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2265    tracing::info!(
2266        %project_id,
2267        "leave project"
2268    );
2269
2270    project_left(&project, &session);
2271    if let Some(room) = room {
2272        room_updated(&room, &session.peer);
2273    }
2274
2275    Ok(())
2276}
2277
2278async fn join_hosted_project(
2279    request: proto::JoinHostedProject,
2280    response: Response<proto::JoinHostedProject>,
2281    session: UserSession,
2282) -> Result<()> {
2283    let (mut project, replica_id) = session
2284        .db()
2285        .await
2286        .join_hosted_project(
2287            ProjectId(request.project_id as i32),
2288            session.user_id(),
2289            session.connection_id,
2290        )
2291        .await?;
2292
2293    join_project_internal(response, session, &mut project, &replica_id)
2294}
2295
2296async fn list_remote_directory(
2297    request: proto::ListRemoteDirectory,
2298    response: Response<proto::ListRemoteDirectory>,
2299    session: UserSession,
2300) -> Result<()> {
2301    let dev_server_id = DevServerId(request.dev_server_id as i32);
2302    let dev_server_connection_id = session
2303        .connection_pool()
2304        .await
2305        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2306
2307    session
2308        .db()
2309        .await
2310        .get_dev_server_for_user(dev_server_id, session.user_id())
2311        .await?;
2312
2313    response.send(
2314        session
2315            .peer
2316            .forward_request(session.connection_id, dev_server_connection_id, request)
2317            .await?,
2318    )?;
2319    Ok(())
2320}
2321
2322async fn update_dev_server_project(
2323    request: proto::UpdateDevServerProject,
2324    response: Response<proto::UpdateDevServerProject>,
2325    session: UserSession,
2326) -> Result<()> {
2327    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2328
2329    let (dev_server_project, update) = session
2330        .db()
2331        .await
2332        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2333        .await?;
2334
2335    let projects = session
2336        .db()
2337        .await
2338        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2339        .await?;
2340
2341    let dev_server_connection_id = session
2342        .connection_pool()
2343        .await
2344        .dev_server_connection_id_supporting(
2345            dev_server_project.dev_server_id,
2346            ZedVersion::with_list_directory(),
2347        )?;
2348
2349    session.peer.send(
2350        dev_server_connection_id,
2351        proto::DevServerInstructions { projects },
2352    )?;
2353
2354    send_dev_server_projects_update(session.user_id(), update, &session).await;
2355
2356    response.send(proto::Ack {})
2357}
2358
2359async fn create_dev_server_project(
2360    request: proto::CreateDevServerProject,
2361    response: Response<proto::CreateDevServerProject>,
2362    session: UserSession,
2363) -> Result<()> {
2364    let dev_server_id = DevServerId(request.dev_server_id as i32);
2365    let dev_server_connection_id = session
2366        .connection_pool()
2367        .await
2368        .dev_server_connection_id(dev_server_id);
2369    let Some(dev_server_connection_id) = dev_server_connection_id else {
2370        Err(ErrorCode::DevServerOffline
2371            .message("Cannot create a remote project when the dev server is offline".to_string())
2372            .anyhow())?
2373    };
2374
2375    let path = request.path.clone();
2376    //Check that the path exists on the dev server
2377    session
2378        .peer
2379        .forward_request(
2380            session.connection_id,
2381            dev_server_connection_id,
2382            proto::ValidateDevServerProjectRequest { path: path.clone() },
2383        )
2384        .await?;
2385
2386    let (dev_server_project, update) = session
2387        .db()
2388        .await
2389        .create_dev_server_project(
2390            DevServerId(request.dev_server_id as i32),
2391            &request.path,
2392            session.user_id(),
2393        )
2394        .await?;
2395
2396    let projects = session
2397        .db()
2398        .await
2399        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2400        .await?;
2401
2402    session.peer.send(
2403        dev_server_connection_id,
2404        proto::DevServerInstructions { projects },
2405    )?;
2406
2407    send_dev_server_projects_update(session.user_id(), update, &session).await;
2408
2409    response.send(proto::CreateDevServerProjectResponse {
2410        dev_server_project: Some(dev_server_project.to_proto(None)),
2411    })?;
2412    Ok(())
2413}
2414
2415async fn create_dev_server(
2416    request: proto::CreateDevServer,
2417    response: Response<proto::CreateDevServer>,
2418    session: UserSession,
2419) -> Result<()> {
2420    let access_token = auth::random_token();
2421    let hashed_access_token = auth::hash_access_token(&access_token);
2422
2423    if request.name.is_empty() {
2424        return Err(proto::ErrorCode::Forbidden
2425            .message("Dev server name cannot be empty".to_string())
2426            .anyhow())?;
2427    }
2428
2429    let (dev_server, status) = session
2430        .db()
2431        .await
2432        .create_dev_server(
2433            &request.name,
2434            request.ssh_connection_string.as_deref(),
2435            &hashed_access_token,
2436            session.user_id(),
2437        )
2438        .await?;
2439
2440    send_dev_server_projects_update(session.user_id(), status, &session).await;
2441
2442    response.send(proto::CreateDevServerResponse {
2443        dev_server_id: dev_server.id.0 as u64,
2444        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2445        name: request.name,
2446    })?;
2447    Ok(())
2448}
2449
2450async fn regenerate_dev_server_token(
2451    request: proto::RegenerateDevServerToken,
2452    response: Response<proto::RegenerateDevServerToken>,
2453    session: UserSession,
2454) -> Result<()> {
2455    let dev_server_id = DevServerId(request.dev_server_id as i32);
2456    let access_token = auth::random_token();
2457    let hashed_access_token = auth::hash_access_token(&access_token);
2458
2459    let connection_id = session
2460        .connection_pool()
2461        .await
2462        .dev_server_connection_id(dev_server_id);
2463    if let Some(connection_id) = connection_id {
2464        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2465        session.peer.send(
2466            connection_id,
2467            proto::ShutdownDevServer {
2468                reason: Some("dev server token was regenerated".to_string()),
2469            },
2470        )?;
2471        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2472    }
2473
2474    let status = session
2475        .db()
2476        .await
2477        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2478        .await?;
2479
2480    send_dev_server_projects_update(session.user_id(), status, &session).await;
2481
2482    response.send(proto::RegenerateDevServerTokenResponse {
2483        dev_server_id: dev_server_id.to_proto(),
2484        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2485    })?;
2486    Ok(())
2487}
2488
2489async fn rename_dev_server(
2490    request: proto::RenameDevServer,
2491    response: Response<proto::RenameDevServer>,
2492    session: UserSession,
2493) -> Result<()> {
2494    if request.name.trim().is_empty() {
2495        return Err(proto::ErrorCode::Forbidden
2496            .message("Dev server name cannot be empty".to_string())
2497            .anyhow())?;
2498    }
2499
2500    let dev_server_id = DevServerId(request.dev_server_id as i32);
2501    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2502    if dev_server.user_id != session.user_id() {
2503        return Err(anyhow!(ErrorCode::Forbidden))?;
2504    }
2505
2506    let status = session
2507        .db()
2508        .await
2509        .rename_dev_server(
2510            dev_server_id,
2511            &request.name,
2512            request.ssh_connection_string.as_deref(),
2513            session.user_id(),
2514        )
2515        .await?;
2516
2517    send_dev_server_projects_update(session.user_id(), status, &session).await;
2518
2519    response.send(proto::Ack {})?;
2520    Ok(())
2521}
2522
2523async fn delete_dev_server(
2524    request: proto::DeleteDevServer,
2525    response: Response<proto::DeleteDevServer>,
2526    session: UserSession,
2527) -> Result<()> {
2528    let dev_server_id = DevServerId(request.dev_server_id as i32);
2529    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2530    if dev_server.user_id != session.user_id() {
2531        return Err(anyhow!(ErrorCode::Forbidden))?;
2532    }
2533
2534    let connection_id = session
2535        .connection_pool()
2536        .await
2537        .dev_server_connection_id(dev_server_id);
2538    if let Some(connection_id) = connection_id {
2539        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2540        session.peer.send(
2541            connection_id,
2542            proto::ShutdownDevServer {
2543                reason: Some("dev server was deleted".to_string()),
2544            },
2545        )?;
2546        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2547    }
2548
2549    let status = session
2550        .db()
2551        .await
2552        .delete_dev_server(dev_server_id, session.user_id())
2553        .await?;
2554
2555    send_dev_server_projects_update(session.user_id(), status, &session).await;
2556
2557    response.send(proto::Ack {})?;
2558    Ok(())
2559}
2560
2561async fn delete_dev_server_project(
2562    request: proto::DeleteDevServerProject,
2563    response: Response<proto::DeleteDevServerProject>,
2564    session: UserSession,
2565) -> Result<()> {
2566    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2567    let dev_server_project = session
2568        .db()
2569        .await
2570        .get_dev_server_project(dev_server_project_id)
2571        .await?;
2572
2573    let dev_server = session
2574        .db()
2575        .await
2576        .get_dev_server(dev_server_project.dev_server_id)
2577        .await?;
2578    if dev_server.user_id != session.user_id() {
2579        return Err(anyhow!(ErrorCode::Forbidden))?;
2580    }
2581
2582    let dev_server_connection_id = session
2583        .connection_pool()
2584        .await
2585        .dev_server_connection_id(dev_server.id);
2586
2587    if let Some(dev_server_connection_id) = dev_server_connection_id {
2588        let project = session
2589            .db()
2590            .await
2591            .find_dev_server_project(dev_server_project_id)
2592            .await;
2593        if let Ok(project) = project {
2594            unshare_project_internal(
2595                project.id,
2596                dev_server_connection_id,
2597                Some(session.user_id()),
2598                &session,
2599            )
2600            .await?;
2601        }
2602    }
2603
2604    let (projects, status) = session
2605        .db()
2606        .await
2607        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2608        .await?;
2609
2610    if let Some(dev_server_connection_id) = dev_server_connection_id {
2611        session.peer.send(
2612            dev_server_connection_id,
2613            proto::DevServerInstructions { projects },
2614        )?;
2615    }
2616
2617    send_dev_server_projects_update(session.user_id(), status, &session).await;
2618
2619    response.send(proto::Ack {})?;
2620    Ok(())
2621}
2622
2623async fn rejoin_dev_server_projects(
2624    request: proto::RejoinRemoteProjects,
2625    response: Response<proto::RejoinRemoteProjects>,
2626    session: UserSession,
2627) -> Result<()> {
2628    let mut rejoined_projects = {
2629        let db = session.db().await;
2630        db.rejoin_dev_server_projects(
2631            &request.rejoined_projects,
2632            session.user_id(),
2633            session.0.connection_id,
2634        )
2635        .await?
2636    };
2637    response.send(proto::RejoinRemoteProjectsResponse {
2638        rejoined_projects: rejoined_projects
2639            .iter()
2640            .map(|project| project.to_proto())
2641            .collect(),
2642    })?;
2643    notify_rejoined_projects(&mut rejoined_projects, &session)
2644}
2645
2646async fn reconnect_dev_server(
2647    request: proto::ReconnectDevServer,
2648    response: Response<proto::ReconnectDevServer>,
2649    session: DevServerSession,
2650) -> Result<()> {
2651    let reshared_projects = {
2652        let db = session.db().await;
2653        db.reshare_dev_server_projects(
2654            &request.reshared_projects,
2655            session.dev_server_id(),
2656            session.0.connection_id,
2657        )
2658        .await?
2659    };
2660
2661    for project in &reshared_projects {
2662        for collaborator in &project.collaborators {
2663            session
2664                .peer
2665                .send(
2666                    collaborator.connection_id,
2667                    proto::UpdateProjectCollaborator {
2668                        project_id: project.id.to_proto(),
2669                        old_peer_id: Some(project.old_connection_id.into()),
2670                        new_peer_id: Some(session.connection_id.into()),
2671                    },
2672                )
2673                .trace_err();
2674        }
2675
2676        broadcast(
2677            Some(session.connection_id),
2678            project
2679                .collaborators
2680                .iter()
2681                .map(|collaborator| collaborator.connection_id),
2682            |connection_id| {
2683                session.peer.forward_send(
2684                    session.connection_id,
2685                    connection_id,
2686                    proto::UpdateProject {
2687                        project_id: project.id.to_proto(),
2688                        worktrees: project.worktrees.clone(),
2689                    },
2690                )
2691            },
2692        );
2693    }
2694
2695    response.send(proto::ReconnectDevServerResponse {
2696        reshared_projects: reshared_projects
2697            .iter()
2698            .map(|project| proto::ResharedProject {
2699                id: project.id.to_proto(),
2700                collaborators: project
2701                    .collaborators
2702                    .iter()
2703                    .map(|collaborator| collaborator.to_proto())
2704                    .collect(),
2705            })
2706            .collect(),
2707    })?;
2708
2709    Ok(())
2710}
2711
2712async fn shutdown_dev_server(
2713    _: proto::ShutdownDevServer,
2714    response: Response<proto::ShutdownDevServer>,
2715    session: DevServerSession,
2716) -> Result<()> {
2717    response.send(proto::Ack {})?;
2718    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2719    remove_dev_server_connection(session.dev_server_id(), &session).await
2720}
2721
2722async fn shutdown_dev_server_internal(
2723    dev_server_id: DevServerId,
2724    connection_id: ConnectionId,
2725    session: &Session,
2726) -> Result<()> {
2727    let (dev_server_projects, dev_server) = {
2728        let db = session.db().await;
2729        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2730        let dev_server = db.get_dev_server(dev_server_id).await?;
2731        (dev_server_projects, dev_server)
2732    };
2733
2734    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2735        unshare_project_internal(
2736            ProjectId::from_proto(project_id),
2737            connection_id,
2738            None,
2739            session,
2740        )
2741        .await?;
2742    }
2743
2744    session
2745        .connection_pool()
2746        .await
2747        .set_dev_server_offline(dev_server_id);
2748
2749    let status = session
2750        .db()
2751        .await
2752        .dev_server_projects_update(dev_server.user_id)
2753        .await?;
2754    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2755
2756    Ok(())
2757}
2758
2759async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2760    let dev_server_connection = session
2761        .connection_pool()
2762        .await
2763        .dev_server_connection_id(dev_server_id);
2764
2765    if let Some(dev_server_connection) = dev_server_connection {
2766        session
2767            .connection_pool()
2768            .await
2769            .remove_connection(dev_server_connection)?;
2770    }
2771    Ok(())
2772}
2773
2774/// Updates other participants with changes to the project
2775async fn update_project(
2776    request: proto::UpdateProject,
2777    response: Response<proto::UpdateProject>,
2778    session: Session,
2779) -> Result<()> {
2780    let project_id = ProjectId::from_proto(request.project_id);
2781    let (room, guest_connection_ids) = &*session
2782        .db()
2783        .await
2784        .update_project(project_id, session.connection_id, &request.worktrees)
2785        .await?;
2786    broadcast(
2787        Some(session.connection_id),
2788        guest_connection_ids.iter().copied(),
2789        |connection_id| {
2790            session
2791                .peer
2792                .forward_send(session.connection_id, connection_id, request.clone())
2793        },
2794    );
2795    if let Some(room) = room {
2796        room_updated(&room, &session.peer);
2797    }
2798    response.send(proto::Ack {})?;
2799
2800    Ok(())
2801}
2802
2803/// Updates other participants with changes to the worktree
2804async fn update_worktree(
2805    request: proto::UpdateWorktree,
2806    response: Response<proto::UpdateWorktree>,
2807    session: Session,
2808) -> Result<()> {
2809    let guest_connection_ids = session
2810        .db()
2811        .await
2812        .update_worktree(&request, session.connection_id)
2813        .await?;
2814
2815    broadcast(
2816        Some(session.connection_id),
2817        guest_connection_ids.iter().copied(),
2818        |connection_id| {
2819            session
2820                .peer
2821                .forward_send(session.connection_id, connection_id, request.clone())
2822        },
2823    );
2824    response.send(proto::Ack {})?;
2825    Ok(())
2826}
2827
2828/// Updates other participants with changes to the diagnostics
2829async fn update_diagnostic_summary(
2830    message: proto::UpdateDiagnosticSummary,
2831    session: Session,
2832) -> Result<()> {
2833    let guest_connection_ids = session
2834        .db()
2835        .await
2836        .update_diagnostic_summary(&message, session.connection_id)
2837        .await?;
2838
2839    broadcast(
2840        Some(session.connection_id),
2841        guest_connection_ids.iter().copied(),
2842        |connection_id| {
2843            session
2844                .peer
2845                .forward_send(session.connection_id, connection_id, message.clone())
2846        },
2847    );
2848
2849    Ok(())
2850}
2851
2852/// Updates other participants with changes to the worktree settings
2853async fn update_worktree_settings(
2854    message: proto::UpdateWorktreeSettings,
2855    session: Session,
2856) -> Result<()> {
2857    let guest_connection_ids = session
2858        .db()
2859        .await
2860        .update_worktree_settings(&message, session.connection_id)
2861        .await?;
2862
2863    broadcast(
2864        Some(session.connection_id),
2865        guest_connection_ids.iter().copied(),
2866        |connection_id| {
2867            session
2868                .peer
2869                .forward_send(session.connection_id, connection_id, message.clone())
2870        },
2871    );
2872
2873    Ok(())
2874}
2875
2876/// Notify other participants that a  language server has started.
2877async fn start_language_server(
2878    request: proto::StartLanguageServer,
2879    session: Session,
2880) -> Result<()> {
2881    let guest_connection_ids = session
2882        .db()
2883        .await
2884        .start_language_server(&request, session.connection_id)
2885        .await?;
2886
2887    broadcast(
2888        Some(session.connection_id),
2889        guest_connection_ids.iter().copied(),
2890        |connection_id| {
2891            session
2892                .peer
2893                .forward_send(session.connection_id, connection_id, request.clone())
2894        },
2895    );
2896    Ok(())
2897}
2898
2899/// Notify other participants that a language server has changed.
2900async fn update_language_server(
2901    request: proto::UpdateLanguageServer,
2902    session: Session,
2903) -> Result<()> {
2904    let project_id = ProjectId::from_proto(request.project_id);
2905    let project_connection_ids = session
2906        .db()
2907        .await
2908        .project_connection_ids(project_id, session.connection_id, true)
2909        .await?;
2910    broadcast(
2911        Some(session.connection_id),
2912        project_connection_ids.iter().copied(),
2913        |connection_id| {
2914            session
2915                .peer
2916                .forward_send(session.connection_id, connection_id, request.clone())
2917        },
2918    );
2919    Ok(())
2920}
2921
2922/// forward a project request to the host. These requests should be read only
2923/// as guests are allowed to send them.
2924async fn forward_read_only_project_request<T>(
2925    request: T,
2926    response: Response<T>,
2927    session: UserSession,
2928) -> Result<()>
2929where
2930    T: EntityMessage + RequestMessage,
2931{
2932    let project_id = ProjectId::from_proto(request.remote_entity_id());
2933    let host_connection_id = session
2934        .db()
2935        .await
2936        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2937        .await?;
2938    let payload = session
2939        .peer
2940        .forward_request(session.connection_id, host_connection_id, request)
2941        .await?;
2942    response.send(payload)?;
2943    Ok(())
2944}
2945
2946/// forward a project request to the dev server. Only allowed
2947/// if it's your dev server.
2948async fn forward_project_request_for_owner<T>(
2949    request: T,
2950    response: Response<T>,
2951    session: UserSession,
2952) -> Result<()>
2953where
2954    T: EntityMessage + RequestMessage,
2955{
2956    let project_id = ProjectId::from_proto(request.remote_entity_id());
2957
2958    let host_connection_id = session
2959        .db()
2960        .await
2961        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2962        .await?;
2963    let payload = session
2964        .peer
2965        .forward_request(session.connection_id, host_connection_id, request)
2966        .await?;
2967    response.send(payload)?;
2968    Ok(())
2969}
2970
2971/// forward a project request to the host. These requests are disallowed
2972/// for guests.
2973async fn forward_mutating_project_request<T>(
2974    request: T,
2975    response: Response<T>,
2976    session: UserSession,
2977) -> Result<()>
2978where
2979    T: EntityMessage + RequestMessage,
2980{
2981    let project_id = ProjectId::from_proto(request.remote_entity_id());
2982
2983    let host_connection_id = session
2984        .db()
2985        .await
2986        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2987        .await?;
2988    let payload = session
2989        .peer
2990        .forward_request(session.connection_id, host_connection_id, request)
2991        .await?;
2992    response.send(payload)?;
2993    Ok(())
2994}
2995
2996/// forward a project request to the host. These requests are disallowed
2997/// for guests.
2998async fn forward_versioned_mutating_project_request<T>(
2999    request: T,
3000    response: Response<T>,
3001    session: UserSession,
3002) -> Result<()>
3003where
3004    T: EntityMessage + RequestMessage + VersionedMessage,
3005{
3006    let project_id = ProjectId::from_proto(request.remote_entity_id());
3007
3008    let host_connection_id = session
3009        .db()
3010        .await
3011        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3012        .await?;
3013    if let Some(host_version) = session
3014        .connection_pool()
3015        .await
3016        .connection(host_connection_id)
3017        .map(|c| c.zed_version)
3018    {
3019        if let Some(min_required_version) = request.required_host_version() {
3020            if min_required_version > host_version {
3021                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3022                    .with_tag("required", &min_required_version.to_string())))?;
3023            }
3024        }
3025    }
3026
3027    let payload = session
3028        .peer
3029        .forward_request(session.connection_id, host_connection_id, request)
3030        .await?;
3031    response.send(payload)?;
3032    Ok(())
3033}
3034
3035/// Notify other participants that a new buffer has been created
3036async fn create_buffer_for_peer(
3037    request: proto::CreateBufferForPeer,
3038    session: Session,
3039) -> Result<()> {
3040    session
3041        .db()
3042        .await
3043        .check_user_is_project_host(
3044            ProjectId::from_proto(request.project_id),
3045            session.connection_id,
3046        )
3047        .await?;
3048    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3049    session
3050        .peer
3051        .forward_send(session.connection_id, peer_id.into(), request)?;
3052    Ok(())
3053}
3054
3055/// Notify other participants that a buffer has been updated. This is
3056/// allowed for guests as long as the update is limited to selections.
3057async fn update_buffer(
3058    request: proto::UpdateBuffer,
3059    response: Response<proto::UpdateBuffer>,
3060    session: Session,
3061) -> Result<()> {
3062    let project_id = ProjectId::from_proto(request.project_id);
3063    let mut capability = Capability::ReadOnly;
3064
3065    for op in request.operations.iter() {
3066        match op.variant {
3067            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3068            Some(_) => capability = Capability::ReadWrite,
3069        }
3070    }
3071
3072    let host = {
3073        let guard = session
3074            .db()
3075            .await
3076            .connections_for_buffer_update(
3077                project_id,
3078                session.principal_id(),
3079                session.connection_id,
3080                capability,
3081            )
3082            .await?;
3083
3084        let (host, guests) = &*guard;
3085
3086        broadcast(
3087            Some(session.connection_id),
3088            guests.clone(),
3089            |connection_id| {
3090                session
3091                    .peer
3092                    .forward_send(session.connection_id, connection_id, request.clone())
3093            },
3094        );
3095
3096        *host
3097    };
3098
3099    if host != session.connection_id {
3100        session
3101            .peer
3102            .forward_request(session.connection_id, host, request.clone())
3103            .await?;
3104    }
3105
3106    response.send(proto::Ack {})?;
3107    Ok(())
3108}
3109
3110async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3111    let project_id = ProjectId::from_proto(message.project_id);
3112
3113    let operation = message.operation.as_ref().context("invalid operation")?;
3114    let capability = match operation.variant.as_ref() {
3115        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3116            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3117                match buffer_op.variant {
3118                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3119                        Capability::ReadOnly
3120                    }
3121                    _ => Capability::ReadWrite,
3122                }
3123            } else {
3124                Capability::ReadWrite
3125            }
3126        }
3127        Some(_) => Capability::ReadWrite,
3128        None => Capability::ReadOnly,
3129    };
3130
3131    let guard = session
3132        .db()
3133        .await
3134        .connections_for_buffer_update(
3135            project_id,
3136            session.principal_id(),
3137            session.connection_id,
3138            capability,
3139        )
3140        .await?;
3141
3142    let (host, guests) = &*guard;
3143
3144    broadcast(
3145        Some(session.connection_id),
3146        guests.iter().chain([host]).copied(),
3147        |connection_id| {
3148            session
3149                .peer
3150                .forward_send(session.connection_id, connection_id, message.clone())
3151        },
3152    );
3153
3154    Ok(())
3155}
3156
3157/// Notify other participants that a project has been updated.
3158async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3159    request: T,
3160    session: Session,
3161) -> Result<()> {
3162    let project_id = ProjectId::from_proto(request.remote_entity_id());
3163    let project_connection_ids = session
3164        .db()
3165        .await
3166        .project_connection_ids(project_id, session.connection_id, false)
3167        .await?;
3168
3169    broadcast(
3170        Some(session.connection_id),
3171        project_connection_ids.iter().copied(),
3172        |connection_id| {
3173            session
3174                .peer
3175                .forward_send(session.connection_id, connection_id, request.clone())
3176        },
3177    );
3178    Ok(())
3179}
3180
3181/// Start following another user in a call.
3182async fn follow(
3183    request: proto::Follow,
3184    response: Response<proto::Follow>,
3185    session: UserSession,
3186) -> Result<()> {
3187    let room_id = RoomId::from_proto(request.room_id);
3188    let project_id = request.project_id.map(ProjectId::from_proto);
3189    let leader_id = request
3190        .leader_id
3191        .ok_or_else(|| anyhow!("invalid leader id"))?
3192        .into();
3193    let follower_id = session.connection_id;
3194
3195    session
3196        .db()
3197        .await
3198        .check_room_participants(room_id, leader_id, session.connection_id)
3199        .await?;
3200
3201    let response_payload = session
3202        .peer
3203        .forward_request(session.connection_id, leader_id, request)
3204        .await?;
3205    response.send(response_payload)?;
3206
3207    if let Some(project_id) = project_id {
3208        let room = session
3209            .db()
3210            .await
3211            .follow(room_id, project_id, leader_id, follower_id)
3212            .await?;
3213        room_updated(&room, &session.peer);
3214    }
3215
3216    Ok(())
3217}
3218
3219/// Stop following another user in a call.
3220async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3221    let room_id = RoomId::from_proto(request.room_id);
3222    let project_id = request.project_id.map(ProjectId::from_proto);
3223    let leader_id = request
3224        .leader_id
3225        .ok_or_else(|| anyhow!("invalid leader id"))?
3226        .into();
3227    let follower_id = session.connection_id;
3228
3229    session
3230        .db()
3231        .await
3232        .check_room_participants(room_id, leader_id, session.connection_id)
3233        .await?;
3234
3235    session
3236        .peer
3237        .forward_send(session.connection_id, leader_id, request)?;
3238
3239    if let Some(project_id) = project_id {
3240        let room = session
3241            .db()
3242            .await
3243            .unfollow(room_id, project_id, leader_id, follower_id)
3244            .await?;
3245        room_updated(&room, &session.peer);
3246    }
3247
3248    Ok(())
3249}
3250
3251/// Notify everyone following you of your current location.
3252async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3253    let room_id = RoomId::from_proto(request.room_id);
3254    let database = session.db.lock().await;
3255
3256    let connection_ids = if let Some(project_id) = request.project_id {
3257        let project_id = ProjectId::from_proto(project_id);
3258        database
3259            .project_connection_ids(project_id, session.connection_id, true)
3260            .await?
3261    } else {
3262        database
3263            .room_connection_ids(room_id, session.connection_id)
3264            .await?
3265    };
3266
3267    // For now, don't send view update messages back to that view's current leader.
3268    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3269        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3270        _ => None,
3271    });
3272
3273    for connection_id in connection_ids.iter().cloned() {
3274        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3275            session
3276                .peer
3277                .forward_send(session.connection_id, connection_id, request.clone())?;
3278        }
3279    }
3280    Ok(())
3281}
3282
3283/// Get public data about users.
3284async fn get_users(
3285    request: proto::GetUsers,
3286    response: Response<proto::GetUsers>,
3287    session: Session,
3288) -> Result<()> {
3289    let user_ids = request
3290        .user_ids
3291        .into_iter()
3292        .map(UserId::from_proto)
3293        .collect();
3294    let users = session
3295        .db()
3296        .await
3297        .get_users_by_ids(user_ids)
3298        .await?
3299        .into_iter()
3300        .map(|user| proto::User {
3301            id: user.id.to_proto(),
3302            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3303            github_login: user.github_login,
3304        })
3305        .collect();
3306    response.send(proto::UsersResponse { users })?;
3307    Ok(())
3308}
3309
3310/// Search for users (to invite) buy Github login
3311async fn fuzzy_search_users(
3312    request: proto::FuzzySearchUsers,
3313    response: Response<proto::FuzzySearchUsers>,
3314    session: UserSession,
3315) -> Result<()> {
3316    let query = request.query;
3317    let users = match query.len() {
3318        0 => vec![],
3319        1 | 2 => session
3320            .db()
3321            .await
3322            .get_user_by_github_login(&query)
3323            .await?
3324            .into_iter()
3325            .collect(),
3326        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3327    };
3328    let users = users
3329        .into_iter()
3330        .filter(|user| user.id != session.user_id())
3331        .map(|user| proto::User {
3332            id: user.id.to_proto(),
3333            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3334            github_login: user.github_login,
3335        })
3336        .collect();
3337    response.send(proto::UsersResponse { users })?;
3338    Ok(())
3339}
3340
3341/// Send a contact request to another user.
3342async fn request_contact(
3343    request: proto::RequestContact,
3344    response: Response<proto::RequestContact>,
3345    session: UserSession,
3346) -> Result<()> {
3347    let requester_id = session.user_id();
3348    let responder_id = UserId::from_proto(request.responder_id);
3349    if requester_id == responder_id {
3350        return Err(anyhow!("cannot add yourself as a contact"))?;
3351    }
3352
3353    let notifications = session
3354        .db()
3355        .await
3356        .send_contact_request(requester_id, responder_id)
3357        .await?;
3358
3359    // Update outgoing contact requests of requester
3360    let mut update = proto::UpdateContacts::default();
3361    update.outgoing_requests.push(responder_id.to_proto());
3362    for connection_id in session
3363        .connection_pool()
3364        .await
3365        .user_connection_ids(requester_id)
3366    {
3367        session.peer.send(connection_id, update.clone())?;
3368    }
3369
3370    // Update incoming contact requests of responder
3371    let mut update = proto::UpdateContacts::default();
3372    update
3373        .incoming_requests
3374        .push(proto::IncomingContactRequest {
3375            requester_id: requester_id.to_proto(),
3376        });
3377    let connection_pool = session.connection_pool().await;
3378    for connection_id in connection_pool.user_connection_ids(responder_id) {
3379        session.peer.send(connection_id, update.clone())?;
3380    }
3381
3382    send_notifications(&connection_pool, &session.peer, notifications);
3383
3384    response.send(proto::Ack {})?;
3385    Ok(())
3386}
3387
3388/// Accept or decline a contact request
3389async fn respond_to_contact_request(
3390    request: proto::RespondToContactRequest,
3391    response: Response<proto::RespondToContactRequest>,
3392    session: UserSession,
3393) -> Result<()> {
3394    let responder_id = session.user_id();
3395    let requester_id = UserId::from_proto(request.requester_id);
3396    let db = session.db().await;
3397    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3398        db.dismiss_contact_notification(responder_id, requester_id)
3399            .await?;
3400    } else {
3401        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3402
3403        let notifications = db
3404            .respond_to_contact_request(responder_id, requester_id, accept)
3405            .await?;
3406        let requester_busy = db.is_user_busy(requester_id).await?;
3407        let responder_busy = db.is_user_busy(responder_id).await?;
3408
3409        let pool = session.connection_pool().await;
3410        // Update responder with new contact
3411        let mut update = proto::UpdateContacts::default();
3412        if accept {
3413            update
3414                .contacts
3415                .push(contact_for_user(requester_id, requester_busy, &pool));
3416        }
3417        update
3418            .remove_incoming_requests
3419            .push(requester_id.to_proto());
3420        for connection_id in pool.user_connection_ids(responder_id) {
3421            session.peer.send(connection_id, update.clone())?;
3422        }
3423
3424        // Update requester with new contact
3425        let mut update = proto::UpdateContacts::default();
3426        if accept {
3427            update
3428                .contacts
3429                .push(contact_for_user(responder_id, responder_busy, &pool));
3430        }
3431        update
3432            .remove_outgoing_requests
3433            .push(responder_id.to_proto());
3434
3435        for connection_id in pool.user_connection_ids(requester_id) {
3436            session.peer.send(connection_id, update.clone())?;
3437        }
3438
3439        send_notifications(&pool, &session.peer, notifications);
3440    }
3441
3442    response.send(proto::Ack {})?;
3443    Ok(())
3444}
3445
3446/// Remove a contact.
3447async fn remove_contact(
3448    request: proto::RemoveContact,
3449    response: Response<proto::RemoveContact>,
3450    session: UserSession,
3451) -> Result<()> {
3452    let requester_id = session.user_id();
3453    let responder_id = UserId::from_proto(request.user_id);
3454    let db = session.db().await;
3455    let (contact_accepted, deleted_notification_id) =
3456        db.remove_contact(requester_id, responder_id).await?;
3457
3458    let pool = session.connection_pool().await;
3459    // Update outgoing contact requests of requester
3460    let mut update = proto::UpdateContacts::default();
3461    if contact_accepted {
3462        update.remove_contacts.push(responder_id.to_proto());
3463    } else {
3464        update
3465            .remove_outgoing_requests
3466            .push(responder_id.to_proto());
3467    }
3468    for connection_id in pool.user_connection_ids(requester_id) {
3469        session.peer.send(connection_id, update.clone())?;
3470    }
3471
3472    // Update incoming contact requests of responder
3473    let mut update = proto::UpdateContacts::default();
3474    if contact_accepted {
3475        update.remove_contacts.push(requester_id.to_proto());
3476    } else {
3477        update
3478            .remove_incoming_requests
3479            .push(requester_id.to_proto());
3480    }
3481    for connection_id in pool.user_connection_ids(responder_id) {
3482        session.peer.send(connection_id, update.clone())?;
3483        if let Some(notification_id) = deleted_notification_id {
3484            session.peer.send(
3485                connection_id,
3486                proto::DeleteNotification {
3487                    notification_id: notification_id.to_proto(),
3488                },
3489            )?;
3490        }
3491    }
3492
3493    response.send(proto::Ack {})?;
3494    Ok(())
3495}
3496
3497fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3498    version.0.minor() < 139
3499}
3500
3501async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3502    let plan = session.current_plan(session.db().await).await?;
3503
3504    session
3505        .peer
3506        .send(
3507            session.connection_id,
3508            proto::UpdateUserPlan { plan: plan.into() },
3509        )
3510        .trace_err();
3511
3512    Ok(())
3513}
3514
3515async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3516    subscribe_user_to_channels(
3517        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3518        &session,
3519    )
3520    .await?;
3521    Ok(())
3522}
3523
3524async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3525    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3526    let mut pool = session.connection_pool().await;
3527    for membership in &channels_for_user.channel_memberships {
3528        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3529    }
3530    session.peer.send(
3531        session.connection_id,
3532        build_update_user_channels(&channels_for_user),
3533    )?;
3534    session.peer.send(
3535        session.connection_id,
3536        build_channels_update(channels_for_user),
3537    )?;
3538    Ok(())
3539}
3540
3541/// Creates a new channel.
3542async fn create_channel(
3543    request: proto::CreateChannel,
3544    response: Response<proto::CreateChannel>,
3545    session: UserSession,
3546) -> Result<()> {
3547    let db = session.db().await;
3548
3549    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3550    let (channel, membership) = db
3551        .create_channel(&request.name, parent_id, session.user_id())
3552        .await?;
3553
3554    let root_id = channel.root_id();
3555    let channel = Channel::from_model(channel);
3556
3557    response.send(proto::CreateChannelResponse {
3558        channel: Some(channel.to_proto()),
3559        parent_id: request.parent_id,
3560    })?;
3561
3562    let mut connection_pool = session.connection_pool().await;
3563    if let Some(membership) = membership {
3564        connection_pool.subscribe_to_channel(
3565            membership.user_id,
3566            membership.channel_id,
3567            membership.role,
3568        );
3569        let update = proto::UpdateUserChannels {
3570            channel_memberships: vec![proto::ChannelMembership {
3571                channel_id: membership.channel_id.to_proto(),
3572                role: membership.role.into(),
3573            }],
3574            ..Default::default()
3575        };
3576        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3577            session.peer.send(connection_id, update.clone())?;
3578        }
3579    }
3580
3581    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3582        if !role.can_see_channel(channel.visibility) {
3583            continue;
3584        }
3585
3586        let update = proto::UpdateChannels {
3587            channels: vec![channel.to_proto()],
3588            ..Default::default()
3589        };
3590        session.peer.send(connection_id, update.clone())?;
3591    }
3592
3593    Ok(())
3594}
3595
3596/// Delete a channel
3597async fn delete_channel(
3598    request: proto::DeleteChannel,
3599    response: Response<proto::DeleteChannel>,
3600    session: UserSession,
3601) -> Result<()> {
3602    let db = session.db().await;
3603
3604    let channel_id = request.channel_id;
3605    let (root_channel, removed_channels) = db
3606        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3607        .await?;
3608    response.send(proto::Ack {})?;
3609
3610    // Notify members of removed channels
3611    let mut update = proto::UpdateChannels::default();
3612    update
3613        .delete_channels
3614        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3615
3616    let connection_pool = session.connection_pool().await;
3617    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3618        session.peer.send(connection_id, update.clone())?;
3619    }
3620
3621    Ok(())
3622}
3623
3624/// Invite someone to join a channel.
3625async fn invite_channel_member(
3626    request: proto::InviteChannelMember,
3627    response: Response<proto::InviteChannelMember>,
3628    session: UserSession,
3629) -> Result<()> {
3630    let db = session.db().await;
3631    let channel_id = ChannelId::from_proto(request.channel_id);
3632    let invitee_id = UserId::from_proto(request.user_id);
3633    let InviteMemberResult {
3634        channel,
3635        notifications,
3636    } = db
3637        .invite_channel_member(
3638            channel_id,
3639            invitee_id,
3640            session.user_id(),
3641            request.role().into(),
3642        )
3643        .await?;
3644
3645    let update = proto::UpdateChannels {
3646        channel_invitations: vec![channel.to_proto()],
3647        ..Default::default()
3648    };
3649
3650    let connection_pool = session.connection_pool().await;
3651    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3652        session.peer.send(connection_id, update.clone())?;
3653    }
3654
3655    send_notifications(&connection_pool, &session.peer, notifications);
3656
3657    response.send(proto::Ack {})?;
3658    Ok(())
3659}
3660
3661/// remove someone from a channel
3662async fn remove_channel_member(
3663    request: proto::RemoveChannelMember,
3664    response: Response<proto::RemoveChannelMember>,
3665    session: UserSession,
3666) -> Result<()> {
3667    let db = session.db().await;
3668    let channel_id = ChannelId::from_proto(request.channel_id);
3669    let member_id = UserId::from_proto(request.user_id);
3670
3671    let RemoveChannelMemberResult {
3672        membership_update,
3673        notification_id,
3674    } = db
3675        .remove_channel_member(channel_id, member_id, session.user_id())
3676        .await?;
3677
3678    let mut connection_pool = session.connection_pool().await;
3679    notify_membership_updated(
3680        &mut connection_pool,
3681        membership_update,
3682        member_id,
3683        &session.peer,
3684    );
3685    for connection_id in connection_pool.user_connection_ids(member_id) {
3686        if let Some(notification_id) = notification_id {
3687            session
3688                .peer
3689                .send(
3690                    connection_id,
3691                    proto::DeleteNotification {
3692                        notification_id: notification_id.to_proto(),
3693                    },
3694                )
3695                .trace_err();
3696        }
3697    }
3698
3699    response.send(proto::Ack {})?;
3700    Ok(())
3701}
3702
3703/// Toggle the channel between public and private.
3704/// Care is taken to maintain the invariant that public channels only descend from public channels,
3705/// (though members-only channels can appear at any point in the hierarchy).
3706async fn set_channel_visibility(
3707    request: proto::SetChannelVisibility,
3708    response: Response<proto::SetChannelVisibility>,
3709    session: UserSession,
3710) -> Result<()> {
3711    let db = session.db().await;
3712    let channel_id = ChannelId::from_proto(request.channel_id);
3713    let visibility = request.visibility().into();
3714
3715    let channel_model = db
3716        .set_channel_visibility(channel_id, visibility, session.user_id())
3717        .await?;
3718    let root_id = channel_model.root_id();
3719    let channel = Channel::from_model(channel_model);
3720
3721    let mut connection_pool = session.connection_pool().await;
3722    for (user_id, role) in connection_pool
3723        .channel_user_ids(root_id)
3724        .collect::<Vec<_>>()
3725        .into_iter()
3726    {
3727        let update = if role.can_see_channel(channel.visibility) {
3728            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3729            proto::UpdateChannels {
3730                channels: vec![channel.to_proto()],
3731                ..Default::default()
3732            }
3733        } else {
3734            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3735            proto::UpdateChannels {
3736                delete_channels: vec![channel.id.to_proto()],
3737                ..Default::default()
3738            }
3739        };
3740
3741        for connection_id in connection_pool.user_connection_ids(user_id) {
3742            session.peer.send(connection_id, update.clone())?;
3743        }
3744    }
3745
3746    response.send(proto::Ack {})?;
3747    Ok(())
3748}
3749
3750/// Alter the role for a user in the channel.
3751async fn set_channel_member_role(
3752    request: proto::SetChannelMemberRole,
3753    response: Response<proto::SetChannelMemberRole>,
3754    session: UserSession,
3755) -> Result<()> {
3756    let db = session.db().await;
3757    let channel_id = ChannelId::from_proto(request.channel_id);
3758    let member_id = UserId::from_proto(request.user_id);
3759    let result = db
3760        .set_channel_member_role(
3761            channel_id,
3762            session.user_id(),
3763            member_id,
3764            request.role().into(),
3765        )
3766        .await?;
3767
3768    match result {
3769        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3770            let mut connection_pool = session.connection_pool().await;
3771            notify_membership_updated(
3772                &mut connection_pool,
3773                membership_update,
3774                member_id,
3775                &session.peer,
3776            )
3777        }
3778        db::SetMemberRoleResult::InviteUpdated(channel) => {
3779            let update = proto::UpdateChannels {
3780                channel_invitations: vec![channel.to_proto()],
3781                ..Default::default()
3782            };
3783
3784            for connection_id in session
3785                .connection_pool()
3786                .await
3787                .user_connection_ids(member_id)
3788            {
3789                session.peer.send(connection_id, update.clone())?;
3790            }
3791        }
3792    }
3793
3794    response.send(proto::Ack {})?;
3795    Ok(())
3796}
3797
3798/// Change the name of a channel
3799async fn rename_channel(
3800    request: proto::RenameChannel,
3801    response: Response<proto::RenameChannel>,
3802    session: UserSession,
3803) -> Result<()> {
3804    let db = session.db().await;
3805    let channel_id = ChannelId::from_proto(request.channel_id);
3806    let channel_model = db
3807        .rename_channel(channel_id, session.user_id(), &request.name)
3808        .await?;
3809    let root_id = channel_model.root_id();
3810    let channel = Channel::from_model(channel_model);
3811
3812    response.send(proto::RenameChannelResponse {
3813        channel: Some(channel.to_proto()),
3814    })?;
3815
3816    let connection_pool = session.connection_pool().await;
3817    let update = proto::UpdateChannels {
3818        channels: vec![channel.to_proto()],
3819        ..Default::default()
3820    };
3821    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3822        if role.can_see_channel(channel.visibility) {
3823            session.peer.send(connection_id, update.clone())?;
3824        }
3825    }
3826
3827    Ok(())
3828}
3829
3830/// Move a channel to a new parent.
3831async fn move_channel(
3832    request: proto::MoveChannel,
3833    response: Response<proto::MoveChannel>,
3834    session: UserSession,
3835) -> Result<()> {
3836    let channel_id = ChannelId::from_proto(request.channel_id);
3837    let to = ChannelId::from_proto(request.to);
3838
3839    let (root_id, channels) = session
3840        .db()
3841        .await
3842        .move_channel(channel_id, to, session.user_id())
3843        .await?;
3844
3845    let connection_pool = session.connection_pool().await;
3846    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3847        let channels = channels
3848            .iter()
3849            .filter_map(|channel| {
3850                if role.can_see_channel(channel.visibility) {
3851                    Some(channel.to_proto())
3852                } else {
3853                    None
3854                }
3855            })
3856            .collect::<Vec<_>>();
3857        if channels.is_empty() {
3858            continue;
3859        }
3860
3861        let update = proto::UpdateChannels {
3862            channels,
3863            ..Default::default()
3864        };
3865
3866        session.peer.send(connection_id, update.clone())?;
3867    }
3868
3869    response.send(Ack {})?;
3870    Ok(())
3871}
3872
3873/// Get the list of channel members
3874async fn get_channel_members(
3875    request: proto::GetChannelMembers,
3876    response: Response<proto::GetChannelMembers>,
3877    session: UserSession,
3878) -> Result<()> {
3879    let db = session.db().await;
3880    let channel_id = ChannelId::from_proto(request.channel_id);
3881    let limit = if request.limit == 0 {
3882        u16::MAX as u64
3883    } else {
3884        request.limit
3885    };
3886    let (members, users) = db
3887        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3888        .await?;
3889    response.send(proto::GetChannelMembersResponse { members, users })?;
3890    Ok(())
3891}
3892
3893/// Accept or decline a channel invitation.
3894async fn respond_to_channel_invite(
3895    request: proto::RespondToChannelInvite,
3896    response: Response<proto::RespondToChannelInvite>,
3897    session: UserSession,
3898) -> Result<()> {
3899    let db = session.db().await;
3900    let channel_id = ChannelId::from_proto(request.channel_id);
3901    let RespondToChannelInvite {
3902        membership_update,
3903        notifications,
3904    } = db
3905        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3906        .await?;
3907
3908    let mut connection_pool = session.connection_pool().await;
3909    if let Some(membership_update) = membership_update {
3910        notify_membership_updated(
3911            &mut connection_pool,
3912            membership_update,
3913            session.user_id(),
3914            &session.peer,
3915        );
3916    } else {
3917        let update = proto::UpdateChannels {
3918            remove_channel_invitations: vec![channel_id.to_proto()],
3919            ..Default::default()
3920        };
3921
3922        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3923            session.peer.send(connection_id, update.clone())?;
3924        }
3925    };
3926
3927    send_notifications(&connection_pool, &session.peer, notifications);
3928
3929    response.send(proto::Ack {})?;
3930
3931    Ok(())
3932}
3933
3934/// Join the channels' room
3935async fn join_channel(
3936    request: proto::JoinChannel,
3937    response: Response<proto::JoinChannel>,
3938    session: UserSession,
3939) -> Result<()> {
3940    let channel_id = ChannelId::from_proto(request.channel_id);
3941    join_channel_internal(channel_id, Box::new(response), session).await
3942}
3943
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// `join_channel_internal` accepts any implementor, which lets it reply to either
/// of the two request types implemented below.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): the fully-qualified path presumably selects the inherent
        // `Response::send` rather than this trait method — confirm before changing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): as above, presumably forwards to the inherent
        // `Response::send` — confirm before changing.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3957
/// Join the room belonging to `channel_id` on behalf of the session's user and
/// reply through `response`.
///
/// Steps, in order: clean up a stale room connection for this user if one exists,
/// join the channel in the database, build LiveKit connection info when a LiveKit
/// client is configured, send the `JoinRoomResponse`, propagate any membership
/// change, broadcast the room and channel updates, and finally refresh the user's
/// contacts.
///
/// Returns an error if any database call fails, if the reply cannot be sent, or if
/// the joined room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The inner block scopes `db` and `connection_pool` so that both are dropped
    // before `channel_updated` and `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when creating the token
        // fails (the error is passed to `trace_err` and the join still proceeds).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; every other
                    // role gets a room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4053
4054/// Start editing the channel notes
4055async fn join_channel_buffer(
4056    request: proto::JoinChannelBuffer,
4057    response: Response<proto::JoinChannelBuffer>,
4058    session: UserSession,
4059) -> Result<()> {
4060    let db = session.db().await;
4061    let channel_id = ChannelId::from_proto(request.channel_id);
4062
4063    let open_response = db
4064        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4065        .await?;
4066
4067    let collaborators = open_response.collaborators.clone();
4068    response.send(open_response)?;
4069
4070    let update = UpdateChannelBufferCollaborators {
4071        channel_id: channel_id.to_proto(),
4072        collaborators: collaborators.clone(),
4073    };
4074    channel_buffer_updated(
4075        session.connection_id,
4076        collaborators
4077            .iter()
4078            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4079        &update,
4080        &session.peer,
4081    );
4082
4083    Ok(())
4084}
4085
4086/// Edit the channel notes
4087async fn update_channel_buffer(
4088    request: proto::UpdateChannelBuffer,
4089    session: UserSession,
4090) -> Result<()> {
4091    let db = session.db().await;
4092    let channel_id = ChannelId::from_proto(request.channel_id);
4093
4094    let (collaborators, epoch, version) = db
4095        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4096        .await?;
4097
4098    channel_buffer_updated(
4099        session.connection_id,
4100        collaborators.clone(),
4101        &proto::UpdateChannelBuffer {
4102            channel_id: channel_id.to_proto(),
4103            operations: request.operations,
4104        },
4105        &session.peer,
4106    );
4107
4108    let pool = &*session.connection_pool().await;
4109
4110    let non_collaborators =
4111        pool.channel_connection_ids(channel_id)
4112            .filter_map(|(connection_id, _)| {
4113                if collaborators.contains(&connection_id) {
4114                    None
4115                } else {
4116                    Some(connection_id)
4117                }
4118            });
4119
4120    broadcast(None, non_collaborators, |peer_id| {
4121        session.peer.send(
4122            peer_id,
4123            proto::UpdateChannels {
4124                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4125                    channel_id: channel_id.to_proto(),
4126                    epoch: epoch as u64,
4127                    version: version.clone(),
4128                }],
4129                ..Default::default()
4130            },
4131        )
4132    });
4133
4134    Ok(())
4135}
4136
4137/// Rejoin the channel notes after a connection blip
4138async fn rejoin_channel_buffers(
4139    request: proto::RejoinChannelBuffers,
4140    response: Response<proto::RejoinChannelBuffers>,
4141    session: UserSession,
4142) -> Result<()> {
4143    let db = session.db().await;
4144    let buffers = db
4145        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4146        .await?;
4147
4148    for rejoined_buffer in &buffers {
4149        let collaborators_to_notify = rejoined_buffer
4150            .buffer
4151            .collaborators
4152            .iter()
4153            .filter_map(|c| Some(c.peer_id?.into()));
4154        channel_buffer_updated(
4155            session.connection_id,
4156            collaborators_to_notify,
4157            &proto::UpdateChannelBufferCollaborators {
4158                channel_id: rejoined_buffer.buffer.channel_id,
4159                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4160            },
4161            &session.peer,
4162        );
4163    }
4164
4165    response.send(proto::RejoinChannelBuffersResponse {
4166        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4167    })?;
4168
4169    Ok(())
4170}
4171
4172/// Stop editing the channel notes
4173async fn leave_channel_buffer(
4174    request: proto::LeaveChannelBuffer,
4175    response: Response<proto::LeaveChannelBuffer>,
4176    session: UserSession,
4177) -> Result<()> {
4178    let db = session.db().await;
4179    let channel_id = ChannelId::from_proto(request.channel_id);
4180
4181    let left_buffer = db
4182        .leave_channel_buffer(channel_id, session.connection_id)
4183        .await?;
4184
4185    response.send(Ack {})?;
4186
4187    channel_buffer_updated(
4188        session.connection_id,
4189        left_buffer.connections,
4190        &proto::UpdateChannelBufferCollaborators {
4191            channel_id: channel_id.to_proto(),
4192            collaborators: left_buffer.collaborators,
4193        },
4194        &session.peer,
4195    );
4196
4197    Ok(())
4198}
4199
4200fn channel_buffer_updated<T: EnvelopedMessage>(
4201    sender_id: ConnectionId,
4202    collaborators: impl IntoIterator<Item = ConnectionId>,
4203    message: &T,
4204    peer: &Peer,
4205) {
4206    broadcast(Some(sender_id), collaborators, |peer_id| {
4207        peer.send(peer_id, message.clone())
4208    });
4209}
4210
4211fn send_notifications(
4212    connection_pool: &ConnectionPool,
4213    peer: &Peer,
4214    notifications: db::NotificationBatch,
4215) {
4216    for (user_id, notification) in notifications {
4217        for connection_id in connection_pool.user_connection_ids(user_id) {
4218            if let Err(error) = peer.send(
4219                connection_id,
4220                proto::AddNotification {
4221                    notification: Some(notification.clone()),
4222                },
4223            ) {
4224                tracing::error!(
4225                    "failed to send notification to {:?} {}",
4226                    connection_id,
4227                    error
4228                );
4229            }
4230        }
4231    }
4232}
4233
4234/// Send a message to the channel
4235async fn send_channel_message(
4236    request: proto::SendChannelMessage,
4237    response: Response<proto::SendChannelMessage>,
4238    session: UserSession,
4239) -> Result<()> {
4240    // Validate the message body.
4241    let body = request.body.trim().to_string();
4242    if body.len() > MAX_MESSAGE_LEN {
4243        return Err(anyhow!("message is too long"))?;
4244    }
4245    if body.is_empty() {
4246        return Err(anyhow!("message can't be blank"))?;
4247    }
4248
4249    // TODO: adjust mentions if body is trimmed
4250
4251    let timestamp = OffsetDateTime::now_utc();
4252    let nonce = request
4253        .nonce
4254        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4255
4256    let channel_id = ChannelId::from_proto(request.channel_id);
4257    let CreatedChannelMessage {
4258        message_id,
4259        participant_connection_ids,
4260        notifications,
4261    } = session
4262        .db()
4263        .await
4264        .create_channel_message(
4265            channel_id,
4266            session.user_id(),
4267            &body,
4268            &request.mentions,
4269            timestamp,
4270            nonce.clone().into(),
4271            match request.reply_to_message_id {
4272                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4273                None => None,
4274            },
4275        )
4276        .await?;
4277
4278    let message = proto::ChannelMessage {
4279        sender_id: session.user_id().to_proto(),
4280        id: message_id.to_proto(),
4281        body,
4282        mentions: request.mentions,
4283        timestamp: timestamp.unix_timestamp() as u64,
4284        nonce: Some(nonce),
4285        reply_to_message_id: request.reply_to_message_id,
4286        edited_at: None,
4287    };
4288    broadcast(
4289        Some(session.connection_id),
4290        participant_connection_ids.clone(),
4291        |connection| {
4292            session.peer.send(
4293                connection,
4294                proto::ChannelMessageSent {
4295                    channel_id: channel_id.to_proto(),
4296                    message: Some(message.clone()),
4297                },
4298            )
4299        },
4300    );
4301    response.send(proto::SendChannelMessageResponse {
4302        message: Some(message),
4303    })?;
4304
4305    let pool = &*session.connection_pool().await;
4306    let non_participants =
4307        pool.channel_connection_ids(channel_id)
4308            .filter_map(|(connection_id, _)| {
4309                if participant_connection_ids.contains(&connection_id) {
4310                    None
4311                } else {
4312                    Some(connection_id)
4313                }
4314            });
4315    broadcast(None, non_participants, |peer_id| {
4316        session.peer.send(
4317            peer_id,
4318            proto::UpdateChannels {
4319                latest_channel_message_ids: vec![proto::ChannelMessageId {
4320                    channel_id: channel_id.to_proto(),
4321                    message_id: message_id.to_proto(),
4322                }],
4323                ..Default::default()
4324            },
4325        )
4326    });
4327    send_notifications(pool, &session.peer, notifications);
4328
4329    Ok(())
4330}
4331
4332/// Delete a channel message
4333async fn remove_channel_message(
4334    request: proto::RemoveChannelMessage,
4335    response: Response<proto::RemoveChannelMessage>,
4336    session: UserSession,
4337) -> Result<()> {
4338    let channel_id = ChannelId::from_proto(request.channel_id);
4339    let message_id = MessageId::from_proto(request.message_id);
4340    let (connection_ids, existing_notification_ids) = session
4341        .db()
4342        .await
4343        .remove_channel_message(channel_id, message_id, session.user_id())
4344        .await?;
4345
4346    broadcast(
4347        Some(session.connection_id),
4348        connection_ids,
4349        move |connection| {
4350            session.peer.send(connection, request.clone())?;
4351
4352            for notification_id in &existing_notification_ids {
4353                session.peer.send(
4354                    connection,
4355                    proto::DeleteNotification {
4356                        notification_id: (*notification_id).to_proto(),
4357                    },
4358                )?;
4359            }
4360
4361            Ok(())
4362        },
4363    );
4364    response.send(proto::Ack {})?;
4365    Ok(())
4366}
4367
4368async fn update_channel_message(
4369    request: proto::UpdateChannelMessage,
4370    response: Response<proto::UpdateChannelMessage>,
4371    session: UserSession,
4372) -> Result<()> {
4373    let channel_id = ChannelId::from_proto(request.channel_id);
4374    let message_id = MessageId::from_proto(request.message_id);
4375    let updated_at = OffsetDateTime::now_utc();
4376    let UpdatedChannelMessage {
4377        message_id,
4378        participant_connection_ids,
4379        notifications,
4380        reply_to_message_id,
4381        timestamp,
4382        deleted_mention_notification_ids,
4383        updated_mention_notifications,
4384    } = session
4385        .db()
4386        .await
4387        .update_channel_message(
4388            channel_id,
4389            message_id,
4390            session.user_id(),
4391            request.body.as_str(),
4392            &request.mentions,
4393            updated_at,
4394        )
4395        .await?;
4396
4397    let nonce = request
4398        .nonce
4399        .clone()
4400        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4401
4402    let message = proto::ChannelMessage {
4403        sender_id: session.user_id().to_proto(),
4404        id: message_id.to_proto(),
4405        body: request.body.clone(),
4406        mentions: request.mentions.clone(),
4407        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4408        nonce: Some(nonce),
4409        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4410        edited_at: Some(updated_at.unix_timestamp() as u64),
4411    };
4412
4413    response.send(proto::Ack {})?;
4414
4415    let pool = &*session.connection_pool().await;
4416    broadcast(
4417        Some(session.connection_id),
4418        participant_connection_ids,
4419        |connection| {
4420            session.peer.send(
4421                connection,
4422                proto::ChannelMessageUpdate {
4423                    channel_id: channel_id.to_proto(),
4424                    message: Some(message.clone()),
4425                },
4426            )?;
4427
4428            for notification_id in &deleted_mention_notification_ids {
4429                session.peer.send(
4430                    connection,
4431                    proto::DeleteNotification {
4432                        notification_id: (*notification_id).to_proto(),
4433                    },
4434                )?;
4435            }
4436
4437            for notification in &updated_mention_notifications {
4438                session.peer.send(
4439                    connection,
4440                    proto::UpdateNotification {
4441                        notification: Some(notification.clone()),
4442                    },
4443                )?;
4444            }
4445
4446            Ok(())
4447        },
4448    );
4449
4450    send_notifications(pool, &session.peer, notifications);
4451
4452    Ok(())
4453}
4454
4455/// Mark a channel message as read
4456async fn acknowledge_channel_message(
4457    request: proto::AckChannelMessage,
4458    session: UserSession,
4459) -> Result<()> {
4460    let channel_id = ChannelId::from_proto(request.channel_id);
4461    let message_id = MessageId::from_proto(request.message_id);
4462    let notifications = session
4463        .db()
4464        .await
4465        .observe_channel_message(channel_id, session.user_id(), message_id)
4466        .await?;
4467    send_notifications(
4468        &*session.connection_pool().await,
4469        &session.peer,
4470        notifications,
4471    );
4472    Ok(())
4473}
4474
4475/// Mark a buffer version as synced
4476async fn acknowledge_buffer_version(
4477    request: proto::AckBufferOperation,
4478    session: UserSession,
4479) -> Result<()> {
4480    let buffer_id = BufferId::from_proto(request.buffer_id);
4481    session
4482        .db()
4483        .await
4484        .observe_buffer_version(
4485            buffer_id,
4486            session.user_id(),
4487            request.epoch as i32,
4488            &request.version,
4489        )
4490        .await?;
4491    Ok(())
4492}
4493
4494async fn count_language_model_tokens(
4495    request: proto::CountLanguageModelTokens,
4496    response: Response<proto::CountLanguageModelTokens>,
4497    session: Session,
4498    config: &Config,
4499) -> Result<()> {
4500    let Some(session) = session.for_user() else {
4501        return Err(anyhow!("user not found"))?;
4502    };
4503    authorize_access_to_legacy_llm_endpoints(&session).await?;
4504
4505    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4506        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4507        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4508    };
4509
4510    session
4511        .app_state
4512        .rate_limiter
4513        .check(&*rate_limit, session.user_id())
4514        .await?;
4515
4516    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4517        Some(proto::LanguageModelProvider::Google) => {
4518            let api_key = config
4519                .google_ai_api_key
4520                .as_ref()
4521                .context("no Google AI API key configured on the server")?;
4522            google_ai::count_tokens(
4523                session.http_client.as_ref(),
4524                google_ai::API_URL,
4525                api_key,
4526                serde_json::from_str(&request.request)?,
4527            )
4528            .await?
4529        }
4530        _ => return Err(anyhow!("unsupported provider"))?,
4531    };
4532
4533    response.send(proto::CountLanguageModelTokensResponse {
4534        token_count: result.total_tokens as u32,
4535    })?;
4536
4537    Ok(())
4538}
4539
4540struct ZedProCountLanguageModelTokensRateLimit;
4541
4542impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4543    fn capacity(&self) -> usize {
4544        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4545            .ok()
4546            .and_then(|v| v.parse().ok())
4547            .unwrap_or(600) // Picked arbitrarily
4548    }
4549
4550    fn refill_duration(&self) -> chrono::Duration {
4551        chrono::Duration::hours(1)
4552    }
4553
4554    fn db_name(&self) -> &'static str {
4555        "zed-pro:count-language-model-tokens"
4556    }
4557}
4558
4559struct FreeCountLanguageModelTokensRateLimit;
4560
4561impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4562    fn capacity(&self) -> usize {
4563        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4564            .ok()
4565            .and_then(|v| v.parse().ok())
4566            .unwrap_or(600 / 10) // Picked arbitrarily
4567    }
4568
4569    fn refill_duration(&self) -> chrono::Duration {
4570        chrono::Duration::hours(1)
4571    }
4572
4573    fn db_name(&self) -> &'static str {
4574        "free:count-language-model-tokens"
4575    }
4576}
4577
4578struct ZedProComputeEmbeddingsRateLimit;
4579
4580impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4581    fn capacity(&self) -> usize {
4582        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4583            .ok()
4584            .and_then(|v| v.parse().ok())
4585            .unwrap_or(5000) // Picked arbitrarily
4586    }
4587
4588    fn refill_duration(&self) -> chrono::Duration {
4589        chrono::Duration::hours(1)
4590    }
4591
4592    fn db_name(&self) -> &'static str {
4593        "zed-pro:compute-embeddings"
4594    }
4595}
4596
4597struct FreeComputeEmbeddingsRateLimit;
4598
4599impl RateLimit for FreeComputeEmbeddingsRateLimit {
4600    fn capacity(&self) -> usize {
4601        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4602            .ok()
4603            .and_then(|v| v.parse().ok())
4604            .unwrap_or(5000 / 10) // Picked arbitrarily
4605    }
4606
4607    fn refill_duration(&self) -> chrono::Duration {
4608        chrono::Duration::hours(1)
4609    }
4610
4611    fn db_name(&self) -> &'static str {
4612        "free:compute-embeddings"
4613    }
4614}
4615
4616async fn compute_embeddings(
4617    request: proto::ComputeEmbeddings,
4618    response: Response<proto::ComputeEmbeddings>,
4619    session: UserSession,
4620    api_key: Option<Arc<str>>,
4621) -> Result<()> {
4622    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4623    authorize_access_to_legacy_llm_endpoints(&session).await?;
4624
4625    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4626        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4627        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4628    };
4629
4630    session
4631        .app_state
4632        .rate_limiter
4633        .check(&*rate_limit, session.user_id())
4634        .await?;
4635
4636    let embeddings = match request.model.as_str() {
4637        "openai/text-embedding-3-small" => {
4638            open_ai::embed(
4639                session.http_client.as_ref(),
4640                OPEN_AI_API_URL,
4641                &api_key,
4642                OpenAiEmbeddingModel::TextEmbedding3Small,
4643                request.texts.iter().map(|text| text.as_str()),
4644            )
4645            .await?
4646        }
4647        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4648    };
4649
4650    let embeddings = request
4651        .texts
4652        .iter()
4653        .map(|text| {
4654            let mut hasher = sha2::Sha256::new();
4655            hasher.update(text.as_bytes());
4656            let result = hasher.finalize();
4657            result.to_vec()
4658        })
4659        .zip(
4660            embeddings
4661                .data
4662                .into_iter()
4663                .map(|embedding| embedding.embedding),
4664        )
4665        .collect::<HashMap<_, _>>();
4666
4667    let db = session.db().await;
4668    db.save_embeddings(&request.model, &embeddings)
4669        .await
4670        .context("failed to save embeddings")
4671        .trace_err();
4672
4673    response.send(proto::ComputeEmbeddingsResponse {
4674        embeddings: embeddings
4675            .into_iter()
4676            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4677            .collect(),
4678    })?;
4679    Ok(())
4680}
4681
4682async fn get_cached_embeddings(
4683    request: proto::GetCachedEmbeddings,
4684    response: Response<proto::GetCachedEmbeddings>,
4685    session: UserSession,
4686) -> Result<()> {
4687    authorize_access_to_legacy_llm_endpoints(&session).await?;
4688
4689    let db = session.db().await;
4690    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4691
4692    response.send(proto::GetCachedEmbeddingsResponse {
4693        embeddings: embeddings
4694            .into_iter()
4695            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4696            .collect(),
4697    })?;
4698    Ok(())
4699}
4700
4701/// This is leftover from before the LLM service.
4702///
4703/// The endpoints protected by this check will be moved there eventually.
4704async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4705    if session.is_staff() {
4706        Ok(())
4707    } else {
4708        Err(anyhow!("permission denied"))?
4709    }
4710}
4711
4712/// Get a Supermaven API key for the user
4713async fn get_supermaven_api_key(
4714    _request: proto::GetSupermavenApiKey,
4715    response: Response<proto::GetSupermavenApiKey>,
4716    session: UserSession,
4717) -> Result<()> {
4718    let user_id: String = session.user_id().to_string();
4719    if !session.is_staff() {
4720        return Err(anyhow!("supermaven not enabled for this account"))?;
4721    }
4722
4723    let email = session
4724        .email()
4725        .ok_or_else(|| anyhow!("user must have an email"))?;
4726
4727    let supermaven_admin_api = session
4728        .supermaven_client
4729        .as_ref()
4730        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4731
4732    let result = supermaven_admin_api
4733        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4734        .await?;
4735
4736    response.send(proto::GetSupermavenApiKeyResponse {
4737        api_key: result.api_key,
4738    })?;
4739
4740    Ok(())
4741}
4742
4743/// Start receiving chat updates for a channel
4744async fn join_channel_chat(
4745    request: proto::JoinChannelChat,
4746    response: Response<proto::JoinChannelChat>,
4747    session: UserSession,
4748) -> Result<()> {
4749    let channel_id = ChannelId::from_proto(request.channel_id);
4750
4751    let db = session.db().await;
4752    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4753        .await?;
4754    let messages = db
4755        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4756        .await?;
4757    response.send(proto::JoinChannelChatResponse {
4758        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4759        messages,
4760    })?;
4761    Ok(())
4762}
4763
4764/// Stop receiving chat updates for a channel
4765async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4766    let channel_id = ChannelId::from_proto(request.channel_id);
4767    session
4768        .db()
4769        .await
4770        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4771        .await?;
4772    Ok(())
4773}
4774
4775/// Retrieve the chat history for a channel
4776async fn get_channel_messages(
4777    request: proto::GetChannelMessages,
4778    response: Response<proto::GetChannelMessages>,
4779    session: UserSession,
4780) -> Result<()> {
4781    let channel_id = ChannelId::from_proto(request.channel_id);
4782    let messages = session
4783        .db()
4784        .await
4785        .get_channel_messages(
4786            channel_id,
4787            session.user_id(),
4788            MESSAGE_COUNT_PER_PAGE,
4789            Some(MessageId::from_proto(request.before_message_id)),
4790        )
4791        .await?;
4792    response.send(proto::GetChannelMessagesResponse {
4793        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4794        messages,
4795    })?;
4796    Ok(())
4797}
4798
4799/// Retrieve specific chat messages
4800async fn get_channel_messages_by_id(
4801    request: proto::GetChannelMessagesById,
4802    response: Response<proto::GetChannelMessagesById>,
4803    session: UserSession,
4804) -> Result<()> {
4805    let message_ids = request
4806        .message_ids
4807        .iter()
4808        .map(|id| MessageId::from_proto(*id))
4809        .collect::<Vec<_>>();
4810    let messages = session
4811        .db()
4812        .await
4813        .get_channel_messages_by_id(session.user_id(), &message_ids)
4814        .await?;
4815    response.send(proto::GetChannelMessagesResponse {
4816        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4817        messages,
4818    })?;
4819    Ok(())
4820}
4821
4822/// Retrieve the current users notifications
4823async fn get_notifications(
4824    request: proto::GetNotifications,
4825    response: Response<proto::GetNotifications>,
4826    session: UserSession,
4827) -> Result<()> {
4828    let notifications = session
4829        .db()
4830        .await
4831        .get_notifications(
4832            session.user_id(),
4833            NOTIFICATION_COUNT_PER_PAGE,
4834            request
4835                .before_id
4836                .map(|id| db::NotificationId::from_proto(id)),
4837        )
4838        .await?;
4839    response.send(proto::GetNotificationsResponse {
4840        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4841        notifications,
4842    })?;
4843    Ok(())
4844}
4845
4846/// Mark notifications as read
4847async fn mark_notification_as_read(
4848    request: proto::MarkNotificationRead,
4849    response: Response<proto::MarkNotificationRead>,
4850    session: UserSession,
4851) -> Result<()> {
4852    let database = &session.db().await;
4853    let notifications = database
4854        .mark_notification_as_read_by_id(
4855            session.user_id(),
4856            NotificationId::from_proto(request.notification_id),
4857        )
4858        .await?;
4859    send_notifications(
4860        &*session.connection_pool().await,
4861        &session.peer,
4862        notifications,
4863    );
4864    response.send(proto::Ack {})?;
4865    Ok(())
4866}
4867
4868/// Get the current users information
4869async fn get_private_user_info(
4870    _request: proto::GetPrivateUserInfo,
4871    response: Response<proto::GetPrivateUserInfo>,
4872    session: UserSession,
4873) -> Result<()> {
4874    let db = session.db().await;
4875
4876    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4877    let user = db
4878        .get_user_by_id(session.user_id())
4879        .await?
4880        .ok_or_else(|| anyhow!("user not found"))?;
4881    let flags = db.get_user_flags(session.user_id()).await?;
4882
4883    response.send(proto::GetPrivateUserInfoResponse {
4884        metrics_id,
4885        staff: user.admin,
4886        flags,
4887        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4888    })?;
4889    Ok(())
4890}
4891
4892/// Accept the terms of service (tos) on behalf of the current user
4893async fn accept_terms_of_service(
4894    _request: proto::AcceptTermsOfService,
4895    response: Response<proto::AcceptTermsOfService>,
4896    session: UserSession,
4897) -> Result<()> {
4898    let db = session.db().await;
4899
4900    let accepted_tos_at = Utc::now();
4901    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4902        .await?;
4903
4904    response.send(proto::AcceptTermsOfServiceResponse {
4905        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4906    })?;
4907    Ok(())
4908}
4909
4910/// The minimum account age an account must have in order to use the LLM service.
4911const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4912
4913async fn get_llm_api_token(
4914    _request: proto::GetLlmToken,
4915    response: Response<proto::GetLlmToken>,
4916    session: UserSession,
4917) -> Result<()> {
4918    let db = session.db().await;
4919
4920    let flags = db.get_user_flags(session.user_id()).await?;
4921    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
4922        Err(anyhow!("permission denied"))?
4923    }
4924
4925    let user_id = session.user_id();
4926    let user = db
4927        .get_user_by_id(user_id)
4928        .await?
4929        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4930
4931    if user.accepted_tos_at.is_none() {
4932        Err(anyhow!("terms of service not accepted"))?
4933    }
4934
4935    let mut account_created_at = user.created_at;
4936    if let Some(github_created_at) = user.github_user_created_at {
4937        account_created_at = account_created_at.min(github_created_at);
4938    }
4939    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4940        Err(anyhow!("account too young"))?
4941    }
4942    let token = LlmTokenClaims::create(
4943        user.id,
4944        user.github_login.clone(),
4945        session.is_staff(),
4946        session.current_plan(db).await?,
4947        &session.app_state.config,
4948    )?;
4949    response.send(proto::GetLlmTokenResponse { token })?;
4950    Ok(())
4951}
4952
4953fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4954    let message = match message {
4955        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4956        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4957        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4958        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4959        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4960            code: frame.code.into(),
4961            reason: frame.reason,
4962        })),
4963        // We should never receive a frame while reading the message, according
4964        // to the `tungstenite` maintainers:
4965        //
4966        // > It cannot occur when you read messages from the WebSocket, but it
4967        // > can be used when you want to send the raw frames (e.g. you want to
4968        // > send the frames to the WebSocket without composing the full message first).
4969        // >
4970        // > — https://github.com/snapview/tungstenite-rs/issues/268
4971        TungsteniteMessage::Frame(_) => {
4972            bail!("received an unexpected frame while reading the message")
4973        }
4974    };
4975
4976    Ok(message)
4977}
4978
4979fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4980    match message {
4981        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4982        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4983        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4984        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4985        AxumMessage::Close(frame) => {
4986            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4987                code: frame.code.into(),
4988                reason: frame.reason,
4989            }))
4990        }
4991    }
4992}
4993
4994fn notify_membership_updated(
4995    connection_pool: &mut ConnectionPool,
4996    result: MembershipUpdated,
4997    user_id: UserId,
4998    peer: &Peer,
4999) {
5000    for membership in &result.new_channels.channel_memberships {
5001        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5002    }
5003    for channel_id in &result.removed_channels {
5004        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5005    }
5006
5007    let user_channels_update = proto::UpdateUserChannels {
5008        channel_memberships: result
5009            .new_channels
5010            .channel_memberships
5011            .iter()
5012            .map(|cm| proto::ChannelMembership {
5013                channel_id: cm.channel_id.to_proto(),
5014                role: cm.role.into(),
5015            })
5016            .collect(),
5017        ..Default::default()
5018    };
5019
5020    let mut update = build_channels_update(result.new_channels);
5021    update.delete_channels = result
5022        .removed_channels
5023        .into_iter()
5024        .map(|id| id.to_proto())
5025        .collect();
5026    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5027
5028    for connection_id in connection_pool.user_connection_ids(user_id) {
5029        peer.send(connection_id, user_channels_update.clone())
5030            .trace_err();
5031        peer.send(connection_id, update.clone()).trace_err();
5032    }
5033}
5034
5035fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5036    proto::UpdateUserChannels {
5037        channel_memberships: channels
5038            .channel_memberships
5039            .iter()
5040            .map(|m| proto::ChannelMembership {
5041                channel_id: m.channel_id.to_proto(),
5042                role: m.role.into(),
5043            })
5044            .collect(),
5045        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5046        observed_channel_message_id: channels.observed_channel_messages.clone(),
5047    }
5048}
5049
5050fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5051    let mut update = proto::UpdateChannels::default();
5052
5053    for channel in channels.channels {
5054        update.channels.push(channel.to_proto());
5055    }
5056
5057    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5058    update.latest_channel_message_ids = channels.latest_channel_messages;
5059
5060    for (channel_id, participants) in channels.channel_participants {
5061        update
5062            .channel_participants
5063            .push(proto::ChannelParticipants {
5064                channel_id: channel_id.to_proto(),
5065                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5066            });
5067    }
5068
5069    for channel in channels.invited_channels {
5070        update.channel_invitations.push(channel.to_proto());
5071    }
5072
5073    update.hosted_projects = channels.hosted_projects;
5074    update
5075}
5076
5077fn build_initial_contacts_update(
5078    contacts: Vec<db::Contact>,
5079    pool: &ConnectionPool,
5080) -> proto::UpdateContacts {
5081    let mut update = proto::UpdateContacts::default();
5082
5083    for contact in contacts {
5084        match contact {
5085            db::Contact::Accepted { user_id, busy } => {
5086                update.contacts.push(contact_for_user(user_id, busy, &pool));
5087            }
5088            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5089            db::Contact::Incoming { user_id } => {
5090                update
5091                    .incoming_requests
5092                    .push(proto::IncomingContactRequest {
5093                        requester_id: user_id.to_proto(),
5094                    })
5095            }
5096        }
5097    }
5098
5099    update
5100}
5101
5102fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5103    proto::Contact {
5104        user_id: user_id.to_proto(),
5105        online: pool.is_user_online(user_id),
5106        busy,
5107    }
5108}
5109
5110fn room_updated(room: &proto::Room, peer: &Peer) {
5111    broadcast(
5112        None,
5113        room.participants
5114            .iter()
5115            .filter_map(|participant| Some(participant.peer_id?.into())),
5116        |peer_id| {
5117            peer.send(
5118                peer_id,
5119                proto::RoomUpdated {
5120                    room: Some(room.clone()),
5121                },
5122            )
5123        },
5124    );
5125}
5126
5127fn channel_updated(
5128    channel: &db::channel::Model,
5129    room: &proto::Room,
5130    peer: &Peer,
5131    pool: &ConnectionPool,
5132) {
5133    let participants = room
5134        .participants
5135        .iter()
5136        .map(|p| p.user_id)
5137        .collect::<Vec<_>>();
5138
5139    broadcast(
5140        None,
5141        pool.channel_connection_ids(channel.root_id())
5142            .filter_map(|(channel_id, role)| {
5143                role.can_see_channel(channel.visibility).then(|| channel_id)
5144            }),
5145        |peer_id| {
5146            peer.send(
5147                peer_id,
5148                proto::UpdateChannels {
5149                    channel_participants: vec![proto::ChannelParticipants {
5150                        channel_id: channel.id.to_proto(),
5151                        participant_user_ids: participants.clone(),
5152                    }],
5153                    ..Default::default()
5154                },
5155            )
5156        },
5157    );
5158}
5159
5160async fn send_dev_server_projects_update(
5161    user_id: UserId,
5162    mut status: proto::DevServerProjectsUpdate,
5163    session: &Session,
5164) {
5165    let pool = session.connection_pool().await;
5166    for dev_server in &mut status.dev_servers {
5167        dev_server.status =
5168            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5169    }
5170    let connections = pool.user_connection_ids(user_id);
5171    for connection_id in connections {
5172        session.peer.send(connection_id, status.clone()).trace_err();
5173    }
5174}
5175
5176async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5177    let db = session.db().await;
5178
5179    let contacts = db.get_contacts(user_id).await?;
5180    let busy = db.is_user_busy(user_id).await?;
5181
5182    let pool = session.connection_pool().await;
5183    let updated_contact = contact_for_user(user_id, busy, &pool);
5184    for contact in contacts {
5185        if let db::Contact::Accepted {
5186            user_id: contact_user_id,
5187            ..
5188        } = contact
5189        {
5190            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5191                session
5192                    .peer
5193                    .send(
5194                        contact_conn_id,
5195                        proto::UpdateContacts {
5196                            contacts: vec![updated_contact.clone()],
5197                            remove_contacts: Default::default(),
5198                            incoming_requests: Default::default(),
5199                            remove_incoming_requests: Default::default(),
5200                            outgoing_requests: Default::default(),
5201                            remove_outgoing_requests: Default::default(),
5202                        },
5203                    )
5204                    .trace_err();
5205            }
5206        }
5207    }
5208    Ok(())
5209}
5210
5211async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5212    log::info!("lost dev server connection, unsharing projects");
5213    let project_ids = session
5214        .db()
5215        .await
5216        .get_stale_dev_server_projects(session.connection_id)
5217        .await?;
5218
5219    for project_id in project_ids {
5220        // not unshare re-checks the connection ids match, so we get away with no transaction
5221        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5222    }
5223
5224    let user_id = session.dev_server().user_id;
5225    let update = session
5226        .db()
5227        .await
5228        .dev_server_projects_update(user_id)
5229        .await?;
5230
5231    send_dev_server_projects_update(user_id, update, session).await;
5232
5233    Ok(())
5234}
5235
5236async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5237    let mut contacts_to_update = HashSet::default();
5238
5239    let room_id;
5240    let canceled_calls_to_user_ids;
5241    let live_kit_room;
5242    let delete_live_kit_room;
5243    let room;
5244    let channel;
5245
5246    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5247        contacts_to_update.insert(session.user_id());
5248
5249        for project in left_room.left_projects.values() {
5250            project_left(project, session);
5251        }
5252
5253        room_id = RoomId::from_proto(left_room.room.id);
5254        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5255        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5256        delete_live_kit_room = left_room.deleted;
5257        room = mem::take(&mut left_room.room);
5258        channel = mem::take(&mut left_room.channel);
5259
5260        room_updated(&room, &session.peer);
5261    } else {
5262        return Ok(());
5263    }
5264
5265    if let Some(channel) = channel {
5266        channel_updated(
5267            &channel,
5268            &room,
5269            &session.peer,
5270            &*session.connection_pool().await,
5271        );
5272    }
5273
5274    {
5275        let pool = session.connection_pool().await;
5276        for canceled_user_id in canceled_calls_to_user_ids {
5277            for connection_id in pool.user_connection_ids(canceled_user_id) {
5278                session
5279                    .peer
5280                    .send(
5281                        connection_id,
5282                        proto::CallCanceled {
5283                            room_id: room_id.to_proto(),
5284                        },
5285                    )
5286                    .trace_err();
5287            }
5288            contacts_to_update.insert(canceled_user_id);
5289        }
5290    }
5291
5292    for contact_user_id in contacts_to_update {
5293        update_user_contacts(contact_user_id, &session).await?;
5294    }
5295
5296    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5297        live_kit
5298            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5299            .await
5300            .trace_err();
5301
5302        if delete_live_kit_room {
5303            live_kit.delete_room(live_kit_room).await.trace_err();
5304        }
5305    }
5306
5307    Ok(())
5308}
5309
5310async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5311    let left_channel_buffers = session
5312        .db()
5313        .await
5314        .leave_channel_buffers(session.connection_id)
5315        .await?;
5316
5317    for left_buffer in left_channel_buffers {
5318        channel_buffer_updated(
5319            session.connection_id,
5320            left_buffer.connections,
5321            &proto::UpdateChannelBufferCollaborators {
5322                channel_id: left_buffer.channel_id.to_proto(),
5323                collaborators: left_buffer.collaborators,
5324            },
5325            &session.peer,
5326        );
5327    }
5328
5329    Ok(())
5330}
5331
5332fn project_left(project: &db::LeftProject, session: &UserSession) {
5333    for connection_id in &project.connection_ids {
5334        if project.should_unshare {
5335            session
5336                .peer
5337                .send(
5338                    *connection_id,
5339                    proto::UnshareProject {
5340                        project_id: project.id.to_proto(),
5341                    },
5342                )
5343                .trace_err();
5344        } else {
5345            session
5346                .peer
5347                .send(
5348                    *connection_id,
5349                    proto::RemoveProjectCollaborator {
5350                        project_id: project.id.to_proto(),
5351                        peer_id: Some(session.connection_id.into()),
5352                    },
5353                )
5354                .trace_err();
5355        }
5356    }
5357}
5358
5359pub trait ResultExt {
5360    type Ok;
5361
5362    fn trace_err(self) -> Option<Self::Ok>;
5363}
5364
5365impl<T, E> ResultExt for Result<T, E>
5366where
5367    E: std::fmt::Debug,
5368{
5369    type Ok = T;
5370
5371    #[track_caller]
5372    fn trace_err(self) -> Option<T> {
5373        match self {
5374            Ok(value) => Some(value),
5375            Err(error) => {
5376                tracing::error!("{:?}", error);
5377                None
5378            }
5379        }
5380    }
5381}