rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
    /// A 30 second timeout related to reconnection.
    // NOTE(review): presumably the window a dropped client has to reconnect before its
    // state is released; the use site is outside this view — confirm.
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
    /// Delay `Server::start` sleeps for before it retrieves and clears stale rooms and
    /// channel buffers.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
    // NOTE(review): the three limits below are not used in the visible part of the file;
    // their meaning is inferred from their names only — confirm at the use sites.
    /// Page size for channel messages (presumably messages returned per request).
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
    /// Upper bound on a message length (unit — bytes vs. chars — not visible here).
  87const MAX_MESSAGE_LEN: usize = 1024;
    /// Page size for notifications (presumably notifications returned per request).
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
    /// Type-erased handler for one incoming message type.
    ///
    /// It receives the boxed envelope together with the `Session` it arrived on and
    /// returns a boxed `'static` future that resolves to `()`. `Server` stores these in
    /// its `handlers` map keyed by `TypeId`.
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111    DevServer(dev_server::Model),
 112}
 113
 114impl Principal {
 115    fn update_span(&self, span: &tracing::Span) {
 116        match &self {
 117            Principal::User(user) => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120            }
 121            Principal::Impersonated { user, admin } => {
 122                span.record("user_id", user.id.0);
 123                span.record("login", &user.github_login);
 124                span.record("impersonator", &admin.github_login);
 125            }
 126            Principal::DevServer(dev_server) => {
 127                span.record("dev_server_id", dev_server.id.0);
 128            }
 129        }
 130    }
 131}
 132
 133#[derive(Clone)]
 134struct Session {
 135    principal: Principal,
 136    connection_id: ConnectionId,
 137    db: Arc<tokio::sync::Mutex<DbHandle>>,
 138    peer: Arc<Peer>,
 139    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 140    app_state: Arc<AppState>,
 141    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 142    http_client: Arc<IsahcHttpClient>,
 143    /// The GeoIP country code for the user.
 144    #[allow(unused)]
 145    geoip_country_code: Option<String>,
 146    _executor: Executor,
 147}
 148
 149impl Session {
 150    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 151        #[cfg(test)]
 152        tokio::task::yield_now().await;
 153        let guard = self.db.lock().await;
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        guard
 157    }
 158
 159    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.connection_pool.lock();
 163        ConnectionPoolGuard {
 164            guard,
 165            _not_send: PhantomData,
 166        }
 167    }
 168
 169    fn for_user(self) -> Option<UserSession> {
 170        UserSession::new(self)
 171    }
 172
 173    fn for_dev_server(self) -> Option<DevServerSession> {
 174        DevServerSession::new(self)
 175    }
 176
 177    fn user_id(&self) -> Option<UserId> {
 178        match &self.principal {
 179            Principal::User(user) => Some(user.id),
 180            Principal::Impersonated { user, .. } => Some(user.id),
 181            Principal::DevServer(_) => None,
 182        }
 183    }
 184
 185    fn is_staff(&self) -> bool {
 186        match &self.principal {
 187            Principal::User(user) => user.admin,
 188            Principal::Impersonated { .. } => true,
 189            Principal::DevServer(_) => false,
 190        }
 191    }
 192
 193    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 194        if self.is_staff() {
 195            return Ok(proto::Plan::ZedPro);
 196        }
 197
 198        let Some(user_id) = self.user_id() else {
 199            return Ok(proto::Plan::Free);
 200        };
 201
 202        if db.has_active_billing_subscription(user_id).await? {
 203            Ok(proto::Plan::ZedPro)
 204        } else {
 205            Ok(proto::Plan::Free)
 206        }
 207    }
 208
 209    fn dev_server_id(&self) -> Option<DevServerId> {
 210        match &self.principal {
 211            Principal::User(_) | Principal::Impersonated { .. } => None,
 212            Principal::DevServer(dev_server) => Some(dev_server.id),
 213        }
 214    }
 215
 216    fn principal_id(&self) -> PrincipalId {
 217        match &self.principal {
 218            Principal::User(user) => PrincipalId::UserId(user.id),
 219            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 220            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 221        }
 222    }
 223}
 224
 225impl Debug for Session {
 226    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 227        let mut result = f.debug_struct("Session");
 228        match &self.principal {
 229            Principal::User(user) => {
 230                result.field("user", &user.github_login);
 231            }
 232            Principal::Impersonated { user, admin } => {
 233                result.field("user", &user.github_login);
 234                result.field("impersonator", &admin.github_login);
 235            }
 236            Principal::DevServer(dev_server) => {
 237                result.field("dev_server", &dev_server.id);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct UserSession(Session);
 245
 246impl UserSession {
 247    pub fn new(s: Session) -> Option<Self> {
 248        s.user_id().map(|_| UserSession(s))
 249    }
 250    pub fn user_id(&self) -> UserId {
 251        self.0.user_id().unwrap()
 252    }
 253
 254    pub fn email(&self) -> Option<String> {
 255        match &self.0.principal {
 256            Principal::User(user) => user.email_address.clone(),
 257            Principal::Impersonated { user, .. } => user.email_address.clone(),
 258            Principal::DevServer(..) => None,
 259        }
 260    }
 261}
 262
 263impl Deref for UserSession {
 264    type Target = Session;
 265
 266    fn deref(&self) -> &Self::Target {
 267        &self.0
 268    }
 269}
 270impl DerefMut for UserSession {
 271    fn deref_mut(&mut self) -> &mut Self::Target {
 272        &mut self.0
 273    }
 274}
 275
 276struct DevServerSession(Session);
 277
 278impl DevServerSession {
 279    pub fn new(s: Session) -> Option<Self> {
 280        s.dev_server_id().map(|_| DevServerSession(s))
 281    }
 282    pub fn dev_server_id(&self) -> DevServerId {
 283        self.0.dev_server_id().unwrap()
 284    }
 285
 286    fn dev_server(&self) -> &dev_server::Model {
 287        match &self.0.principal {
 288            Principal::DevServer(dev_server) => dev_server,
 289            _ => unreachable!(),
 290        }
 291    }
 292}
 293
 294impl Deref for DevServerSession {
 295    type Target = Session;
 296
 297    fn deref(&self) -> &Self::Target {
 298        &self.0
 299    }
 300}
 301impl DerefMut for DevServerSession {
 302    fn deref_mut(&mut self) -> &mut Self::Target {
 303        &mut self.0
 304    }
 305}
 306
 307fn user_handler<M: RequestMessage, Fut>(
 308    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 310where
 311    Fut: Send + Future<Output = Result<()>>,
 312{
 313    let handler = Arc::new(handler);
 314    move |message, response, session| {
 315        let handler = handler.clone();
 316        Box::pin(async move {
 317            if let Some(user_session) = session.for_user() {
 318                Ok(handler(message, response, user_session).await?)
 319            } else {
 320                Err(Error::Internal(anyhow!(
 321                    "must be a user to call {}",
 322                    M::NAME
 323                )))
 324            }
 325        })
 326    }
 327}
 328
 329fn dev_server_handler<M: RequestMessage, Fut>(
 330    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 332where
 333    Fut: Send + Future<Output = Result<()>>,
 334{
 335    let handler = Arc::new(handler);
 336    move |message, response, session| {
 337        let handler = handler.clone();
 338        Box::pin(async move {
 339            if let Some(dev_server_session) = session.for_dev_server() {
 340                Ok(handler(message, response, dev_server_session).await?)
 341            } else {
 342                Err(Error::Internal(anyhow!(
 343                    "must be a dev server to call {}",
 344                    M::NAME
 345                )))
 346            }
 347        })
 348    }
 349}
 350
 351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 352    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 354where
 355    InnertRetFut: Send + Future<Output = Result<()>>,
 356{
 357    let handler = Arc::new(handler);
 358    move |message, session| {
 359        let handler = handler.clone();
 360        Box::pin(async move {
 361            if let Some(user_session) = session.for_user() {
 362                Ok(handler(message, user_session).await?)
 363            } else {
 364                Err(Error::Internal(anyhow!(
 365                    "must be a user to call {}",
 366                    M::NAME
 367                )))
 368            }
 369        })
 370    }
 371}
 372
 373struct DbHandle(Arc<Database>);
 374
 375impl Deref for DbHandle {
 376    type Target = Database;
 377
 378    fn deref(&self) -> &Self::Target {
 379        self.0.as_ref()
 380    }
 381}
 382
    /// The RPC server: owns the peer, the shared connection pool, and the table of
    /// registered message handlers.
 383pub struct Server {
        /// This server's id, behind a lock.
 384    id: parking_lot::Mutex<ServerId>,
        /// Peer used to send messages to connections.
 385    peer: Arc<Peer>,
        /// Connection pool shared with sessions; visible within the crate.
 386    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
        /// Shared application state (database, config, executor, ...).
 387    app_state: Arc<AppState>,
        /// Registered handlers, keyed by `TypeId`.
        // NOTE(review): presumably the `TypeId` of the message type — the insertion
        // code is outside this view; confirm.
 388    handlers: HashMap<TypeId, MessageHandler>,
        /// Watch channel sender, created with the value `false` in `Server::new`.
        // NOTE(review): presumably flipped to `true` to signal teardown — confirm.
 389    teardown: watch::Sender<bool>,
 390}
 391
    /// A held lock on the `ConnectionPool`.
    ///
    /// The `PhantomData<Rc<()>>` marker makes this type `!Send`, since `Rc` is not
    /// `Send`.
    // NOTE(review): presumably this is to stop the guard from being held across an
    // `.await` in a `Send` future — confirm intent.
 392pub(crate) struct ConnectionPoolGuard<'a> {
 393    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 394    _not_send: PhantomData<Rc<()>>,
 395}
 396
    /// Serializable view of the server: a borrowed peer plus a locked connection pool.
    ///
    /// The pool guard is serialized through `serialize_deref`, i.e. as the value it
    /// dereferences to rather than as the guard itself.
 397#[derive(Serialize)]
 398pub struct ServerSnapshot<'a> {
 399    peer: &'a Peer,
 400    #[serde(serialize_with = "serialize_deref")]
 401    connection_pool: ConnectionPoolGuard<'a>,
 402}
 403
 404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 405where
 406    S: Serializer,
 407    T: Deref<Target = U>,
 408    U: Serialize,
 409{
 410    Serialize::serialize(value.deref(), serializer)
 411}
 412
 413impl Server {
        /// Builds a server and registers every RPC handler on it.
        ///
        /// The server starts with a peer created from `id`, a default (empty) connection
        /// pool, an empty handler table, and a `teardown` watch channel holding `false`.
        /// All request and message handlers are then registered in one chain and the
        /// finished server is returned inside an `Arc`.
        ///
        /// Handlers wrapped in `user_handler` / `user_message_handler` fail unless the
        /// session's principal is a user; those wrapped in `dev_server_handler` fail
        /// unless it is a dev server. Unwrapped handlers receive the raw `Session`.
 414    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 415        let mut server = Self {
 416            id: parking_lot::Mutex::new(id),
                // NOTE(review): `as u32` truncates if the id does not fit in 32 bits — confirm
                // server ids stay in range.
 417            peer: Peer::new(id.0 as u32),
 418            app_state: app_state.clone(),
 419            connection_pool: Default::default(),
 420            handlers: Default::default(),
 421            teardown: watch::channel(false).0,
 422        };
 423
 424        server
 425            .add_request_handler(ping)
 426            .add_request_handler(user_handler(create_room))
 427            .add_request_handler(user_handler(join_room))
 428            .add_request_handler(user_handler(rejoin_room))
 429            .add_request_handler(user_handler(leave_room))
 430            .add_request_handler(user_handler(set_room_participant_role))
 431            .add_request_handler(user_handler(call))
 432            .add_request_handler(user_handler(cancel_call))
 433            .add_message_handler(user_message_handler(decline_call))
 434            .add_request_handler(user_handler(update_participant_location))
 435            .add_request_handler(user_handler(share_project))
 436            .add_message_handler(unshare_project)
 437            .add_request_handler(user_handler(join_project))
 438            .add_request_handler(user_handler(join_hosted_project))
 439            .add_request_handler(user_handler(rejoin_dev_server_projects))
 440            .add_request_handler(user_handler(create_dev_server_project))
 441            .add_request_handler(user_handler(update_dev_server_project))
 442            .add_request_handler(user_handler(delete_dev_server_project))
 443            .add_request_handler(user_handler(create_dev_server))
 444            .add_request_handler(user_handler(regenerate_dev_server_token))
 445            .add_request_handler(user_handler(rename_dev_server))
 446            .add_request_handler(user_handler(delete_dev_server))
 447            .add_request_handler(user_handler(list_remote_directory))
 448            .add_request_handler(dev_server_handler(share_dev_server_project))
 449            .add_request_handler(dev_server_handler(shutdown_dev_server))
 450            .add_request_handler(dev_server_handler(reconnect_dev_server))
 451            .add_message_handler(user_message_handler(leave_project))
 452            .add_request_handler(update_project)
 453            .add_request_handler(update_worktree)
 454            .add_message_handler(start_language_server)
 455            .add_message_handler(update_language_server)
 456            .add_message_handler(update_diagnostic_summary)
 457            .add_message_handler(update_worktree_settings)
                // Project requests handled by the generic `forward_*` handlers, one
                // registration per proto message type.
 458            .add_request_handler(user_handler(
 459                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_project_request_for_owner::<proto::TaskTemplates>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetHover>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::GetDefinition>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetTypeDefinition>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetReferences>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::SearchProject>,
 478            ))
 479            .add_request_handler(user_handler(forward_find_search_candidates_request))
 480            .add_request_handler(user_handler(
 481                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 482            ))
 483            .add_request_handler(user_handler(
 484                forward_read_only_project_request::<proto::GetProjectSymbols>,
 485            ))
 486            .add_request_handler(user_handler(
 487                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 488            ))
 489            .add_request_handler(user_handler(
 490                forward_read_only_project_request::<proto::OpenBufferById>,
 491            ))
 492            .add_request_handler(user_handler(
 493                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 494            ))
 495            .add_request_handler(user_handler(
 496                forward_read_only_project_request::<proto::InlayHints>,
 497            ))
 498            .add_request_handler(user_handler(
 499                forward_read_only_project_request::<proto::ResolveInlayHint>,
 500            ))
 501            .add_request_handler(user_handler(
 502                forward_read_only_project_request::<proto::OpenBufferByPath>,
 503            ))
 504            .add_request_handler(user_handler(
 505                forward_mutating_project_request::<proto::GetCompletions>,
 506            ))
 507            .add_request_handler(user_handler(
 508                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 509            ))
 510            .add_request_handler(user_handler(
 511                forward_mutating_project_request::<proto::OpenNewBuffer>,
 512            ))
 513            .add_request_handler(user_handler(
 514                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 515            ))
 516            .add_request_handler(user_handler(
 517                forward_mutating_project_request::<proto::GetCodeActions>,
 518            ))
 519            .add_request_handler(user_handler(
 520                forward_mutating_project_request::<proto::ApplyCodeAction>,
 521            ))
 522            .add_request_handler(user_handler(
 523                forward_mutating_project_request::<proto::PrepareRename>,
 524            ))
 525            .add_request_handler(user_handler(
 526                forward_mutating_project_request::<proto::PerformRename>,
 527            ))
 528            .add_request_handler(user_handler(
 529                forward_mutating_project_request::<proto::ReloadBuffers>,
 530            ))
 531            .add_request_handler(user_handler(
 532                forward_mutating_project_request::<proto::FormatBuffers>,
 533            ))
 534            .add_request_handler(user_handler(
 535                forward_mutating_project_request::<proto::CreateProjectEntry>,
 536            ))
 537            .add_request_handler(user_handler(
 538                forward_mutating_project_request::<proto::RenameProjectEntry>,
 539            ))
 540            .add_request_handler(user_handler(
 541                forward_mutating_project_request::<proto::CopyProjectEntry>,
 542            ))
 543            .add_request_handler(user_handler(
 544                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 545            ))
 546            .add_request_handler(user_handler(
 547                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 548            ))
 549            .add_request_handler(user_handler(
 550                forward_mutating_project_request::<proto::OnTypeFormatting>,
 551            ))
 552            .add_request_handler(user_handler(
 553                forward_mutating_project_request::<proto::SaveBuffer>,
 554            ))
 555            .add_request_handler(user_handler(
 556                forward_mutating_project_request::<proto::BlameBuffer>,
 557            ))
 558            .add_request_handler(user_handler(
 559                forward_mutating_project_request::<proto::MultiLspQuery>,
 560            ))
 561            .add_request_handler(user_handler(
 562                forward_mutating_project_request::<proto::RestartLanguageServers>,
 563            ))
 564            .add_request_handler(user_handler(
 565                forward_mutating_project_request::<proto::LinkedEditingRange>,
 566            ))
 567            .add_message_handler(create_buffer_for_peer)
 568            .add_request_handler(update_buffer)
 569            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 570            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 571            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 572            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 573            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 574            .add_request_handler(get_users)
 575            .add_request_handler(user_handler(fuzzy_search_users))
 576            .add_request_handler(user_handler(request_contact))
 577            .add_request_handler(user_handler(remove_contact))
 578            .add_request_handler(user_handler(respond_to_contact_request))
 579            .add_message_handler(subscribe_to_channels)
 580            .add_request_handler(user_handler(create_channel))
 581            .add_request_handler(user_handler(delete_channel))
 582            .add_request_handler(user_handler(invite_channel_member))
 583            .add_request_handler(user_handler(remove_channel_member))
 584            .add_request_handler(user_handler(set_channel_member_role))
 585            .add_request_handler(user_handler(set_channel_visibility))
 586            .add_request_handler(user_handler(rename_channel))
 587            .add_request_handler(user_handler(join_channel_buffer))
 588            .add_request_handler(user_handler(leave_channel_buffer))
 589            .add_message_handler(user_message_handler(update_channel_buffer))
 590            .add_request_handler(user_handler(rejoin_channel_buffers))
 591            .add_request_handler(user_handler(get_channel_members))
 592            .add_request_handler(user_handler(respond_to_channel_invite))
 593            .add_request_handler(user_handler(join_channel))
 594            .add_request_handler(user_handler(join_channel_chat))
 595            .add_message_handler(user_message_handler(leave_channel_chat))
 596            .add_request_handler(user_handler(send_channel_message))
 597            .add_request_handler(user_handler(remove_channel_message))
 598            .add_request_handler(user_handler(update_channel_message))
 599            .add_request_handler(user_handler(get_channel_messages))
 600            .add_request_handler(user_handler(get_channel_messages_by_id))
 601            .add_request_handler(user_handler(get_notifications))
 602            .add_request_handler(user_handler(mark_notification_as_read))
 603            .add_request_handler(user_handler(move_channel))
 604            .add_request_handler(user_handler(follow))
 605            .add_message_handler(user_message_handler(unfollow))
 606            .add_message_handler(user_message_handler(update_followers))
 607            .add_request_handler(user_handler(get_private_user_info))
 608            .add_request_handler(user_handler(get_llm_api_token))
 609            .add_request_handler(user_handler(accept_terms_of_service))
 610            .add_message_handler(user_message_handler(acknowledge_channel_message))
 611            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 612            .add_request_handler(user_handler(get_supermaven_api_key))
 613            .add_request_handler(user_handler(
 614                forward_mutating_project_request::<proto::OpenContext>,
 615            ))
 616            .add_request_handler(user_handler(
 617                forward_mutating_project_request::<proto::CreateContext>,
 618            ))
 619            .add_request_handler(user_handler(
 620                forward_mutating_project_request::<proto::SynchronizeContexts>,
 621            ))
 622            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 623            .add_message_handler(update_context)
                // The remaining handlers are closures: two capture a clone of `app_state`
                // so they can read the config when a request arrives.
 624            .add_request_handler({
 625                let app_state = app_state.clone();
 626                move |request, response, session| {
 627                    let app_state = app_state.clone();
 628                    async move {
 629                        count_language_model_tokens(request, response, session, &app_state.config)
 630                            .await
 631                    }
 632                }
 633            })
 634            .add_request_handler({
 635                user_handler(move |request, response, session| {
 636                    get_cached_embeddings(request, response, session)
 637                })
 638            })
 639            .add_request_handler({
 640                let app_state = app_state.clone();
 641                user_handler(move |request, response, session| {
 642                    compute_embeddings(
 643                        request,
 644                        response,
 645                        session,
 646                        app_state.config.openai_api_key.clone(),
 647                    )
 648                })
 649            });
 650
 651        Arc::new(server)
 652    }
 653
 654    pub async fn start(&self) -> Result<()> {
 655        let server_id = *self.id.lock();
 656        let app_state = self.app_state.clone();
 657        let peer = self.peer.clone();
 658        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 659        let pool = self.connection_pool.clone();
 660        let live_kit_client = self.app_state.live_kit_client.clone();
 661
 662        let span = info_span!("start server");
 663        self.app_state.executor.spawn_detached(
 664            async move {
 665                tracing::info!("waiting for cleanup timeout");
 666                timeout.await;
 667                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 668                if let Some((room_ids, channel_ids)) = app_state
 669                    .db
 670                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 671                    .await
 672                    .trace_err()
 673                {
 674                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 675                    tracing::info!(
 676                        stale_channel_buffer_count = channel_ids.len(),
 677                        "retrieved stale channel buffers"
 678                    );
 679
 680                    for channel_id in channel_ids {
 681                        if let Some(refreshed_channel_buffer) = app_state
 682                            .db
 683                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 684                            .await
 685                            .trace_err()
 686                        {
 687                            for connection_id in refreshed_channel_buffer.connection_ids {
 688                                peer.send(
 689                                    connection_id,
 690                                    proto::UpdateChannelBufferCollaborators {
 691                                        channel_id: channel_id.to_proto(),
 692                                        collaborators: refreshed_channel_buffer
 693                                            .collaborators
 694                                            .clone(),
 695                                    },
 696                                )
 697                                .trace_err();
 698                            }
 699                        }
 700                    }
 701
 702                    for room_id in room_ids {
 703                        let mut contacts_to_update = HashSet::default();
 704                        let mut canceled_calls_to_user_ids = Vec::new();
 705                        let mut live_kit_room = String::new();
 706                        let mut delete_live_kit_room = false;
 707
 708                        if let Some(mut refreshed_room) = app_state
 709                            .db
 710                            .clear_stale_room_participants(room_id, server_id)
 711                            .await
 712                            .trace_err()
 713                        {
 714                            tracing::info!(
 715                                room_id = room_id.0,
 716                                new_participant_count = refreshed_room.room.participants.len(),
 717                                "refreshed room"
 718                            );
 719                            room_updated(&refreshed_room.room, &peer);
 720                            if let Some(channel) = refreshed_room.channel.as_ref() {
 721                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 722                            }
 723                            contacts_to_update
 724                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 725                            contacts_to_update
 726                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 727                            canceled_calls_to_user_ids =
 728                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 729                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 730                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 731                        }
 732
 733                        {
 734                            let pool = pool.lock();
 735                            for canceled_user_id in canceled_calls_to_user_ids {
 736                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 737                                    peer.send(
 738                                        connection_id,
 739                                        proto::CallCanceled {
 740                                            room_id: room_id.to_proto(),
 741                                        },
 742                                    )
 743                                    .trace_err();
 744                                }
 745                            }
 746                        }
 747
 748                        for user_id in contacts_to_update {
 749                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 750                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 751                            if let Some((busy, contacts)) = busy.zip(contacts) {
 752                                let pool = pool.lock();
 753                                let updated_contact = contact_for_user(user_id, busy, &pool);
 754                                for contact in contacts {
 755                                    if let db::Contact::Accepted {
 756                                        user_id: contact_user_id,
 757                                        ..
 758                                    } = contact
 759                                    {
 760                                        for contact_conn_id in
 761                                            pool.user_connection_ids(contact_user_id)
 762                                        {
 763                                            peer.send(
 764                                                contact_conn_id,
 765                                                proto::UpdateContacts {
 766                                                    contacts: vec![updated_contact.clone()],
 767                                                    remove_contacts: Default::default(),
 768                                                    incoming_requests: Default::default(),
 769                                                    remove_incoming_requests: Default::default(),
 770                                                    outgoing_requests: Default::default(),
 771                                                    remove_outgoing_requests: Default::default(),
 772                                                },
 773                                            )
 774                                            .trace_err();
 775                                        }
 776                                    }
 777                                }
 778                            }
 779                        }
 780
 781                        if let Some(live_kit) = live_kit_client.as_ref() {
 782                            if delete_live_kit_room {
 783                                live_kit.delete_room(live_kit_room).await.trace_err();
 784                            }
 785                        }
 786                    }
 787                }
 788
 789                app_state
 790                    .db
 791                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 792                    .await
 793                    .trace_err();
 794            }
 795            .instrument(span),
 796        );
 797        Ok(())
 798    }
 799
 800    pub fn teardown(&self) {
 801        self.peer.teardown();
 802        self.connection_pool.lock().reset();
 803        let _ = self.teardown.send(true);
 804    }
 805
 806    #[cfg(test)]
 807    pub fn reset(&self, id: ServerId) {
 808        self.teardown();
 809        *self.id.lock() = id;
 810        self.peer.reset(id.0 as u32);
 811        let _ = self.teardown.send(false);
 812    }
 813
 814    #[cfg(test)]
 815    pub fn id(&self) -> ServerId {
 816        *self.id.lock()
 817    }
 818
 819    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 820    where
 821        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 822        Fut: 'static + Send + Future<Output = Result<()>>,
 823        M: EnvelopedMessage,
 824    {
 825        let prev_handler = self.handlers.insert(
 826            TypeId::of::<M>(),
 827            Box::new(move |envelope, session| {
 828                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 829                let received_at = envelope.received_at;
 830                tracing::info!("message received");
 831                let start_time = Instant::now();
 832                let future = (handler)(*envelope, session);
 833                async move {
 834                    let result = future.await;
 835                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 836                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 837                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 838                    let payload_type = M::NAME;
 839
 840                    match result {
 841                        Err(error) => {
 842                            tracing::error!(
 843                                ?error,
 844                                total_duration_ms,
 845                                processing_duration_ms,
 846                                queue_duration_ms,
 847                                payload_type,
 848                                "error handling message"
 849                            )
 850                        }
 851                        Ok(()) => tracing::info!(
 852                            total_duration_ms,
 853                            processing_duration_ms,
 854                            queue_duration_ms,
 855                            "finished handling message"
 856                        ),
 857                    }
 858                }
 859                .boxed()
 860            }),
 861        );
 862        if prev_handler.is_some() {
 863            panic!("registered a handler for the same message twice");
 864        }
 865        self
 866    }
 867
 868    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 869    where
 870        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 871        Fut: 'static + Send + Future<Output = Result<()>>,
 872        M: EnvelopedMessage,
 873    {
 874        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 875        self
 876    }
 877
 878    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 879    where
 880        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 881        Fut: Send + Future<Output = Result<()>>,
 882        M: RequestMessage,
 883    {
 884        let handler = Arc::new(handler);
 885        self.add_handler(move |envelope, session| {
 886            let receipt = envelope.receipt();
 887            let handler = handler.clone();
 888            async move {
 889                let peer = session.peer.clone();
 890                let responded = Arc::new(AtomicBool::default());
 891                let response = Response {
 892                    peer: peer.clone(),
 893                    responded: responded.clone(),
 894                    receipt,
 895                };
 896                match (handler)(envelope.payload, response, session).await {
 897                    Ok(()) => {
 898                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 899                            Ok(())
 900                        } else {
 901                            Err(anyhow!("handler did not send a response"))?
 902                        }
 903                    }
 904                    Err(error) => {
 905                        let proto_err = match &error {
 906                            Error::Internal(err) => err.to_proto(),
 907                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 908                        };
 909                        peer.respond_with_error(receipt, proto_err)?;
 910                        Err(error)
 911                    }
 912                }
 913            }
 914        })
 915    }
 916
 917    #[allow(clippy::too_many_arguments)]
 918    pub fn handle_connection(
 919        self: &Arc<Self>,
 920        connection: Connection,
 921        address: String,
 922        principal: Principal,
 923        zed_version: ZedVersion,
 924        geoip_country_code: Option<String>,
 925        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 926        executor: Executor,
 927    ) -> impl Future<Output = ()> {
 928        let this = self.clone();
 929        let span = info_span!("handle connection", %address,
 930            connection_id=field::Empty,
 931            user_id=field::Empty,
 932            login=field::Empty,
 933            impersonator=field::Empty,
 934            dev_server_id=field::Empty,
 935            geoip_country_code=field::Empty
 936        );
 937        principal.update_span(&span);
 938        if let Some(country_code) = geoip_country_code.as_ref() {
 939            span.record("geoip_country_code", country_code);
 940        }
 941
 942        let mut teardown = self.teardown.subscribe();
 943        async move {
 944            if *teardown.borrow() {
 945                tracing::error!("server is tearing down");
 946                return
 947            }
 948            let (connection_id, handle_io, mut incoming_rx) = this
 949                .peer
 950                .add_connection(connection, {
 951                    let executor = executor.clone();
 952                    move |duration| executor.sleep(duration)
 953                });
 954            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 955
 956            tracing::info!("connection opened");
 957
 958            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 959            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 960                Ok(http_client) => Arc::new(http_client),
 961                Err(error) => {
 962                    tracing::error!(?error, "failed to create HTTP client");
 963                    return;
 964                }
 965            };
 966
 967            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 968                    supermaven_admin_api_key.to_string(),
 969                    http_client.clone(),
 970                )));
 971
 972            let session = Session {
 973                principal: principal.clone(),
 974                connection_id,
 975                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 976                peer: this.peer.clone(),
 977                connection_pool: this.connection_pool.clone(),
 978                app_state: this.app_state.clone(),
 979                http_client,
 980                geoip_country_code,
 981                _executor: executor.clone(),
 982                supermaven_client,
 983            };
 984
 985            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 986                tracing::error!(?error, "failed to send initial client update");
 987                return;
 988            }
 989
 990            let handle_io = handle_io.fuse();
 991            futures::pin_mut!(handle_io);
 992
 993            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 994            // This prevents deadlocks when e.g., client A performs a request to client B and
 995            // client B performs a request to client A. If both clients stop processing further
 996            // messages until their respective request completes, they won't have a chance to
 997            // respond to the other client's request and cause a deadlock.
 998            //
 999            // This arrangement ensures we will attempt to process earlier messages first, but fall
