rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use isahc_http_client::IsahcHttpClient;
  40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone, Debug)]
 109pub enum Principal {
 110    User(User),
 111    Impersonated { user: User, admin: User },
 112    DevServer(dev_server::Model),
 113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
 134#[derive(Clone)]
 135struct Session {
 136    principal: Principal,
 137    connection_id: ConnectionId,
 138    db: Arc<tokio::sync::Mutex<DbHandle>>,
 139    peer: Arc<Peer>,
 140    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 141    app_state: Arc<AppState>,
 142    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 143    http_client: Arc<dyn HttpClient>,
 144    /// The GeoIP country code for the user.
 145    #[allow(unused)]
 146    geoip_country_code: Option<String>,
 147    _executor: Executor,
 148}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 195        if self.is_staff() {
 196            return Ok(proto::Plan::ZedPro);
 197        }
 198
 199        let Some(user_id) = self.user_id() else {
 200            return Ok(proto::Plan::Free);
 201        };
 202
 203        if db.has_active_billing_subscription(user_id).await? {
 204            Ok(proto::Plan::ZedPro)
 205        } else {
 206            Ok(proto::Plan::Free)
 207        }
 208    }
 209
 210    fn dev_server_id(&self) -> Option<DevServerId> {
 211        match &self.principal {
 212            Principal::User(_) | Principal::Impersonated { .. } => None,
 213            Principal::DevServer(dev_server) => Some(dev_server.id),
 214        }
 215    }
 216
 217    fn principal_id(&self) -> PrincipalId {
 218        match &self.principal {
 219            Principal::User(user) => PrincipalId::UserId(user.id),
 220            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 221            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 222        }
 223    }
 224}
 225
 226impl Debug for Session {
 227    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 228        let mut result = f.debug_struct("Session");
 229        match &self.principal {
 230            Principal::User(user) => {
 231                result.field("user", &user.github_login);
 232            }
 233            Principal::Impersonated { user, admin } => {
 234                result.field("user", &user.github_login);
 235                result.field("impersonator", &admin.github_login);
 236            }
 237            Principal::DevServer(dev_server) => {
 238                result.field("dev_server", &dev_server.id);
 239            }
 240        }
 241        result.field("connection_id", &self.connection_id).finish()
 242    }
 243}
 244
 245struct UserSession(Session);
 246
 247impl UserSession {
 248    pub fn new(s: Session) -> Option<Self> {
 249        s.user_id().map(|_| UserSession(s))
 250    }
 251    pub fn user_id(&self) -> UserId {
 252        self.0.user_id().unwrap()
 253    }
 254
 255    pub fn email(&self) -> Option<String> {
 256        match &self.0.principal {
 257            Principal::User(user) => user.email_address.clone(),
 258            Principal::Impersonated { user, .. } => user.email_address.clone(),
 259            Principal::DevServer(..) => None,
 260        }
 261    }
 262}
 263
 264impl Deref for UserSession {
 265    type Target = Session;
 266
 267    fn deref(&self) -> &Self::Target {
 268        &self.0
 269    }
 270}
 271impl DerefMut for UserSession {
 272    fn deref_mut(&mut self) -> &mut Self::Target {
 273        &mut self.0
 274    }
 275}
 276
 277struct DevServerSession(Session);
 278
 279impl DevServerSession {
 280    pub fn new(s: Session) -> Option<Self> {
 281        s.dev_server_id().map(|_| DevServerSession(s))
 282    }
 283    pub fn dev_server_id(&self) -> DevServerId {
 284        self.0.dev_server_id().unwrap()
 285    }
 286
 287    fn dev_server(&self) -> &dev_server::Model {
 288        match &self.0.principal {
 289            Principal::DevServer(dev_server) => dev_server,
 290            _ => unreachable!(),
 291        }
 292    }
 293}
 294
 295impl Deref for DevServerSession {
 296    type Target = Session;
 297
 298    fn deref(&self) -> &Self::Target {
 299        &self.0
 300    }
 301}
 302impl DerefMut for DevServerSession {
 303    fn deref_mut(&mut self) -> &mut Self::Target {
 304        &mut self.0
 305    }
 306}
 307
 308fn user_handler<M: RequestMessage, Fut>(
 309    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 311where
 312    Fut: Send + Future<Output = Result<()>>,
 313{
 314    let handler = Arc::new(handler);
 315    move |message, response, session| {
 316        let handler = handler.clone();
 317        Box::pin(async move {
 318            if let Some(user_session) = session.for_user() {
 319                Ok(handler(message, response, user_session).await?)
 320            } else {
 321                Err(Error::Internal(anyhow!(
 322                    "must be a user to call {}",
 323                    M::NAME
 324                )))
 325            }
 326        })
 327    }
 328}
 329
 330fn dev_server_handler<M: RequestMessage, Fut>(
 331    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 333where
 334    Fut: Send + Future<Output = Result<()>>,
 335{
 336    let handler = Arc::new(handler);
 337    move |message, response, session| {
 338        let handler = handler.clone();
 339        Box::pin(async move {
 340            if let Some(dev_server_session) = session.for_dev_server() {
 341                Ok(handler(message, response, dev_server_session).await?)
 342            } else {
 343                Err(Error::Internal(anyhow!(
 344                    "must be a dev server to call {}",
 345                    M::NAME
 346                )))
 347            }
 348        })
 349    }
 350}
 351
 352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 353    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 355where
 356    InnertRetFut: Send + Future<Output = Result<()>>,
 357{
 358    let handler = Arc::new(handler);
 359    move |message, session| {
 360        let handler = handler.clone();
 361        Box::pin(async move {
 362            if let Some(user_session) = session.for_user() {
 363                Ok(handler(message, user_session).await?)
 364            } else {
 365                Err(Error::Internal(anyhow!(
 366                    "must be a user to call {}",
 367                    M::NAME
 368                )))
 369            }
 370        })
 371    }
 372}
 373
 374struct DbHandle(Arc<Database>);
 375
 376impl Deref for DbHandle {
 377    type Target = Database;
 378
 379    fn deref(&self) -> &Self::Target {
 380        self.0.as_ref()
 381    }
 382}
 383
 384pub struct Server {
 385    id: parking_lot::Mutex<ServerId>,
 386    peer: Arc<Peer>,
 387    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 388    app_state: Arc<AppState>,
 389    handlers: HashMap<TypeId, MessageHandler>,
 390    teardown: watch::Sender<bool>,
 391}
 392
 393pub(crate) struct ConnectionPoolGuard<'a> {
 394    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 395    _not_send: PhantomData<Rc<()>>,
 396}
 397
 398#[derive(Serialize)]
 399pub struct ServerSnapshot<'a> {
 400    peer: &'a Peer,
 401    #[serde(serialize_with = "serialize_deref")]
 402    connection_pool: ConnectionPoolGuard<'a>,
 403}
 404
 405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 406where
 407    S: Serializer,
 408    T: Deref<Target = U>,
 409    U: Serialize,
 410{
 411    Serialize::serialize(value.deref(), serializer)
 412}
 413
 414impl Server {
 415    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 416        let mut server = Self {
 417            id: parking_lot::Mutex::new(id),
 418            peer: Peer::new(id.0 as u32),
 419            app_state: app_state.clone(),
 420            connection_pool: Default::default(),
 421            handlers: Default::default(),
 422            teardown: watch::channel(false).0,
 423        };
 424
 425        server
 426            .add_request_handler(ping)
 427            .add_request_handler(user_handler(create_room))
 428            .add_request_handler(user_handler(join_room))
 429            .add_request_handler(user_handler(rejoin_room))
 430            .add_request_handler(user_handler(leave_room))
 431            .add_request_handler(user_handler(set_room_participant_role))
 432            .add_request_handler(user_handler(call))
 433            .add_request_handler(user_handler(cancel_call))
 434            .add_message_handler(user_message_handler(decline_call))
 435            .add_request_handler(user_handler(update_participant_location))
 436            .add_request_handler(user_handler(share_project))
 437            .add_message_handler(unshare_project)
 438            .add_request_handler(user_handler(join_project))
 439            .add_request_handler(user_handler(join_hosted_project))
 440            .add_request_handler(user_handler(rejoin_dev_server_projects))
 441            .add_request_handler(user_handler(create_dev_server_project))
 442            .add_request_handler(user_handler(update_dev_server_project))
 443            .add_request_handler(user_handler(delete_dev_server_project))
 444            .add_request_handler(user_handler(create_dev_server))
 445            .add_request_handler(user_handler(regenerate_dev_server_token))
 446            .add_request_handler(user_handler(rename_dev_server))
 447            .add_request_handler(user_handler(delete_dev_server))
 448            .add_request_handler(user_handler(list_remote_directory))
 449            .add_request_handler(dev_server_handler(share_dev_server_project))
 450            .add_request_handler(dev_server_handler(shutdown_dev_server))
 451            .add_request_handler(dev_server_handler(reconnect_dev_server))
 452            .add_message_handler(user_message_handler(leave_project))
 453            .add_request_handler(update_project)
 454            .add_request_handler(update_worktree)
 455            .add_message_handler(start_language_server)
 456            .add_message_handler(update_language_server)
 457            .add_message_handler(update_diagnostic_summary)
 458            .add_message_handler(update_worktree_settings)
 459            .add_request_handler(user_handler(
 460                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 461            ))
 462            .add_request_handler(user_handler(
 463                forward_project_request_for_owner::<proto::TaskTemplates>,
 464            ))
 465            .add_request_handler(user_handler(
 466                forward_read_only_project_request::<proto::GetHover>,
 467            ))
 468            .add_request_handler(user_handler(
 469                forward_read_only_project_request::<proto::GetDefinition>,
 470            ))
 471            .add_request_handler(user_handler(
 472                forward_read_only_project_request::<proto::GetTypeDefinition>,
 473            ))
 474            .add_request_handler(user_handler(
 475                forward_read_only_project_request::<proto::GetReferences>,
 476            ))
 477            .add_request_handler(user_handler(
 478                forward_read_only_project_request::<proto::SearchProject>,
 479            ))
 480            .add_request_handler(user_handler(forward_find_search_candidates_request))
 481            .add_request_handler(user_handler(
 482                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_read_only_project_request::<proto::GetProjectSymbols>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_read_only_project_request::<proto::OpenBufferById>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_read_only_project_request::<proto::InlayHints>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_read_only_project_request::<proto::ResolveInlayHint>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_read_only_project_request::<proto::OpenBufferByPath>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_mutating_project_request::<proto::GetCompletions>,
 507            ))
 508            .add_request_handler(user_handler(
 509                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 510            ))
 511            .add_request_handler(user_handler(
 512                forward_mutating_project_request::<proto::OpenNewBuffer>,
 513            ))
 514            .add_request_handler(user_handler(
 515                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 516            ))
 517            .add_request_handler(user_handler(
 518                forward_mutating_project_request::<proto::GetCodeActions>,
 519            ))
 520            .add_request_handler(user_handler(
 521                forward_mutating_project_request::<proto::ApplyCodeAction>,
 522            ))
 523            .add_request_handler(user_handler(
 524                forward_mutating_project_request::<proto::PrepareRename>,
 525            ))
 526            .add_request_handler(user_handler(
 527                forward_mutating_project_request::<proto::PerformRename>,
 528            ))
 529            .add_request_handler(user_handler(
 530                forward_mutating_project_request::<proto::ReloadBuffers>,
 531            ))
 532            .add_request_handler(user_handler(
 533                forward_mutating_project_request::<proto::FormatBuffers>,
 534            ))
 535            .add_request_handler(user_handler(
 536                forward_mutating_project_request::<proto::CreateProjectEntry>,
 537            ))
 538            .add_request_handler(user_handler(
 539                forward_mutating_project_request::<proto::RenameProjectEntry>,
 540            ))
 541            .add_request_handler(user_handler(
 542                forward_mutating_project_request::<proto::CopyProjectEntry>,
 543            ))
 544            .add_request_handler(user_handler(
 545                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 546            ))
 547            .add_request_handler(user_handler(
 548                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 549            ))
 550            .add_request_handler(user_handler(
 551                forward_mutating_project_request::<proto::OnTypeFormatting>,
 552            ))
 553            .add_request_handler(user_handler(
 554                forward_mutating_project_request::<proto::SaveBuffer>,
 555            ))
 556            .add_request_handler(user_handler(
 557                forward_mutating_project_request::<proto::BlameBuffer>,
 558            ))
 559            .add_request_handler(user_handler(
 560                forward_mutating_project_request::<proto::MultiLspQuery>,
 561            ))
 562            .add_request_handler(user_handler(
 563                forward_mutating_project_request::<proto::RestartLanguageServers>,
 564            ))
 565            .add_request_handler(user_handler(
 566                forward_mutating_project_request::<proto::LinkedEditingRange>,
 567            ))
 568            .add_message_handler(create_buffer_for_peer)
 569            .add_request_handler(update_buffer)
 570            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 571            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 572            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 573            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 574            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 575            .add_request_handler(get_users)
 576            .add_request_handler(user_handler(fuzzy_search_users))
 577            .add_request_handler(user_handler(request_contact))
 578            .add_request_handler(user_handler(remove_contact))
 579            .add_request_handler(user_handler(respond_to_contact_request))
 580            .add_message_handler(subscribe_to_channels)
 581            .add_request_handler(user_handler(create_channel))
 582            .add_request_handler(user_handler(delete_channel))
 583            .add_request_handler(user_handler(invite_channel_member))
 584            .add_request_handler(user_handler(remove_channel_member))
 585            .add_request_handler(user_handler(set_channel_member_role))
 586            .add_request_handler(user_handler(set_channel_visibility))
 587            .add_request_handler(user_handler(rename_channel))
 588            .add_request_handler(user_handler(join_channel_buffer))
 589            .add_request_handler(user_handler(leave_channel_buffer))
 590            .add_message_handler(user_message_handler(update_channel_buffer))
 591            .add_request_handler(user_handler(rejoin_channel_buffers))
 592            .add_request_handler(user_handler(get_channel_members))
 593            .add_request_handler(user_handler(respond_to_channel_invite))
 594            .add_request_handler(user_handler(join_channel))
 595            .add_request_handler(user_handler(join_channel_chat))
 596            .add_message_handler(user_message_handler(leave_channel_chat))
 597            .add_request_handler(user_handler(send_channel_message))
 598            .add_request_handler(user_handler(remove_channel_message))
 599            .add_request_handler(user_handler(update_channel_message))
 600            .add_request_handler(user_handler(get_channel_messages))
 601            .add_request_handler(user_handler(get_channel_messages_by_id))
 602            .add_request_handler(user_handler(get_notifications))
 603            .add_request_handler(user_handler(mark_notification_as_read))
 604            .add_request_handler(user_handler(move_channel))
 605            .add_request_handler(user_handler(follow))
 606            .add_message_handler(user_message_handler(unfollow))
 607            .add_message_handler(user_message_handler(update_followers))
 608            .add_request_handler(user_handler(get_private_user_info))
 609            .add_request_handler(user_handler(get_llm_api_token))
 610            .add_request_handler(user_handler(accept_terms_of_service))
 611            .add_message_handler(user_message_handler(acknowledge_channel_message))
 612            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 613            .add_request_handler(user_handler(get_supermaven_api_key))
 614            .add_request_handler(user_handler(
 615                forward_mutating_project_request::<proto::OpenContext>,
 616            ))
 617            .add_request_handler(user_handler(
 618                forward_mutating_project_request::<proto::CreateContext>,
 619            ))
 620            .add_request_handler(user_handler(
 621                forward_mutating_project_request::<proto::SynchronizeContexts>,
 622            ))
 623            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 624            .add_message_handler(update_context)
 625            .add_request_handler({
 626                let app_state = app_state.clone();
 627                move |request, response, session| {
 628                    let app_state = app_state.clone();
 629                    async move {
 630                        count_language_model_tokens(request, response, session, &app_state.config)
 631                            .await
 632                    }
 633                }
 634            })
 635            .add_request_handler({
 636                user_handler(move |request, response, session| {
 637                    get_cached_embeddings(request, response, session)
 638                })
 639            })
 640            .add_request_handler({
 641                let app_state = app_state.clone();
 642                user_handler(move |request, response, session| {
 643                    compute_embeddings(
 644                        request,
 645                        response,
 646                        session,
 647                        app_state.config.openai_api_key.clone(),
 648                    )
 649                })
 650            });
 651
 652        Arc::new(server)
 653    }
 654
 655    pub async fn start(&self) -> Result<()> {
 656        let server_id = *self.id.lock();
 657        let app_state = self.app_state.clone();
 658        let peer = self.peer.clone();
 659        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 660        let pool = self.connection_pool.clone();
 661        let live_kit_client = self.app_state.live_kit_client.clone();
 662
 663        let span = info_span!("start server");
 664        self.app_state.executor.spawn_detached(
 665            async move {
 666                tracing::info!("waiting for cleanup timeout");
 667                timeout.await;
 668                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 669                if let Some((room_ids, channel_ids)) = app_state
 670                    .db
 671                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 672                    .await
 673                    .trace_err()
 674                {
 675                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 676                    tracing::info!(
 677                        stale_channel_buffer_count = channel_ids.len(),
 678                        "retrieved stale channel buffers"
 679                    );
 680
 681                    for channel_id in channel_ids {
 682                        if let Some(refreshed_channel_buffer) = app_state
 683                            .db
 684                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 685                            .await
 686                            .trace_err()
 687                        {
 688                            for connection_id in refreshed_channel_buffer.connection_ids {
 689                                peer.send(
 690                                    connection_id,
 691                                    proto::UpdateChannelBufferCollaborators {
 692                                        channel_id: channel_id.to_proto(),
 693                                        collaborators: refreshed_channel_buffer
 694                                            .collaborators
 695                                            .clone(),
 696                                    },
 697                                )
 698                                .trace_err();
 699                            }
 700                        }
 701                    }
 702
 703                    for room_id in room_ids {
 704                        let mut contacts_to_update = HashSet::default();
 705                        let mut canceled_calls_to_user_ids = Vec::new();
 706                        let mut live_kit_room = String::new();
 707                        let mut delete_live_kit_room = false;
 708
 709                        if let Some(mut refreshed_room) = app_state
 710                            .db
 711                            .clear_stale_room_participants(room_id, server_id)
 712                            .await
 713                            .trace_err()
 714                        {
 715                            tracing::info!(
 716                                room_id = room_id.0,
 717                                new_participant_count = refreshed_room.room.participants.len(),
 718                                "refreshed room"
 719                            );
 720                            room_updated(&refreshed_room.room, &peer);
 721                            if let Some(channel) = refreshed_room.channel.as_ref() {
 722                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 723                            }
 724                            contacts_to_update
 725                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 726                            contacts_to_update
 727                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 728                            canceled_calls_to_user_ids =
 729                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 730                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 731                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 732                        }
 733
 734                        {
 735                            let pool = pool.lock();
 736                            for canceled_user_id in canceled_calls_to_user_ids {
 737                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 738                                    peer.send(
 739                                        connection_id,
 740                                        proto::CallCanceled {
 741                                            room_id: room_id.to_proto(),
 742                                        },
 743                                    )
 744                                    .trace_err();
 745                                }
 746                            }
 747                        }
 748
 749                        for user_id in contacts_to_update {
 750                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 751                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 752                            if let Some((busy, contacts)) = busy.zip(contacts) {
 753                                let pool = pool.lock();
 754                                let updated_contact = contact_for_user(user_id, busy, &pool);
 755                                for contact in contacts {
 756                                    if let db::Contact::Accepted {
 757                                        user_id: contact_user_id,
 758                                        ..
 759                                    } = contact
 760                                    {
 761                                        for contact_conn_id in
 762                                            pool.user_connection_ids(contact_user_id)
 763                                        {
 764                                            peer.send(
 765                                                contact_conn_id,
 766                                                proto::UpdateContacts {
 767                                                    contacts: vec![updated_contact.clone()],
 768                                                    remove_contacts: Default::default(),
 769                                                    incoming_requests: Default::default(),
 770                                                    remove_incoming_requests: Default::default(),
 771                                                    outgoing_requests: Default::default(),
 772                                                    remove_outgoing_requests: Default::default(),
 773                                                },
 774                                            )
 775                                            .trace_err();
 776                                        }
 777                                    }
 778                                }
 779                            }
 780                        }
 781
 782                        if let Some(live_kit) = live_kit_client.as_ref() {
 783                            if delete_live_kit_room {
 784                                live_kit.delete_room(live_kit_room).await.trace_err();
 785                            }
 786                        }
 787                    }
 788                }
 789
 790                app_state
 791                    .db
 792                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 793                    .await
 794                    .trace_err();
 795            }
 796            .instrument(span),
 797        );
 798        Ok(())
 799    }
 800
 801    pub fn teardown(&self) {
 802        self.peer.teardown();
 803        self.connection_pool.lock().reset();
 804        let _ = self.teardown.send(true);
 805    }
 806
 807    #[cfg(test)]
 808    pub fn reset(&self, id: ServerId) {
 809        self.teardown();
 810        *self.id.lock() = id;
 811        self.peer.reset(id.0 as u32);
 812        let _ = self.teardown.send(false);
 813    }
 814
 815    #[cfg(test)]
 816    pub fn id(&self) -> ServerId {
 817        *self.id.lock()
 818    }
 819
 820    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 821    where
 822        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 823        Fut: 'static + Send + Future<Output = Result<()>>,
 824        M: EnvelopedMessage,
 825    {
 826        let prev_handler = self.handlers.insert(
 827            TypeId::of::<M>(),
 828            Box::new(move |envelope, session| {
 829                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 830                let received_at = envelope.received_at;
 831                tracing::info!("message received");
 832                let start_time = Instant::now();
 833                let future = (handler)(*envelope, session);
 834                async move {
 835                    let result = future.await;
 836                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 837                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 838                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 839                    let payload_type = M::NAME;
 840
 841                    match result {
 842                        Err(error) => {
 843                            tracing::error!(
 844                                ?error,
 845                                total_duration_ms,
 846                                processing_duration_ms,
 847                                queue_duration_ms,
 848                                payload_type,
 849                                "error handling message"
 850                            )
 851                        }
 852                        Ok(()) => tracing::info!(
 853                            total_duration_ms,
 854                            processing_duration_ms,
 855                            queue_duration_ms,
 856                            "finished handling message"
 857                        ),
 858                    }
 859                }
 860                .boxed()
 861            }),
 862        );
 863        if prev_handler.is_some() {
 864            panic!("registered a handler for the same message twice");
 865        }
 866        self
 867    }
 868
 869    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 870    where
 871        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 872        Fut: 'static + Send + Future<Output = Result<()>>,
 873        M: EnvelopedMessage,
 874    {
 875        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 876        self
 877    }
 878
 879    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 880    where
 881        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 882        Fut: Send + Future<Output = Result<()>>,
 883        M: RequestMessage,
 884    {
 885        let handler = Arc::new(handler);
 886        self.add_handler(move |envelope, session| {
 887            let receipt = envelope.receipt();
 888            let handler = handler.clone();
 889            async move {
 890                let peer = session.peer.clone();
 891                let responded = Arc::new(AtomicBool::default());
 892                let response = Response {
 893                    peer: peer.clone(),
 894                    responded: responded.clone(),
 895                    receipt,
 896                };
 897                match (handler)(envelope.payload, response, session).await {
 898                    Ok(()) => {
 899                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 900                            Ok(())
 901                        } else {
 902                            Err(anyhow!("handler did not send a response"))?
 903                        }
 904                    }
 905                    Err(error) => {
 906                        let proto_err = match &error {
 907                            Error::Internal(err) => err.to_proto(),
 908                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 909                        };
 910                        peer.respond_with_error(receipt, proto_err)?;
 911                        Err(error)
 912                    }
 913                }
 914            }
 915        })
 916    }
 917
 918    #[allow(clippy::too_many_arguments)]
 919    pub fn handle_connection(
 920        self: &Arc<Self>,
 921        connection: Connection,
 922        address: String,
 923        principal: Principal,
 924        zed_version: ZedVersion,
 925        geoip_country_code: Option<String>,
 926        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 927        executor: Executor,
 928    ) -> impl Future<Output = ()> {
 929        let this = self.clone();
 930        let span = info_span!("handle connection", %address,
 931            connection_id=field::Empty,
 932            user_id=field::Empty,
 933            login=field::Empty,
 934            impersonator=field::Empty,
 935            dev_server_id=field::Empty,
 936            geoip_country_code=field::Empty
 937        );
 938        principal.update_span(&span);
 939        if let Some(country_code) = geoip_country_code.as_ref() {
 940            span.record("geoip_country_code", country_code);
 941        }
 942
 943        let mut teardown = self.teardown.subscribe();
 944        async move {
 945            if *teardown.borrow() {
 946                tracing::error!("server is tearing down");
 947                return
 948            }
 949            let (connection_id, handle_io, mut incoming_rx) = this
 950                .peer
 951                .add_connection(connection, {
 952                    let executor = executor.clone();
 953                    move |duration| executor.sleep(duration)
 954                });
 955            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 956
 957            tracing::info!("connection opened");
 958
 959            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 960            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 961                Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
 962                Err(error) => {
 963                    tracing::error!(?error, "failed to create HTTP client");
 964                    return;
 965                }
 966            };
 967
 968            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 969                    supermaven_admin_api_key.to_string(),
 970                    http_client.clone(),
 971                )));
 972
 973            let session = Session {
 974                principal: principal.clone(),
 975                connection_id,
 976                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 977                peer: this.peer.clone(),
 978                connection_pool: this.connection_pool.clone(),
 979                app_state: this.app_state.clone(),
 980                http_client,
 981                geoip_country_code,
 982                _executor: executor.clone(),
 983                supermaven_client,
 984            };
 985
 986            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 987                tracing::error!(?error, "failed to send initial client update");
 988                return;
 989            }
 990
 991            let handle_io = handle_io.fuse();
 992            futures::pin_mut!(handle_io);
 993
 994            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 995            // This prevents deadlocks when e.g., client A performs a request to client B and
 996            // client B performs a request to client A. If both clients stop processing further
 997            // messages until their respective request completes, they won't have a chance to
 998            // respond to the other client's request and cause a deadlock.
 999            //
1000            // This arrangement ensures we will attempt to process earlier messages first, but fall
1001            // back to processing messages arrived later in the spirit of making progress.
1002            let mut foreground_message_handlers = FuturesUnordered::new();
1003            let concurrent_handlers = Arc::new(Semaphore::new(256));
1004            loop {
1005                let next_message = async {
1006                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1007                    let message = incoming_rx.next().await;
1008                    (permit, message)
1009                }.fuse();
1010                futures::pin_mut!(next_message);
1011                futures::select_biased! {
1012                    _ = teardown.changed().fuse() => return,
1013                    result = handle_io => {
1014                        if let Err(error) = result {
1015                            tracing::error!(?error, "error handling I/O");
1016                        }
1017                        break;
1018                    }
1019                    _ = foreground_message_handlers.next() => {}
1020                    next_message = next_message => {
1021                        let (permit, message) = next_message;
1022                        if let Some(message) = message {
1023                            let type_name = message.payload_type_name();
1024                            // note: we copy all the fields from the parent span so we can query them in the logs.
1025                            // (https://github.com/tokio-rs/tracing/issues/2670).
1026                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1027                                user_id=field::Empty,
1028                                login=field::Empty,
1029                                impersonator=field::Empty,
1030                                dev_server_id=field::Empty
1031                            );
1032                            principal.update_span(&span);
1033                            let span_enter = span.enter();
1034                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1035                                let is_background = message.is_background();
1036                                let handle_message = (handler)(message, session.clone());
1037                                drop(span_enter);
1038
1039                                let handle_message = async move {
1040                                    handle_message.await;
1041                                    drop(permit);
1042                                }.instrument(span);
1043                                if is_background {
1044                                    executor.spawn_detached(handle_message);
1045                                } else {
1046                                    foreground_message_handlers.push(handle_message);
1047                                }
1048                            } else {
1049                                tracing::error!("no message handler");
1050                            }
1051                        } else {
1052                            tracing::info!("connection closed");
1053                            break;
1054                        }
1055                    }
1056                }
1057            }
1058
1059            drop(foreground_message_handlers);
1060            tracing::info!("signing out");
1061            if let Err(error) = connection_lost(session, teardown, executor).await {
1062                tracing::error!(?error, "error signing out");
1063            }
1064
1065        }.instrument(span)
1066    }
1067
1068    async fn send_initial_client_update(
1069        &self,
1070        connection_id: ConnectionId,
1071        principal: &Principal,
1072        zed_version: ZedVersion,
1073        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1074        session: &Session,
1075    ) -> Result<()> {
1076        self.peer.send(
1077            connection_id,
1078            proto::Hello {
1079                peer_id: Some(connection_id.into()),
1080            },
1081        )?;
1082        tracing::info!("sent hello message");
1083        if let Some(send_connection_id) = send_connection_id.take() {
1084            let _ = send_connection_id.send(connection_id);
1085        }
1086
1087        match principal {
1088            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1089                if !user.connected_once {
1090                    self.peer.send(connection_id, proto::ShowContacts {})?;
1091                    self.app_state
1092                        .db
1093                        .set_user_connected_once(user.id, true)
1094                        .await?;
1095                }
1096
1097                update_user_plan(user.id, session).await?;
1098
1099                let (contacts, dev_server_projects) = future::try_join(
1100                    self.app_state.db.get_contacts(user.id),
1101                    self.app_state.db.dev_server_projects_update(user.id),
1102                )
1103                .await?;
1104
1105                {
1106                    let mut pool = self.connection_pool.lock();
1107                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1108                    self.peer.send(
1109                        connection_id,
1110                        build_initial_contacts_update(contacts, &pool),
1111                    )?;
1112                }
1113
1114                if should_auto_subscribe_to_channels(zed_version) {
1115                    subscribe_user_to_channels(user.id, session).await?;
1116                }
1117
1118                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1119
1120                if let Some(incoming_call) =
1121                    self.app_state.db.incoming_call_for_user(user.id).await?
1122                {
1123                    self.peer.send(connection_id, incoming_call)?;
1124                }
1125
1126                update_user_contacts(user.id, session).await?;
1127            }
1128            Principal::DevServer(dev_server) => {
1129                {
1130                    let mut pool = self.connection_pool.lock();
1131                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1132                    {
1133                        self.peer.send(
1134                            stale_connection_id,
1135                            proto::ShutdownDevServer {
1136                                reason: Some(
1137                                    "another dev server connected with the same token".to_string(),
1138                                ),
1139                            },
1140                        )?;
1141                        pool.remove_connection(stale_connection_id)?;
1142                    };
1143                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1144                }
1145
1146                let projects = self
1147                    .app_state
1148                    .db
1149                    .get_projects_for_dev_server(dev_server.id)
1150                    .await?;
1151                self.peer
1152                    .send(connection_id, proto::DevServerInstructions { projects })?;
1153
1154                let status = self
1155                    .app_state
1156                    .db
1157                    .dev_server_projects_update(dev_server.user_id)
1158                    .await?;
1159                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1160            }
1161        }
1162
1163        Ok(())
1164    }
1165
1166    pub async fn invite_code_redeemed(
1167        self: &Arc<Self>,
1168        inviter_id: UserId,
1169        invitee_id: UserId,
1170    ) -> Result<()> {
1171        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1172            if let Some(code) = &user.invite_code {
1173                let pool = self.connection_pool.lock();
1174                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1175                for connection_id in pool.user_connection_ids(inviter_id) {
1176                    self.peer.send(
1177                        connection_id,
1178                        proto::UpdateContacts {
1179                            contacts: vec![invitee_contact.clone()],
1180                            ..Default::default()
1181                        },
1182                    )?;
1183                    self.peer.send(
1184                        connection_id,
1185                        proto::UpdateInviteInfo {
1186                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1187                            count: user.invite_count as u32,
1188                        },
1189                    )?;
1190                }
1191            }
1192        }
1193        Ok(())
1194    }
1195
1196    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1197        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1198            if let Some(invite_code) = &user.invite_code {
1199                let pool = self.connection_pool.lock();
1200                for connection_id in pool.user_connection_ids(user_id) {
1201                    self.peer.send(
1202                        connection_id,
1203                        proto::UpdateInviteInfo {
1204                            url: format!(
1205                                "{}{}",
1206                                self.app_state.config.invite_link_prefix, invite_code
1207                            ),
1208                            count: user.invite_count as u32,
1209                        },
1210                    )?;
1211                }
1212            }
1213        }
1214        Ok(())
1215    }
1216
1217    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1218        ServerSnapshot {
1219            connection_pool: ConnectionPoolGuard {
1220                guard: self.connection_pool.lock(),
1221                _not_send: PhantomData,
1222            },
1223            peer: &self.peer,
1224        }
1225    }
1226}
1227
1228impl<'a> Deref for ConnectionPoolGuard<'a> {
1229    type Target = ConnectionPool;
1230
1231    fn deref(&self) -> &Self::Target {
1232        &self.guard
1233    }
1234}
1235
1236impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1237    fn deref_mut(&mut self) -> &mut Self::Target {
1238        &mut self.guard
1239    }
1240}
1241
1242impl<'a> Drop for ConnectionPoolGuard<'a> {
1243    fn drop(&mut self) {
1244        #[cfg(test)]
1245        self.check_invariants();
1246    }
1247}
1248
1249fn broadcast<F>(
1250    sender_id: Option<ConnectionId>,
1251    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1252    mut f: F,
1253) where
1254    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1255{
1256    for receiver_id in receiver_ids {
1257        if Some(receiver_id) != sender_id {
1258            if let Err(error) = f(receiver_id) {
1259                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1260            }
1261        }
1262    }
1263}
1264
1265pub struct ProtocolVersion(u32);
1266
1267impl Header for ProtocolVersion {
1268    fn name() -> &'static HeaderName {
1269        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1270        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1271    }
1272
1273    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1274    where
1275        Self: Sized,
1276        I: Iterator<Item = &'i axum::http::HeaderValue>,
1277    {
1278        let version = values
1279            .next()
1280            .ok_or_else(axum::headers::Error::invalid)?
1281            .to_str()
1282            .map_err(|_| axum::headers::Error::invalid())?
1283            .parse()
1284            .map_err(|_| axum::headers::Error::invalid())?;
1285        Ok(Self(version))
1286    }
1287
1288    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1289        values.extend([self.0.to_string().parse().unwrap()]);
1290    }
1291}
1292
1293pub struct AppVersionHeader(SemanticVersion);
1294impl Header for AppVersionHeader {
1295    fn name() -> &'static HeaderName {
1296        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1298    }
1299
1300    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301    where
1302        Self: Sized,
1303        I: Iterator<Item = &'i axum::http::HeaderValue>,
1304    {
1305        let version = values
1306            .next()
1307            .ok_or_else(axum::headers::Error::invalid)?
1308            .to_str()
1309            .map_err(|_| axum::headers::Error::invalid())?
1310            .parse()
1311            .map_err(|_| axum::headers::Error::invalid())?;
1312        Ok(Self(version))
1313    }
1314
1315    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316        values.extend([self.0.to_string().parse().unwrap()]);
1317    }
1318}
1319
1320pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1321    Router::new()
1322        .route("/rpc", get(handle_websocket_request))
1323        .layer(
1324            ServiceBuilder::new()
1325                .layer(Extension(server.app_state.clone()))
1326                .layer(middleware::from_fn(auth::validate_header)),
1327        )
1328        .route("/metrics", get(handle_metrics))
1329        .layer(Extension(server))
1330}
1331
1332pub async fn handle_websocket_request(
1333    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1334    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1335    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1336    Extension(server): Extension<Arc<Server>>,
1337    Extension(principal): Extension<Principal>,
1338    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1339    ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341    if protocol_version != rpc::PROTOCOL_VERSION {
1342        return (
1343            StatusCode::UPGRADE_REQUIRED,
1344            "client must be upgraded".to_string(),
1345        )
1346            .into_response();
1347    }
1348
1349    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350        return (
1351            StatusCode::UPGRADE_REQUIRED,
1352            "no version header found".to_string(),
1353        )
1354            .into_response();
1355    };
1356
1357    if !version.can_collaborate() {
1358        return (
1359            StatusCode::UPGRADE_REQUIRED,
1360            "client must be upgraded".to_string(),
1361        )
1362            .into_response();
1363    }
1364
1365    let socket_address = socket_address.to_string();
1366    ws.on_upgrade(move |socket| {
1367        let socket = socket
1368            .map_ok(to_tungstenite_message)
1369            .err_into()
1370            .with(|message| async move { to_axum_message(message) });
1371        let connection = Connection::new(Box::pin(socket));
1372        async move {
1373            server
1374                .handle_connection(
1375                    connection,
1376                    socket_address,
1377                    principal,
1378                    version,
1379                    country_code_header.map(|header| header.to_string()),
1380                    None,
1381                    Executor::Production,
1382                )
1383                .await;
1384        }
1385    })
1386}
1387
1388pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1389    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1390    let connections_metric = CONNECTIONS_METRIC
1391        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1392
1393    let connections = server
1394        .connection_pool
1395        .lock()
1396        .connections()
1397        .filter(|connection| !connection.admin)
1398        .count();
1399    connections_metric.set(connections as _);
1400
1401    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1402    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1403        register_int_gauge!(
1404            "shared_projects",
1405            "number of open projects with one or more guests"
1406        )
1407        .unwrap()
1408    });
1409
1410    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1411    shared_projects_metric.set(shared_projects as _);
1412
1413    let encoder = prometheus::TextEncoder::new();
1414    let metric_families = prometheus::gather();
1415    let encoded_metrics = encoder
1416        .encode_to_string(&metric_families)
1417        .map_err(|err| anyhow!("{}", err))?;
1418    Ok(encoded_metrics)
1419}
1420
/// Cleans up after a client connection has ended.
///
/// Immediate steps:
/// 1. Disconnect the connection from the peer.
/// 2. Remove it from the connection pool. If that fails, return the error right away and
///    skip the remaining steps.
/// 3. Report the loss to the database (errors are only traced).
///
/// It then waits `RECONNECT_TIMEOUT`, giving the client a chance to reconnect. If the
/// `teardown` channel changes first, nothing else happens. Otherwise the rest of the
/// cleanup depends on the principal:
/// - users leave their room and their channel buffers;
/// - if the user has no other connection online, any pending call is declined;
/// - the user's contacts are then updated;
/// - dev servers go through `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in the order listed: the reconnect timeout
    // first, then the teardown notification.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline pending calls when this was the user's last connection.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // The server is tearing down (or resetting); skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1475