1000            // back to processing messages arrived later in the spirit of making progress.
1001            let mut foreground_message_handlers = FuturesUnordered::new();
1002            let concurrent_handlers = Arc::new(Semaphore::new(256));
1003            loop {
1004                let next_message = async {
1005                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1006                    let message = incoming_rx.next().await;
1007                    (permit, message)
1008                }.fuse();
1009                futures::pin_mut!(next_message);
1010                futures::select_biased! {
1011                    _ = teardown.changed().fuse() => return,
1012                    result = handle_io => {
1013                        if let Err(error) = result {
1014                            tracing::error!(?error, "error handling I/O");
1015                        }
1016                        break;
1017                    }
1018                    _ = foreground_message_handlers.next() => {}
1019                    next_message = next_message => {
1020                        let (permit, message) = next_message;
1021                        if let Some(message) = message {
1022                            let type_name = message.payload_type_name();
1023                            // note: we copy all the fields from the parent span so we can query them in the logs.
1024                            // (https://github.com/tokio-rs/tracing/issues/2670).
1025                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1026                                user_id=field::Empty,
1027                                login=field::Empty,
1028                                impersonator=field::Empty,
1029                                dev_server_id=field::Empty
1030                            );
1031                            principal.update_span(&span);
1032                            let span_enter = span.enter();
1033                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1034                                let is_background = message.is_background();
1035                                let handle_message = (handler)(message, session.clone());
1036                                drop(span_enter);
1037
1038                                let handle_message = async move {
1039                                    handle_message.await;
1040                                    drop(permit);
1041                                }.instrument(span);
1042                                if is_background {
1043                                    executor.spawn_detached(handle_message);
1044                                } else {
1045                                    foreground_message_handlers.push(handle_message);
1046                                }
1047                            } else {
1048                                tracing::error!("no message handler");
1049                            }
1050                        } else {
1051                            tracing::info!("connection closed");
1052                            break;
1053                        }
1054                    }
1055                }
1056            }
1057
1058            drop(foreground_message_handlers);
1059            tracing::info!("signing out");
1060            if let Err(error) = connection_lost(session, teardown, executor).await {
1061                tracing::error!(?error, "error signing out");
1062            }
1063
1064        }.instrument(span)
1065    }
1066
1067    async fn send_initial_client_update(
1068        &self,
1069        connection_id: ConnectionId,
1070        principal: &Principal,
1071        zed_version: ZedVersion,
1072        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1073        session: &Session,
1074    ) -> Result<()> {
1075        self.peer.send(
1076            connection_id,
1077            proto::Hello {
1078                peer_id: Some(connection_id.into()),
1079            },
1080        )?;
1081        tracing::info!("sent hello message");
1082        if let Some(send_connection_id) = send_connection_id.take() {
1083            let _ = send_connection_id.send(connection_id);
1084        }
1085
1086        match principal {
1087            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1088                if !user.connected_once {
1089                    self.peer.send(connection_id, proto::ShowContacts {})?;
1090                    self.app_state
1091                        .db
1092                        .set_user_connected_once(user.id, true)
1093                        .await?;
1094                }
1095
1096                update_user_plan(user.id, session).await?;
1097
1098                let (contacts, dev_server_projects) = future::try_join(
1099                    self.app_state.db.get_contacts(user.id),
1100                    self.app_state.db.dev_server_projects_update(user.id),
1101                )
1102                .await?;
1103
1104                {
1105                    let mut pool = self.connection_pool.lock();
1106                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1107                    self.peer.send(
1108                        connection_id,
1109                        build_initial_contacts_update(contacts, &pool),
1110                    )?;
1111                }
1112
1113                if should_auto_subscribe_to_channels(zed_version) {
1114                    subscribe_user_to_channels(user.id, session).await?;
1115                }
1116
1117                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1118
1119                if let Some(incoming_call) =
1120                    self.app_state.db.incoming_call_for_user(user.id).await?
1121                {
1122                    self.peer.send(connection_id, incoming_call)?;
1123                }
1124
1125                update_user_contacts(user.id, session).await?;
1126            }
1127            Principal::DevServer(dev_server) => {
1128                {
1129                    let mut pool = self.connection_pool.lock();
1130                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1131                    {
1132                        self.peer.send(
1133                            stale_connection_id,
1134                            proto::ShutdownDevServer {
1135                                reason: Some(
1136                                    "another dev server connected with the same token".to_string(),
1137                                ),
1138                            },
1139                        )?;
1140                        pool.remove_connection(stale_connection_id)?;
1141                    };
1142                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1143                }
1144
1145                let projects = self
1146                    .app_state
1147                    .db
1148                    .get_projects_for_dev_server(dev_server.id)
1149                    .await?;
1150                self.peer
1151                    .send(connection_id, proto::DevServerInstructions { projects })?;
1152
1153                let status = self
1154                    .app_state
1155                    .db
1156                    .dev_server_projects_update(dev_server.user_id)
1157                    .await?;
1158                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1159            }
1160        }
1161
1162        Ok(())
1163    }
1164
1165    pub async fn invite_code_redeemed(
1166        self: &Arc<Self>,
1167        inviter_id: UserId,
1168        invitee_id: UserId,
1169    ) -> Result<()> {
1170        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1171            if let Some(code) = &user.invite_code {
1172                let pool = self.connection_pool.lock();
1173                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1174                for connection_id in pool.user_connection_ids(inviter_id) {
1175                    self.peer.send(
1176                        connection_id,
1177                        proto::UpdateContacts {
1178                            contacts: vec![invitee_contact.clone()],
1179                            ..Default::default()
1180                        },
1181                    )?;
1182                    self.peer.send(
1183                        connection_id,
1184                        proto::UpdateInviteInfo {
1185                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1186                            count: user.invite_count as u32,
1187                        },
1188                    )?;
1189                }
1190            }
1191        }
1192        Ok(())
1193    }
1194
1195    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1196        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1197            if let Some(invite_code) = &user.invite_code {
1198                let pool = self.connection_pool.lock();
1199                for connection_id in pool.user_connection_ids(user_id) {
1200                    self.peer.send(
1201                        connection_id,
1202                        proto::UpdateInviteInfo {
1203                            url: format!(
1204                                "{}{}",
1205                                self.app_state.config.invite_link_prefix, invite_code
1206                            ),
1207                            count: user.invite_count as u32,
1208                        },
1209                    )?;
1210                }
1211            }
1212        }
1213        Ok(())
1214    }
1215
1216    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1217        ServerSnapshot {
1218            connection_pool: ConnectionPoolGuard {
1219                guard: self.connection_pool.lock(),
1220                _not_send: PhantomData,
1221            },
1222            peer: &self.peer,
1223        }
1224    }
1225}
1226
1227impl<'a> Deref for ConnectionPoolGuard<'a> {
1228    type Target = ConnectionPool;
1229
1230    fn deref(&self) -> &Self::Target {
1231        &self.guard
1232    }
1233}
1234
1235impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1236    fn deref_mut(&mut self) -> &mut Self::Target {
1237        &mut self.guard
1238    }
1239}
1240
1241impl<'a> Drop for ConnectionPoolGuard<'a> {
1242    fn drop(&mut self) {
1243        #[cfg(test)]
1244        self.check_invariants();
1245    }
1246}
1247
1248fn broadcast<F>(
1249    sender_id: Option<ConnectionId>,
1250    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1251    mut f: F,
1252) where
1253    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1254{
1255    for receiver_id in receiver_ids {
1256        if Some(receiver_id) != sender_id {
1257            if let Err(error) = f(receiver_id) {
1258                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1259            }
1260        }
1261    }
1262}
1263
1264pub struct ProtocolVersion(u32);
1265
1266impl Header for ProtocolVersion {
1267    fn name() -> &'static HeaderName {
1268        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1269        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1270    }
1271
1272    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1273    where
1274        Self: Sized,
1275        I: Iterator<Item = &'i axum::http::HeaderValue>,
1276    {
1277        let version = values
1278            .next()
1279            .ok_or_else(axum::headers::Error::invalid)?
1280            .to_str()
1281            .map_err(|_| axum::headers::Error::invalid())?
1282            .parse()
1283            .map_err(|_| axum::headers::Error::invalid())?;
1284        Ok(Self(version))
1285    }
1286
1287    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1288        values.extend([self.0.to_string().parse().unwrap()]);
1289    }
1290}
1291
1292pub struct AppVersionHeader(SemanticVersion);
1293impl Header for AppVersionHeader {
1294    fn name() -> &'static HeaderName {
1295        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1297    }
1298
1299    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300    where
1301        Self: Sized,
1302        I: Iterator<Item = &'i axum::http::HeaderValue>,
1303    {
1304        let version = values
1305            .next()
1306            .ok_or_else(axum::headers::Error::invalid)?
1307            .to_str()
1308            .map_err(|_| axum::headers::Error::invalid())?
1309            .parse()
1310            .map_err(|_| axum::headers::Error::invalid())?;
1311        Ok(Self(version))
1312    }
1313
1314    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315        values.extend([self.0.to_string().parse().unwrap()]);
1316    }
1317}
1318
1319pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1320    Router::new()
1321        .route("/rpc", get(handle_websocket_request))
1322        .layer(
1323            ServiceBuilder::new()
1324                .layer(Extension(server.app_state.clone()))
1325                .layer(middleware::from_fn(auth::validate_header)),
1326        )
1327        .route("/metrics", get(handle_metrics))
1328        .layer(Extension(server))
1329}
1330
1331pub async fn handle_websocket_request(
1332    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1333    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1334    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1335    Extension(server): Extension<Arc<Server>>,
1336    Extension(principal): Extension<Principal>,
1337    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1338    ws: WebSocketUpgrade,
1339) -> axum::response::Response {
1340    if protocol_version != rpc::PROTOCOL_VERSION {
1341        return (
1342            StatusCode::UPGRADE_REQUIRED,
1343            "client must be upgraded".to_string(),
1344        )
1345            .into_response();
1346    }
1347
1348    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1349        return (
1350            StatusCode::UPGRADE_REQUIRED,
1351            "no version header found".to_string(),
1352        )
1353            .into_response();
1354    };
1355
1356    if !version.can_collaborate() {
1357        return (
1358            StatusCode::UPGRADE_REQUIRED,
1359            "client must be upgraded".to_string(),
1360        )
1361            .into_response();
1362    }
1363
1364    let socket_address = socket_address.to_string();
1365    ws.on_upgrade(move |socket| {
1366        let socket = socket
1367            .map_ok(to_tungstenite_message)
1368            .err_into()
1369            .with(|message| async move { to_axum_message(message) });
1370        let connection = Connection::new(Box::pin(socket));
1371        async move {
1372            server
1373                .handle_connection(
1374                    connection,
1375                    socket_address,
1376                    principal,
1377                    version,
1378                    country_code_header.map(|header| header.to_string()),
1379                    None,
1380                    Executor::Production,
1381                )
1382                .await;
1383        }
1384    })
1385}
1386
1387pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1388    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1389    let connections_metric = CONNECTIONS_METRIC
1390        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1391
1392    let connections = server
1393        .connection_pool
1394        .lock()
1395        .connections()
1396        .filter(|connection| !connection.admin)
1397        .count();
1398    connections_metric.set(connections as _);
1399
1400    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1402        register_int_gauge!(
1403            "shared_projects",
1404            "number of open projects with one or more guests"
1405        )
1406        .unwrap()
1407    });
1408
1409    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1410    shared_projects_metric.set(shared_projects as _);
1411
1412    let encoder = prometheus::TextEncoder::new();
1413    let metric_families = prometheus::gather();
1414    let encoded_metrics = encoder
1415        .encode_to_string(&metric_families)
1416        .map_err(|err| anyhow!("{}", err))?;
1417    Ok(encoded_metrics)
1418}
1419
1420#[instrument(err, skip(executor))]
1421async fn connection_lost(
1422    session: Session,
1423    mut teardown: watch::Receiver<bool>,
1424    executor: Executor,
1425) -> Result<()> {
1426    session.peer.disconnect(session.connection_id);
1427    session
1428        .connection_pool()
1429        .await
1430        .remove_connection(session.connection_id)?;
1431
1432    session
1433        .db()
1434        .await
1435        .connection_lost(session.connection_id)
1436        .await
1437        .trace_err();
1438
1439    futures::select_biased! {
1440        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1441            match &session.principal {
1442                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1443                    let session = session.for_user().unwrap();
1444
1445                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1446                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1447                    leave_channel_buffers_for_session(&session)
1448                        .await
1449                        .trace_err();
1450
1451                    if !session
1452                        .connection_pool()
1453                        .await
1454                        .is_user_online(session.user_id())
1455                    {
1456                        let db = session.db().await;
1457                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1458                            room_updated(&room, &session.peer);
1459                        }
1460                    }
1461
1462                    update_user_contacts(session.user_id(), &session).await?;
1463                },
1464            Principal::DevServer(_) => {
1465                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1466            },
1467        }
1468        },
1469        _ = teardown.changed().fuse() => {}
1470    }
1471
1472    Ok(())
1473}
1474
1475/// Acknowledges a ping from a client, used to keep the connection alive.
1476async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1477    response.send(proto::Ack {})?;
1478    Ok(())
1479}
1480
1481/// Creates a new room for calling (outside of channels)
1482async fn create_room(
1483    _request: proto::CreateRoom,
1484    response: Response<proto::CreateRoom>,
1485    session: UserSession,
1486) -> Result<()> {
1487    let live_kit_room = nanoid::nanoid!(30);
1488
1489    let live_kit_connection_info = util::maybe!(async {
1490        let live_kit = session.app_state.live_kit_client.as_ref();
1491        let live_kit = live_kit?;
1492        let user_id = session.user_id().to_string();
1493
1494        let token = live_kit
1495            .room_token(&live_kit_room, &user_id.to_string())
1496            .trace_err()?;
1497
1498        Some(proto::LiveKitConnectionInfo {
1499            server_url: live_kit.url().into(),
1500            token,
1501            can_publish: true,
1502        })
1503    })
1504    .await;
1505
1506    let room = session
1507        .db()
1508        .await
1509        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1510        .await?;
1511
1512    response.send(proto::CreateRoomResponse {
1513        room: Some(room.clone()),
1514        live_kit_connection_info,
1515    })?;
1516
1517    update_user_contacts(session.user_id(), &session).await?;
1518    Ok(())
1519}
1520
1521/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1522async fn join_room(
1523    request: proto::JoinRoom,
1524    response: Response<proto::JoinRoom>,
1525    session: UserSession,
1526) -> Result<()> {
1527    let room_id = RoomId::from_proto(request.id);
1528
1529    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1530
1531    if let Some(channel_id) = channel_id {
1532        return join_channel_internal(channel_id, Box::new(response), session).await;
1533    }
1534
1535    let joined_room = {
1536        let room = session
1537            .db()
1538            .await
1539            .join_room(room_id, session.user_id(), session.connection_id)
1540            .await?;
1541        room_updated(&room.room, &session.peer);
1542        room.into_inner()
1543    };
1544
1545    for connection_id in session
1546        .connection_pool()
1547        .await
1548        .user_connection_ids(session.user_id())
1549    {
1550        session
1551            .peer
1552            .send(
1553                connection_id,
1554                proto::CallCanceled {
1555                    room_id: room_id.to_proto(),
1556                },
1557            )
1558            .trace_err();
1559    }
1560
1561    let live_kit_connection_info =
1562        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1563            live_kit
1564                .room_token(
1565                    &joined_room.room.live_kit_room,
1566                    &session.user_id().to_string(),
1567                )
1568                .trace_err()
1569                .map(|token| proto::LiveKitConnectionInfo {
1570                    server_url: live_kit.url().into(),
1571                    token,
1572                    can_publish: true,
1573                })
1574        } else {
1575            None
1576        };
1577
1578    response.send(proto::JoinRoomResponse {
1579        room: Some(joined_room.room),
1580        channel_id: None,
1581        live_kit_connection_info,
1582    })?;
1583
1584    update_user_contacts(session.user_id(), &session).await?;
1585    Ok(())
1586}
1587
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The handler asks the database to rejoin the room for this user/connection, replies
/// with the room plus the lists of reshared and rejoined projects, and then notifies
/// the other participants:
///
/// * the updated room is published with `room_updated`;
/// * for every reshared project, each collaborator is told about the peer-id change
///   and receives an `UpdateProject` carrying the project's worktrees;
/// * rejoined projects are handled by `notify_rejoined_projects`.
///
/// Finally the channel (if the room has one) is updated and the user's contacts are
/// refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared up front so they can outlive the inner block, which bounds the lifetime
    // of the value returned by `rejoin_room`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the requester first, before fanning out notifications to others.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this participant's peer id changed from the
            // old connection to the current one. Send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to the collaborators on
            // behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the inner value so `room` and `channel` can be moved out
        // of the block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1679