1476/// Acknowledges a ping from a client, used to keep the connection alive.
1477async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1478    response.send(proto::Ack {})?;
1479    Ok(())
1480}
1481
1482/// Creates a new room for calling (outside of channels)
1483async fn create_room(
1484    _request: proto::CreateRoom,
1485    response: Response<proto::CreateRoom>,
1486    session: UserSession,
1487) -> Result<()> {
1488    let live_kit_room = nanoid::nanoid!(30);
1489
1490    let live_kit_connection_info = util::maybe!(async {
1491        let live_kit = session.app_state.live_kit_client.as_ref();
1492        let live_kit = live_kit?;
1493        let user_id = session.user_id().to_string();
1494
1495        let token = live_kit
1496            .room_token(&live_kit_room, &user_id.to_string())
1497            .trace_err()?;
1498
1499        Some(proto::LiveKitConnectionInfo {
1500            server_url: live_kit.url().into(),
1501            token,
1502            can_publish: true,
1503        })
1504    })
1505    .await;
1506
1507    let room = session
1508        .db()
1509        .await
1510        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1511        .await?;
1512
1513    response.send(proto::CreateRoomResponse {
1514        room: Some(room.clone()),
1515        live_kit_connection_info,
1516    })?;
1517
1518    update_user_contacts(session.user_id(), &session).await?;
1519    Ok(())
1520}
1521
1522/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1523async fn join_room(
1524    request: proto::JoinRoom,
1525    response: Response<proto::JoinRoom>,
1526    session: UserSession,
1527) -> Result<()> {
1528    let room_id = RoomId::from_proto(request.id);
1529
1530    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1531
1532    if let Some(channel_id) = channel_id {
1533        return join_channel_internal(channel_id, Box::new(response), session).await;
1534    }
1535
1536    let joined_room = {
1537        let room = session
1538            .db()
1539            .await
1540            .join_room(room_id, session.user_id(), session.connection_id)
1541            .await?;
1542        room_updated(&room.room, &session.peer);
1543        room.into_inner()
1544    };
1545
1546    for connection_id in session
1547        .connection_pool()
1548        .await
1549        .user_connection_ids(session.user_id())
1550    {
1551        session
1552            .peer
1553            .send(
1554                connection_id,
1555                proto::CallCanceled {
1556                    room_id: room_id.to_proto(),
1557                },
1558            )
1559            .trace_err();
1560    }
1561
1562    let live_kit_connection_info =
1563        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1564            live_kit
1565                .room_token(
1566                    &joined_room.room.live_kit_room,
1567                    &session.user_id().to_string(),
1568                )
1569                .trace_err()
1570                .map(|token| proto::LiveKitConnectionInfo {
1571                    server_url: live_kit.url().into(),
1572                    token,
1573                    can_publish: true,
1574                })
1575        } else {
1576            None
1577        };
1578
1579    response.send(proto::JoinRoomResponse {
1580        room: Some(joined_room.room),
1581        channel_id: None,
1582        live_kit_connection_info,
1583    })?;
1584
1585    update_user_contacts(session.user_id(), &session).await?;
1586    Ok(())
1587}
1588
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-registers this connection in the room through the database. It replies
/// with the room plus the reshared and rejoined project state, then broadcasts
/// the updated room. Collaborators of every affected project are told that this
/// user's peer id changed from the project's old connection to the current one.
/// If the database returned a channel for the room, that channel's state is
/// broadcast as well. The user's contacts are refreshed last.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the inner block so these values outlive `rejoined_room`,
    // which is dropped at the end of that block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects (NOTE(review): presumably the ones this user hosts):
        // tell each collaborator about this user's new peer id, then forward the
        // project's current worktree list to them. Sends are best-effort.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: update collaborators and stream project state back
        // to this connection. This moves the worktrees out of `rejoined_projects`.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Runs only after `rejoined_room` has been dropped at the end of the block above.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1680