1680fn notify_rejoined_projects(
1681    rejoined_projects: &mut Vec<RejoinedProject>,
1682    session: &UserSession,
1683) -> Result<()> {
1684    for project in rejoined_projects.iter() {
1685        for collaborator in &project.collaborators {
1686            session
1687                .peer
1688                .send(
1689                    collaborator.connection_id,
1690                    proto::UpdateProjectCollaborator {
1691                        project_id: project.id.to_proto(),
1692                        old_peer_id: Some(project.old_connection_id.into()),
1693                        new_peer_id: Some(session.connection_id.into()),
1694                    },
1695                )
1696                .trace_err();
1697        }
1698    }
1699
1700    for project in rejoined_projects {
1701        for worktree in mem::take(&mut project.worktrees) {
1702            #[cfg(any(test, feature = "test-support"))]
1703            const MAX_CHUNK_SIZE: usize = 2;
1704            #[cfg(not(any(test, feature = "test-support")))]
1705            const MAX_CHUNK_SIZE: usize = 256;
1706
1707            // Stream this worktree's entries.
1708            let message = proto::UpdateWorktree {
1709                project_id: project.id.to_proto(),
1710                worktree_id: worktree.id,
1711                abs_path: worktree.abs_path.clone(),
1712                root_name: worktree.root_name,
1713                updated_entries: worktree.updated_entries,
1714                removed_entries: worktree.removed_entries,
1715                scan_id: worktree.scan_id,
1716                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1717                updated_repositories: worktree.updated_repositories,
1718                removed_repositories: worktree.removed_repositories,
1719            };
1720            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1721                session.peer.send(session.connection_id, update.clone())?;
1722            }
1723
1724            // Stream this worktree's diagnostics.
1725            for summary in worktree.diagnostic_summaries {
1726                session.peer.send(
1727                    session.connection_id,
1728                    proto::UpdateDiagnosticSummary {
1729                        project_id: project.id.to_proto(),
1730                        worktree_id: worktree.id,
1731                        summary: Some(summary),
1732                    },
1733                )?;
1734            }
1735
1736            for settings_file in worktree.settings_files {
1737                session.peer.send(
1738                    session.connection_id,
1739                    proto::UpdateWorktreeSettings {
1740                        project_id: project.id.to_proto(),
1741                        worktree_id: worktree.id,
1742                        path: settings_file.path,
1743                        content: Some(settings_file.content),
1744                    },
1745                )?;
1746            }
1747        }
1748
1749        for language_server in &project.language_servers {
1750            session.peer.send(
1751                session.connection_id,
1752                proto::UpdateLanguageServer {
1753                    project_id: project.id.to_proto(),
1754                    language_server_id: language_server.id,
1755                    variant: Some(
1756                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1757                            proto::LspDiskBasedDiagnosticsUpdated {},
1758                        ),
1759                    ),
1760                },
1761            )?;
1762        }
1763    }
1764    Ok(())
1765}
1766
1767/// leave room disconnects from the room.
1768async fn leave_room(
1769    _: proto::LeaveRoom,
1770    response: Response<proto::LeaveRoom>,
1771    session: UserSession,
1772) -> Result<()> {
1773    leave_room_for_session(&session, session.connection_id).await?;
1774    response.send(proto::Ack {})?;
1775    Ok(())
1776}
1777
1778/// Updates the permissions of someone else in the room.
1779async fn set_room_participant_role(
1780    request: proto::SetRoomParticipantRole,
1781    response: Response<proto::SetRoomParticipantRole>,
1782    session: UserSession,
1783) -> Result<()> {
1784    let user_id = UserId::from_proto(request.user_id);
1785    let role = ChannelRole::from(request.role());
1786
1787    let (live_kit_room, can_publish) = {
1788        let room = session
1789            .db()
1790            .await
1791            .set_room_participant_role(
1792                session.user_id(),
1793                RoomId::from_proto(request.room_id),
1794                user_id,
1795                role,
1796            )
1797            .await?;
1798
1799        let live_kit_room = room.live_kit_room.clone();
1800        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1801        room_updated(&room, &session.peer);
1802        (live_kit_room, can_publish)
1803    };
1804
1805    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1806        live_kit
1807            .update_participant(
1808                live_kit_room.clone(),
1809                request.user_id.to_string(),
1810                live_kit_server::proto::ParticipantPermission {
1811                    can_subscribe: true,
1812                    can_publish,
1813                    can_publish_data: can_publish,
1814                    hidden: false,
1815                    recorder: false,
1816                },
1817            )
1818            .await
1819            .trace_err();
1820    }
1821
1822    response.send(proto::Ack {})?;
1823    Ok(())
1824}
1825
1826/// Call someone else into the current room
1827async fn call(
1828    request: proto::Call,
1829    response: Response<proto::Call>,
1830    session: UserSession,
1831) -> Result<()> {
1832    let room_id = RoomId::from_proto(request.room_id);
1833    let calling_user_id = session.user_id();
1834    let calling_connection_id = session.connection_id;
1835    let called_user_id = UserId::from_proto(request.called_user_id);
1836    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1837    if !session
1838        .db()
1839        .await
1840        .has_contact(calling_user_id, called_user_id)
1841        .await?
1842    {
1843        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1844    }
1845
1846    let incoming_call = {
1847        let (room, incoming_call) = &mut *session
1848            .db()
1849            .await
1850            .call(
1851                room_id,
1852                calling_user_id,
1853                calling_connection_id,
1854                called_user_id,
1855                initial_project_id,
1856            )
1857            .await?;
1858        room_updated(room, &session.peer);
1859        mem::take(incoming_call)
1860    };
1861    update_user_contacts(called_user_id, &session).await?;
1862
1863    let mut calls = session
1864        .connection_pool()
1865        .await
1866        .user_connection_ids(called_user_id)
1867        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1868        .collect::<FuturesUnordered<_>>();
1869
1870    while let Some(call_response) = calls.next().await {
1871        match call_response.as_ref() {
1872            Ok(_) => {
1873                response.send(proto::Ack {})?;
1874                return Ok(());
1875            }
1876            Err(_) => {
1877                call_response.trace_err();
1878            }
1879        }
1880    }
1881
1882    {
1883        let room = session
1884            .db()
1885            .await
1886            .call_failed(room_id, called_user_id)
1887            .await?;
1888        room_updated(&room, &session.peer);
1889    }
1890    update_user_contacts(called_user_id, &session).await?;
1891
1892    Err(anyhow!("failed to ring user"))?
1893}
1894
1895/// Cancel an outgoing call.
1896async fn cancel_call(
1897    request: proto::CancelCall,
1898    response: Response<proto::CancelCall>,
1899    session: UserSession,
1900) -> Result<()> {
1901    let called_user_id = UserId::from_proto(request.called_user_id);
1902    let room_id = RoomId::from_proto(request.room_id);
1903    {
1904        let room = session
1905            .db()
1906            .await
1907            .cancel_call(room_id, session.connection_id, called_user_id)
1908            .await?;
1909        room_updated(&room, &session.peer);
1910    }
1911
1912    for connection_id in session
1913        .connection_pool()
1914        .await
1915        .user_connection_ids(called_user_id)
1916    {
1917        session
1918            .peer
1919            .send(
1920                connection_id,
1921                proto::CallCanceled {
1922                    room_id: room_id.to_proto(),
1923                },
1924            )
1925            .trace_err();
1926    }
1927    response.send(proto::Ack {})?;
1928
1929    update_user_contacts(called_user_id, &session).await?;
1930    Ok(())
1931}
1932
1933/// Decline an incoming call.
1934async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1935    let room_id = RoomId::from_proto(message.room_id);
1936    {
1937        let room = session
1938            .db()
1939            .await
1940            .decline_call(Some(room_id), session.user_id())
1941            .await?
1942            .ok_or_else(|| anyhow!("failed to decline call"))?;
1943        room_updated(&room, &session.peer);
1944    }
1945
1946    for connection_id in session
1947        .connection_pool()
1948        .await
1949        .user_connection_ids(session.user_id())
1950    {
1951        session
1952            .peer
1953            .send(
1954                connection_id,
1955                proto::CallCanceled {
1956                    room_id: room_id.to_proto(),
1957                },
1958            )
1959            .trace_err();
1960    }
1961    update_user_contacts(session.user_id(), &session).await?;
1962    Ok(())
1963}
1964
1965/// Updates other participants in the room with your current location.
1966async fn update_participant_location(
1967    request: proto::UpdateParticipantLocation,
1968    response: Response<proto::UpdateParticipantLocation>,
1969    session: UserSession,
1970) -> Result<()> {
1971    let room_id = RoomId::from_proto(request.room_id);
1972    let location = request
1973        .location
1974        .ok_or_else(|| anyhow!("invalid location"))?;
1975
1976    let db = session.db().await;
1977    let room = db
1978        .update_room_participant_location(room_id, session.connection_id, location)
1979        .await?;
1980
1981    room_updated(&room, &session.peer);
1982    response.send(proto::Ack {})?;
1983    Ok(())
1984}
1985
1986/// Share a project into the room.
1987async fn share_project(
1988    request: proto::ShareProject,
1989    response: Response<proto::ShareProject>,
1990    session: UserSession,
1991) -> Result<()> {
1992    let (project_id, room) = &*session
1993        .db()
1994        .await
1995        .share_project(
1996            RoomId::from_proto(request.room_id),
1997            session.connection_id,
1998            &request.worktrees,
1999            request
2000                .dev_server_project_id
2001                .map(DevServerProjectId::from_proto),
2002        )
2003        .await?;
2004    response.send(proto::ShareProjectResponse {
2005        project_id: project_id.to_proto(),
2006    })?;
2007    room_updated(room, &session.peer);
2008
2009    Ok(())
2010}
2011
2012/// Unshare a project from the room.
2013async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2014    let project_id = ProjectId::from_proto(message.project_id);
2015    unshare_project_internal(
2016        project_id,
2017        session.connection_id,
2018        session.user_id(),
2019        &session,
2020    )
2021    .await
2022}
2023
2024async fn unshare_project_internal(
2025    project_id: ProjectId,
2026    connection_id: ConnectionId,
2027    user_id: Option<UserId>,
2028    session: &Session,
2029) -> Result<()> {
2030    let delete = {
2031        let room_guard = session
2032            .db()
2033            .await
2034            .unshare_project(project_id, connection_id, user_id)
2035            .await?;
2036
2037        let (delete, room, guest_connection_ids) = &*room_guard;
2038
2039        let message = proto::UnshareProject {
2040            project_id: project_id.to_proto(),
2041        };
2042
2043        broadcast(
2044            Some(connection_id),
2045            guest_connection_ids.iter().copied(),
2046            |conn_id| session.peer.send(conn_id, message.clone()),
2047        );
2048        if let Some(room) = room {
2049            room_updated(room, &session.peer);
2050        }
2051
2052        *delete
2053    };
2054
2055    if delete {
2056        let db = session.db().await;
2057        db.delete_project(project_id).await?;
2058    }
2059
2060    Ok(())
2061}
2062
2063/// DevServer makes a project available online
2064async fn share_dev_server_project(
2065    request: proto::ShareDevServerProject,
2066    response: Response<proto::ShareDevServerProject>,
2067    session: DevServerSession,
2068) -> Result<()> {
2069    let (dev_server_project, user_id, status) = session
2070        .db()
2071        .await
2072        .share_dev_server_project(
2073            DevServerProjectId::from_proto(request.dev_server_project_id),
2074            session.dev_server_id(),
2075            session.connection_id,
2076            &request.worktrees,
2077        )
2078        .await?;
2079    let Some(project_id) = dev_server_project.project_id else {
2080        return Err(anyhow!("failed to share remote project"))?;
2081    };
2082
2083    send_dev_server_projects_update(user_id, status, &session).await;
2084
2085    response.send(proto::ShareProjectResponse { project_id })?;
2086
2087    Ok(())
2088}
2089
/// Join someone else's shared project.
///
/// Registers this connection/user as a participant of the project through the
/// database layer and then delegates to `join_project_internal`, which sends the
/// response and streams the project state to the joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the value returned by `join_project`,
    // which stays alive until the end of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle explicitly before streaming the project state.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2108