1681fn notify_rejoined_projects(
1682    rejoined_projects: &mut Vec<RejoinedProject>,
1683    session: &UserSession,
1684) -> Result<()> {
1685    for project in rejoined_projects.iter() {
1686        for collaborator in &project.collaborators {
1687            session
1688                .peer
1689                .send(
1690                    collaborator.connection_id,
1691                    proto::UpdateProjectCollaborator {
1692                        project_id: project.id.to_proto(),
1693                        old_peer_id: Some(project.old_connection_id.into()),
1694                        new_peer_id: Some(session.connection_id.into()),
1695                    },
1696                )
1697                .trace_err();
1698        }
1699    }
1700
1701    for project in rejoined_projects {
1702        for worktree in mem::take(&mut project.worktrees) {
1703            #[cfg(any(test, feature = "test-support"))]
1704            const MAX_CHUNK_SIZE: usize = 2;
1705            #[cfg(not(any(test, feature = "test-support")))]
1706            const MAX_CHUNK_SIZE: usize = 256;
1707
1708            // Stream this worktree's entries.
1709            let message = proto::UpdateWorktree {
1710                project_id: project.id.to_proto(),
1711                worktree_id: worktree.id,
1712                abs_path: worktree.abs_path.clone(),
1713                root_name: worktree.root_name,
1714                updated_entries: worktree.updated_entries,
1715                removed_entries: worktree.removed_entries,
1716                scan_id: worktree.scan_id,
1717                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1718                updated_repositories: worktree.updated_repositories,
1719                removed_repositories: worktree.removed_repositories,
1720            };
1721            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1722                session.peer.send(session.connection_id, update.clone())?;
1723            }
1724
1725            // Stream this worktree's diagnostics.
1726            for summary in worktree.diagnostic_summaries {
1727                session.peer.send(
1728                    session.connection_id,
1729                    proto::UpdateDiagnosticSummary {
1730                        project_id: project.id.to_proto(),
1731                        worktree_id: worktree.id,
1732                        summary: Some(summary),
1733                    },
1734                )?;
1735            }
1736
1737            for settings_file in worktree.settings_files {
1738                session.peer.send(
1739                    session.connection_id,
1740                    proto::UpdateWorktreeSettings {
1741                        project_id: project.id.to_proto(),
1742                        worktree_id: worktree.id,
1743                        path: settings_file.path,
1744                        content: Some(settings_file.content),
1745                    },
1746                )?;
1747            }
1748        }
1749
1750        for language_server in &project.language_servers {
1751            session.peer.send(
1752                session.connection_id,
1753                proto::UpdateLanguageServer {
1754                    project_id: project.id.to_proto(),
1755                    language_server_id: language_server.id,
1756                    variant: Some(
1757                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1758                            proto::LspDiskBasedDiagnosticsUpdated {},
1759                        ),
1760                    ),
1761                },
1762            )?;
1763        }
1764    }
1765    Ok(())
1766}
1767
1768/// leave room disconnects from the room.
1769async fn leave_room(
1770    _: proto::LeaveRoom,
1771    response: Response<proto::LeaveRoom>,
1772    session: UserSession,
1773) -> Result<()> {
1774    leave_room_for_session(&session, session.connection_id).await?;
1775    response.send(proto::Ack {})?;
1776    Ok(())
1777}
1778
1779/// Updates the permissions of someone else in the room.
1780async fn set_room_participant_role(
1781    request: proto::SetRoomParticipantRole,
1782    response: Response<proto::SetRoomParticipantRole>,
1783    session: UserSession,
1784) -> Result<()> {
1785    let user_id = UserId::from_proto(request.user_id);
1786    let role = ChannelRole::from(request.role());
1787
1788    let (live_kit_room, can_publish) = {
1789        let room = session
1790            .db()
1791            .await
1792            .set_room_participant_role(
1793                session.user_id(),
1794                RoomId::from_proto(request.room_id),
1795                user_id,
1796                role,
1797            )
1798            .await?;
1799
1800        let live_kit_room = room.live_kit_room.clone();
1801        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1802        room_updated(&room, &session.peer);
1803        (live_kit_room, can_publish)
1804    };
1805
1806    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1807        live_kit
1808            .update_participant(
1809                live_kit_room.clone(),
1810                request.user_id.to_string(),
1811                live_kit_server::proto::ParticipantPermission {
1812                    can_subscribe: true,
1813                    can_publish,
1814                    can_publish_data: can_publish,
1815                    hidden: false,
1816                    recorder: false,
1817                },
1818            )
1819            .await
1820            .trace_err();
1821    }
1822
1823    response.send(proto::Ack {})?;
1824    Ok(())
1825}
1826
/// Call someone else into the current room
///
/// Only contacts of the caller may be called. The steps are:
/// 1. Record the pending call and broadcast the updated room.
/// 2. Refresh the callee's contacts.
/// 3. Send the incoming-call request to every connection of the callee concurrently.
///
/// The request is acknowledged as soon as any one of those connections responds
/// successfully. If none does, the call is marked as failed, the room is
/// re-broadcast, the callee's contacts are refreshed again, and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `db.call` is dropped before contacts are updated.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the incoming-call message out instead of cloning it.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins; failed responses are logged and skipped.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection responded successfully: record the failed call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1895
1896/// Cancel an outgoing call.
1897async fn cancel_call(
1898    request: proto::CancelCall,
1899    response: Response<proto::CancelCall>,
1900    session: UserSession,
1901) -> Result<()> {
1902    let called_user_id = UserId::from_proto(request.called_user_id);
1903    let room_id = RoomId::from_proto(request.room_id);
1904    {
1905        let room = session
1906            .db()
1907            .await
1908            .cancel_call(room_id, session.connection_id, called_user_id)
1909            .await?;
1910        room_updated(&room, &session.peer);
1911    }
1912
1913    for connection_id in session
1914        .connection_pool()
1915        .await
1916        .user_connection_ids(called_user_id)
1917    {
1918        session
1919            .peer
1920            .send(
1921                connection_id,
1922                proto::CallCanceled {
1923                    room_id: room_id.to_proto(),
1924                },
1925            )
1926            .trace_err();
1927    }
1928    response.send(proto::Ack {})?;
1929
1930    update_user_contacts(called_user_id, &session).await?;
1931    Ok(())
1932}
1933
1934/// Decline an incoming call.
1935async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1936    let room_id = RoomId::from_proto(message.room_id);
1937    {
1938        let room = session
1939            .db()
1940            .await
1941            .decline_call(Some(room_id), session.user_id())
1942            .await?
1943            .ok_or_else(|| anyhow!("failed to decline call"))?;
1944        room_updated(&room, &session.peer);
1945    }
1946
1947    for connection_id in session
1948        .connection_pool()
1949        .await
1950        .user_connection_ids(session.user_id())
1951    {
1952        session
1953            .peer
1954            .send(
1955                connection_id,
1956                proto::CallCanceled {
1957                    room_id: room_id.to_proto(),
1958                },
1959            )
1960            .trace_err();
1961    }
1962    update_user_contacts(session.user_id(), &session).await?;
1963    Ok(())
1964}
1965
1966/// Updates other participants in the room with your current location.
1967async fn update_participant_location(
1968    request: proto::UpdateParticipantLocation,
1969    response: Response<proto::UpdateParticipantLocation>,
1970    session: UserSession,
1971) -> Result<()> {
1972    let room_id = RoomId::from_proto(request.room_id);
1973    let location = request
1974        .location
1975        .ok_or_else(|| anyhow!("invalid location"))?;
1976
1977    let db = session.db().await;
1978    let room = db
1979        .update_room_participant_location(room_id, session.connection_id, location)
1980        .await?;
1981
1982    room_updated(&room, &session.peer);
1983    response.send(proto::Ack {})?;
1984    Ok(())
1985}
1986
1987/// Share a project into the room.
1988async fn share_project(
1989    request: proto::ShareProject,
1990    response: Response<proto::ShareProject>,
1991    session: UserSession,
1992) -> Result<()> {
1993    let (project_id, room) = &*session
1994        .db()
1995        .await
1996        .share_project(
1997            RoomId::from_proto(request.room_id),
1998            session.connection_id,
1999            &request.worktrees,
2000            request.is_ssh_project,
2001            request
2002                .dev_server_project_id
2003                .map(DevServerProjectId::from_proto),
2004        )
2005        .await?;
2006    response.send(proto::ShareProjectResponse {
2007        project_id: project_id.to_proto(),
2008    })?;
2009    room_updated(room, &session.peer);
2010
2011    Ok(())
2012}
2013
2014/// Unshare a project from the room.
2015async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2016    let project_id = ProjectId::from_proto(message.project_id);
2017    unshare_project_internal(
2018        project_id,
2019        session.connection_id,
2020        session.user_id(),
2021        &session,
2022    )
2023    .await
2024}
2025
2026async fn unshare_project_internal(
2027    project_id: ProjectId,
2028    connection_id: ConnectionId,
2029    user_id: Option<UserId>,
2030    session: &Session,
2031) -> Result<()> {
2032    let delete = {
2033        let room_guard = session
2034            .db()
2035            .await
2036            .unshare_project(project_id, connection_id, user_id)
2037            .await?;
2038
2039        let (delete, room, guest_connection_ids) = &*room_guard;
2040
2041        let message = proto::UnshareProject {
2042            project_id: project_id.to_proto(),
2043        };
2044
2045        broadcast(
2046            Some(connection_id),
2047            guest_connection_ids.iter().copied(),
2048            |conn_id| session.peer.send(conn_id, message.clone()),
2049        );
2050        if let Some(room) = room {
2051            room_updated(room, &session.peer);
2052        }
2053
2054        *delete
2055    };
2056
2057    if delete {
2058        let db = session.db().await;
2059        db.delete_project(project_id).await?;
2060    }
2061
2062    Ok(())
2063}
2064
2065/// DevServer makes a project available online
2066async fn share_dev_server_project(
2067    request: proto::ShareDevServerProject,
2068    response: Response<proto::ShareDevServerProject>,
2069    session: DevServerSession,
2070) -> Result<()> {
2071    let (dev_server_project, user_id, status) = session
2072        .db()
2073        .await
2074        .share_dev_server_project(
2075            DevServerProjectId::from_proto(request.dev_server_project_id),
2076            session.dev_server_id(),
2077            session.connection_id,
2078            &request.worktrees,
2079        )
2080        .await?;
2081    let Some(project_id) = dev_server_project.project_id else {
2082        return Err(anyhow!("failed to share remote project"))?;
2083    };
2084
2085    send_dev_server_projects_update(user_id, status, &session).await;
2086
2087    response.send(proto::ShareProjectResponse { project_id })?;
2088
2089    Ok(())
2090}
2091
/// Join someone elses shared project.
///
/// Registers this connection as a collaborator of `request.project_id` in the
/// database. It then hands off to `join_project_internal`, which notifies the
/// other collaborators and streams the project's state to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2110
2111trait JoinProjectInternalResponse {
2112    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2113}
2114impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2115    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2116        Response::<proto::JoinProject>::send(self, result)
2117    }
2118}
2119impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2120    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2121        Response::<proto::JoinHostedProject>::send(self, result)
2122    }
2123}
2124
2125fn join_project_internal(
2126    response: impl JoinProjectInternalResponse,
2127    session: UserSession,
2128    project: &mut Project,
2129    replica_id: &ReplicaId,
2130) -> Result<()> {
2131    let collaborators = project
2132        .collaborators
2133        .iter()
2134        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2135        .map(|collaborator| collaborator.to_proto())
2136        .collect::<Vec<_>>();
2137    let project_id = project.id;
2138    let guest_user_id = session.user_id();
2139
2140    let worktrees = project
2141        .worktrees
2142        .iter()
2143        .map(|(id, worktree)| proto::WorktreeMetadata {
2144            id: *id,
2145            root_name: worktree.root_name.clone(),
2146            visible: worktree.visible,
2147            abs_path: worktree.abs_path.clone(),
2148        })
2149        .collect::<Vec<_>>();
2150
2151    let add_project_collaborator = proto::AddProjectCollaborator {
2152        project_id: project_id.to_proto(),
2153        collaborator: Some(proto::Collaborator {
2154            peer_id: Some(session.connection_id.into()),
2155            replica_id: replica_id.0 as u32,
2156            user_id: guest_user_id.to_proto(),
2157        }),
2158    };
2159
2160    for collaborator in &collaborators {
2161        session
2162            .peer
2163            .send(
2164                collaborator.peer_id.unwrap().into(),
2165                add_project_collaborator.clone(),
2166            )
2167            .trace_err();
2168    }
2169
2170    // First, we send the metadata associated with each worktree.
2171    response.send(proto::JoinProjectResponse {
2172        project_id: project.id.0 as u64,
2173        worktrees: worktrees.clone(),
2174        replica_id: replica_id.0 as u32,
2175        collaborators: collaborators.clone(),
2176        language_servers: project.language_servers.clone(),
2177        role: project.role.into(),
2178        dev_server_project_id: project
2179            .dev_server_project_id
2180            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2181    })?;
2182
2183    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2184        #[cfg(any(test, feature = "test-support"))]
2185        const MAX_CHUNK_SIZE: usize = 2;
2186        #[cfg(not(any(test, feature = "test-support")))]
2187        const MAX_CHUNK_SIZE: usize = 256;
2188
2189        // Stream this worktree's entries.
2190        let message = proto::UpdateWorktree {
2191            project_id: project_id.to_proto(),
2192            worktree_id,
2193            abs_path: worktree.abs_path.clone(),
2194            root_name: worktree.root_name,
2195            updated_entries: worktree.entries,
2196            removed_entries: Default::default(),
2197            scan_id: worktree.scan_id,
2198            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2199            updated_repositories: worktree.repository_entries.into_values().collect(),
2200            removed_repositories: Default::default(),
2201        };
2202        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2203            session.peer.send(session.connection_id, update.clone())?;
2204        }
2205
2206        // Stream this worktree's diagnostics.
2207        for summary in worktree.diagnostic_summaries {
2208            session.peer.send(
2209                session.connection_id,
2210                proto::UpdateDiagnosticSummary {
2211                    project_id: project_id.to_proto(),
2212                    worktree_id: worktree.id,
2213                    summary: Some(summary),
2214                },
2215            )?;
2216        }
2217
2218        for settings_file in worktree.settings_files {
2219            session.peer.send(
2220                session.connection_id,
2221                proto::UpdateWorktreeSettings {
2222                    project_id: project_id.to_proto(),
2223                    worktree_id: worktree.id,
2224                    path: settings_file.path,
2225                    content: Some(settings_file.content),
2226                },
2227            )?;
2228        }
2229    }
2230
2231    for language_server in &project.language_servers {
2232        session.peer.send(
2233            session.connection_id,
2234            proto::UpdateLanguageServer {
2235                project_id: project_id.to_proto(),
2236                language_server_id: language_server.id,
2237                variant: Some(
2238                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2239                        proto::LspDiskBasedDiagnosticsUpdated {},
2240                    ),
2241                ),
2242            },
2243        )?;
2244    }
2245
2246    Ok(())
2247}
2248
2249/// Leave someone elses shared project.
2250async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2251    let sender_id = session.connection_id;
2252    let project_id = ProjectId::from_proto(request.project_id);
2253    let db = session.db().await;
2254    if db.is_hosted_project(project_id).await? {
2255        let project = db.leave_hosted_project(project_id, sender_id).await?;
2256        project_left(&project, &session);
2257        return Ok(());
2258    }
2259
2260    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2261    tracing::info!(
2262        %project_id,
2263        "leave project"
2264    );
2265
2266    project_left(project, &session);
2267    if let Some(room) = room {
2268        room_updated(room, &session.peer);
2269    }
2270
2271    Ok(())
2272}
2273
2274async fn join_hosted_project(
2275    request: proto::JoinHostedProject,
2276    response: Response<proto::JoinHostedProject>,
2277    session: UserSession,
2278) -> Result<()> {
2279    let (mut project, replica_id) = session
2280        .db()
2281        .await
2282        .join_hosted_project(
2283            ProjectId(request.project_id as i32),
2284            session.user_id(),
2285            session.connection_id,
2286        )
2287        .await?;
2288
2289    join_project_internal(response, session, &mut project, &replica_id)
2290}
2291
2292async fn list_remote_directory(
2293    request: proto::ListRemoteDirectory,
2294    response: Response<proto::ListRemoteDirectory>,
2295    session: UserSession,
2296) -> Result<()> {
2297    let dev_server_id = DevServerId(request.dev_server_id as i32);
2298    let dev_server_connection_id = session
2299        .connection_pool()
2300        .await
2301        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2302
2303    session
2304        .db()
2305        .await
2306        .get_dev_server_for_user(dev_server_id, session.user_id())
2307        .await?;
2308
2309    response.send(
2310        session
2311            .peer
2312            .forward_request(session.connection_id, dev_server_connection_id, request)
2313            .await?,
2314    )?;
2315    Ok(())
2316}
2317
2318async fn update_dev_server_project(
2319    request: proto::UpdateDevServerProject,
2320    response: Response<proto::UpdateDevServerProject>,
2321    session: UserSession,
2322) -> Result<()> {
2323    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2324
2325    let (dev_server_project, update) = session
2326        .db()
2327        .await
2328        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2329        .await?;
2330
2331    let projects = session
2332        .db()
2333        .await
2334        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2335        .await?;
2336
2337    let dev_server_connection_id = session
2338        .connection_pool()
2339        .await
2340        .dev_server_connection_id_supporting(
2341            dev_server_project.dev_server_id,
2342            ZedVersion::with_list_directory(),
2343        )?;
2344
2345    session.peer.send(
2346        dev_server_connection_id,
2347        proto::DevServerInstructions { projects },
2348    )?;
2349
2350    send_dev_server_projects_update(session.user_id(), update, &session).await;
2351
2352    response.send(proto::Ack {})
2353}
2354
2355async fn create_dev_server_project(
2356    request: proto::CreateDevServerProject,
2357    response: Response<proto::CreateDevServerProject>,
2358    session: UserSession,
2359) -> Result<()> {
2360    let dev_server_id = DevServerId(request.dev_server_id as i32);
2361    let dev_server_connection_id = session
2362        .connection_pool()
2363        .await
2364        .dev_server_connection_id(dev_server_id);
2365    let Some(dev_server_connection_id) = dev_server_connection_id else {
2366        Err(ErrorCode::DevServerOffline
2367            .message("Cannot create a remote project when the dev server is offline".to_string())
2368            .anyhow())?
2369    };
2370
2371    let path = request.path.clone();
2372    //Check that the path exists on the dev server
2373    session
2374        .peer
2375        .forward_request(
2376            session.connection_id,
2377            dev_server_connection_id,
2378            proto::ValidateDevServerProjectRequest { path: path.clone() },
2379        )
2380        .await?;
2381
2382    let (dev_server_project, update) = session
2383        .db()
2384        .await
2385        .create_dev_server_project(
2386            DevServerId(request.dev_server_id as i32),
2387            &request.path,
2388            session.user_id(),
2389        )
2390        .await?;
2391
2392    let projects = session
2393        .db()
2394        .await
2395        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2396        .await?;
2397
2398    session.peer.send(
2399        dev_server_connection_id,
2400        proto::DevServerInstructions { projects },
2401    )?;
2402
2403    send_dev_server_projects_update(session.user_id(), update, &session).await;
2404
2405    response.send(proto::CreateDevServerProjectResponse {
2406        dev_server_project: Some(dev_server_project.to_proto(None)),
2407    })?;
2408    Ok(())
2409}
2410
2411async fn create_dev_server(
2412    request: proto::CreateDevServer,
2413    response: Response<proto::CreateDevServer>,
2414    session: UserSession,
2415) -> Result<()> {
2416    let access_token = auth::random_token();
2417    let hashed_access_token = auth::hash_access_token(&access_token);
2418
2419    if request.name.is_empty() {
2420        return Err(proto::ErrorCode::Forbidden
2421            .message("Dev server name cannot be empty".to_string())
2422            .anyhow())?;
2423    }
2424
2425    let (dev_server, status) = session
2426        .db()
2427        .await
2428        .create_dev_server(
2429            &request.name,
2430            request.ssh_connection_string.as_deref(),
2431            &hashed_access_token,
2432            session.user_id(),
2433        )
2434        .await?;
2435
2436    send_dev_server_projects_update(session.user_id(), status, &session).await;
2437
2438    response.send(proto::CreateDevServerResponse {
2439        dev_server_id: dev_server.id.0 as u64,
2440        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2441        name: request.name,
2442    })?;
2443    Ok(())
2444}
2445
2446async fn regenerate_dev_server_token(
2447    request: proto::RegenerateDevServerToken,
2448    response: Response<proto::RegenerateDevServerToken>,
2449    session: UserSession,
2450) -> Result<()> {
2451    let dev_server_id = DevServerId(request.dev_server_id as i32);
2452    let access_token = auth::random_token();
2453    let hashed_access_token = auth::hash_access_token(&access_token);
2454
2455    let connection_id = session
2456        .connection_pool()
2457        .await
2458        .dev_server_connection_id(dev_server_id);
2459    if let Some(connection_id) = connection_id {
2460        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2461        session.peer.send(
2462            connection_id,
2463            proto::ShutdownDevServer {
2464                reason: Some("dev server token was regenerated".to_string()),
2465            },
2466        )?;
2467        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2468    }
2469
2470    let status = session
2471        .db()
2472        .await
2473        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2474        .await?;
2475
2476    send_dev_server_projects_update(session.user_id(), status, &session).await;
2477
2478    response.send(proto::RegenerateDevServerTokenResponse {
2479        dev_server_id: dev_server_id.to_proto(),
2480        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2481    })?;
2482    Ok(())
2483}
2484
2485async fn rename_dev_server(
2486    request: proto::RenameDevServer,
2487    response: Response<proto::RenameDevServer>,
2488    session: UserSession,
2489) -> Result<()> {
2490    if request.name.trim().is_empty() {
2491        return Err(proto::ErrorCode::Forbidden
2492            .message("Dev server name cannot be empty".to_string())
2493            .anyhow())?;
2494    }
2495
2496    let dev_server_id = DevServerId(request.dev_server_id as i32);
2497    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2498    if dev_server.user_id != session.user_id() {
2499        return Err(anyhow!(ErrorCode::Forbidden))?;
2500    }
2501
2502    let status = session
2503        .db()
2504        .await
2505        .rename_dev_server(
2506            dev_server_id,
2507            &request.name,
2508            request.ssh_connection_string.as_deref(),
2509            session.user_id(),
2510        )
2511        .await?;
2512
2513    send_dev_server_projects_update(session.user_id(), status, &session).await;
2514
2515    response.send(proto::Ack {})?;
2516    Ok(())
2517}
2518
/// Delete one of the caller's dev servers.
///
/// Fails with `Forbidden` unless the caller owns the dev server. If the dev
/// server is currently connected, it goes through `shutdown_dev_server_internal`
/// (its shared projects are unshared and it is marked offline), is sent a
/// `ShutdownDevServer` message explaining why, and has its connection removed
/// from the pool. The dev server is then deleted from the database and the
/// owner's clients receive a refreshed dev-server-projects update.
async fn delete_dev_server(
    request: proto::DeleteDevServer,
    response: Response<proto::DeleteDevServer>,
    session: UserSession,
) -> Result<()> {
    let dev_server_id = DevServerId(request.dev_server_id as i32);
    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
    // Only the owner may delete a dev server.
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    // Looked up in its own statement so the connection-pool guard is released
    // before `shutdown_dev_server_internal`, which acquires the pool again.
    let connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server_id);
    if let Some(connection_id) = connection_id {
        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
        session.peer.send(
            connection_id,
            proto::ShutdownDevServer {
                reason: Some("dev server was deleted".to_string()),
            },
        )?;
        // Best-effort: a failure to remove the connection is ignored.
        let _ = remove_dev_server_connection(dev_server_id, &session).await;
    }

    let status = session
        .db()
        .await
        .delete_dev_server(dev_server_id, session.user_id())
        .await?;

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2556
/// Delete a project from one of the caller's dev servers.
///
/// Fails with `Forbidden` unless the caller owns the dev server that hosts the
/// project. If the dev server is online and `find_dev_server_project` succeeds,
/// the project is unshared first. After the database deletion, an online dev
/// server is sent its remaining projects via `DevServerInstructions`, and the
/// owner's clients receive a refreshed dev-server-projects update.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Ownership is checked against the dev server that hosts the project.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A lookup error is deliberately ignored. Presumably it means the
        // project is not currently shared, so there is nothing to unshare.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Tell an online dev server which projects it still hosts.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2618
2619async fn rejoin_dev_server_projects(
2620    request: proto::RejoinRemoteProjects,
2621    response: Response<proto::RejoinRemoteProjects>,
2622    session: UserSession,
2623) -> Result<()> {
2624    let mut rejoined_projects = {
2625        let db = session.db().await;
2626        db.rejoin_dev_server_projects(
2627            &request.rejoined_projects,
2628            session.user_id(),
2629            session.0.connection_id,
2630        )
2631        .await?
2632    };
2633    response.send(proto::RejoinRemoteProjectsResponse {
2634        rejoined_projects: rejoined_projects
2635            .iter()
2636            .map(|project| project.to_proto())
2637            .collect(),
2638    })?;
2639    notify_rejoined_projects(&mut rejoined_projects, &session)
2640}
2641
/// Handle a dev server reconnecting.
///
/// The database reshares the listed projects for this dev server on its current
/// connection. For each reshared project, every collaborator receives an
/// `UpdateProjectCollaborator` that maps the project's old host peer id to the
/// new one, plus an `UpdateProject` carrying the project's worktrees. The
/// response lists each reshared project with its collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // Scoped so the database guard is dropped before messages are sent.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        for collaborator in &project.collaborators {
            // A failed send is not propagated (`trace_err`); the remaining
            // collaborators are still notified.
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2707
2708async fn shutdown_dev_server(
2709    _: proto::ShutdownDevServer,
2710    response: Response<proto::ShutdownDevServer>,
2711    session: DevServerSession,
2712) -> Result<()> {
2713    response.send(proto::Ack {})?;
2714    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2715    remove_dev_server_connection(session.dev_server_id(), &session).await
2716}
2717
/// Take a dev server offline.
///
/// Calls `unshare_project_internal` for every one of the dev server's projects
/// that has a `project_id` (presumably meaning it is currently shared). It then
/// marks the dev server offline in the connection pool and sends the dev
/// server's owner a refreshed dev-server-projects update.
///
/// `connection_id` is the dev server connection passed to
/// `unshare_project_internal`. This function does not remove that connection
/// from the pool; the callers in this file do so separately afterwards.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Both lookups share one database guard, dropped at the end of this block
    // before any project is unshared.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    // Notify the owner (not necessarily the session's user) of the new state.
    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, session).await;

    Ok(())
}
2754
2755async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2756    let dev_server_connection = session
2757        .connection_pool()
2758        .await
2759        .dev_server_connection_id(dev_server_id);
2760
2761    if let Some(dev_server_connection) = dev_server_connection {
2762        session
2763            .connection_pool()
2764            .await
2765            .remove_connection(dev_server_connection)?;
2766    }
2767    Ok(())
2768}
2769
2770/// Updates other participants with changes to the project
2771async fn update_project(
2772    request: proto::UpdateProject,
2773    response: Response<proto::UpdateProject>,
2774    session: Session,
2775) -> Result<()> {
2776    let project_id = ProjectId::from_proto(request.project_id);
2777    let (room, guest_connection_ids) = &*session
2778        .db()
2779        .await
2780        .update_project(project_id, session.connection_id, &request.worktrees)
2781        .await?;
2782    broadcast(
2783        Some(session.connection_id),
2784        guest_connection_ids.iter().copied(),
2785        |connection_id| {
2786            session
2787                .peer
2788                .forward_send(session.connection_id, connection_id, request.clone())
2789        },
2790    );
2791    if let Some(room) = room {
2792        room_updated(room, &session.peer);
2793    }
2794    response.send(proto::Ack {})?;
2795
2796    Ok(())
2797}
2798
2799/// Updates other participants with changes to the worktree
2800async fn update_worktree(
2801    request: proto::UpdateWorktree,
2802    response: Response<proto::UpdateWorktree>,
2803    session: Session,
2804) -> Result<()> {
2805    let guest_connection_ids = session
2806        .db()
2807        .await
2808        .update_worktree(&request, session.connection_id)
2809        .await?;
2810
2811    broadcast(
2812        Some(session.connection_id),
2813        guest_connection_ids.iter().copied(),
2814        |connection_id| {
2815            session
2816                .peer
2817                .forward_send(session.connection_id, connection_id, request.clone())
2818        },
2819    );
2820    response.send(proto::Ack {})?;
2821    Ok(())
2822}
2823
2824/// Updates other participants with changes to the diagnostics
2825async fn update_diagnostic_summary(
2826    message: proto::UpdateDiagnosticSummary,
2827    session: Session,
2828) -> Result<()> {
2829    let guest_connection_ids = session
2830        .db()
2831        .await
2832        .update_diagnostic_summary(&message, session.connection_id)
2833        .await?;
2834
2835    broadcast(
2836        Some(session.connection_id),
2837        guest_connection_ids.iter().copied(),
2838        |connection_id| {
2839            session
2840                .peer
2841                .forward_send(session.connection_id, connection_id, message.clone())
2842        },
2843    );
2844
2845    Ok(())
2846}
2847
2848/// Updates other participants with changes to the worktree settings
2849async fn update_worktree_settings(
2850    message: proto::UpdateWorktreeSettings,
2851    session: Session,
2852) -> Result<()> {
2853    let guest_connection_ids = session
2854        .db()
2855        .await
2856        .update_worktree_settings(&message, session.connection_id)
2857        .await?;
2858
2859    broadcast(
2860        Some(session.connection_id),
2861        guest_connection_ids.iter().copied(),
2862        |connection_id| {
2863            session
2864                .peer
2865                .forward_send(session.connection_id, connection_id, message.clone())
2866        },
2867    );
2868
2869    Ok(())
2870}
2871
2872/// Notify other participants that a  language server has started.
2873async fn start_language_server(
2874    request: proto::StartLanguageServer,
2875    session: Session,
2876) -> Result<()> {
2877    let guest_connection_ids = session
2878        .db()
2879        .await
2880        .start_language_server(&request, session.connection_id)
2881        .await?;
2882
2883    broadcast(
2884        Some(session.connection_id),
2885        guest_connection_ids.iter().copied(),
2886        |connection_id| {
2887            session
2888                .peer
2889                .forward_send(session.connection_id, connection_id, request.clone())
2890        },
2891    );
2892    Ok(())
2893}
2894
2895/// Notify other participants that a language server has changed.
2896async fn update_language_server(
2897    request: proto::UpdateLanguageServer,
2898    session: Session,
2899) -> Result<()> {
2900    let project_id = ProjectId::from_proto(request.project_id);
2901    let project_connection_ids = session
2902        .db()
2903        .await
2904        .project_connection_ids(project_id, session.connection_id, true)
2905        .await?;
2906    broadcast(
2907        Some(session.connection_id),
2908        project_connection_ids.iter().copied(),
2909        |connection_id| {
2910            session
2911                .peer
2912                .forward_send(session.connection_id, connection_id, request.clone())
2913        },
2914    );
2915    Ok(())
2916}
2917
2918/// forward a project request to the host. These requests should be read only
2919/// as guests are allowed to send them.
2920async fn forward_read_only_project_request<T>(
2921    request: T,
2922    response: Response<T>,
2923    session: UserSession,
2924) -> Result<()>
2925where
2926    T: EntityMessage + RequestMessage,
2927{
2928    let project_id = ProjectId::from_proto(request.remote_entity_id());
2929    let host_connection_id = session
2930        .db()
2931        .await
2932        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2933        .await?;
2934    let payload = session
2935        .peer
2936        .forward_request(session.connection_id, host_connection_id, request)
2937        .await?;
2938    response.send(payload)?;
2939    Ok(())
2940}
2941
2942async fn forward_find_search_candidates_request(
2943    request: proto::FindSearchCandidates,
2944    response: Response<proto::FindSearchCandidates>,
2945    session: UserSession,
2946) -> Result<()> {
2947    let project_id = ProjectId::from_proto(request.remote_entity_id());
2948    let host_connection_id = session
2949        .db()
2950        .await
2951        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2952        .await?;
2953
2954    let host_version = session
2955        .connection_pool()
2956        .await
2957        .connection(host_connection_id)
2958        .map(|c| c.zed_version);
2959
2960    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2961    {
2962        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2963        let search = proto::SearchProject {
2964            project_id: project_id.to_proto(),
2965            query: query.query,
2966            regex: query.regex,
2967            whole_word: query.whole_word,
2968            case_sensitive: query.case_sensitive,
2969            files_to_include: query.files_to_include,
2970            files_to_exclude: query.files_to_exclude,
2971            include_ignored: query.include_ignored,
2972        };
2973
2974        let payload = session
2975            .peer
2976            .forward_request(session.connection_id, host_connection_id, search)
2977            .await?;
2978        return response.send(proto::FindSearchCandidatesResponse {
2979            buffer_ids: payload
2980                .locations
2981                .into_iter()
2982                .map(|loc| loc.buffer_id)
2983                .collect(),
2984        });
2985    }
2986
2987    let payload = session
2988        .peer
2989        .forward_request(session.connection_id, host_connection_id, request)
2990        .await?;
2991    response.send(payload)?;
2992    Ok(())
2993}
2994
2995/// forward a project request to the dev server. Only allowed
2996/// if it's your dev server.
2997async fn forward_project_request_for_owner<T>(
2998    request: T,
2999    response: Response<T>,
3000    session: UserSession,
3001) -> Result<()>
3002where
3003    T: EntityMessage + RequestMessage,
3004{
3005    let project_id = ProjectId::from_proto(request.remote_entity_id());
3006
3007    let host_connection_id = session
3008        .db()
3009        .await
3010        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3011        .await?;
3012    let payload = session
3013        .peer
3014        .forward_request(session.connection_id, host_connection_id, request)
3015        .await?;
3016    response.send(payload)?;
3017    Ok(())
3018}
3019
3020/// forward a project request to the host. These requests are disallowed
3021/// for guests.
3022async fn forward_mutating_project_request<T>(
3023    request: T,
3024    response: Response<T>,
3025    session: UserSession,
3026) -> Result<()>
3027where
3028    T: EntityMessage + RequestMessage,
3029{
3030    let project_id = ProjectId::from_proto(request.remote_entity_id());
3031
3032    let host_connection_id = session
3033        .db()
3034        .await
3035        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3036        .await?;
3037    let payload = session
3038        .peer
3039        .forward_request(session.connection_id, host_connection_id, request)
3040        .await?;
3041    response.send(payload)?;
3042    Ok(())
3043}
3044
3045/// Notify other participants that a new buffer has been created
3046async fn create_buffer_for_peer(
3047    request: proto::CreateBufferForPeer,
3048    session: Session,
3049) -> Result<()> {
3050    session
3051        .db()
3052        .await
3053        .check_user_is_project_host(
3054            ProjectId::from_proto(request.project_id),
3055            session.connection_id,
3056        )
3057        .await?;
3058    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3059    session
3060        .peer
3061        .forward_send(session.connection_id, peer_id.into(), request)?;
3062    Ok(())
3063}
3064
3065/// Notify other participants that a buffer has been updated. This is
3066/// allowed for guests as long as the update is limited to selections.
3067async fn update_buffer(
3068    request: proto::UpdateBuffer,
3069    response: Response<proto::UpdateBuffer>,
3070    session: Session,
3071) -> Result<()> {
3072    let project_id = ProjectId::from_proto(request.project_id);
3073    let mut capability = Capability::ReadOnly;
3074
3075    for op in request.operations.iter() {
3076        match op.variant {
3077            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3078            Some(_) => capability = Capability::ReadWrite,
3079        }
3080    }
3081
3082    let host = {
3083        let guard = session
3084            .db()
3085            .await
3086            .connections_for_buffer_update(
3087                project_id,
3088                session.principal_id(),
3089                session.connection_id,
3090                capability,
3091            )
3092            .await?;
3093
3094        let (host, guests) = &*guard;
3095
3096        broadcast(
3097            Some(session.connection_id),
3098            guests.clone(),
3099            |connection_id| {
3100                session
3101                    .peer
3102                    .forward_send(session.connection_id, connection_id, request.clone())
3103            },
3104        );
3105
3106        *host
3107    };
3108
3109    if host != session.connection_id {
3110        session
3111            .peer
3112            .forward_request(session.connection_id, host, request.clone())
3113            .await?;
3114    }
3115
3116    response.send(proto::Ack {})?;
3117    Ok(())
3118}
3119
3120async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3121    let project_id = ProjectId::from_proto(message.project_id);
3122
3123    let operation = message.operation.as_ref().context("invalid operation")?;
3124    let capability = match operation.variant.as_ref() {
3125        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3126            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3127                match buffer_op.variant {
3128                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3129                        Capability::ReadOnly
3130                    }
3131                    _ => Capability::ReadWrite,
3132                }
3133            } else {
3134                Capability::ReadWrite
3135            }
3136        }
3137        Some(_) => Capability::ReadWrite,
3138        None => Capability::ReadOnly,
3139    };
3140
3141    let guard = session
3142        .db()
3143        .await
3144        .connections_for_buffer_update(
3145            project_id,
3146            session.principal_id(),
3147            session.connection_id,
3148            capability,
3149        )
3150        .await?;
3151
3152    let (host, guests) = &*guard;
3153
3154    broadcast(
3155        Some(session.connection_id),
3156        guests.iter().chain([host]).copied(),
3157        |connection_id| {
3158            session
3159                .peer
3160                .forward_send(session.connection_id, connection_id, message.clone())
3161        },
3162    );
3163
3164    Ok(())
3165}
3166
3167/// Notify other participants that a project has been updated.
3168async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3169    request: T,
3170    session: Session,
3171) -> Result<()> {
3172    let project_id = ProjectId::from_proto(request.remote_entity_id());
3173    let project_connection_ids = session
3174        .db()
3175        .await
3176        .project_connection_ids(project_id, session.connection_id, false)
3177        .await?;
3178
3179    broadcast(
3180        Some(session.connection_id),
3181        project_connection_ids.iter().copied(),
3182        |connection_id| {
3183            session
3184                .peer
3185                .forward_send(session.connection_id, connection_id, request.clone())
3186        },
3187    );
3188    Ok(())
3189}
3190
3191/// Start following another user in a call.
3192async fn follow(
3193    request: proto::Follow,
3194    response: Response<proto::Follow>,
3195    session: UserSession,
3196) -> Result<()> {
3197    let room_id = RoomId::from_proto(request.room_id);
3198    let project_id = request.project_id.map(ProjectId::from_proto);
3199    let leader_id = request
3200        .leader_id
3201        .ok_or_else(|| anyhow!("invalid leader id"))?
3202        .into();
3203    let follower_id = session.connection_id;
3204
3205    session
3206        .db()
3207        .await
3208        .check_room_participants(room_id, leader_id, session.connection_id)
3209        .await?;
3210
3211    let response_payload = session
3212        .peer
3213        .forward_request(session.connection_id, leader_id, request)
3214        .await?;
3215    response.send(response_payload)?;
3216
3217    if let Some(project_id) = project_id {
3218        let room = session
3219            .db()
3220            .await
3221            .follow(room_id, project_id, leader_id, follower_id)
3222            .await?;
3223        room_updated(&room, &session.peer);
3224    }
3225
3226    Ok(())
3227}
3228
3229/// Stop following another user in a call.
3230async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3231    let room_id = RoomId::from_proto(request.room_id);
3232    let project_id = request.project_id.map(ProjectId::from_proto);
3233    let leader_id = request
3234        .leader_id
3235        .ok_or_else(|| anyhow!("invalid leader id"))?
3236        .into();
3237    let follower_id = session.connection_id;
3238
3239    session
3240        .db()
3241        .await
3242        .check_room_participants(room_id, leader_id, session.connection_id)
3243        .await?;
3244
3245    session
3246        .peer
3247        .forward_send(session.connection_id, leader_id, request)?;
3248
3249    if let Some(project_id) = project_id {
3250        let room = session
3251            .db()
3252            .await
3253            .unfollow(room_id, project_id, leader_id, follower_id)
3254            .await?;
3255        room_updated(&room, &session.peer);
3256    }
3257
3258    Ok(())
3259}
3260
3261/// Notify everyone following you of your current location.
3262async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3263    let room_id = RoomId::from_proto(request.room_id);
3264    let database = session.db.lock().await;
3265
3266    let connection_ids = if let Some(project_id) = request.project_id {
3267        let project_id = ProjectId::from_proto(project_id);
3268        database
3269            .project_connection_ids(project_id, session.connection_id, true)
3270            .await?
3271    } else {
3272        database
3273            .room_connection_ids(room_id, session.connection_id)
3274            .await?
3275    };
3276
3277    // For now, don't send view update messages back to that view's current leader.
3278    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3279        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3280        _ => None,
3281    });
3282
3283    for connection_id in connection_ids.iter().cloned() {
3284        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3285            session
3286                .peer
3287                .forward_send(session.connection_id, connection_id, request.clone())?;
3288        }
3289    }
3290    Ok(())
3291}
3292
3293/// Get public data about users.
3294async fn get_users(
3295    request: proto::GetUsers,
3296    response: Response<proto::GetUsers>,
3297    session: Session,
3298) -> Result<()> {
3299    let user_ids = request
3300        .user_ids
3301        .into_iter()
3302        .map(UserId::from_proto)
3303        .collect();
3304    let users = session
3305        .db()
3306        .await
3307        .get_users_by_ids(user_ids)
3308        .await?
3309        .into_iter()
3310        .map(|user| proto::User {
3311            id: user.id.to_proto(),
3312            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3313            github_login: user.github_login,
3314        })
3315        .collect();
3316    response.send(proto::UsersResponse { users })?;
3317    Ok(())
3318}
3319
3320/// Search for users (to invite) buy Github login
3321async fn fuzzy_search_users(
3322    request: proto::FuzzySearchUsers,
3323    response: Response<proto::FuzzySearchUsers>,
3324    session: UserSession,
3325) -> Result<()> {
3326    let query = request.query;
3327    let users = match query.len() {
3328        0 => vec![],
3329        1 | 2 => session
3330            .db()
3331            .await
3332            .get_user_by_github_login(&query)
3333            .await?
3334            .into_iter()
3335            .collect(),
3336        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3337    };
3338    let users = users
3339        .into_iter()
3340        .filter(|user| user.id != session.user_id())
3341        .map(|user| proto::User {
3342            id: user.id.to_proto(),
3343            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3344            github_login: user.github_login,
3345        })
3346        .collect();
3347    response.send(proto::UsersResponse { users })?;
3348    Ok(())
3349}
3350
3351/// Send a contact request to another user.
3352async fn request_contact(
3353    request: proto::RequestContact,
3354    response: Response<proto::RequestContact>,
3355    session: UserSession,
3356) -> Result<()> {
3357    let requester_id = session.user_id();
3358    let responder_id = UserId::from_proto(request.responder_id);
3359    if requester_id == responder_id {
3360        return Err(anyhow!("cannot add yourself as a contact"))?;
3361    }
3362
3363    let notifications = session
3364        .db()
3365        .await
3366        .send_contact_request(requester_id, responder_id)
3367        .await?;
3368
3369    // Update outgoing contact requests of requester
3370    let mut update = proto::UpdateContacts::default();
3371    update.outgoing_requests.push(responder_id.to_proto());
3372    for connection_id in session
3373        .connection_pool()
3374        .await
3375        .user_connection_ids(requester_id)
3376    {
3377        session.peer.send(connection_id, update.clone())?;
3378    }
3379
3380    // Update incoming contact requests of responder
3381    let mut update = proto::UpdateContacts::default();
3382    update
3383        .incoming_requests
3384        .push(proto::IncomingContactRequest {
3385            requester_id: requester_id.to_proto(),
3386        });
3387    let connection_pool = session.connection_pool().await;
3388    for connection_id in connection_pool.user_connection_ids(responder_id) {
3389        session.peer.send(connection_id, update.clone())?;
3390    }
3391
3392    send_notifications(&connection_pool, &session.peer, notifications);
3393
3394    response.send(proto::Ack {})?;
3395    Ok(())
3396}
3397
3398/// Accept or decline a contact request
3399async fn respond_to_contact_request(
3400    request: proto::RespondToContactRequest,
3401    response: Response<proto::RespondToContactRequest>,
3402    session: UserSession,
3403) -> Result<()> {
3404    let responder_id = session.user_id();
3405    let requester_id = UserId::from_proto(request.requester_id);
3406    let db = session.db().await;
3407    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3408        db.dismiss_contact_notification(responder_id, requester_id)
3409            .await?;
3410    } else {
3411        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3412
3413        let notifications = db
3414            .respond_to_contact_request(responder_id, requester_id, accept)
3415            .await?;
3416        let requester_busy = db.is_user_busy(requester_id).await?;
3417        let responder_busy = db.is_user_busy(responder_id).await?;
3418
3419        let pool = session.connection_pool().await;
3420        // Update responder with new contact
3421        let mut update = proto::UpdateContacts::default();
3422        if accept {
3423            update
3424                .contacts
3425                .push(contact_for_user(requester_id, requester_busy, &pool));
3426        }
3427        update
3428            .remove_incoming_requests
3429            .push(requester_id.to_proto());
3430        for connection_id in pool.user_connection_ids(responder_id) {
3431            session.peer.send(connection_id, update.clone())?;
3432        }
3433
3434        // Update requester with new contact
3435        let mut update = proto::UpdateContacts::default();
3436        if accept {
3437            update
3438                .contacts
3439                .push(contact_for_user(responder_id, responder_busy, &pool));
3440        }
3441        update
3442            .remove_outgoing_requests
3443            .push(responder_id.to_proto());
3444
3445        for connection_id in pool.user_connection_ids(requester_id) {
3446            session.peer.send(connection_id, update.clone())?;
3447        }
3448
3449        send_notifications(&pool, &session.peer, notifications);
3450    }
3451
3452    response.send(proto::Ack {})?;
3453    Ok(())
3454}
3455
3456/// Remove a contact.
3457async fn remove_contact(
3458    request: proto::RemoveContact,
3459    response: Response<proto::RemoveContact>,
3460    session: UserSession,
3461) -> Result<()> {
3462    let requester_id = session.user_id();
3463    let responder_id = UserId::from_proto(request.user_id);
3464    let db = session.db().await;
3465    let (contact_accepted, deleted_notification_id) =
3466        db.remove_contact(requester_id, responder_id).await?;
3467
3468    let pool = session.connection_pool().await;
3469    // Update outgoing contact requests of requester
3470    let mut update = proto::UpdateContacts::default();
3471    if contact_accepted {
3472        update.remove_contacts.push(responder_id.to_proto());
3473    } else {
3474        update
3475            .remove_outgoing_requests
3476            .push(responder_id.to_proto());
3477    }
3478    for connection_id in pool.user_connection_ids(requester_id) {
3479        session.peer.send(connection_id, update.clone())?;
3480    }
3481
3482    // Update incoming contact requests of responder
3483    let mut update = proto::UpdateContacts::default();
3484    if contact_accepted {
3485        update.remove_contacts.push(requester_id.to_proto());
3486    } else {
3487        update
3488            .remove_incoming_requests
3489            .push(requester_id.to_proto());
3490    }
3491    for connection_id in pool.user_connection_ids(responder_id) {
3492        session.peer.send(connection_id, update.clone())?;
3493        if let Some(notification_id) = deleted_notification_id {
3494            session.peer.send(
3495                connection_id,
3496                proto::DeleteNotification {
3497                    notification_id: notification_id.to_proto(),
3498                },
3499            )?;
3500        }
3501    }
3502
3503    response.send(proto::Ack {})?;
3504    Ok(())
3505}
3506
3507fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3508    version.0.minor() < 139
3509}
3510
3511async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3512    let plan = session.current_plan(session.db().await).await?;
3513
3514    session
3515        .peer
3516        .send(
3517            session.connection_id,
3518            proto::UpdateUserPlan { plan: plan.into() },
3519        )
3520        .trace_err();
3521
3522    Ok(())
3523}
3524
3525async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3526    subscribe_user_to_channels(
3527        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3528        &session,
3529    )
3530    .await?;
3531    Ok(())
3532}
3533
3534async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3535    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3536    let mut pool = session.connection_pool().await;
3537    for membership in &channels_for_user.channel_memberships {
3538        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3539    }
3540    session.peer.send(
3541        session.connection_id,
3542        build_update_user_channels(&channels_for_user),
3543    )?;
3544    session.peer.send(
3545        session.connection_id,
3546        build_channels_update(channels_for_user),
3547    )?;
3548    Ok(())
3549}
3550
3551/// Creates a new channel.
3552async fn create_channel(
3553    request: proto::CreateChannel,
3554    response: Response<proto::CreateChannel>,
3555    session: UserSession,
3556) -> Result<()> {
3557    let db = session.db().await;
3558
3559    let parent_id = request.parent_id.map(ChannelId::from_proto);
3560    let (channel, membership) = db
3561        .create_channel(&request.name, parent_id, session.user_id())
3562        .await?;
3563
3564    let root_id = channel.root_id();
3565    let channel = Channel::from_model(channel);
3566
3567    response.send(proto::CreateChannelResponse {
3568        channel: Some(channel.to_proto()),
3569        parent_id: request.parent_id,
3570    })?;
3571
3572    let mut connection_pool = session.connection_pool().await;
3573    if let Some(membership) = membership {
3574        connection_pool.subscribe_to_channel(
3575            membership.user_id,
3576            membership.channel_id,
3577            membership.role,
3578        );
3579        let update = proto::UpdateUserChannels {
3580            channel_memberships: vec![proto::ChannelMembership {
3581                channel_id: membership.channel_id.to_proto(),
3582                role: membership.role.into(),
3583            }],
3584            ..Default::default()
3585        };
3586        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3587            session.peer.send(connection_id, update.clone())?;
3588        }
3589    }
3590
3591    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3592        if !role.can_see_channel(channel.visibility) {
3593            continue;
3594        }
3595
3596        let update = proto::UpdateChannels {
3597            channels: vec![channel.to_proto()],
3598            ..Default::default()
3599        };
3600        session.peer.send(connection_id, update.clone())?;
3601    }
3602
3603    Ok(())
3604}
3605
3606/// Delete a channel
3607async fn delete_channel(
3608    request: proto::DeleteChannel,
3609    response: Response<proto::DeleteChannel>,
3610    session: UserSession,
3611) -> Result<()> {
3612    let db = session.db().await;
3613
3614    let channel_id = request.channel_id;
3615    let (root_channel, removed_channels) = db
3616        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3617        .await?;
3618    response.send(proto::Ack {})?;
3619
3620    // Notify members of removed channels
3621    let mut update = proto::UpdateChannels::default();
3622    update
3623        .delete_channels
3624        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3625
3626    let connection_pool = session.connection_pool().await;
3627    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3628        session.peer.send(connection_id, update.clone())?;
3629    }
3630
3631    Ok(())
3632}
3633
3634/// Invite someone to join a channel.
3635async fn invite_channel_member(
3636    request: proto::InviteChannelMember,
3637    response: Response<proto::InviteChannelMember>,
3638    session: UserSession,
3639) -> Result<()> {
3640    let db = session.db().await;
3641    let channel_id = ChannelId::from_proto(request.channel_id);
3642    let invitee_id = UserId::from_proto(request.user_id);
3643    let InviteMemberResult {
3644        channel,
3645        notifications,
3646    } = db
3647        .invite_channel_member(
3648            channel_id,
3649            invitee_id,
3650            session.user_id(),
3651            request.role().into(),
3652        )
3653        .await?;
3654
3655    let update = proto::UpdateChannels {
3656        channel_invitations: vec![channel.to_proto()],
3657        ..Default::default()
3658    };
3659
3660    let connection_pool = session.connection_pool().await;
3661    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3662        session.peer.send(connection_id, update.clone())?;
3663    }
3664
3665    send_notifications(&connection_pool, &session.peer, notifications);
3666
3667    response.send(proto::Ack {})?;
3668    Ok(())
3669}
3670
3671/// remove someone from a channel
3672async fn remove_channel_member(
3673    request: proto::RemoveChannelMember,
3674    response: Response<proto::RemoveChannelMember>,
3675    session: UserSession,
3676) -> Result<()> {
3677    let db = session.db().await;
3678    let channel_id = ChannelId::from_proto(request.channel_id);
3679    let member_id = UserId::from_proto(request.user_id);
3680
3681    let RemoveChannelMemberResult {
3682        membership_update,
3683        notification_id,
3684    } = db
3685        .remove_channel_member(channel_id, member_id, session.user_id())
3686        .await?;
3687
3688    let mut connection_pool = session.connection_pool().await;
3689    notify_membership_updated(
3690        &mut connection_pool,
3691        membership_update,
3692        member_id,
3693        &session.peer,
3694    );
3695    for connection_id in connection_pool.user_connection_ids(member_id) {
3696        if let Some(notification_id) = notification_id {
3697            session
3698                .peer
3699                .send(
3700                    connection_id,
3701                    proto::DeleteNotification {
3702                        notification_id: notification_id.to_proto(),
3703                    },
3704                )
3705                .trace_err();
3706        }
3707    }
3708
3709    response.send(proto::Ack {})?;
3710    Ok(())
3711}
3712
3713/// Toggle the channel between public and private.
3714/// Care is taken to maintain the invariant that public channels only descend from public channels,
3715/// (though members-only channels can appear at any point in the hierarchy).
3716async fn set_channel_visibility(
3717    request: proto::SetChannelVisibility,
3718    response: Response<proto::SetChannelVisibility>,
3719    session: UserSession,
3720) -> Result<()> {
3721    let db = session.db().await;
3722    let channel_id = ChannelId::from_proto(request.channel_id);
3723    let visibility = request.visibility().into();
3724
3725    let channel_model = db
3726        .set_channel_visibility(channel_id, visibility, session.user_id())
3727        .await?;
3728    let root_id = channel_model.root_id();
3729    let channel = Channel::from_model(channel_model);
3730
3731    let mut connection_pool = session.connection_pool().await;
3732    for (user_id, role) in connection_pool
3733        .channel_user_ids(root_id)
3734        .collect::<Vec<_>>()
3735        .into_iter()
3736    {
3737        let update = if role.can_see_channel(channel.visibility) {
3738            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3739            proto::UpdateChannels {
3740                channels: vec![channel.to_proto()],
3741                ..Default::default()
3742            }
3743        } else {
3744            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3745            proto::UpdateChannels {
3746                delete_channels: vec![channel.id.to_proto()],
3747                ..Default::default()
3748            }
3749        };
3750
3751        for connection_id in connection_pool.user_connection_ids(user_id) {
3752            session.peer.send(connection_id, update.clone())?;
3753        }
3754    }
3755
3756    response.send(proto::Ack {})?;
3757    Ok(())
3758}
3759
3760/// Alter the role for a user in the channel.
3761async fn set_channel_member_role(
3762    request: proto::SetChannelMemberRole,
3763    response: Response<proto::SetChannelMemberRole>,
3764    session: UserSession,
3765) -> Result<()> {
3766    let db = session.db().await;
3767    let channel_id = ChannelId::from_proto(request.channel_id);
3768    let member_id = UserId::from_proto(request.user_id);
3769    let result = db
3770        .set_channel_member_role(
3771            channel_id,
3772            session.user_id(),
3773            member_id,
3774            request.role().into(),
3775        )
3776        .await?;
3777
3778    match result {
3779        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3780            let mut connection_pool = session.connection_pool().await;
3781            notify_membership_updated(
3782                &mut connection_pool,
3783                membership_update,
3784                member_id,
3785                &session.peer,
3786            )
3787        }
3788        db::SetMemberRoleResult::InviteUpdated(channel) => {
3789            let update = proto::UpdateChannels {
3790                channel_invitations: vec![channel.to_proto()],
3791                ..Default::default()
3792            };
3793
3794            for connection_id in session
3795                .connection_pool()
3796                .await
3797                .user_connection_ids(member_id)
3798            {
3799                session.peer.send(connection_id, update.clone())?;
3800            }
3801        }
3802    }
3803
3804    response.send(proto::Ack {})?;
3805    Ok(())
3806}
3807
3808/// Change the name of a channel
3809async fn rename_channel(
3810    request: proto::RenameChannel,
3811    response: Response<proto::RenameChannel>,
3812    session: UserSession,
3813) -> Result<()> {
3814    let db = session.db().await;
3815    let channel_id = ChannelId::from_proto(request.channel_id);
3816    let channel_model = db
3817        .rename_channel(channel_id, session.user_id(), &request.name)
3818        .await?;
3819    let root_id = channel_model.root_id();
3820    let channel = Channel::from_model(channel_model);
3821
3822    response.send(proto::RenameChannelResponse {
3823        channel: Some(channel.to_proto()),
3824    })?;
3825
3826    let connection_pool = session.connection_pool().await;
3827    let update = proto::UpdateChannels {
3828        channels: vec![channel.to_proto()],
3829        ..Default::default()
3830    };
3831    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3832        if role.can_see_channel(channel.visibility) {
3833            session.peer.send(connection_id, update.clone())?;
3834        }
3835    }
3836
3837    Ok(())
3838}
3839
3840/// Move a channel to a new parent.
3841async fn move_channel(
3842    request: proto::MoveChannel,
3843    response: Response<proto::MoveChannel>,
3844    session: UserSession,
3845) -> Result<()> {
3846    let channel_id = ChannelId::from_proto(request.channel_id);
3847    let to = ChannelId::from_proto(request.to);
3848
3849    let (root_id, channels) = session
3850        .db()
3851        .await
3852        .move_channel(channel_id, to, session.user_id())
3853        .await?;
3854
3855    let connection_pool = session.connection_pool().await;
3856    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3857        let channels = channels
3858            .iter()
3859            .filter_map(|channel| {
3860                if role.can_see_channel(channel.visibility) {
3861                    Some(channel.to_proto())
3862                } else {
3863                    None
3864                }
3865            })
3866            .collect::<Vec<_>>();
3867        if channels.is_empty() {
3868            continue;
3869        }
3870
3871        let update = proto::UpdateChannels {
3872            channels,
3873            ..Default::default()
3874        };
3875
3876        session.peer.send(connection_id, update.clone())?;
3877    }
3878
3879    response.send(Ack {})?;
3880    Ok(())
3881}
3882
3883/// Get the list of channel members
3884async fn get_channel_members(
3885    request: proto::GetChannelMembers,
3886    response: Response<proto::GetChannelMembers>,
3887    session: UserSession,
3888) -> Result<()> {
3889    let db = session.db().await;
3890    let channel_id = ChannelId::from_proto(request.channel_id);
3891    let limit = if request.limit == 0 {
3892        u16::MAX as u64
3893    } else {
3894        request.limit
3895    };
3896    let (members, users) = db
3897        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3898        .await?;
3899    response.send(proto::GetChannelMembersResponse { members, users })?;
3900    Ok(())
3901}
3902
3903/// Accept or decline a channel invitation.
3904async fn respond_to_channel_invite(
3905    request: proto::RespondToChannelInvite,
3906    response: Response<proto::RespondToChannelInvite>,
3907    session: UserSession,
3908) -> Result<()> {
3909    let db = session.db().await;
3910    let channel_id = ChannelId::from_proto(request.channel_id);
3911    let RespondToChannelInvite {
3912        membership_update,
3913        notifications,
3914    } = db
3915        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3916        .await?;
3917
3918    let mut connection_pool = session.connection_pool().await;
3919    if let Some(membership_update) = membership_update {
3920        notify_membership_updated(
3921            &mut connection_pool,
3922            membership_update,
3923            session.user_id(),
3924            &session.peer,
3925        );
3926    } else {
3927        let update = proto::UpdateChannels {
3928            remove_channel_invitations: vec![channel_id.to_proto()],
3929            ..Default::default()
3930        };
3931
3932        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3933            session.peer.send(connection_id, update.clone())?;
3934        }
3935    };
3936
3937    send_notifications(&connection_pool, &session.peer, notifications);
3938
3939    response.send(proto::Ack {})?;
3940
3941    Ok(())
3942}
3943
3944/// Join the channels' room
3945async fn join_channel(
3946    request: proto::JoinChannel,
3947    response: Response<proto::JoinChannel>,
3948    session: UserSession,
3949) -> Result<()> {
3950    let channel_id = ChannelId::from_proto(request.channel_id);
3951    join_channel_internal(channel_id, Box::new(response), session).await
3952}
3953
/// Abstracts over the response handles of the requests that share
/// `join_channel_internal`. Both `proto::JoinChannel` and `proto::JoinRoom`
/// reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consume the response handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request by delegating to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request by delegating to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3967
/// Shared implementation for joining the room of channel `channel_id`.
///
/// It is generic over the reply handle via `JoinChannelInternalResponse`. Steps:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first (see the comment inside).
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role.
/// 3. If a LiveKit client is configured, build connection info:
///    - guests get a guest token and `can_publish: false`;
///    - everyone else gets a room token and `can_publish: true`.
///    A token error is logged via `trace_err` and results in no LiveKit info
///    being sent, rather than failing the join.
/// 4. Reply with the room, channel id, and LiveKit info. Propagate any
///    membership update, then call `room_updated` for the joined room.
/// 5. Once the inner block has released its `db` and `connection_pool` guards,
///    call `channel_updated`, which locks the pool again. Finally refresh the
///    user's contacts.
///
/// Returns an error if a database call, the stale-room cleanup, sending the
/// reply, or `update_user_contacts` fails. It also errors if the join did not
/// return a channel; that check happens after the reply has been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the `db` and `connection_pool` guards so that both are
    // dropped before the pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting a token failed
        // (the `?` after `trace_err()` exits the closure early).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4063