/// Abstraction over the response handles that can answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own `send`.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2122
2123fn join_project_internal(
2124    response: impl JoinProjectInternalResponse,
2125    session: UserSession,
2126    project: &mut Project,
2127    replica_id: &ReplicaId,
2128) -> Result<()> {
2129    let collaborators = project
2130        .collaborators
2131        .iter()
2132        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2133        .map(|collaborator| collaborator.to_proto())
2134        .collect::<Vec<_>>();
2135    let project_id = project.id;
2136    let guest_user_id = session.user_id();
2137
2138    let worktrees = project
2139        .worktrees
2140        .iter()
2141        .map(|(id, worktree)| proto::WorktreeMetadata {
2142            id: *id,
2143            root_name: worktree.root_name.clone(),
2144            visible: worktree.visible,
2145            abs_path: worktree.abs_path.clone(),
2146        })
2147        .collect::<Vec<_>>();
2148
2149    let add_project_collaborator = proto::AddProjectCollaborator {
2150        project_id: project_id.to_proto(),
2151        collaborator: Some(proto::Collaborator {
2152            peer_id: Some(session.connection_id.into()),
2153            replica_id: replica_id.0 as u32,
2154            user_id: guest_user_id.to_proto(),
2155        }),
2156    };
2157
2158    for collaborator in &collaborators {
2159        session
2160            .peer
2161            .send(
2162                collaborator.peer_id.unwrap().into(),
2163                add_project_collaborator.clone(),
2164            )
2165            .trace_err();
2166    }
2167
2168    // First, we send the metadata associated with each worktree.
2169    response.send(proto::JoinProjectResponse {
2170        project_id: project.id.0 as u64,
2171        worktrees: worktrees.clone(),
2172        replica_id: replica_id.0 as u32,
2173        collaborators: collaborators.clone(),
2174        language_servers: project.language_servers.clone(),
2175        role: project.role.into(),
2176        dev_server_project_id: project
2177            .dev_server_project_id
2178            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2179    })?;
2180
2181    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2182        #[cfg(any(test, feature = "test-support"))]
2183        const MAX_CHUNK_SIZE: usize = 2;
2184        #[cfg(not(any(test, feature = "test-support")))]
2185        const MAX_CHUNK_SIZE: usize = 256;
2186
2187        // Stream this worktree's entries.
2188        let message = proto::UpdateWorktree {
2189            project_id: project_id.to_proto(),
2190            worktree_id,
2191            abs_path: worktree.abs_path.clone(),
2192            root_name: worktree.root_name,
2193            updated_entries: worktree.entries,
2194            removed_entries: Default::default(),
2195            scan_id: worktree.scan_id,
2196            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2197            updated_repositories: worktree.repository_entries.into_values().collect(),
2198            removed_repositories: Default::default(),
2199        };
2200        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2201            session.peer.send(session.connection_id, update.clone())?;
2202        }
2203
2204        // Stream this worktree's diagnostics.
2205        for summary in worktree.diagnostic_summaries {
2206            session.peer.send(
2207                session.connection_id,
2208                proto::UpdateDiagnosticSummary {
2209                    project_id: project_id.to_proto(),
2210                    worktree_id: worktree.id,
2211                    summary: Some(summary),
2212                },
2213            )?;
2214        }
2215
2216        for settings_file in worktree.settings_files {
2217            session.peer.send(
2218                session.connection_id,
2219                proto::UpdateWorktreeSettings {
2220                    project_id: project_id.to_proto(),
2221                    worktree_id: worktree.id,
2222                    path: settings_file.path,
2223                    content: Some(settings_file.content),
2224                },
2225            )?;
2226        }
2227    }
2228
2229    for language_server in &project.language_servers {
2230        session.peer.send(
2231            session.connection_id,
2232            proto::UpdateLanguageServer {
2233                project_id: project_id.to_proto(),
2234                language_server_id: language_server.id,
2235                variant: Some(
2236                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2237                        proto::LspDiskBasedDiagnosticsUpdated {},
2238                    ),
2239                ),
2240            },
2241        )?;
2242    }
2243
2244    Ok(())
2245}
2246
2247/// Leave someone elses shared project.
2248async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2249    let sender_id = session.connection_id;
2250    let project_id = ProjectId::from_proto(request.project_id);
2251    let db = session.db().await;
2252    if db.is_hosted_project(project_id).await? {
2253        let project = db.leave_hosted_project(project_id, sender_id).await?;
2254        project_left(&project, &session);
2255        return Ok(());
2256    }
2257
2258    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2259    tracing::info!(
2260        %project_id,
2261        "leave project"
2262    );
2263
2264    project_left(project, &session);
2265    if let Some(room) = room {
2266        room_updated(room, &session.peer);
2267    }
2268
2269    Ok(())
2270}
2271
2272async fn join_hosted_project(
2273    request: proto::JoinHostedProject,
2274    response: Response<proto::JoinHostedProject>,
2275    session: UserSession,
2276) -> Result<()> {
2277    let (mut project, replica_id) = session
2278        .db()
2279        .await
2280        .join_hosted_project(
2281            ProjectId(request.project_id as i32),
2282            session.user_id(),
2283            session.connection_id,
2284        )
2285        .await?;
2286
2287    join_project_internal(response, session, &mut project, &replica_id)
2288}
2289
2290async fn list_remote_directory(
2291    request: proto::ListRemoteDirectory,
2292    response: Response<proto::ListRemoteDirectory>,
2293    session: UserSession,
2294) -> Result<()> {
2295    let dev_server_id = DevServerId(request.dev_server_id as i32);
2296    let dev_server_connection_id = session
2297        .connection_pool()
2298        .await
2299        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2300
2301    session
2302        .db()
2303        .await
2304        .get_dev_server_for_user(dev_server_id, session.user_id())
2305        .await?;
2306
2307    response.send(
2308        session
2309            .peer
2310            .forward_request(session.connection_id, dev_server_connection_id, request)
2311            .await?,
2312    )?;
2313    Ok(())
2314}
2315
2316async fn update_dev_server_project(
2317    request: proto::UpdateDevServerProject,
2318    response: Response<proto::UpdateDevServerProject>,
2319    session: UserSession,
2320) -> Result<()> {
2321    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2322
2323    let (dev_server_project, update) = session
2324        .db()
2325        .await
2326        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2327        .await?;
2328
2329    let projects = session
2330        .db()
2331        .await
2332        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2333        .await?;
2334
2335    let dev_server_connection_id = session
2336        .connection_pool()
2337        .await
2338        .dev_server_connection_id_supporting(
2339            dev_server_project.dev_server_id,
2340            ZedVersion::with_list_directory(),
2341        )?;
2342
2343    session.peer.send(
2344        dev_server_connection_id,
2345        proto::DevServerInstructions { projects },
2346    )?;
2347
2348    send_dev_server_projects_update(session.user_id(), update, &session).await;
2349
2350    response.send(proto::Ack {})
2351}
2352
2353async fn create_dev_server_project(
2354    request: proto::CreateDevServerProject,
2355    response: Response<proto::CreateDevServerProject>,
2356    session: UserSession,
2357) -> Result<()> {
2358    let dev_server_id = DevServerId(request.dev_server_id as i32);
2359    let dev_server_connection_id = session
2360        .connection_pool()
2361        .await
2362        .dev_server_connection_id(dev_server_id);
2363    let Some(dev_server_connection_id) = dev_server_connection_id else {
2364        Err(ErrorCode::DevServerOffline
2365            .message("Cannot create a remote project when the dev server is offline".to_string())
2366            .anyhow())?
2367    };
2368
2369    let path = request.path.clone();
2370    //Check that the path exists on the dev server
2371    session
2372        .peer
2373        .forward_request(
2374            session.connection_id,
2375            dev_server_connection_id,
2376            proto::ValidateDevServerProjectRequest { path: path.clone() },
2377        )
2378        .await?;
2379
2380    let (dev_server_project, update) = session
2381        .db()
2382        .await
2383        .create_dev_server_project(
2384            DevServerId(request.dev_server_id as i32),
2385            &request.path,
2386            session.user_id(),
2387        )
2388        .await?;
2389
2390    let projects = session
2391        .db()
2392        .await
2393        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2394        .await?;
2395
2396    session.peer.send(
2397        dev_server_connection_id,
2398        proto::DevServerInstructions { projects },
2399    )?;
2400
2401    send_dev_server_projects_update(session.user_id(), update, &session).await;
2402
2403    response.send(proto::CreateDevServerProjectResponse {
2404        dev_server_project: Some(dev_server_project.to_proto(None)),
2405    })?;
2406    Ok(())
2407}
2408
2409async fn create_dev_server(
2410    request: proto::CreateDevServer,
2411    response: Response<proto::CreateDevServer>,
2412    session: UserSession,
2413) -> Result<()> {
2414    let access_token = auth::random_token();
2415    let hashed_access_token = auth::hash_access_token(&access_token);
2416
2417    if request.name.is_empty() {
2418        return Err(proto::ErrorCode::Forbidden
2419            .message("Dev server name cannot be empty".to_string())
2420            .anyhow())?;
2421    }
2422
2423    let (dev_server, status) = session
2424        .db()
2425        .await
2426        .create_dev_server(
2427            &request.name,
2428            request.ssh_connection_string.as_deref(),
2429            &hashed_access_token,
2430            session.user_id(),
2431        )
2432        .await?;
2433
2434    send_dev_server_projects_update(session.user_id(), status, &session).await;
2435
2436    response.send(proto::CreateDevServerResponse {
2437        dev_server_id: dev_server.id.0 as u64,
2438        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2439        name: request.name,
2440    })?;
2441    Ok(())
2442}
2443
2444async fn regenerate_dev_server_token(
2445    request: proto::RegenerateDevServerToken,
2446    response: Response<proto::RegenerateDevServerToken>,
2447    session: UserSession,
2448) -> Result<()> {
2449    let dev_server_id = DevServerId(request.dev_server_id as i32);
2450    let access_token = auth::random_token();
2451    let hashed_access_token = auth::hash_access_token(&access_token);
2452
2453    let connection_id = session
2454        .connection_pool()
2455        .await
2456        .dev_server_connection_id(dev_server_id);
2457    if let Some(connection_id) = connection_id {
2458        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2459        session.peer.send(
2460            connection_id,
2461            proto::ShutdownDevServer {
2462                reason: Some("dev server token was regenerated".to_string()),
2463            },
2464        )?;
2465        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2466    }
2467
2468    let status = session
2469        .db()
2470        .await
2471        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2472        .await?;
2473
2474    send_dev_server_projects_update(session.user_id(), status, &session).await;
2475
2476    response.send(proto::RegenerateDevServerTokenResponse {
2477        dev_server_id: dev_server_id.to_proto(),
2478        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2479    })?;
2480    Ok(())
2481}
2482
2483async fn rename_dev_server(
2484    request: proto::RenameDevServer,
2485    response: Response<proto::RenameDevServer>,
2486    session: UserSession,
2487) -> Result<()> {
2488    if request.name.trim().is_empty() {
2489        return Err(proto::ErrorCode::Forbidden
2490            .message("Dev server name cannot be empty".to_string())
2491            .anyhow())?;
2492    }
2493
2494    let dev_server_id = DevServerId(request.dev_server_id as i32);
2495    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2496    if dev_server.user_id != session.user_id() {
2497        return Err(anyhow!(ErrorCode::Forbidden))?;
2498    }
2499
2500    let status = session
2501        .db()
2502        .await
2503        .rename_dev_server(
2504            dev_server_id,
2505            &request.name,
2506            request.ssh_connection_string.as_deref(),
2507            session.user_id(),
2508        )
2509        .await?;
2510
2511    send_dev_server_projects_update(session.user_id(), status, &session).await;
2512
2513    response.send(proto::Ack {})?;
2514    Ok(())
2515}
2516
2517async fn delete_dev_server(
2518    request: proto::DeleteDevServer,
2519    response: Response<proto::DeleteDevServer>,
2520    session: UserSession,
2521) -> Result<()> {
2522    let dev_server_id = DevServerId(request.dev_server_id as i32);
2523    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2524    if dev_server.user_id != session.user_id() {
2525        return Err(anyhow!(ErrorCode::Forbidden))?;
2526    }
2527
2528    let connection_id = session
2529        .connection_pool()
2530        .await
2531        .dev_server_connection_id(dev_server_id);
2532    if let Some(connection_id) = connection_id {
2533        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2534        session.peer.send(
2535            connection_id,
2536            proto::ShutdownDevServer {
2537                reason: Some("dev server was deleted".to_string()),
2538            },
2539        )?;
2540        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2541    }
2542
2543    let status = session
2544        .db()
2545        .await
2546        .delete_dev_server(dev_server_id, session.user_id())
2547        .await?;
2548
2549    send_dev_server_projects_update(session.user_id(), status, &session).await;
2550
2551    response.send(proto::Ack {})?;
2552    Ok(())
2553}
2554
2555async fn delete_dev_server_project(
2556    request: proto::DeleteDevServerProject,
2557    response: Response<proto::DeleteDevServerProject>,
2558    session: UserSession,
2559) -> Result<()> {
2560    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2561    let dev_server_project = session
2562        .db()
2563        .await
2564        .get_dev_server_project(dev_server_project_id)
2565        .await?;
2566
2567    let dev_server = session
2568        .db()
2569        .await
2570        .get_dev_server(dev_server_project.dev_server_id)
2571        .await?;
2572    if dev_server.user_id != session.user_id() {
2573        return Err(anyhow!(ErrorCode::Forbidden))?;
2574    }
2575
2576    let dev_server_connection_id = session
2577        .connection_pool()
2578        .await
2579        .dev_server_connection_id(dev_server.id);
2580
2581    if let Some(dev_server_connection_id) = dev_server_connection_id {
2582        let project = session
2583            .db()
2584            .await
2585            .find_dev_server_project(dev_server_project_id)
2586            .await;
2587        if let Ok(project) = project {
2588            unshare_project_internal(
2589                project.id,
2590                dev_server_connection_id,
2591                Some(session.user_id()),
2592                &session,
2593            )
2594            .await?;
2595        }
2596    }
2597
2598    let (projects, status) = session
2599        .db()
2600        .await
2601        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2602        .await?;
2603
2604    if let Some(dev_server_connection_id) = dev_server_connection_id {
2605        session.peer.send(
2606            dev_server_connection_id,
2607            proto::DevServerInstructions { projects },
2608        )?;
2609    }
2610
2611    send_dev_server_projects_update(session.user_id(), status, &session).await;
2612
2613    response.send(proto::Ack {})?;
2614    Ok(())
2615}
2616
2617async fn rejoin_dev_server_projects(
2618    request: proto::RejoinRemoteProjects,
2619    response: Response<proto::RejoinRemoteProjects>,
2620    session: UserSession,
2621) -> Result<()> {
2622    let mut rejoined_projects = {
2623        let db = session.db().await;
2624        db.rejoin_dev_server_projects(
2625            &request.rejoined_projects,
2626            session.user_id(),
2627            session.0.connection_id,
2628        )
2629        .await?
2630    };
2631    response.send(proto::RejoinRemoteProjectsResponse {
2632        rejoined_projects: rejoined_projects
2633            .iter()
2634            .map(|project| project.to_proto())
2635            .collect(),
2636    })?;
2637    notify_rejoined_projects(&mut rejoined_projects, &session)
2638}
2639
2640async fn reconnect_dev_server(
2641    request: proto::ReconnectDevServer,
2642    response: Response<proto::ReconnectDevServer>,
2643    session: DevServerSession,
2644) -> Result<()> {
2645    let reshared_projects = {
2646        let db = session.db().await;
2647        db.reshare_dev_server_projects(
2648            &request.reshared_projects,
2649            session.dev_server_id(),
2650            session.0.connection_id,
2651        )
2652        .await?
2653    };
2654
2655    for project in &reshared_projects {
2656        for collaborator in &project.collaborators {
2657            session
2658                .peer
2659                .send(
2660                    collaborator.connection_id,
2661                    proto::UpdateProjectCollaborator {
2662                        project_id: project.id.to_proto(),
2663                        old_peer_id: Some(project.old_connection_id.into()),
2664                        new_peer_id: Some(session.connection_id.into()),
2665                    },
2666                )
2667                .trace_err();
2668        }
2669
2670        broadcast(
2671            Some(session.connection_id),
2672            project
2673                .collaborators
2674                .iter()
2675                .map(|collaborator| collaborator.connection_id),
2676            |connection_id| {
2677                session.peer.forward_send(
2678                    session.connection_id,
2679                    connection_id,
2680                    proto::UpdateProject {
2681                        project_id: project.id.to_proto(),
2682                        worktrees: project.worktrees.clone(),
2683                    },
2684                )
2685            },
2686        );
2687    }
2688
2689    response.send(proto::ReconnectDevServerResponse {
2690        reshared_projects: reshared_projects
2691            .iter()
2692            .map(|project| proto::ResharedProject {
2693                id: project.id.to_proto(),
2694                collaborators: project
2695                    .collaborators
2696                    .iter()
2697                    .map(|collaborator| collaborator.to_proto())
2698                    .collect(),
2699            })
2700            .collect(),
2701    })?;
2702
2703    Ok(())
2704}
2705
2706async fn shutdown_dev_server(
2707    _: proto::ShutdownDevServer,
2708    response: Response<proto::ShutdownDevServer>,
2709    session: DevServerSession,
2710) -> Result<()> {
2711    response.send(proto::Ack {})?;
2712    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2713    remove_dev_server_connection(session.dev_server_id(), &session).await
2714}
2715
2716async fn shutdown_dev_server_internal(
2717    dev_server_id: DevServerId,
2718    connection_id: ConnectionId,
2719    session: &Session,
2720) -> Result<()> {
2721    let (dev_server_projects, dev_server) = {
2722        let db = session.db().await;
2723        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2724        let dev_server = db.get_dev_server(dev_server_id).await?;
2725        (dev_server_projects, dev_server)
2726    };
2727
2728    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2729        unshare_project_internal(
2730            ProjectId::from_proto(project_id),
2731            connection_id,
2732            None,
2733            session,
2734        )
2735        .await?;
2736    }
2737
2738    session
2739        .connection_pool()
2740        .await
2741        .set_dev_server_offline(dev_server_id);
2742
2743    let status = session
2744        .db()
2745        .await
2746        .dev_server_projects_update(dev_server.user_id)
2747        .await?;
2748    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2749
2750    Ok(())
2751}
2752
2753async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2754    let dev_server_connection = session
2755        .connection_pool()
2756        .await
2757        .dev_server_connection_id(dev_server_id);
2758
2759    if let Some(dev_server_connection) = dev_server_connection {
2760        session
2761            .connection_pool()
2762            .await
2763            .remove_connection(dev_server_connection)?;
2764    }
2765    Ok(())
2766}
2767
2768/// Updates other participants with changes to the project
2769async fn update_project(
2770    request: proto::UpdateProject,
2771    response: Response<proto::UpdateProject>,
2772    session: Session,
2773) -> Result<()> {
2774    let project_id = ProjectId::from_proto(request.project_id);
2775    let (room, guest_connection_ids) = &*session
2776        .db()
2777        .await
2778        .update_project(project_id, session.connection_id, &request.worktrees)
2779        .await?;
2780    broadcast(
2781        Some(session.connection_id),
2782        guest_connection_ids.iter().copied(),
2783        |connection_id| {
2784            session
2785                .peer
2786                .forward_send(session.connection_id, connection_id, request.clone())
2787        },
2788    );
2789    if let Some(room) = room {
2790        room_updated(room, &session.peer);
2791    }
2792    response.send(proto::Ack {})?;
2793
2794    Ok(())
2795}
2796
2797/// Updates other participants with changes to the worktree
2798async fn update_worktree(
2799    request: proto::UpdateWorktree,
2800    response: Response<proto::UpdateWorktree>,
2801    session: Session,
2802) -> Result<()> {
2803    let guest_connection_ids = session
2804        .db()
2805        .await
2806        .update_worktree(&request, session.connection_id)
2807        .await?;
2808
2809    broadcast(
2810        Some(session.connection_id),
2811        guest_connection_ids.iter().copied(),
2812        |connection_id| {
2813            session
2814                .peer
2815                .forward_send(session.connection_id, connection_id, request.clone())
2816        },
2817    );
2818    response.send(proto::Ack {})?;
2819    Ok(())
2820}
2821
2822/// Updates other participants with changes to the diagnostics
2823async fn update_diagnostic_summary(
2824    message: proto::UpdateDiagnosticSummary,
2825    session: Session,
2826) -> Result<()> {
2827    let guest_connection_ids = session
2828        .db()
2829        .await
2830        .update_diagnostic_summary(&message, session.connection_id)
2831        .await?;
2832
2833    broadcast(
2834        Some(session.connection_id),
2835        guest_connection_ids.iter().copied(),
2836        |connection_id| {
2837            session
2838                .peer
2839                .forward_send(session.connection_id, connection_id, message.clone())
2840        },
2841    );
2842
2843    Ok(())
2844}
2845
2846/// Updates other participants with changes to the worktree settings
2847async fn update_worktree_settings(
2848    message: proto::UpdateWorktreeSettings,
2849    session: Session,
2850) -> Result<()> {
2851    let guest_connection_ids = session
2852        .db()
2853        .await
2854        .update_worktree_settings(&message, session.connection_id)
2855        .await?;
2856
2857    broadcast(
2858        Some(session.connection_id),
2859        guest_connection_ids.iter().copied(),
2860        |connection_id| {
2861            session
2862                .peer
2863                .forward_send(session.connection_id, connection_id, message.clone())
2864        },
2865    );
2866
2867    Ok(())
2868}
2869
2870/// Notify other participants that a  language server has started.
2871async fn start_language_server(
2872    request: proto::StartLanguageServer,
2873    session: Session,
2874) -> Result<()> {
2875    let guest_connection_ids = session
2876        .db()
2877        .await
2878        .start_language_server(&request, session.connection_id)
2879        .await?;
2880
2881    broadcast(
2882        Some(session.connection_id),
2883        guest_connection_ids.iter().copied(),
2884        |connection_id| {
2885            session
2886                .peer
2887                .forward_send(session.connection_id, connection_id, request.clone())
2888        },
2889    );
2890    Ok(())
2891}
2892
2893/// Notify other participants that a language server has changed.
2894async fn update_language_server(
2895    request: proto::UpdateLanguageServer,
2896    session: Session,
2897) -> Result<()> {
2898    let project_id = ProjectId::from_proto(request.project_id);
2899    let project_connection_ids = session
2900        .db()
2901        .await
2902        .project_connection_ids(project_id, session.connection_id, true)
2903        .await?;
2904    broadcast(
2905        Some(session.connection_id),
2906        project_connection_ids.iter().copied(),
2907        |connection_id| {
2908            session
2909                .peer
2910                .forward_send(session.connection_id, connection_id, request.clone())
2911        },
2912    );
2913    Ok(())
2914}
2915
2916/// forward a project request to the host. These requests should be read only
2917/// as guests are allowed to send them.
2918async fn forward_read_only_project_request<T>(
2919    request: T,
2920    response: Response<T>,
2921    session: UserSession,
2922) -> Result<()>
2923where
2924    T: EntityMessage + RequestMessage,
2925{
2926    let project_id = ProjectId::from_proto(request.remote_entity_id());
2927    let host_connection_id = session
2928        .db()
2929        .await
2930        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2931        .await?;
2932    let payload = session
2933        .peer
2934        .forward_request(session.connection_id, host_connection_id, request)
2935        .await?;
2936    response.send(payload)?;
2937    Ok(())
2938}
2939
2940async fn forward_find_search_candidates_request(
2941    request: proto::FindSearchCandidates,
2942    response: Response<proto::FindSearchCandidates>,
2943    session: UserSession,
2944) -> Result<()> {
2945    let project_id = ProjectId::from_proto(request.remote_entity_id());
2946    let host_connection_id = session
2947        .db()
2948        .await
2949        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2950        .await?;
2951
2952    let host_version = session
2953        .connection_pool()
2954        .await
2955        .connection(host_connection_id)
2956        .map(|c| c.zed_version);
2957
2958    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2959    {
2960        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2961        let search = proto::SearchProject {
2962            project_id: project_id.to_proto(),
2963            query: query.query,
2964            regex: query.regex,
2965            whole_word: query.whole_word,
2966            case_sensitive: query.case_sensitive,
2967            files_to_include: query.files_to_include,
2968            files_to_exclude: query.files_to_exclude,
2969            include_ignored: query.include_ignored,
2970        };
2971
2972        let payload = session
2973            .peer
2974            .forward_request(session.connection_id, host_connection_id, search)
2975            .await?;
2976        return response.send(proto::FindSearchCandidatesResponse {
2977            buffer_ids: payload
2978                .locations
2979                .into_iter()
2980                .map(|loc| loc.buffer_id)
2981                .collect(),
2982        });
2983    }
2984
2985    let payload = session
2986        .peer
2987        .forward_request(session.connection_id, host_connection_id, request)
2988        .await?;
2989    response.send(payload)?;
2990    Ok(())
2991}
2992
2993/// forward a project request to the dev server. Only allowed
2994/// if it's your dev server.
2995async fn forward_project_request_for_owner<T>(
2996    request: T,
2997    response: Response<T>,
2998    session: UserSession,
2999) -> Result<()>
3000where
3001    T: EntityMessage + RequestMessage,
3002{
3003    let project_id = ProjectId::from_proto(request.remote_entity_id());
3004
3005    let host_connection_id = session
3006        .db()
3007        .await
3008        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3009        .await?;
3010    let payload = session
3011        .peer
3012        .forward_request(session.connection_id, host_connection_id, request)
3013        .await?;
3014    response.send(payload)?;
3015    Ok(())
3016}
3017
3018/// forward a project request to the host. These requests are disallowed
3019/// for guests.
3020async fn forward_mutating_project_request<T>(
3021    request: T,
3022    response: Response<T>,
3023    session: UserSession,
3024) -> Result<()>
3025where
3026    T: EntityMessage + RequestMessage,
3027{
3028    let project_id = ProjectId::from_proto(request.remote_entity_id());
3029
3030    let host_connection_id = session
3031        .db()
3032        .await
3033        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3034        .await?;
3035    let payload = session
3036        .peer
3037        .forward_request(session.connection_id, host_connection_id, request)
3038        .await?;
3039    response.send(payload)?;
3040    Ok(())
3041}
3042
3043/// Notify other participants that a new buffer has been created
3044async fn create_buffer_for_peer(
3045    request: proto::CreateBufferForPeer,
3046    session: Session,
3047) -> Result<()> {
3048    session
3049        .db()
3050        .await
3051        .check_user_is_project_host(
3052            ProjectId::from_proto(request.project_id),
3053            session.connection_id,
3054        )
3055        .await?;
3056    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3057    session
3058        .peer
3059        .forward_send(session.connection_id, peer_id.into(), request)?;
3060    Ok(())
3061}
3062
3063/// Notify other participants that a buffer has been updated. This is
3064/// allowed for guests as long as the update is limited to selections.
3065async fn update_buffer(
3066    request: proto::UpdateBuffer,
3067    response: Response<proto::UpdateBuffer>,
3068    session: Session,
3069) -> Result<()> {
3070    let project_id = ProjectId::from_proto(request.project_id);
3071    let mut capability = Capability::ReadOnly;
3072
3073    for op in request.operations.iter() {
3074        match op.variant {
3075            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3076            Some(_) => capability = Capability::ReadWrite,
3077        }
3078    }
3079
3080    let host = {
3081        let guard = session
3082            .db()
3083            .await
3084            .connections_for_buffer_update(
3085                project_id,
3086                session.principal_id(),
3087                session.connection_id,
3088                capability,
3089            )
3090            .await?;
3091
3092        let (host, guests) = &*guard;
3093
3094        broadcast(
3095            Some(session.connection_id),
3096            guests.clone(),
3097            |connection_id| {
3098                session
3099                    .peer
3100                    .forward_send(session.connection_id, connection_id, request.clone())
3101            },
3102        );
3103
3104        *host
3105    };
3106
3107    if host != session.connection_id {
3108        session
3109            .peer
3110            .forward_request(session.connection_id, host, request.clone())
3111            .await?;
3112    }
3113
3114    response.send(proto::Ack {})?;
3115    Ok(())
3116}
3117
3118async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3119    let project_id = ProjectId::from_proto(message.project_id);
3120
3121    let operation = message.operation.as_ref().context("invalid operation")?;
3122    let capability = match operation.variant.as_ref() {
3123        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3124            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3125                match buffer_op.variant {
3126                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3127                        Capability::ReadOnly
3128                    }
3129                    _ => Capability::ReadWrite,
3130                }
3131            } else {
3132                Capability::ReadWrite
3133            }
3134        }
3135        Some(_) => Capability::ReadWrite,
3136        None => Capability::ReadOnly,
3137    };
3138
3139    let guard = session
3140        .db()
3141        .await
3142        .connections_for_buffer_update(
3143            project_id,
3144            session.principal_id(),
3145            session.connection_id,
3146            capability,
3147        )
3148        .await?;
3149
3150    let (host, guests) = &*guard;
3151
3152    broadcast(
3153        Some(session.connection_id),
3154        guests.iter().chain([host]).copied(),
3155        |connection_id| {
3156            session
3157                .peer
3158                .forward_send(session.connection_id, connection_id, message.clone())
3159        },
3160    );
3161
3162    Ok(())
3163}
3164
3165/// Notify other participants that a project has been updated.
3166async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3167    request: T,
3168    session: Session,
3169) -> Result<()> {
3170    let project_id = ProjectId::from_proto(request.remote_entity_id());
3171    let project_connection_ids = session
3172        .db()
3173        .await
3174        .project_connection_ids(project_id, session.connection_id, false)
3175        .await?;
3176
3177    broadcast(
3178        Some(session.connection_id),
3179        project_connection_ids.iter().copied(),
3180        |connection_id| {
3181            session
3182                .peer
3183                .forward_send(session.connection_id, connection_id, request.clone())
3184        },
3185    );
3186    Ok(())
3187}
3188
3189/// Start following another user in a call.
3190async fn follow(
3191    request: proto::Follow,
3192    response: Response<proto::Follow>,
3193    session: UserSession,
3194) -> Result<()> {
3195    let room_id = RoomId::from_proto(request.room_id);
3196    let project_id = request.project_id.map(ProjectId::from_proto);
3197    let leader_id = request
3198        .leader_id
3199        .ok_or_else(|| anyhow!("invalid leader id"))?
3200        .into();
3201    let follower_id = session.connection_id;
3202
3203    session
3204        .db()
3205        .await
3206        .check_room_participants(room_id, leader_id, session.connection_id)
3207        .await?;
3208
3209    let response_payload = session
3210        .peer
3211        .forward_request(session.connection_id, leader_id, request)
3212        .await?;
3213    response.send(response_payload)?;
3214
3215    if let Some(project_id) = project_id {
3216        let room = session
3217            .db()
3218            .await
3219            .follow(room_id, project_id, leader_id, follower_id)
3220            .await?;
3221        room_updated(&room, &session.peer);
3222    }
3223
3224    Ok(())
3225}
3226
3227/// Stop following another user in a call.
3228async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3229    let room_id = RoomId::from_proto(request.room_id);
3230    let project_id = request.project_id.map(ProjectId::from_proto);
3231    let leader_id = request
3232        .leader_id
3233        .ok_or_else(|| anyhow!("invalid leader id"))?
3234        .into();
3235    let follower_id = session.connection_id;
3236
3237    session
3238        .db()
3239        .await
3240        .check_room_participants(room_id, leader_id, session.connection_id)
3241        .await?;
3242
3243    session
3244        .peer
3245        .forward_send(session.connection_id, leader_id, request)?;
3246
3247    if let Some(project_id) = project_id {
3248        let room = session
3249            .db()
3250            .await
3251            .unfollow(room_id, project_id, leader_id, follower_id)
3252            .await?;
3253        room_updated(&room, &session.peer);
3254    }
3255
3256    Ok(())
3257}
3258
3259/// Notify everyone following you of your current location.
3260async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3261    let room_id = RoomId::from_proto(request.room_id);
3262    let database = session.db.lock().await;
3263
3264    let connection_ids = if let Some(project_id) = request.project_id {
3265        let project_id = ProjectId::from_proto(project_id);
3266        database
3267            .project_connection_ids(project_id, session.connection_id, true)
3268            .await?
3269    } else {
3270        database
3271            .room_connection_ids(room_id, session.connection_id)
3272            .await?
3273    };
3274
3275    // For now, don't send view update messages back to that view's current leader.
3276    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3277        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3278        _ => None,
3279    });
3280
3281    for connection_id in connection_ids.iter().cloned() {
3282        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3283            session
3284                .peer
3285                .forward_send(session.connection_id, connection_id, request.clone())?;
3286        }
3287    }
3288    Ok(())
3289}
3290
3291/// Get public data about users.
3292async fn get_users(
3293    request: proto::GetUsers,
3294    response: Response<proto::GetUsers>,
3295    session: Session,
3296) -> Result<()> {
3297    let user_ids = request
3298        .user_ids
3299        .into_iter()
3300        .map(UserId::from_proto)
3301        .collect();
3302    let users = session
3303        .db()
3304        .await
3305        .get_users_by_ids(user_ids)
3306        .await?
3307        .into_iter()
3308        .map(|user| proto::User {
3309            id: user.id.to_proto(),
3310            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3311            github_login: user.github_login,
3312        })
3313        .collect();
3314    response.send(proto::UsersResponse { users })?;
3315    Ok(())
3316}
3317
3318/// Search for users (to invite) buy Github login
3319async fn fuzzy_search_users(
3320    request: proto::FuzzySearchUsers,
3321    response: Response<proto::FuzzySearchUsers>,
3322    session: UserSession,
3323) -> Result<()> {
3324    let query = request.query;
3325    let users = match query.len() {
3326        0 => vec![],
3327        1 | 2 => session
3328            .db()
3329            .await
3330            .get_user_by_github_login(&query)
3331            .await?
3332            .into_iter()
3333            .collect(),
3334        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3335    };
3336    let users = users
3337        .into_iter()
3338        .filter(|user| user.id != session.user_id())
3339        .map(|user| proto::User {
3340            id: user.id.to_proto(),
3341            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3342            github_login: user.github_login,
3343        })
3344        .collect();
3345    response.send(proto::UsersResponse { users })?;
3346    Ok(())
3347}
3348
3349/// Send a contact request to another user.
3350async fn request_contact(
3351    request: proto::RequestContact,
3352    response: Response<proto::RequestContact>,
3353    session: UserSession,
3354) -> Result<()> {
3355    let requester_id = session.user_id();
3356    let responder_id = UserId::from_proto(request.responder_id);
3357    if requester_id == responder_id {
3358        return Err(anyhow!("cannot add yourself as a contact"))?;
3359    }
3360
3361    let notifications = session
3362        .db()
3363        .await
3364        .send_contact_request(requester_id, responder_id)
3365        .await?;
3366
3367    // Update outgoing contact requests of requester
3368    let mut update = proto::UpdateContacts::default();
3369    update.outgoing_requests.push(responder_id.to_proto());
3370    for connection_id in session
3371        .connection_pool()
3372        .await
3373        .user_connection_ids(requester_id)
3374    {
3375        session.peer.send(connection_id, update.clone())?;
3376    }
3377
3378    // Update incoming contact requests of responder
3379    let mut update = proto::UpdateContacts::default();
3380    update
3381        .incoming_requests
3382        .push(proto::IncomingContactRequest {
3383            requester_id: requester_id.to_proto(),
3384        });
3385    let connection_pool = session.connection_pool().await;
3386    for connection_id in connection_pool.user_connection_ids(responder_id) {
3387        session.peer.send(connection_id, update.clone())?;
3388    }
3389
3390    send_notifications(&connection_pool, &session.peer, notifications);
3391
3392    response.send(proto::Ack {})?;
3393    Ok(())
3394}
3395
3396/// Accept or decline a contact request
3397async fn respond_to_contact_request(
3398    request: proto::RespondToContactRequest,
3399    response: Response<proto::RespondToContactRequest>,
3400    session: UserSession,
3401) -> Result<()> {
3402    let responder_id = session.user_id();
3403    let requester_id = UserId::from_proto(request.requester_id);
3404    let db = session.db().await;
3405    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3406        db.dismiss_contact_notification(responder_id, requester_id)
3407            .await?;
3408    } else {
3409        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3410
3411        let notifications = db
3412            .respond_to_contact_request(responder_id, requester_id, accept)
3413            .await?;
3414        let requester_busy = db.is_user_busy(requester_id).await?;
3415        let responder_busy = db.is_user_busy(responder_id).await?;
3416
3417        let pool = session.connection_pool().await;
3418        // Update responder with new contact
3419        let mut update = proto::UpdateContacts::default();
3420        if accept {
3421            update
3422                .contacts
3423                .push(contact_for_user(requester_id, requester_busy, &pool));
3424        }
3425        update
3426            .remove_incoming_requests
3427            .push(requester_id.to_proto());
3428        for connection_id in pool.user_connection_ids(responder_id) {
3429            session.peer.send(connection_id, update.clone())?;
3430        }
3431
3432        // Update requester with new contact
3433        let mut update = proto::UpdateContacts::default();
3434        if accept {
3435            update
3436                .contacts
3437                .push(contact_for_user(responder_id, responder_busy, &pool));
3438        }
3439        update
3440            .remove_outgoing_requests
3441            .push(responder_id.to_proto());
3442
3443        for connection_id in pool.user_connection_ids(requester_id) {
3444            session.peer.send(connection_id, update.clone())?;
3445        }
3446
3447        send_notifications(&pool, &session.peer, notifications);
3448    }
3449
3450    response.send(proto::Ack {})?;
3451    Ok(())
3452}
3453
3454/// Remove a contact.
3455async fn remove_contact(
3456    request: proto::RemoveContact,
3457    response: Response<proto::RemoveContact>,
3458    session: UserSession,
3459) -> Result<()> {
3460    let requester_id = session.user_id();
3461    let responder_id = UserId::from_proto(request.user_id);
3462    let db = session.db().await;
3463    let (contact_accepted, deleted_notification_id) =
3464        db.remove_contact(requester_id, responder_id).await?;
3465
3466    let pool = session.connection_pool().await;
3467    // Update outgoing contact requests of requester
3468    let mut update = proto::UpdateContacts::default();
3469    if contact_accepted {
3470        update.remove_contacts.push(responder_id.to_proto());
3471    } else {
3472        update
3473            .remove_outgoing_requests
3474            .push(responder_id.to_proto());
3475    }
3476    for connection_id in pool.user_connection_ids(requester_id) {
3477        session.peer.send(connection_id, update.clone())?;
3478    }
3479
3480    // Update incoming contact requests of responder
3481    let mut update = proto::UpdateContacts::default();
3482    if contact_accepted {
3483        update.remove_contacts.push(requester_id.to_proto());
3484    } else {
3485        update
3486            .remove_incoming_requests
3487            .push(requester_id.to_proto());
3488    }
3489    for connection_id in pool.user_connection_ids(responder_id) {
3490        session.peer.send(connection_id, update.clone())?;
3491        if let Some(notification_id) = deleted_notification_id {
3492            session.peer.send(
3493                connection_id,
3494                proto::DeleteNotification {
3495                    notification_id: notification_id.to_proto(),
3496                },
3497            )?;
3498        }
3499    }
3500
3501    response.send(proto::Ack {})?;
3502    Ok(())
3503}
3504
3505fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3506    version.0.minor() < 139
3507}
3508
3509async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3510    let plan = session.current_plan(session.db().await).await?;
3511
3512    session
3513        .peer
3514        .send(
3515            session.connection_id,
3516            proto::UpdateUserPlan { plan: plan.into() },
3517        )
3518        .trace_err();
3519
3520    Ok(())
3521}
3522
3523async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3524    subscribe_user_to_channels(
3525        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3526        &session,
3527    )
3528    .await?;
3529    Ok(())
3530}
3531
3532async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3533    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3534    let mut pool = session.connection_pool().await;
3535    for membership in &channels_for_user.channel_memberships {
3536        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3537    }
3538    session.peer.send(
3539        session.connection_id,
3540        build_update_user_channels(&channels_for_user),
3541    )?;
3542    session.peer.send(
3543        session.connection_id,
3544        build_channels_update(channels_for_user),
3545    )?;
3546    Ok(())
3547}
3548
3549/// Creates a new channel.
3550async fn create_channel(
3551    request: proto::CreateChannel,
3552    response: Response<proto::CreateChannel>,
3553    session: UserSession,
3554) -> Result<()> {
3555    let db = session.db().await;
3556
3557    let parent_id = request.parent_id.map(ChannelId::from_proto);
3558    let (channel, membership) = db
3559        .create_channel(&request.name, parent_id, session.user_id())
3560        .await?;
3561
3562    let root_id = channel.root_id();
3563    let channel = Channel::from_model(channel);
3564
3565    response.send(proto::CreateChannelResponse {
3566        channel: Some(channel.to_proto()),
3567        parent_id: request.parent_id,
3568    })?;
3569
3570    let mut connection_pool = session.connection_pool().await;
3571    if let Some(membership) = membership {
3572        connection_pool.subscribe_to_channel(
3573            membership.user_id,
3574            membership.channel_id,
3575            membership.role,
3576        );
3577        let update = proto::UpdateUserChannels {
3578            channel_memberships: vec![proto::ChannelMembership {
3579                channel_id: membership.channel_id.to_proto(),
3580                role: membership.role.into(),
3581            }],
3582            ..Default::default()
3583        };
3584        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3585            session.peer.send(connection_id, update.clone())?;
3586        }
3587    }
3588
3589    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3590        if !role.can_see_channel(channel.visibility) {
3591            continue;
3592        }
3593
3594        let update = proto::UpdateChannels {
3595            channels: vec![channel.to_proto()],
3596            ..Default::default()
3597        };
3598        session.peer.send(connection_id, update.clone())?;
3599    }
3600
3601    Ok(())
3602}
3603
3604/// Delete a channel
3605async fn delete_channel(
3606    request: proto::DeleteChannel,
3607    response: Response<proto::DeleteChannel>,
3608    session: UserSession,
3609) -> Result<()> {
3610    let db = session.db().await;
3611
3612    let channel_id = request.channel_id;
3613    let (root_channel, removed_channels) = db
3614        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3615        .await?;
3616    response.send(proto::Ack {})?;
3617
3618    // Notify members of removed channels
3619    let mut update = proto::UpdateChannels::default();
3620    update
3621        .delete_channels
3622        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3623
3624    let connection_pool = session.connection_pool().await;
3625    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3626        session.peer.send(connection_id, update.clone())?;
3627    }
3628
3629    Ok(())
3630}
3631
3632/// Invite someone to join a channel.
3633async fn invite_channel_member(
3634    request: proto::InviteChannelMember,
3635    response: Response<proto::InviteChannelMember>,
3636    session: UserSession,
3637) -> Result<()> {
3638    let db = session.db().await;
3639    let channel_id = ChannelId::from_proto(request.channel_id);
3640    let invitee_id = UserId::from_proto(request.user_id);
3641    let InviteMemberResult {
3642        channel,
3643        notifications,
3644    } = db
3645        .invite_channel_member(
3646            channel_id,
3647            invitee_id,
3648            session.user_id(),
3649            request.role().into(),
3650        )
3651        .await?;
3652
3653    let update = proto::UpdateChannels {
3654        channel_invitations: vec![channel.to_proto()],
3655        ..Default::default()
3656    };
3657
3658    let connection_pool = session.connection_pool().await;
3659    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3660        session.peer.send(connection_id, update.clone())?;
3661    }
3662
3663    send_notifications(&connection_pool, &session.peer, notifications);
3664
3665    response.send(proto::Ack {})?;
3666    Ok(())
3667}
3668
3669/// remove someone from a channel
3670async fn remove_channel_member(
3671    request: proto::RemoveChannelMember,
3672    response: Response<proto::RemoveChannelMember>,
3673    session: UserSession,
3674) -> Result<()> {
3675    let db = session.db().await;
3676    let channel_id = ChannelId::from_proto(request.channel_id);
3677    let member_id = UserId::from_proto(request.user_id);
3678
3679    let RemoveChannelMemberResult {
3680        membership_update,
3681        notification_id,
3682    } = db
3683        .remove_channel_member(channel_id, member_id, session.user_id())
3684        .await?;
3685
3686    let mut connection_pool = session.connection_pool().await;
3687    notify_membership_updated(
3688        &mut connection_pool,
3689        membership_update,
3690        member_id,
3691        &session.peer,
3692    );
3693    for connection_id in connection_pool.user_connection_ids(member_id) {
3694        if let Some(notification_id) = notification_id {
3695            session
3696                .peer
3697                .send(
3698                    connection_id,
3699                    proto::DeleteNotification {
3700                        notification_id: notification_id.to_proto(),
3701                    },
3702                )
3703                .trace_err();
3704        }
3705    }
3706
3707    response.send(proto::Ack {})?;
3708    Ok(())
3709}
3710
3711/// Toggle the channel between public and private.
3712/// Care is taken to maintain the invariant that public channels only descend from public channels,
3713/// (though members-only channels can appear at any point in the hierarchy).
3714async fn set_channel_visibility(
3715    request: proto::SetChannelVisibility,
3716    response: Response<proto::SetChannelVisibility>,
3717    session: UserSession,
3718) -> Result<()> {
3719    let db = session.db().await;
3720    let channel_id = ChannelId::from_proto(request.channel_id);
3721    let visibility = request.visibility().into();
3722
3723    let channel_model = db
3724        .set_channel_visibility(channel_id, visibility, session.user_id())
3725        .await?;
3726    let root_id = channel_model.root_id();
3727    let channel = Channel::from_model(channel_model);
3728
3729    let mut connection_pool = session.connection_pool().await;
3730    for (user_id, role) in connection_pool
3731        .channel_user_ids(root_id)
3732        .collect::<Vec<_>>()
3733        .into_iter()
3734    {
3735        let update = if role.can_see_channel(channel.visibility) {
3736            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3737            proto::UpdateChannels {
3738                channels: vec![channel.to_proto()],
3739                ..Default::default()
3740            }
3741        } else {
3742            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3743            proto::UpdateChannels {
3744                delete_channels: vec![channel.id.to_proto()],
3745                ..Default::default()
3746            }
3747        };
3748
3749        for connection_id in connection_pool.user_connection_ids(user_id) {
3750            session.peer.send(connection_id, update.clone())?;
3751        }
3752    }
3753
3754    response.send(proto::Ack {})?;
3755    Ok(())
3756}
3757
3758/// Alter the role for a user in the channel.
3759async fn set_channel_member_role(
3760    request: proto::SetChannelMemberRole,
3761    response: Response<proto::SetChannelMemberRole>,
3762    session: UserSession,
3763) -> Result<()> {
3764    let db = session.db().await;
3765    let channel_id = ChannelId::from_proto(request.channel_id);
3766    let member_id = UserId::from_proto(request.user_id);
3767    let result = db
3768        .set_channel_member_role(
3769            channel_id,
3770            session.user_id(),
3771            member_id,
3772            request.role().into(),
3773        )
3774        .await?;
3775
3776    match result {
3777        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3778            let mut connection_pool = session.connection_pool().await;
3779            notify_membership_updated(
3780                &mut connection_pool,
3781                membership_update,
3782                member_id,
3783                &session.peer,
3784            )
3785        }
3786        db::SetMemberRoleResult::InviteUpdated(channel) => {
3787            let update = proto::UpdateChannels {
3788                channel_invitations: vec![channel.to_proto()],
3789                ..Default::default()
3790            };
3791
3792            for connection_id in session
3793                .connection_pool()
3794                .await
3795                .user_connection_ids(member_id)
3796            {
3797                session.peer.send(connection_id, update.clone())?;
3798            }
3799        }
3800    }
3801
3802    response.send(proto::Ack {})?;
3803    Ok(())
3804}
3805
3806/// Change the name of a channel
3807async fn rename_channel(
3808    request: proto::RenameChannel,
3809    response: Response<proto::RenameChannel>,
3810    session: UserSession,
3811) -> Result<()> {
3812    let db = session.db().await;
3813    let channel_id = ChannelId::from_proto(request.channel_id);
3814    let channel_model = db
3815        .rename_channel(channel_id, session.user_id(), &request.name)
3816        .await?;
3817    let root_id = channel_model.root_id();
3818    let channel = Channel::from_model(channel_model);
3819
3820    response.send(proto::RenameChannelResponse {
3821        channel: Some(channel.to_proto()),
3822    })?;
3823
3824    let connection_pool = session.connection_pool().await;
3825    let update = proto::UpdateChannels {
3826        channels: vec![channel.to_proto()],
3827        ..Default::default()
3828    };
3829    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3830        if role.can_see_channel(channel.visibility) {
3831            session.peer.send(connection_id, update.clone())?;
3832        }
3833    }
3834
3835    Ok(())
3836}
3837
3838/// Move a channel to a new parent.
3839async fn move_channel(
3840    request: proto::MoveChannel,
3841    response: Response<proto::MoveChannel>,
3842    session: UserSession,
3843) -> Result<()> {
3844    let channel_id = ChannelId::from_proto(request.channel_id);
3845    let to = ChannelId::from_proto(request.to);
3846
3847    let (root_id, channels) = session
3848        .db()
3849        .await
3850        .move_channel(channel_id, to, session.user_id())
3851        .await?;
3852
3853    let connection_pool = session.connection_pool().await;
3854    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3855        let channels = channels
3856            .iter()
3857            .filter_map(|channel| {
3858                if role.can_see_channel(channel.visibility) {
3859                    Some(channel.to_proto())
3860                } else {
3861                    None
3862                }
3863            })
3864            .collect::<Vec<_>>();
3865        if channels.is_empty() {
3866            continue;
3867        }
3868
3869        let update = proto::UpdateChannels {
3870            channels,
3871            ..Default::default()
3872        };
3873
3874        session.peer.send(connection_id, update.clone())?;
3875    }
3876
3877    response.send(Ack {})?;
3878    Ok(())
3879}
3880
3881/// Get the list of channel members
3882async fn get_channel_members(
3883    request: proto::GetChannelMembers,
3884    response: Response<proto::GetChannelMembers>,
3885    session: UserSession,
3886) -> Result<()> {
3887    let db = session.db().await;
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    let limit = if request.limit == 0 {
3890        u16::MAX as u64
3891    } else {
3892        request.limit
3893    };
3894    let (members, users) = db
3895        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3896        .await?;
3897    response.send(proto::GetChannelMembersResponse { members, users })?;
3898    Ok(())
3899}
3900
3901/// Accept or decline a channel invitation.
3902async fn respond_to_channel_invite(
3903    request: proto::RespondToChannelInvite,
3904    response: Response<proto::RespondToChannelInvite>,
3905    session: UserSession,
3906) -> Result<()> {
3907    let db = session.db().await;
3908    let channel_id = ChannelId::from_proto(request.channel_id);
3909    let RespondToChannelInvite {
3910        membership_update,
3911        notifications,
3912    } = db
3913        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3914        .await?;
3915
3916    let mut connection_pool = session.connection_pool().await;
3917    if let Some(membership_update) = membership_update {
3918        notify_membership_updated(
3919            &mut connection_pool,
3920            membership_update,
3921            session.user_id(),
3922            &session.peer,
3923        );
3924    } else {
3925        let update = proto::UpdateChannels {
3926            remove_channel_invitations: vec![channel_id.to_proto()],
3927            ..Default::default()
3928        };
3929
3930        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3931            session.peer.send(connection_id, update.clone())?;
3932        }
3933    };
3934
3935    send_notifications(&connection_pool, &session.peer, notifications);
3936
3937    response.send(proto::Ack {})?;
3938
3939    Ok(())
3940}
3941
3942/// Join the channels' room
3943async fn join_channel(
3944    request: proto::JoinChannel,
3945    response: Response<proto::JoinChannel>,
3946    session: UserSession,
3947) -> Result<()> {
3948    let channel_id = ChannelId::from_proto(request.channel_id);
3949    join_channel_internal(channel_id, Box::new(response), session).await
3950}
3951
3952trait JoinChannelInternalResponse {
3953    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3954}
3955impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3956    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3957        Response::<proto::JoinChannel>::send(self, result)
3958    }
3959}
3960impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3961    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3962        Response::<proto::JoinRoom>::send(self, result)
3963    }
3964}
3965
3966async fn join_channel_internal(
3967    channel_id: ChannelId,
3968    response: Box<impl JoinChannelInternalResponse>,
3969    session: UserSession,
3970) -> Result<()> {
3971    let joined_room = {
3972        let mut db = session.db().await;
3973        // If zed quits without leaving the room, and the user re-opens zed before the
3974        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3975        // room they were in.
3976        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3977            tracing::info!(
3978                stale_connection_id = %connection,
3979                "cleaning up stale connection",
3980            );
3981            drop(db);
3982            leave_room_for_session(&session, connection).await?;
3983            db = session.db().await;
3984        }
3985
3986        let (joined_room, membership_updated, role) = db
3987            .join_channel(channel_id, session.user_id(), session.connection_id)
3988            .await?;
3989
3990        let live_kit_connection_info =
3991            session
3992                .app_state
3993                .live_kit_client
3994                .as_ref()
3995                .and_then(|live_kit| {
3996                    let (can_publish, token) = if role == ChannelRole::Guest {
3997                        (
3998                            false,
3999                            live_kit
4000                                .guest_token(
4001                                    &joined_room.room.live_kit_room,
4002                                    &session.user_id().to_string(),
4003                                )
4004                                .trace_err()?,
4005                        )
4006                    } else {
4007                        (
4008                            true,
4009                            live_kit
4010                                .room_token(
4011                                    &joined_room.room.live_kit_room,
4012                                    &session.user_id().to_string(),
4013                                )
4014                                .trace_err()?,
4015                        )
4016                    };
4017
4018                    Some(LiveKitConnectionInfo {
4019                        server_url: live_kit.url().into(),
4020                        token,
4021                        can_publish,
4022                    })
4023                });
4024
4025        response.send(proto::JoinRoomResponse {
4026            room: Some(joined_room.room.clone()),
4027            channel_id: joined_room
4028                .channel
4029                .as_ref()
4030                .map(|channel| channel.id.to_proto()),
4031            live_kit_connection_info,
4032        })?;
4033
4034        let mut connection_pool = session.connection_pool().await;
4035        if let Some(membership_updated) = membership_updated {
4036            notify_membership_updated(
4037                &mut connection_pool,
4038                membership_updated,
4039                session.user_id(),
4040                &session.peer,
4041            );
4042        }
4043
4044        room_updated(&joined_room.room, &session.peer);
4045
4046        joined_room
4047    };
4048
4049    channel_updated(
4050        &joined_room
4051            .channel
4052            .ok_or_else(|| anyhow!("channel not returned"))?,
4053        &joined_room.room,
4054        &session.peer,
4055        &*session.connection_pool().await,
4056    );
4057
4058    update_user_contacts(session.user_id(), &session).await?;
4059    Ok(())
4060}
4061
4062/// Start editing the channel notes
4063async fn join_channel_buffer(
4064    request: proto::JoinChannelBuffer,
4065    response: Response<proto::JoinChannelBuffer>,
4066    session: UserSession,
4067) -> Result<()> {
4068    let db = session.db().await;
4069    let channel_id = ChannelId::from_proto(request.channel_id);
4070
4071    let open_response = db
4072        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4073        .await?;
4074
4075    let collaborators = open_response.collaborators.clone();
4076    response.send(open_response)?;
4077
4078    let update = UpdateChannelBufferCollaborators {
4079        channel_id: channel_id.to_proto(),
4080        collaborators: collaborators.clone(),
4081    };
4082    channel_buffer_updated(
4083        session.connection_id,
4084        collaborators
4085            .iter()
4086            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4087        &update,
4088        &session.peer,
4089    );
4090
4091    Ok(())
4092}
4093
4094/// Edit the channel notes
4095async fn update_channel_buffer(
4096    request: proto::UpdateChannelBuffer,
4097    session: UserSession,
4098) -> Result<()> {
4099    let db = session.db().await;
4100    let channel_id = ChannelId::from_proto(request.channel_id);
4101
4102    let (collaborators, epoch, version) = db
4103        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4104        .await?;
4105
4106    channel_buffer_updated(
4107        session.connection_id,
4108        collaborators.clone(),
4109        &proto::UpdateChannelBuffer {
4110            channel_id: channel_id.to_proto(),
4111            operations: request.operations,
4112        },
4113        &session.peer,
4114    );
4115
4116    let pool = &*session.connection_pool().await;
4117
4118    let non_collaborators =
4119        pool.channel_connection_ids(channel_id)
4120            .filter_map(|(connection_id, _)| {
4121                if collaborators.contains(&connection_id) {
4122                    None
4123                } else {
4124                    Some(connection_id)
4125                }
4126            });
4127
4128    broadcast(None, non_collaborators, |peer_id| {
4129        session.peer.send(
4130            peer_id,
4131            proto::UpdateChannels {
4132                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4133                    channel_id: channel_id.to_proto(),
4134                    epoch: epoch as u64,
4135                    version: version.clone(),
4136                }],
4137                ..Default::default()
4138            },
4139        )
4140    });
4141
4142    Ok(())
4143}
4144
4145/// Rejoin the channel notes after a connection blip
4146async fn rejoin_channel_buffers(
4147    request: proto::RejoinChannelBuffers,
4148    response: Response<proto::RejoinChannelBuffers>,
4149    session: UserSession,
4150) -> Result<()> {
4151    let db = session.db().await;
4152    let buffers = db
4153        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4154        .await?;
4155
4156    for rejoined_buffer in &buffers {
4157        let collaborators_to_notify = rejoined_buffer
4158            .buffer
4159            .collaborators
4160            .iter()
4161            .filter_map(|c| Some(c.peer_id?.into()));
4162        channel_buffer_updated(
4163            session.connection_id,
4164            collaborators_to_notify,
4165            &proto::UpdateChannelBufferCollaborators {
4166                channel_id: rejoined_buffer.buffer.channel_id,
4167                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4168            },
4169            &session.peer,
4170        );
4171    }
4172
4173    response.send(proto::RejoinChannelBuffersResponse {
4174        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4175    })?;
4176
4177    Ok(())
4178}
4179
4180/// Stop editing the channel notes
4181async fn leave_channel_buffer(
4182    request: proto::LeaveChannelBuffer,
4183    response: Response<proto::LeaveChannelBuffer>,
4184    session: UserSession,
4185) -> Result<()> {
4186    let db = session.db().await;
4187    let channel_id = ChannelId::from_proto(request.channel_id);
4188
4189    let left_buffer = db
4190        .leave_channel_buffer(channel_id, session.connection_id)
4191        .await?;
4192
4193    response.send(Ack {})?;
4194
4195    channel_buffer_updated(
4196        session.connection_id,
4197        left_buffer.connections,
4198        &proto::UpdateChannelBufferCollaborators {
4199            channel_id: channel_id.to_proto(),
4200            collaborators: left_buffer.collaborators,
4201        },
4202        &session.peer,
4203    );
4204
4205    Ok(())
4206}
4207
4208fn channel_buffer_updated<T: EnvelopedMessage>(
4209    sender_id: ConnectionId,
4210    collaborators: impl IntoIterator<Item = ConnectionId>,
4211    message: &T,
4212    peer: &Peer,
4213) {
4214    broadcast(Some(sender_id), collaborators, |peer_id| {
4215        peer.send(peer_id, message.clone())
4216    });
4217}
4218
4219fn send_notifications(
4220    connection_pool: &ConnectionPool,
4221    peer: &Peer,
4222    notifications: db::NotificationBatch,
4223) {
4224    for (user_id, notification) in notifications {
4225        for connection_id in connection_pool.user_connection_ids(user_id) {
4226            if let Err(error) = peer.send(
4227                connection_id,
4228                proto::AddNotification {
4229                    notification: Some(notification.clone()),
4230                },
4231            ) {
4232                tracing::error!(
4233                    "failed to send notification to {:?} {}",
4234                    connection_id,
4235                    error
4236                );
4237            }
4238        }
4239    }
4240}
4241
4242/// Send a message to the channel
4243async fn send_channel_message(
4244    request: proto::SendChannelMessage,
4245    response: Response<proto::SendChannelMessage>,
4246    session: UserSession,
4247) -> Result<()> {
4248    // Validate the message body.
4249    let body = request.body.trim().to_string();
4250    if body.len() > MAX_MESSAGE_LEN {
4251        return Err(anyhow!("message is too long"))?;
4252    }
4253    if body.is_empty() {
4254        return Err(anyhow!("message can't be blank"))?;
4255    }
4256
4257    // TODO: adjust mentions if body is trimmed
4258
4259    let timestamp = OffsetDateTime::now_utc();
4260    let nonce = request
4261        .nonce
4262        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4263
4264    let channel_id = ChannelId::from_proto(request.channel_id);
4265    let CreatedChannelMessage {
4266        message_id,
4267        participant_connection_ids,
4268        notifications,
4269    } = session
4270        .db()
4271        .await
4272        .create_channel_message(
4273            channel_id,
4274            session.user_id(),
4275            &body,
4276            &request.mentions,
4277            timestamp,
4278            nonce.clone().into(),
4279            request.reply_to_message_id.map(MessageId::from_proto),
4280        )
4281        .await?;
4282
4283    let message = proto::ChannelMessage {
4284        sender_id: session.user_id().to_proto(),
4285        id: message_id.to_proto(),
4286        body,
4287        mentions: request.mentions,
4288        timestamp: timestamp.unix_timestamp() as u64,
4289        nonce: Some(nonce),
4290        reply_to_message_id: request.reply_to_message_id,
4291        edited_at: None,
4292    };
4293    broadcast(
4294        Some(session.connection_id),
4295        participant_connection_ids.clone(),
4296        |connection| {
4297            session.peer.send(
4298                connection,
4299                proto::ChannelMessageSent {
4300                    channel_id: channel_id.to_proto(),
4301                    message: Some(message.clone()),
4302                },
4303            )
4304        },
4305    );
4306    response.send(proto::SendChannelMessageResponse {
4307        message: Some(message),
4308    })?;
4309
4310    let pool = &*session.connection_pool().await;
4311    let non_participants =
4312        pool.channel_connection_ids(channel_id)
4313            .filter_map(|(connection_id, _)| {
4314                if participant_connection_ids.contains(&connection_id) {
4315                    None
4316                } else {
4317                    Some(connection_id)
4318                }
4319            });
4320    broadcast(None, non_participants, |peer_id| {
4321        session.peer.send(
4322            peer_id,
4323            proto::UpdateChannels {
4324                latest_channel_message_ids: vec![proto::ChannelMessageId {
4325                    channel_id: channel_id.to_proto(),
4326                    message_id: message_id.to_proto(),
4327                }],
4328                ..Default::default()
4329            },
4330        )
4331    });
4332    send_notifications(pool, &session.peer, notifications);
4333
4334    Ok(())
4335}
4336
4337/// Delete a channel message
4338async fn remove_channel_message(
4339    request: proto::RemoveChannelMessage,
4340    response: Response<proto::RemoveChannelMessage>,
4341    session: UserSession,
4342) -> Result<()> {
4343    let channel_id = ChannelId::from_proto(request.channel_id);
4344    let message_id = MessageId::from_proto(request.message_id);
4345    let (connection_ids, existing_notification_ids) = session
4346        .db()
4347        .await
4348        .remove_channel_message(channel_id, message_id, session.user_id())
4349        .await?;
4350
4351    broadcast(
4352        Some(session.connection_id),
4353        connection_ids,
4354        move |connection| {
4355            session.peer.send(connection, request.clone())?;
4356
4357            for notification_id in &existing_notification_ids {
4358                session.peer.send(
4359                    connection,
4360                    proto::DeleteNotification {
4361                        notification_id: (*notification_id).to_proto(),
4362                    },
4363                )?;
4364            }
4365
4366            Ok(())
4367        },
4368    );
4369    response.send(proto::Ack {})?;
4370    Ok(())
4371}
4372
4373async fn update_channel_message(
4374    request: proto::UpdateChannelMessage,
4375    response: Response<proto::UpdateChannelMessage>,
4376    session: UserSession,
4377) -> Result<()> {
4378    let channel_id = ChannelId::from_proto(request.channel_id);
4379    let message_id = MessageId::from_proto(request.message_id);
4380    let updated_at = OffsetDateTime::now_utc();
4381    let UpdatedChannelMessage {
4382        message_id,
4383        participant_connection_ids,
4384        notifications,
4385        reply_to_message_id,
4386        timestamp,
4387        deleted_mention_notification_ids,
4388        updated_mention_notifications,
4389    } = session
4390        .db()
4391        .await
4392        .update_channel_message(
4393            channel_id,
4394            message_id,
4395            session.user_id(),
4396            request.body.as_str(),
4397            &request.mentions,
4398            updated_at,
4399        )
4400        .await?;
4401
4402    let nonce = request
4403        .nonce
4404        .clone()
4405        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4406
4407    let message = proto::ChannelMessage {
4408        sender_id: session.user_id().to_proto(),
4409        id: message_id.to_proto(),
4410        body: request.body.clone(),
4411        mentions: request.mentions.clone(),
4412        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4413        nonce: Some(nonce),
4414        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4415        edited_at: Some(updated_at.unix_timestamp() as u64),
4416    };
4417
4418    response.send(proto::Ack {})?;
4419
4420    let pool = &*session.connection_pool().await;
4421    broadcast(
4422        Some(session.connection_id),
4423        participant_connection_ids,
4424        |connection| {
4425            session.peer.send(
4426                connection,
4427                proto::ChannelMessageUpdate {
4428                    channel_id: channel_id.to_proto(),
4429                    message: Some(message.clone()),
4430                },
4431            )?;
4432
4433            for notification_id in &deleted_mention_notification_ids {
4434                session.peer.send(
4435                    connection,
4436                    proto::DeleteNotification {
4437                        notification_id: (*notification_id).to_proto(),
4438                    },
4439                )?;
4440            }
4441
4442            for notification in &updated_mention_notifications {
4443                session.peer.send(
4444                    connection,
4445                    proto::UpdateNotification {
4446                        notification: Some(notification.clone()),
4447                    },
4448                )?;
4449            }
4450
4451            Ok(())
4452        },
4453    );
4454
4455    send_notifications(pool, &session.peer, notifications);
4456
4457    Ok(())
4458}
4459
4460/// Mark a channel message as read
4461async fn acknowledge_channel_message(
4462    request: proto::AckChannelMessage,
4463    session: UserSession,
4464) -> Result<()> {
4465    let channel_id = ChannelId::from_proto(request.channel_id);
4466    let message_id = MessageId::from_proto(request.message_id);
4467    let notifications = session
4468        .db()
4469        .await
4470        .observe_channel_message(channel_id, session.user_id(), message_id)
4471        .await?;
4472    send_notifications(
4473        &*session.connection_pool().await,
4474        &session.peer,
4475        notifications,
4476    );
4477    Ok(())
4478}
4479
4480/// Mark a buffer version as synced
4481async fn acknowledge_buffer_version(
4482    request: proto::AckBufferOperation,
4483    session: UserSession,
4484) -> Result<()> {
4485    let buffer_id = BufferId::from_proto(request.buffer_id);
4486    session
4487        .db()
4488        .await
4489        .observe_buffer_version(
4490            buffer_id,
4491            session.user_id(),
4492            request.epoch as i32,
4493            &request.version,
4494        )
4495        .await?;
4496    Ok(())
4497}
4498
4499async fn count_language_model_tokens(
4500    request: proto::CountLanguageModelTokens,
4501    response: Response<proto::CountLanguageModelTokens>,
4502    session: Session,
4503    config: &Config,
4504) -> Result<()> {
4505    let Some(session) = session.for_user() else {
4506        return Err(anyhow!("user not found"))?;
4507    };
4508    authorize_access_to_legacy_llm_endpoints(&session).await?;
4509
4510    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4511        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4512        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4513    };
4514
4515    session
4516        .app_state
4517        .rate_limiter
4518        .check(&*rate_limit, session.user_id())
4519        .await?;
4520
4521    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4522        Some(proto::LanguageModelProvider::Google) => {
4523            let api_key = config
4524                .google_ai_api_key
4525                .as_ref()
4526                .context("no Google AI API key configured on the server")?;
4527            google_ai::count_tokens(
4528                session.http_client.as_ref(),
4529                google_ai::API_URL,
4530                api_key,
4531                serde_json::from_str(&request.request)?,
4532                None,
4533            )
4534            .await?
4535        }
4536        _ => return Err(anyhow!("unsupported provider"))?,
4537    };
4538
4539    response.send(proto::CountLanguageModelTokensResponse {
4540        token_count: result.total_tokens as u32,
4541    })?;
4542
4543    Ok(())
4544}
4545
4546struct ZedProCountLanguageModelTokensRateLimit;
4547
4548impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4549    fn capacity(&self) -> usize {
4550        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4551            .ok()
4552            .and_then(|v| v.parse().ok())
4553            .unwrap_or(600) // Picked arbitrarily
4554    }
4555
4556    fn refill_duration(&self) -> chrono::Duration {
4557        chrono::Duration::hours(1)
4558    }
4559
4560    fn db_name(&self) -> &'static str {
4561        "zed-pro:count-language-model-tokens"
4562    }
4563}
4564
4565struct FreeCountLanguageModelTokensRateLimit;
4566
4567impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4568    fn capacity(&self) -> usize {
4569        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4570            .ok()
4571            .and_then(|v| v.parse().ok())
4572            .unwrap_or(600 / 10) // Picked arbitrarily
4573    }
4574
4575    fn refill_duration(&self) -> chrono::Duration {
4576        chrono::Duration::hours(1)
4577    }
4578
4579    fn db_name(&self) -> &'static str {
4580        "free:count-language-model-tokens"
4581    }
4582}
4583
4584struct ZedProComputeEmbeddingsRateLimit;
4585
4586impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4587    fn capacity(&self) -> usize {
4588        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4589            .ok()
4590            .and_then(|v| v.parse().ok())
4591            .unwrap_or(5000) // Picked arbitrarily
4592    }
4593
4594    fn refill_duration(&self) -> chrono::Duration {
4595        chrono::Duration::hours(1)
4596    }
4597
4598    fn db_name(&self) -> &'static str {
4599        "zed-pro:compute-embeddings"
4600    }
4601}
4602
4603struct FreeComputeEmbeddingsRateLimit;
4604
4605impl RateLimit for FreeComputeEmbeddingsRateLimit {
4606    fn capacity(&self) -> usize {
4607        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4608            .ok()
4609            .and_then(|v| v.parse().ok())
4610            .unwrap_or(5000 / 10) // Picked arbitrarily
4611    }
4612
4613    fn refill_duration(&self) -> chrono::Duration {
4614        chrono::Duration::hours(1)
4615    }
4616
4617    fn db_name(&self) -> &'static str {
4618        "free:compute-embeddings"
4619    }
4620}
4621
4622async fn compute_embeddings(
4623    request: proto::ComputeEmbeddings,
4624    response: Response<proto::ComputeEmbeddings>,
4625    session: UserSession,
4626    api_key: Option<Arc<str>>,
4627) -> Result<()> {
4628    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4629    authorize_access_to_legacy_llm_endpoints(&session).await?;
4630
4631    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4632        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4633        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4634    };
4635
4636    session
4637        .app_state
4638        .rate_limiter
4639        .check(&*rate_limit, session.user_id())
4640        .await?;
4641
4642    let embeddings = match request.model.as_str() {
4643        "openai/text-embedding-3-small" => {
4644            open_ai::embed(
4645                session.http_client.as_ref(),
4646                OPEN_AI_API_URL,
4647                &api_key,
4648                OpenAiEmbeddingModel::TextEmbedding3Small,
4649                request.texts.iter().map(|text| text.as_str()),
4650            )
4651            .await?
4652        }
4653        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4654    };
4655
4656    let embeddings = request
4657        .texts
4658        .iter()
4659        .map(|text| {
4660            let mut hasher = sha2::Sha256::new();
4661            hasher.update(text.as_bytes());
4662            let result = hasher.finalize();
4663            result.to_vec()
4664        })
4665        .zip(
4666            embeddings
4667                .data
4668                .into_iter()
4669                .map(|embedding| embedding.embedding),
4670        )
4671        .collect::<HashMap<_, _>>();
4672
4673    let db = session.db().await;
4674    db.save_embeddings(&request.model, &embeddings)
4675        .await
4676        .context("failed to save embeddings")
4677        .trace_err();
4678
4679    response.send(proto::ComputeEmbeddingsResponse {
4680        embeddings: embeddings
4681            .into_iter()
4682            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4683            .collect(),
4684    })?;
4685    Ok(())
4686}
4687
4688async fn get_cached_embeddings(
4689    request: proto::GetCachedEmbeddings,
4690    response: Response<proto::GetCachedEmbeddings>,
4691    session: UserSession,
4692) -> Result<()> {
4693    authorize_access_to_legacy_llm_endpoints(&session).await?;
4694
4695    let db = session.db().await;
4696    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4697
4698    response.send(proto::GetCachedEmbeddingsResponse {
4699        embeddings: embeddings
4700            .into_iter()
4701            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4702            .collect(),
4703    })?;
4704    Ok(())
4705}
4706
4707/// This is leftover from before the LLM service.
4708///
4709/// The endpoints protected by this check will be moved there eventually.
4710async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4711    if session.is_staff() {
4712        Ok(())
4713    } else {
4714        Err(anyhow!("permission denied"))?
4715    }
4716}
4717
4718/// Get a Supermaven API key for the user
4719async fn get_supermaven_api_key(
4720    _request: proto::GetSupermavenApiKey,
4721    response: Response<proto::GetSupermavenApiKey>,
4722    session: UserSession,
4723) -> Result<()> {
4724    let user_id: String = session.user_id().to_string();
4725    if !session.is_staff() {
4726        return Err(anyhow!("supermaven not enabled for this account"))?;
4727    }
4728
4729    let email = session
4730        .email()
4731        .ok_or_else(|| anyhow!("user must have an email"))?;
4732
4733    let supermaven_admin_api = session
4734        .supermaven_client
4735        .as_ref()
4736        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4737
4738    let result = supermaven_admin_api
4739        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4740        .await?;
4741
4742    response.send(proto::GetSupermavenApiKeyResponse {
4743        api_key: result.api_key,
4744    })?;
4745
4746    Ok(())
4747}
4748
4749/// Start receiving chat updates for a channel
4750async fn join_channel_chat(
4751    request: proto::JoinChannelChat,
4752    response: Response<proto::JoinChannelChat>,
4753    session: UserSession,
4754) -> Result<()> {
4755    let channel_id = ChannelId::from_proto(request.channel_id);
4756
4757    let db = session.db().await;
4758    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4759        .await?;
4760    let messages = db
4761        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4762        .await?;
4763    response.send(proto::JoinChannelChatResponse {
4764        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4765        messages,
4766    })?;
4767    Ok(())
4768}
4769
4770/// Stop receiving chat updates for a channel
4771async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4772    let channel_id = ChannelId::from_proto(request.channel_id);
4773    session
4774        .db()
4775        .await
4776        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4777        .await?;
4778    Ok(())
4779}
4780
4781/// Retrieve the chat history for a channel
4782async fn get_channel_messages(
4783    request: proto::GetChannelMessages,
4784    response: Response<proto::GetChannelMessages>,
4785    session: UserSession,
4786) -> Result<()> {
4787    let channel_id = ChannelId::from_proto(request.channel_id);
4788    let messages = session
4789        .db()
4790        .await
4791        .get_channel_messages(
4792            channel_id,
4793            session.user_id(),
4794            MESSAGE_COUNT_PER_PAGE,
4795            Some(MessageId::from_proto(request.before_message_id)),
4796        )
4797        .await?;
4798    response.send(proto::GetChannelMessagesResponse {
4799        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4800        messages,
4801    })?;
4802    Ok(())
4803}
4804
4805/// Retrieve specific chat messages
4806async fn get_channel_messages_by_id(
4807    request: proto::GetChannelMessagesById,
4808    response: Response<proto::GetChannelMessagesById>,
4809    session: UserSession,
4810) -> Result<()> {
4811    let message_ids = request
4812        .message_ids
4813        .iter()
4814        .map(|id| MessageId::from_proto(*id))
4815        .collect::<Vec<_>>();
4816    let messages = session
4817        .db()
4818        .await
4819        .get_channel_messages_by_id(session.user_id(), &message_ids)
4820        .await?;
4821    response.send(proto::GetChannelMessagesResponse {
4822        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4823        messages,
4824    })?;
4825    Ok(())
4826}
4827
4828/// Retrieve the current users notifications
4829async fn get_notifications(
4830    request: proto::GetNotifications,
4831    response: Response<proto::GetNotifications>,
4832    session: UserSession,
4833) -> Result<()> {
4834    let notifications = session
4835        .db()
4836        .await
4837        .get_notifications(
4838            session.user_id(),
4839            NOTIFICATION_COUNT_PER_PAGE,
4840            request.before_id.map(db::NotificationId::from_proto),
4841        )
4842        .await?;
4843    response.send(proto::GetNotificationsResponse {
4844        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4845        notifications,
4846    })?;
4847    Ok(())
4848}
4849
4850/// Mark notifications as read
4851async fn mark_notification_as_read(
4852    request: proto::MarkNotificationRead,
4853    response: Response<proto::MarkNotificationRead>,
4854    session: UserSession,
4855) -> Result<()> {
4856    let database = &session.db().await;
4857    let notifications = database
4858        .mark_notification_as_read_by_id(
4859            session.user_id(),
4860            NotificationId::from_proto(request.notification_id),
4861        )
4862        .await?;
4863    send_notifications(
4864        &*session.connection_pool().await,
4865        &session.peer,
4866        notifications,
4867    );
4868    response.send(proto::Ack {})?;
4869    Ok(())
4870}
4871
4872/// Get the current users information
4873async fn get_private_user_info(
4874    _request: proto::GetPrivateUserInfo,
4875    response: Response<proto::GetPrivateUserInfo>,
4876    session: UserSession,
4877) -> Result<()> {
4878    let db = session.db().await;
4879
4880    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4881    let user = db
4882        .get_user_by_id(session.user_id())
4883        .await?
4884        .ok_or_else(|| anyhow!("user not found"))?;
4885    let flags = db.get_user_flags(session.user_id()).await?;
4886
4887    response.send(proto::GetPrivateUserInfoResponse {
4888        metrics_id,
4889        staff: user.admin,
4890        flags,
4891        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4892    })?;
4893    Ok(())
4894}
4895
4896/// Accept the terms of service (tos) on behalf of the current user
4897async fn accept_terms_of_service(
4898    _request: proto::AcceptTermsOfService,
4899    response: Response<proto::AcceptTermsOfService>,
4900    session: UserSession,
4901) -> Result<()> {
4902    let db = session.db().await;
4903
4904    let accepted_tos_at = Utc::now();
4905    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4906        .await?;
4907
4908    response.send(proto::AcceptTermsOfServiceResponse {
4909        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4910    })?;
4911    Ok(())
4912}
4913
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures the age from the earlier of the user's
/// creation time and, when known, their GitHub account creation time.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4916
4917async fn get_llm_api_token(
4918    _request: proto::GetLlmToken,
4919    response: Response<proto::GetLlmToken>,
4920    session: UserSession,
4921) -> Result<()> {
4922    let db = session.db().await;
4923
4924    let flags = db.get_user_flags(session.user_id()).await?;
4925    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4926    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4927
4928    if !session.is_staff() && !has_language_models_feature_flag {
4929        Err(anyhow!("permission denied"))?
4930    }
4931
4932    let user_id = session.user_id();
4933    let user = db
4934        .get_user_by_id(user_id)
4935        .await?
4936        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4937
4938    if user.accepted_tos_at.is_none() {
4939        Err(anyhow!("terms of service not accepted"))?
4940    }
4941
4942    let mut account_created_at = user.created_at;
4943    if let Some(github_created_at) = user.github_user_created_at {
4944        account_created_at = account_created_at.min(github_created_at);
4945    }
4946    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4947        Err(anyhow!("account too young"))?
4948    }
4949    let token = LlmTokenClaims::create(
4950        user.id,
4951        user.github_login.clone(),
4952        session.is_staff(),
4953        has_llm_closed_beta_feature_flag,
4954        session.current_plan(db).await?,
4955        &session.app_state.config,
4956    )?;
4957    response.send(proto::GetLlmTokenResponse { token })?;
4958    Ok(())
4959}
4960
4961fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4962    let message = match message {
4963        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4964        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4965        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4966        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4967        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4968            code: frame.code.into(),
4969            reason: frame.reason,
4970        })),
4971        // We should never receive a frame while reading the message, according
4972        // to the `tungstenite` maintainers:
4973        //
4974        // > It cannot occur when you read messages from the WebSocket, but it
4975        // > can be used when you want to send the raw frames (e.g. you want to
4976        // > send the frames to the WebSocket without composing the full message first).
4977        // >
4978        // > — https://github.com/snapview/tungstenite-rs/issues/268
4979        TungsteniteMessage::Frame(_) => {
4980            bail!("received an unexpected frame while reading the message")
4981        }
4982    };
4983
4984    Ok(message)
4985}
4986
4987fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4988    match message {
4989        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4990        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4991        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4992        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4993        AxumMessage::Close(frame) => {
4994            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4995                code: frame.code.into(),
4996                reason: frame.reason,
4997            }))
4998        }
4999    }
5000}
5001
5002fn notify_membership_updated(
5003    connection_pool: &mut ConnectionPool,
5004    result: MembershipUpdated,
5005    user_id: UserId,
5006    peer: &Peer,
5007) {
5008    for membership in &result.new_channels.channel_memberships {
5009        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5010    }
5011    for channel_id in &result.removed_channels {
5012        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5013    }
5014
5015    let user_channels_update = proto::UpdateUserChannels {
5016        channel_memberships: result
5017            .new_channels
5018            .channel_memberships
5019            .iter()
5020            .map(|cm| proto::ChannelMembership {
5021                channel_id: cm.channel_id.to_proto(),
5022                role: cm.role.into(),
5023            })
5024            .collect(),
5025        ..Default::default()
5026    };
5027
5028    let mut update = build_channels_update(result.new_channels);
5029    update.delete_channels = result
5030        .removed_channels
5031        .into_iter()
5032        .map(|id| id.to_proto())
5033        .collect();
5034    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5035
5036    for connection_id in connection_pool.user_connection_ids(user_id) {
5037        peer.send(connection_id, user_channels_update.clone())
5038            .trace_err();
5039        peer.send(connection_id, update.clone()).trace_err();
5040    }
5041}
5042
5043fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5044    proto::UpdateUserChannels {
5045        channel_memberships: channels
5046            .channel_memberships
5047            .iter()
5048            .map(|m| proto::ChannelMembership {
5049                channel_id: m.channel_id.to_proto(),
5050                role: m.role.into(),
5051            })
5052            .collect(),
5053        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5054        observed_channel_message_id: channels.observed_channel_messages.clone(),
5055    }
5056}
5057
5058fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5059    let mut update = proto::UpdateChannels::default();
5060
5061    for channel in channels.channels {
5062        update.channels.push(channel.to_proto());
5063    }
5064
5065    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5066    update.latest_channel_message_ids = channels.latest_channel_messages;
5067
5068    for (channel_id, participants) in channels.channel_participants {
5069        update
5070            .channel_participants
5071            .push(proto::ChannelParticipants {
5072                channel_id: channel_id.to_proto(),
5073                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5074            });
5075    }
5076
5077    for channel in channels.invited_channels {
5078        update.channel_invitations.push(channel.to_proto());
5079    }
5080
5081    update.hosted_projects = channels.hosted_projects;
5082    update
5083}
5084
5085fn build_initial_contacts_update(
5086    contacts: Vec<db::Contact>,
5087    pool: &ConnectionPool,
5088) -> proto::UpdateContacts {
5089    let mut update = proto::UpdateContacts::default();
5090
5091    for contact in contacts {
5092        match contact {
5093            db::Contact::Accepted { user_id, busy } => {
5094                update.contacts.push(contact_for_user(user_id, busy, pool));
5095            }
5096            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5097            db::Contact::Incoming { user_id } => {
5098                update
5099                    .incoming_requests
5100                    .push(proto::IncomingContactRequest {
5101                        requester_id: user_id.to_proto(),
5102                    })
5103            }
5104        }
5105    }
5106
5107    update
5108}
5109
5110fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5111    proto::Contact {
5112        user_id: user_id.to_proto(),
5113        online: pool.is_user_online(user_id),
5114        busy,
5115    }
5116}
5117
5118fn room_updated(room: &proto::Room, peer: &Peer) {
5119    broadcast(
5120        None,
5121        room.participants
5122            .iter()
5123            .filter_map(|participant| Some(participant.peer_id?.into())),
5124        |peer_id| {
5125            peer.send(
5126                peer_id,
5127                proto::RoomUpdated {
5128                    room: Some(room.clone()),
5129                },
5130            )
5131        },
5132    );
5133}
5134
5135fn channel_updated(
5136    channel: &db::channel::Model,
5137    room: &proto::Room,
5138    peer: &Peer,
5139    pool: &ConnectionPool,
5140) {
5141    let participants = room
5142        .participants
5143        .iter()
5144        .map(|p| p.user_id)
5145        .collect::<Vec<_>>();
5146
5147    broadcast(
5148        None,
5149        pool.channel_connection_ids(channel.root_id())
5150            .filter_map(|(channel_id, role)| {
5151                role.can_see_channel(channel.visibility)
5152                    .then_some(channel_id)
5153            }),
5154        |peer_id| {
5155            peer.send(
5156                peer_id,
5157                proto::UpdateChannels {
5158                    channel_participants: vec![proto::ChannelParticipants {
5159                        channel_id: channel.id.to_proto(),
5160                        participant_user_ids: participants.clone(),
5161                    }],
5162                    ..Default::default()
5163                },
5164            )
5165        },
5166    );
5167}
5168
5169async fn send_dev_server_projects_update(
5170    user_id: UserId,
5171    mut status: proto::DevServerProjectsUpdate,
5172    session: &Session,
5173) {
5174    let pool = session.connection_pool().await;
5175    for dev_server in &mut status.dev_servers {
5176        dev_server.status =
5177            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5178    }
5179    let connections = pool.user_connection_ids(user_id);
5180    for connection_id in connections {
5181        session.peer.send(connection_id, status.clone()).trace_err();
5182    }
5183}
5184
5185async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5186    let db = session.db().await;
5187
5188    let contacts = db.get_contacts(user_id).await?;
5189    let busy = db.is_user_busy(user_id).await?;
5190
5191    let pool = session.connection_pool().await;
5192    let updated_contact = contact_for_user(user_id, busy, &pool);
5193    for contact in contacts {
5194        if let db::Contact::Accepted {
5195            user_id: contact_user_id,
5196            ..
5197        } = contact
5198        {
5199            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5200                session
5201                    .peer
5202                    .send(
5203                        contact_conn_id,
5204                        proto::UpdateContacts {
5205                            contacts: vec![updated_contact.clone()],
5206                            remove_contacts: Default::default(),
5207                            incoming_requests: Default::default(),
5208                            remove_incoming_requests: Default::default(),
5209                            outgoing_requests: Default::default(),
5210                            remove_outgoing_requests: Default::default(),
5211                        },
5212                    )
5213                    .trace_err();
5214            }
5215        }
5216    }
5217    Ok(())
5218}
5219
5220async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5221    log::info!("lost dev server connection, unsharing projects");
5222    let project_ids = session
5223        .db()
5224        .await
5225        .get_stale_dev_server_projects(session.connection_id)
5226        .await?;
5227
5228    for project_id in project_ids {
5229        // not unshare re-checks the connection ids match, so we get away with no transaction
5230        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5231    }
5232
5233    let user_id = session.dev_server().user_id;
5234    let update = session
5235        .db()
5236        .await
5237        .dev_server_projects_update(user_id)
5238        .await?;
5239
5240    send_dev_server_projects_update(user_id, update, session).await;
5241
5242    Ok(())
5243}
5244
5245async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5246    let mut contacts_to_update = HashSet::default();
5247
5248    let room_id;
5249    let canceled_calls_to_user_ids;
5250    let live_kit_room;
5251    let delete_live_kit_room;
5252    let room;
5253    let channel;
5254
5255    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5256        contacts_to_update.insert(session.user_id());
5257
5258        for project in left_room.left_projects.values() {
5259            project_left(project, session);
5260        }
5261
5262        room_id = RoomId::from_proto(left_room.room.id);
5263        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5264        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5265        delete_live_kit_room = left_room.deleted;
5266        room = mem::take(&mut left_room.room);
5267        channel = mem::take(&mut left_room.channel);
5268
5269        room_updated(&room, &session.peer);
5270    } else {
5271        return Ok(());
5272    }
5273
5274    if let Some(channel) = channel {
5275        channel_updated(
5276            &channel,
5277            &room,
5278            &session.peer,
5279            &*session.connection_pool().await,
5280        );
5281    }
5282
5283    {
5284        let pool = session.connection_pool().await;
5285        for canceled_user_id in canceled_calls_to_user_ids {
5286            for connection_id in pool.user_connection_ids(canceled_user_id) {
5287                session
5288                    .peer
5289                    .send(
5290                        connection_id,
5291                        proto::CallCanceled {
5292                            room_id: room_id.to_proto(),
5293                        },
5294                    )
5295                    .trace_err();
5296            }
5297            contacts_to_update.insert(canceled_user_id);
5298        }
5299    }
5300
5301    for contact_user_id in contacts_to_update {
5302        update_user_contacts(contact_user_id, session).await?;
5303    }
5304
5305    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5306        live_kit
5307            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5308            .await
5309            .trace_err();
5310
5311        if delete_live_kit_room {
5312            live_kit.delete_room(live_kit_room).await.trace_err();
5313        }
5314    }
5315
5316    Ok(())
5317}
5318
5319async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5320    let left_channel_buffers = session
5321        .db()
5322        .await
5323        .leave_channel_buffers(session.connection_id)
5324        .await?;
5325
5326    for left_buffer in left_channel_buffers {
5327        channel_buffer_updated(
5328            session.connection_id,
5329            left_buffer.connections,
5330            &proto::UpdateChannelBufferCollaborators {
5331                channel_id: left_buffer.channel_id.to_proto(),
5332                collaborators: left_buffer.collaborators,
5333            },
5334            &session.peer,
5335        );
5336    }
5337
5338    Ok(())
5339}
5340
5341fn project_left(project: &db::LeftProject, session: &UserSession) {
5342    for connection_id in &project.connection_ids {
5343        if project.should_unshare {
5344            session
5345                .peer
5346                .send(
5347                    *connection_id,
5348                    proto::UnshareProject {
5349                        project_id: project.id.to_proto(),
5350                    },
5351                )
5352                .trace_err();
5353        } else {
5354            session
5355                .peer
5356                .send(
5357                    *connection_id,
5358                    proto::RemoveProjectCollaborator {
5359                        project_id: project.id.to_proto(),
5360                        peer_id: Some(session.connection_id.into()),
5361                    },
5362                )
5363                .trace_err();
5364        }
5365    }
5366}
5367
5368pub trait ResultExt {
5369    type Ok;
5370
5371    fn trace_err(self) -> Option<Self::Ok>;
5372}
5373
5374impl<T, E> ResultExt for Result<T, E>
5375where
5376    E: std::fmt::Debug,
5377{
5378    type Ok = T;
5379
5380    #[track_caller]
5381    fn trace_err(self) -> Option<T> {
5382        match self {
5383            Ok(value) => Some(value),
5384            Err(error) => {
5385                tracing::error!("{:?}", error);
5386                None
5387            }
5388        }
5389    }
5390}