4064/// Start editing the channel notes
4065async fn join_channel_buffer(
4066    request: proto::JoinChannelBuffer,
4067    response: Response<proto::JoinChannelBuffer>,
4068    session: UserSession,
4069) -> Result<()> {
4070    let db = session.db().await;
4071    let channel_id = ChannelId::from_proto(request.channel_id);
4072
4073    let open_response = db
4074        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4075        .await?;
4076
4077    let collaborators = open_response.collaborators.clone();
4078    response.send(open_response)?;
4079
4080    let update = UpdateChannelBufferCollaborators {
4081        channel_id: channel_id.to_proto(),
4082        collaborators: collaborators.clone(),
4083    };
4084    channel_buffer_updated(
4085        session.connection_id,
4086        collaborators
4087            .iter()
4088            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4089        &update,
4090        &session.peer,
4091    );
4092
4093    Ok(())
4094}
4095
4096/// Edit the channel notes
4097async fn update_channel_buffer(
4098    request: proto::UpdateChannelBuffer,
4099    session: UserSession,
4100) -> Result<()> {
4101    let db = session.db().await;
4102    let channel_id = ChannelId::from_proto(request.channel_id);
4103
4104    let (collaborators, epoch, version) = db
4105        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4106        .await?;
4107
4108    channel_buffer_updated(
4109        session.connection_id,
4110        collaborators.clone(),
4111        &proto::UpdateChannelBuffer {
4112            channel_id: channel_id.to_proto(),
4113            operations: request.operations,
4114        },
4115        &session.peer,
4116    );
4117
4118    let pool = &*session.connection_pool().await;
4119
4120    let non_collaborators =
4121        pool.channel_connection_ids(channel_id)
4122            .filter_map(|(connection_id, _)| {
4123                if collaborators.contains(&connection_id) {
4124                    None
4125                } else {
4126                    Some(connection_id)
4127                }
4128            });
4129
4130    broadcast(None, non_collaborators, |peer_id| {
4131        session.peer.send(
4132            peer_id,
4133            proto::UpdateChannels {
4134                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4135                    channel_id: channel_id.to_proto(),
4136                    epoch: epoch as u64,
4137                    version: version.clone(),
4138                }],
4139                ..Default::default()
4140            },
4141        )
4142    });
4143
4144    Ok(())
4145}
4146
4147/// Rejoin the channel notes after a connection blip
4148async fn rejoin_channel_buffers(
4149    request: proto::RejoinChannelBuffers,
4150    response: Response<proto::RejoinChannelBuffers>,
4151    session: UserSession,
4152) -> Result<()> {
4153    let db = session.db().await;
4154    let buffers = db
4155        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4156        .await?;
4157
4158    for rejoined_buffer in &buffers {
4159        let collaborators_to_notify = rejoined_buffer
4160            .buffer
4161            .collaborators
4162            .iter()
4163            .filter_map(|c| Some(c.peer_id?.into()));
4164        channel_buffer_updated(
4165            session.connection_id,
4166            collaborators_to_notify,
4167            &proto::UpdateChannelBufferCollaborators {
4168                channel_id: rejoined_buffer.buffer.channel_id,
4169                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4170            },
4171            &session.peer,
4172        );
4173    }
4174
4175    response.send(proto::RejoinChannelBuffersResponse {
4176        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4177    })?;
4178
4179    Ok(())
4180}
4181
4182/// Stop editing the channel notes
4183async fn leave_channel_buffer(
4184    request: proto::LeaveChannelBuffer,
4185    response: Response<proto::LeaveChannelBuffer>,
4186    session: UserSession,
4187) -> Result<()> {
4188    let db = session.db().await;
4189    let channel_id = ChannelId::from_proto(request.channel_id);
4190
4191    let left_buffer = db
4192        .leave_channel_buffer(channel_id, session.connection_id)
4193        .await?;
4194
4195    response.send(Ack {})?;
4196
4197    channel_buffer_updated(
4198        session.connection_id,
4199        left_buffer.connections,
4200        &proto::UpdateChannelBufferCollaborators {
4201            channel_id: channel_id.to_proto(),
4202            collaborators: left_buffer.collaborators,
4203        },
4204        &session.peer,
4205    );
4206
4207    Ok(())
4208}
4209
4210fn channel_buffer_updated<T: EnvelopedMessage>(
4211    sender_id: ConnectionId,
4212    collaborators: impl IntoIterator<Item = ConnectionId>,
4213    message: &T,
4214    peer: &Peer,
4215) {
4216    broadcast(Some(sender_id), collaborators, |peer_id| {
4217        peer.send(peer_id, message.clone())
4218    });
4219}
4220
4221fn send_notifications(
4222    connection_pool: &ConnectionPool,
4223    peer: &Peer,
4224    notifications: db::NotificationBatch,
4225) {
4226    for (user_id, notification) in notifications {
4227        for connection_id in connection_pool.user_connection_ids(user_id) {
4228            if let Err(error) = peer.send(
4229                connection_id,
4230                proto::AddNotification {
4231                    notification: Some(notification.clone()),
4232                },
4233            ) {
4234                tracing::error!(
4235                    "failed to send notification to {:?} {}",
4236                    connection_id,
4237                    error
4238                );
4239            }
4240        }
4241    }
4242}
4243
4244/// Send a message to the channel
4245async fn send_channel_message(
4246    request: proto::SendChannelMessage,
4247    response: Response<proto::SendChannelMessage>,
4248    session: UserSession,
4249) -> Result<()> {
4250    // Validate the message body.
4251    let body = request.body.trim().to_string();
4252    if body.len() > MAX_MESSAGE_LEN {
4253        return Err(anyhow!("message is too long"))?;
4254    }
4255    if body.is_empty() {
4256        return Err(anyhow!("message can't be blank"))?;
4257    }
4258
4259    // TODO: adjust mentions if body is trimmed
4260
4261    let timestamp = OffsetDateTime::now_utc();
4262    let nonce = request
4263        .nonce
4264        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4265
4266    let channel_id = ChannelId::from_proto(request.channel_id);
4267    let CreatedChannelMessage {
4268        message_id,
4269        participant_connection_ids,
4270        notifications,
4271    } = session
4272        .db()
4273        .await
4274        .create_channel_message(
4275            channel_id,
4276            session.user_id(),
4277            &body,
4278            &request.mentions,
4279            timestamp,
4280            nonce.clone().into(),
4281            request.reply_to_message_id.map(MessageId::from_proto),
4282        )
4283        .await?;
4284
4285    let message = proto::ChannelMessage {
4286        sender_id: session.user_id().to_proto(),
4287        id: message_id.to_proto(),
4288        body,
4289        mentions: request.mentions,
4290        timestamp: timestamp.unix_timestamp() as u64,
4291        nonce: Some(nonce),
4292        reply_to_message_id: request.reply_to_message_id,
4293        edited_at: None,
4294    };
4295    broadcast(
4296        Some(session.connection_id),
4297        participant_connection_ids.clone(),
4298        |connection| {
4299            session.peer.send(
4300                connection,
4301                proto::ChannelMessageSent {
4302                    channel_id: channel_id.to_proto(),
4303                    message: Some(message.clone()),
4304                },
4305            )
4306        },
4307    );
4308    response.send(proto::SendChannelMessageResponse {
4309        message: Some(message),
4310    })?;
4311
4312    let pool = &*session.connection_pool().await;
4313    let non_participants =
4314        pool.channel_connection_ids(channel_id)
4315            .filter_map(|(connection_id, _)| {
4316                if participant_connection_ids.contains(&connection_id) {
4317                    None
4318                } else {
4319                    Some(connection_id)
4320                }
4321            });
4322    broadcast(None, non_participants, |peer_id| {
4323        session.peer.send(
4324            peer_id,
4325            proto::UpdateChannels {
4326                latest_channel_message_ids: vec![proto::ChannelMessageId {
4327                    channel_id: channel_id.to_proto(),
4328                    message_id: message_id.to_proto(),
4329                }],
4330                ..Default::default()
4331            },
4332        )
4333    });
4334    send_notifications(pool, &session.peer, notifications);
4335
4336    Ok(())
4337}
4338
4339/// Delete a channel message
4340async fn remove_channel_message(
4341    request: proto::RemoveChannelMessage,
4342    response: Response<proto::RemoveChannelMessage>,
4343    session: UserSession,
4344) -> Result<()> {
4345    let channel_id = ChannelId::from_proto(request.channel_id);
4346    let message_id = MessageId::from_proto(request.message_id);
4347    let (connection_ids, existing_notification_ids) = session
4348        .db()
4349        .await
4350        .remove_channel_message(channel_id, message_id, session.user_id())
4351        .await?;
4352
4353    broadcast(
4354        Some(session.connection_id),
4355        connection_ids,
4356        move |connection| {
4357            session.peer.send(connection, request.clone())?;
4358
4359            for notification_id in &existing_notification_ids {
4360                session.peer.send(
4361                    connection,
4362                    proto::DeleteNotification {
4363                        notification_id: (*notification_id).to_proto(),
4364                    },
4365                )?;
4366            }
4367
4368            Ok(())
4369        },
4370    );
4371    response.send(proto::Ack {})?;
4372    Ok(())
4373}
4374
4375async fn update_channel_message(
4376    request: proto::UpdateChannelMessage,
4377    response: Response<proto::UpdateChannelMessage>,
4378    session: UserSession,
4379) -> Result<()> {
4380    let channel_id = ChannelId::from_proto(request.channel_id);
4381    let message_id = MessageId::from_proto(request.message_id);
4382    let updated_at = OffsetDateTime::now_utc();
4383    let UpdatedChannelMessage {
4384        message_id,
4385        participant_connection_ids,
4386        notifications,
4387        reply_to_message_id,
4388        timestamp,
4389        deleted_mention_notification_ids,
4390        updated_mention_notifications,
4391    } = session
4392        .db()
4393        .await
4394        .update_channel_message(
4395            channel_id,
4396            message_id,
4397            session.user_id(),
4398            request.body.as_str(),
4399            &request.mentions,
4400            updated_at,
4401        )
4402        .await?;
4403
4404    let nonce = request
4405        .nonce
4406        .clone()
4407        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4408
4409    let message = proto::ChannelMessage {
4410        sender_id: session.user_id().to_proto(),
4411        id: message_id.to_proto(),
4412        body: request.body.clone(),
4413        mentions: request.mentions.clone(),
4414        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4415        nonce: Some(nonce),
4416        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4417        edited_at: Some(updated_at.unix_timestamp() as u64),
4418    };
4419
4420    response.send(proto::Ack {})?;
4421
4422    let pool = &*session.connection_pool().await;
4423    broadcast(
4424        Some(session.connection_id),
4425        participant_connection_ids,
4426        |connection| {
4427            session.peer.send(
4428                connection,
4429                proto::ChannelMessageUpdate {
4430                    channel_id: channel_id.to_proto(),
4431                    message: Some(message.clone()),
4432                },
4433            )?;
4434
4435            for notification_id in &deleted_mention_notification_ids {
4436                session.peer.send(
4437                    connection,
4438                    proto::DeleteNotification {
4439                        notification_id: (*notification_id).to_proto(),
4440                    },
4441                )?;
4442            }
4443
4444            for notification in &updated_mention_notifications {
4445                session.peer.send(
4446                    connection,
4447                    proto::UpdateNotification {
4448                        notification: Some(notification.clone()),
4449                    },
4450                )?;
4451            }
4452
4453            Ok(())
4454        },
4455    );
4456
4457    send_notifications(pool, &session.peer, notifications);
4458
4459    Ok(())
4460}
4461
4462/// Mark a channel message as read
4463async fn acknowledge_channel_message(
4464    request: proto::AckChannelMessage,
4465    session: UserSession,
4466) -> Result<()> {
4467    let channel_id = ChannelId::from_proto(request.channel_id);
4468    let message_id = MessageId::from_proto(request.message_id);
4469    let notifications = session
4470        .db()
4471        .await
4472        .observe_channel_message(channel_id, session.user_id(), message_id)
4473        .await?;
4474    send_notifications(
4475        &*session.connection_pool().await,
4476        &session.peer,
4477        notifications,
4478    );
4479    Ok(())
4480}
4481
4482/// Mark a buffer version as synced
4483async fn acknowledge_buffer_version(
4484    request: proto::AckBufferOperation,
4485    session: UserSession,
4486) -> Result<()> {
4487    let buffer_id = BufferId::from_proto(request.buffer_id);
4488    session
4489        .db()
4490        .await
4491        .observe_buffer_version(
4492            buffer_id,
4493            session.user_id(),
4494            request.epoch as i32,
4495            &request.version,
4496        )
4497        .await?;
4498    Ok(())
4499}
4500
4501async fn count_language_model_tokens(
4502    request: proto::CountLanguageModelTokens,
4503    response: Response<proto::CountLanguageModelTokens>,
4504    session: Session,
4505    config: &Config,
4506) -> Result<()> {
4507    let Some(session) = session.for_user() else {
4508        return Err(anyhow!("user not found"))?;
4509    };
4510    authorize_access_to_legacy_llm_endpoints(&session).await?;
4511
4512    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4513        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4514        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4515    };
4516
4517    session
4518        .app_state
4519        .rate_limiter
4520        .check(&*rate_limit, session.user_id())
4521        .await?;
4522
4523    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4524        Some(proto::LanguageModelProvider::Google) => {
4525            let api_key = config
4526                .google_ai_api_key
4527                .as_ref()
4528                .context("no Google AI API key configured on the server")?;
4529            google_ai::count_tokens(
4530                session.http_client.as_ref(),
4531                google_ai::API_URL,
4532                api_key,
4533                serde_json::from_str(&request.request)?,
4534                None,
4535            )
4536            .await?
4537        }
4538        _ => return Err(anyhow!("unsupported provider"))?,
4539    };
4540
4541    response.send(proto::CountLanguageModelTokensResponse {
4542        token_count: result.total_tokens as u32,
4543    })?;
4544
4545    Ok(())
4546}
4547
4548struct ZedProCountLanguageModelTokensRateLimit;
4549
4550impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4551    fn capacity(&self) -> usize {
4552        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4553            .ok()
4554            .and_then(|v| v.parse().ok())
4555            .unwrap_or(600) // Picked arbitrarily
4556    }
4557
4558    fn refill_duration(&self) -> chrono::Duration {
4559        chrono::Duration::hours(1)
4560    }
4561
4562    fn db_name(&self) -> &'static str {
4563        "zed-pro:count-language-model-tokens"
4564    }
4565}
4566
4567struct FreeCountLanguageModelTokensRateLimit;
4568
4569impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4570    fn capacity(&self) -> usize {
4571        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4572            .ok()
4573            .and_then(|v| v.parse().ok())
4574            .unwrap_or(600 / 10) // Picked arbitrarily
4575    }
4576
4577    fn refill_duration(&self) -> chrono::Duration {
4578        chrono::Duration::hours(1)
4579    }
4580
4581    fn db_name(&self) -> &'static str {
4582        "free:count-language-model-tokens"
4583    }
4584}
4585
4586struct ZedProComputeEmbeddingsRateLimit;
4587
4588impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4589    fn capacity(&self) -> usize {
4590        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4591            .ok()
4592            .and_then(|v| v.parse().ok())
4593            .unwrap_or(5000) // Picked arbitrarily
4594    }
4595
4596    fn refill_duration(&self) -> chrono::Duration {
4597        chrono::Duration::hours(1)
4598    }
4599
4600    fn db_name(&self) -> &'static str {
4601        "zed-pro:compute-embeddings"
4602    }
4603}
4604
4605struct FreeComputeEmbeddingsRateLimit;
4606
4607impl RateLimit for FreeComputeEmbeddingsRateLimit {
4608    fn capacity(&self) -> usize {
4609        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4610            .ok()
4611            .and_then(|v| v.parse().ok())
4612            .unwrap_or(5000 / 10) // Picked arbitrarily
4613    }
4614
4615    fn refill_duration(&self) -> chrono::Duration {
4616        chrono::Duration::hours(1)
4617    }
4618
4619    fn db_name(&self) -> &'static str {
4620        "free:compute-embeddings"
4621    }
4622}
4623
4624async fn compute_embeddings(
4625    request: proto::ComputeEmbeddings,
4626    response: Response<proto::ComputeEmbeddings>,
4627    session: UserSession,
4628    api_key: Option<Arc<str>>,
4629) -> Result<()> {
4630    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4631    authorize_access_to_legacy_llm_endpoints(&session).await?;
4632
4633    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4634        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4635        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4636    };
4637
4638    session
4639        .app_state
4640        .rate_limiter
4641        .check(&*rate_limit, session.user_id())
4642        .await?;
4643
4644    let embeddings = match request.model.as_str() {
4645        "openai/text-embedding-3-small" => {
4646            open_ai::embed(
4647                session.http_client.as_ref(),
4648                OPEN_AI_API_URL,
4649                &api_key,
4650                OpenAiEmbeddingModel::TextEmbedding3Small,
4651                request.texts.iter().map(|text| text.as_str()),
4652            )
4653            .await?
4654        }
4655        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4656    };
4657
4658    let embeddings = request
4659        .texts
4660        .iter()
4661        .map(|text| {
4662            let mut hasher = sha2::Sha256::new();
4663            hasher.update(text.as_bytes());
4664            let result = hasher.finalize();
4665            result.to_vec()
4666        })
4667        .zip(
4668            embeddings
4669                .data
4670                .into_iter()
4671                .map(|embedding| embedding.embedding),
4672        )
4673        .collect::<HashMap<_, _>>();
4674
4675    let db = session.db().await;
4676    db.save_embeddings(&request.model, &embeddings)
4677        .await
4678        .context("failed to save embeddings")
4679        .trace_err();
4680
4681    response.send(proto::ComputeEmbeddingsResponse {
4682        embeddings: embeddings
4683            .into_iter()
4684            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4685            .collect(),
4686    })?;
4687    Ok(())
4688}
4689
4690async fn get_cached_embeddings(
4691    request: proto::GetCachedEmbeddings,
4692    response: Response<proto::GetCachedEmbeddings>,
4693    session: UserSession,
4694) -> Result<()> {
4695    authorize_access_to_legacy_llm_endpoints(&session).await?;
4696
4697    let db = session.db().await;
4698    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4699
4700    response.send(proto::GetCachedEmbeddingsResponse {
4701        embeddings: embeddings
4702            .into_iter()
4703            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4704            .collect(),
4705    })?;
4706    Ok(())
4707}
4708
4709/// This is leftover from before the LLM service.
4710///
4711/// The endpoints protected by this check will be moved there eventually.
4712async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4713    if session.is_staff() {
4714        Ok(())
4715    } else {
4716        Err(anyhow!("permission denied"))?
4717    }
4718}
4719
4720/// Get a Supermaven API key for the user
4721async fn get_supermaven_api_key(
4722    _request: proto::GetSupermavenApiKey,
4723    response: Response<proto::GetSupermavenApiKey>,
4724    session: UserSession,
4725) -> Result<()> {
4726    let user_id: String = session.user_id().to_string();
4727    if !session.is_staff() {
4728        return Err(anyhow!("supermaven not enabled for this account"))?;
4729    }
4730
4731    let email = session
4732        .email()
4733        .ok_or_else(|| anyhow!("user must have an email"))?;
4734
4735    let supermaven_admin_api = session
4736        .supermaven_client
4737        .as_ref()
4738        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4739
4740    let result = supermaven_admin_api
4741        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4742        .await?;
4743
4744    response.send(proto::GetSupermavenApiKeyResponse {
4745        api_key: result.api_key,
4746    })?;
4747
4748    Ok(())
4749}
4750
4751/// Start receiving chat updates for a channel
4752async fn join_channel_chat(
4753    request: proto::JoinChannelChat,
4754    response: Response<proto::JoinChannelChat>,
4755    session: UserSession,
4756) -> Result<()> {
4757    let channel_id = ChannelId::from_proto(request.channel_id);
4758
4759    let db = session.db().await;
4760    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4761        .await?;
4762    let messages = db
4763        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4764        .await?;
4765    response.send(proto::JoinChannelChatResponse {
4766        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4767        messages,
4768    })?;
4769    Ok(())
4770}
4771
4772/// Stop receiving chat updates for a channel
4773async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4774    let channel_id = ChannelId::from_proto(request.channel_id);
4775    session
4776        .db()
4777        .await
4778        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4779        .await?;
4780    Ok(())
4781}
4782
4783/// Retrieve the chat history for a channel
4784async fn get_channel_messages(
4785    request: proto::GetChannelMessages,
4786    response: Response<proto::GetChannelMessages>,
4787    session: UserSession,
4788) -> Result<()> {
4789    let channel_id = ChannelId::from_proto(request.channel_id);
4790    let messages = session
4791        .db()
4792        .await
4793        .get_channel_messages(
4794            channel_id,
4795            session.user_id(),
4796            MESSAGE_COUNT_PER_PAGE,
4797            Some(MessageId::from_proto(request.before_message_id)),
4798        )
4799        .await?;
4800    response.send(proto::GetChannelMessagesResponse {
4801        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4802        messages,
4803    })?;
4804    Ok(())
4805}
4806
4807/// Retrieve specific chat messages
4808async fn get_channel_messages_by_id(
4809    request: proto::GetChannelMessagesById,
4810    response: Response<proto::GetChannelMessagesById>,
4811    session: UserSession,
4812) -> Result<()> {
4813    let message_ids = request
4814        .message_ids
4815        .iter()
4816        .map(|id| MessageId::from_proto(*id))
4817        .collect::<Vec<_>>();
4818    let messages = session
4819        .db()
4820        .await
4821        .get_channel_messages_by_id(session.user_id(), &message_ids)
4822        .await?;
4823    response.send(proto::GetChannelMessagesResponse {
4824        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4825        messages,
4826    })?;
4827    Ok(())
4828}
4829
4830/// Retrieve the current users notifications
4831async fn get_notifications(
4832    request: proto::GetNotifications,
4833    response: Response<proto::GetNotifications>,
4834    session: UserSession,
4835) -> Result<()> {
4836    let notifications = session
4837        .db()
4838        .await
4839        .get_notifications(
4840            session.user_id(),
4841            NOTIFICATION_COUNT_PER_PAGE,
4842            request.before_id.map(db::NotificationId::from_proto),
4843        )
4844        .await?;
4845    response.send(proto::GetNotificationsResponse {
4846        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4847        notifications,
4848    })?;
4849    Ok(())
4850}
4851
4852/// Mark notifications as read
4853async fn mark_notification_as_read(
4854    request: proto::MarkNotificationRead,
4855    response: Response<proto::MarkNotificationRead>,
4856    session: UserSession,
4857) -> Result<()> {
4858    let database = &session.db().await;
4859    let notifications = database
4860        .mark_notification_as_read_by_id(
4861            session.user_id(),
4862            NotificationId::from_proto(request.notification_id),
4863        )
4864        .await?;
4865    send_notifications(
4866        &*session.connection_pool().await,
4867        &session.peer,
4868        notifications,
4869    );
4870    response.send(proto::Ack {})?;
4871    Ok(())
4872}
4873
4874/// Get the current users information
4875async fn get_private_user_info(
4876    _request: proto::GetPrivateUserInfo,
4877    response: Response<proto::GetPrivateUserInfo>,
4878    session: UserSession,
4879) -> Result<()> {
4880    let db = session.db().await;
4881
4882    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4883    let user = db
4884        .get_user_by_id(session.user_id())
4885        .await?
4886        .ok_or_else(|| anyhow!("user not found"))?;
4887    let flags = db.get_user_flags(session.user_id()).await?;
4888
4889    response.send(proto::GetPrivateUserInfoResponse {
4890        metrics_id,
4891        staff: user.admin,
4892        flags,
4893        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4894    })?;
4895    Ok(())
4896}
4897
4898/// Accept the terms of service (tos) on behalf of the current user
4899async fn accept_terms_of_service(
4900    _request: proto::AcceptTermsOfService,
4901    response: Response<proto::AcceptTermsOfService>,
4902    session: UserSession,
4903) -> Result<()> {
4904    let db = session.db().await;
4905
4906    let accepted_tos_at = Utc::now();
4907    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4908        .await?;
4909
4910    response.send(proto::AcceptTermsOfServiceResponse {
4911        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4912    })?;
4913    Ok(())
4914}
4915
4916/// The minimum account age an account must have in order to use the LLM service.
4917const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4918
4919async fn get_llm_api_token(
4920    _request: proto::GetLlmToken,
4921    response: Response<proto::GetLlmToken>,
4922    session: UserSession,
4923) -> Result<()> {
4924    let db = session.db().await;
4925
4926    let flags = db.get_user_flags(session.user_id()).await?;
4927    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4928    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4929
4930    if !session.is_staff() && !has_language_models_feature_flag {
4931        Err(anyhow!("permission denied"))?
4932    }
4933
4934    let user_id = session.user_id();
4935    let user = db
4936        .get_user_by_id(user_id)
4937        .await?
4938        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4939
4940    if user.accepted_tos_at.is_none() {
4941        Err(anyhow!("terms of service not accepted"))?
4942    }
4943
4944    let mut account_created_at = user.created_at;
4945    if let Some(github_created_at) = user.github_user_created_at {
4946        account_created_at = account_created_at.min(github_created_at);
4947    }
4948    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4949        Err(anyhow!("account too young"))?
4950    }
4951    let token = LlmTokenClaims::create(
4952        user.id,
4953        user.github_login.clone(),
4954        session.is_staff(),
4955        has_llm_closed_beta_feature_flag,
4956        session.current_plan(db).await?,
4957        &session.app_state.config,
4958    )?;
4959    response.send(proto::GetLlmTokenResponse { token })?;
4960    Ok(())
4961}
4962
4963fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4964    let message = match message {
4965        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4966        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4967        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4968        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4969        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4970            code: frame.code.into(),
4971            reason: frame.reason,
4972        })),
4973        // We should never receive a frame while reading the message, according
4974        // to the `tungstenite` maintainers:
4975        //
4976        // > It cannot occur when you read messages from the WebSocket, but it
4977        // > can be used when you want to send the raw frames (e.g. you want to
4978        // > send the frames to the WebSocket without composing the full message first).
4979        // >
4980        // > — https://github.com/snapview/tungstenite-rs/issues/268
4981        TungsteniteMessage::Frame(_) => {
4982            bail!("received an unexpected frame while reading the message")
4983        }
4984    };
4985
4986    Ok(message)
4987}
4988
4989fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4990    match message {
4991        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4992        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4993        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4994        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4995        AxumMessage::Close(frame) => {
4996            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4997                code: frame.code.into(),
4998                reason: frame.reason,
4999            }))
5000        }
5001    }
5002}
5003
5004fn notify_membership_updated(
5005    connection_pool: &mut ConnectionPool,
5006    result: MembershipUpdated,
5007    user_id: UserId,
5008    peer: &Peer,
5009) {
5010    for membership in &result.new_channels.channel_memberships {
5011        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5012    }
5013    for channel_id in &result.removed_channels {
5014        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5015    }
5016
5017    let user_channels_update = proto::UpdateUserChannels {
5018        channel_memberships: result
5019            .new_channels
5020            .channel_memberships
5021            .iter()
5022            .map(|cm| proto::ChannelMembership {
5023                channel_id: cm.channel_id.to_proto(),
5024                role: cm.role.into(),
5025            })
5026            .collect(),
5027        ..Default::default()
5028    };
5029
5030    let mut update = build_channels_update(result.new_channels);
5031    update.delete_channels = result
5032        .removed_channels
5033        .into_iter()
5034        .map(|id| id.to_proto())
5035        .collect();
5036    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5037
5038    for connection_id in connection_pool.user_connection_ids(user_id) {
5039        peer.send(connection_id, user_channels_update.clone())
5040            .trace_err();
5041        peer.send(connection_id, update.clone()).trace_err();
5042    }
5043}
5044
5045fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5046    proto::UpdateUserChannels {
5047        channel_memberships: channels
5048            .channel_memberships
5049            .iter()
5050            .map(|m| proto::ChannelMembership {
5051                channel_id: m.channel_id.to_proto(),
5052                role: m.role.into(),
5053            })
5054            .collect(),
5055        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5056        observed_channel_message_id: channels.observed_channel_messages.clone(),
5057    }
5058}
5059
5060fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5061    let mut update = proto::UpdateChannels::default();
5062
5063    for channel in channels.channels {
5064        update.channels.push(channel.to_proto());
5065    }
5066
5067    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5068    update.latest_channel_message_ids = channels.latest_channel_messages;
5069
5070    for (channel_id, participants) in channels.channel_participants {
5071        update
5072            .channel_participants
5073            .push(proto::ChannelParticipants {
5074                channel_id: channel_id.to_proto(),
5075                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5076            });
5077    }
5078
5079    for channel in channels.invited_channels {
5080        update.channel_invitations.push(channel.to_proto());
5081    }
5082
5083    update.hosted_projects = channels.hosted_projects;
5084    update
5085}
5086
5087fn build_initial_contacts_update(
5088    contacts: Vec<db::Contact>,
5089    pool: &ConnectionPool,
5090) -> proto::UpdateContacts {
5091    let mut update = proto::UpdateContacts::default();
5092
5093    for contact in contacts {
5094        match contact {
5095            db::Contact::Accepted { user_id, busy } => {
5096                update.contacts.push(contact_for_user(user_id, busy, pool));
5097            }
5098            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5099            db::Contact::Incoming { user_id } => {
5100                update
5101                    .incoming_requests
5102                    .push(proto::IncomingContactRequest {
5103                        requester_id: user_id.to_proto(),
5104                    })
5105            }
5106        }
5107    }
5108
5109    update
5110}
5111
5112fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5113    proto::Contact {
5114        user_id: user_id.to_proto(),
5115        online: pool.is_user_online(user_id),
5116        busy,
5117    }
5118}
5119
5120fn room_updated(room: &proto::Room, peer: &Peer) {
5121    broadcast(
5122        None,
5123        room.participants
5124            .iter()
5125            .filter_map(|participant| Some(participant.peer_id?.into())),
5126        |peer_id| {
5127            peer.send(
5128                peer_id,
5129                proto::RoomUpdated {
5130                    room: Some(room.clone()),
5131                },
5132            )
5133        },
5134    );
5135}
5136
5137fn channel_updated(
5138    channel: &db::channel::Model,
5139    room: &proto::Room,
5140    peer: &Peer,
5141    pool: &ConnectionPool,
5142) {
5143    let participants = room
5144        .participants
5145        .iter()
5146        .map(|p| p.user_id)
5147        .collect::<Vec<_>>();
5148
5149    broadcast(
5150        None,
5151        pool.channel_connection_ids(channel.root_id())
5152            .filter_map(|(channel_id, role)| {
5153                role.can_see_channel(channel.visibility)
5154                    .then_some(channel_id)
5155            }),
5156        |peer_id| {
5157            peer.send(
5158                peer_id,
5159                proto::UpdateChannels {
5160                    channel_participants: vec![proto::ChannelParticipants {
5161                        channel_id: channel.id.to_proto(),
5162                        participant_user_ids: participants.clone(),
5163                    }],
5164                    ..Default::default()
5165                },
5166            )
5167        },
5168    );
5169}
5170
5171async fn send_dev_server_projects_update(
5172    user_id: UserId,
5173    mut status: proto::DevServerProjectsUpdate,
5174    session: &Session,
5175) {
5176    let pool = session.connection_pool().await;
5177    for dev_server in &mut status.dev_servers {
5178        dev_server.status =
5179            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5180    }
5181    let connections = pool.user_connection_ids(user_id);
5182    for connection_id in connections {
5183        session.peer.send(connection_id, status.clone()).trace_err();
5184    }
5185}
5186
5187async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5188    let db = session.db().await;
5189
5190    let contacts = db.get_contacts(user_id).await?;
5191    let busy = db.is_user_busy(user_id).await?;
5192
5193    let pool = session.connection_pool().await;
5194    let updated_contact = contact_for_user(user_id, busy, &pool);
5195    for contact in contacts {
5196        if let db::Contact::Accepted {
5197            user_id: contact_user_id,
5198            ..
5199        } = contact
5200        {
5201            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5202                session
5203                    .peer
5204                    .send(
5205                        contact_conn_id,
5206                        proto::UpdateContacts {
5207                            contacts: vec![updated_contact.clone()],
5208                            remove_contacts: Default::default(),
5209                            incoming_requests: Default::default(),
5210                            remove_incoming_requests: Default::default(),
5211                            outgoing_requests: Default::default(),
5212                            remove_outgoing_requests: Default::default(),
5213                        },
5214                    )
5215                    .trace_err();
5216            }
5217        }
5218    }
5219    Ok(())
5220}
5221
/// Cleans up after a dev server's connection is lost.
///
/// Unshares every project the database reports as stale for this dev server
/// connection. It then loads a fresh `DevServerProjectsUpdate` for the dev
/// server's user and sends it to all of that user's connections.
///
/// # Errors
///
/// Propagates database errors and any error from unsharing a project.
async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
    log::info!("lost dev server connection, unsharing projects");
    // The value returned by `session.db().await` is a temporary of this `let`
    // statement, so it is dropped here rather than held across the loop below.
    let project_ids = session
        .db()
        .await
        .get_stale_dev_server_projects(session.connection_id)
        .await?;

    for project_id in project_ids {
        // Note: unshare re-checks that the connection ids match, so we can get
        // away without wrapping this loop in a transaction.
        unshare_project_internal(project_id, session.connection_id, None, session).await?;
    }

    let user_id = session.dev_server().user_id;
    let update = session
        .db()
        .await
        .dev_server_projects_update(user_id)
        .await?;

    send_dev_server_projects_update(user_id, update, session).await;

    Ok(())
}
5246
/// Removes `connection_id` from whatever room it is in and notifies everyone
/// affected.
///
/// If the database reports that the connection left a room, this function:
/// - notifies the connections of each project the session left (`project_left`),
/// - broadcasts the updated room to its participants and, for a channel room,
///   the new participant list to connections that can see the channel,
/// - sends `CallCanceled` to users whose pending calls were canceled,
/// - refreshes contact state for this user and for those users, and
/// - removes the participant from LiveKit, deleting the LiveKit room when the
///   database marked the room as deleted.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit
/// failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared without initializers so they can be moved out of `left_room`
    // inside the `if let` below and still be used after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the room broadcast by
        // `room_updated` carries the default `live_kit_room` value.
        // NOTE(review): presumably intentional (don't re-broadcast the LiveKit room name); confirm.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is dropped before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not propagated.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5320
5321async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5322    let left_channel_buffers = session
5323        .db()
5324        .await
5325        .leave_channel_buffers(session.connection_id)
5326        .await?;
5327
5328    for left_buffer in left_channel_buffers {
5329        channel_buffer_updated(
5330            session.connection_id,
5331            left_buffer.connections,
5332            &proto::UpdateChannelBufferCollaborators {
5333                channel_id: left_buffer.channel_id.to_proto(),
5334                collaborators: left_buffer.collaborators,
5335            },
5336            &session.peer,
5337        );
5338    }
5339
5340    Ok(())
5341}
5342
5343fn project_left(project: &db::LeftProject, session: &UserSession) {
5344    for connection_id in &project.connection_ids {
5345        if project.should_unshare {
5346            session
5347                .peer
5348                .send(
5349                    *connection_id,
5350                    proto::UnshareProject {
5351                        project_id: project.id.to_proto(),
5352                    },
5353                )
5354                .trace_err();
5355        } else {
5356            session
5357                .peer
5358                .send(
5359                    *connection_id,
5360                    proto::RemoveProjectCollaborator {
5361                        project_id: project.id.to_proto(),
5362                        peer_id: Some(session.connection_id.into()),
5363                    },
5364                )
5365                .trace_err();
5366        }
5367    }
5368}
5369
5370pub trait ResultExt {
5371    type Ok;
5372
5373    fn trace_err(self) -> Option<Self::Ok>;
5374}
5375
5376impl<T, E> ResultExt for Result<T, E>
5377where
5378    E: std::fmt::Debug,
5379{
5380    type Ok = T;
5381
5382    #[track_caller]
5383    fn trace_err(self) -> Option<T> {
5384        match self {
5385            Ok(value) => Some(value),
5386            Err(error) => {
5387                tracing::error!("{:?}", error);
5388                None
5389            }
5390        }
5391    }
5392}