rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
/// The identity on whose behalf a [`Session`] acts.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `admin` acting as `user`; `user` is the identity the session acts as,
    /// and `admin` is recorded as the impersonator.
    Impersonated { user: User, admin: User },
    /// A dev server, represented by its database model.
    DevServer(dev_server::Model),
}
 113
 114impl Principal {
 115    fn update_span(&self, span: &tracing::Span) {
 116        match &self {
 117            Principal::User(user) => {
 118                span.record("user_id", &user.id.0);
 119                span.record("login", &user.github_login);
 120            }
 121            Principal::Impersonated { user, admin } => {
 122                span.record("user_id", &user.id.0);
 123                span.record("login", &user.github_login);
 124                span.record("impersonator", &admin.github_login);
 125            }
 126            Principal::DevServer(dev_server) => {
 127                span.record("dev_server_id", &dev_server.id.0);
 128            }
 129        }
 130    }
 131}
 132
/// Per-connection state passed to every message handler.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; lock it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool behind a synchronous mutex; lock it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` presumably because nothing reads this field
    // yet — confirm, and drop the attribute once it is used.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
 148
 149impl Session {
 150    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 151        #[cfg(test)]
 152        tokio::task::yield_now().await;
 153        let guard = self.db.lock().await;
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        guard
 157    }
 158
 159    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.connection_pool.lock();
 163        ConnectionPoolGuard {
 164            guard,
 165            _not_send: PhantomData,
 166        }
 167    }
 168
 169    fn for_user(self) -> Option<UserSession> {
 170        UserSession::new(self)
 171    }
 172
 173    fn for_dev_server(self) -> Option<DevServerSession> {
 174        DevServerSession::new(self)
 175    }
 176
 177    fn user_id(&self) -> Option<UserId> {
 178        match &self.principal {
 179            Principal::User(user) => Some(user.id),
 180            Principal::Impersonated { user, .. } => Some(user.id),
 181            Principal::DevServer(_) => None,
 182        }
 183    }
 184
 185    fn is_staff(&self) -> bool {
 186        match &self.principal {
 187            Principal::User(user) => user.admin,
 188            Principal::Impersonated { .. } => true,
 189            Principal::DevServer(_) => false,
 190        }
 191    }
 192
 193    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 194        if self.is_staff() {
 195            return Ok(proto::Plan::ZedPro);
 196        }
 197
 198        let Some(user_id) = self.user_id() else {
 199            return Ok(proto::Plan::Free);
 200        };
 201
 202        if db.has_active_billing_subscription(user_id).await? {
 203            Ok(proto::Plan::ZedPro)
 204        } else {
 205            Ok(proto::Plan::Free)
 206        }
 207    }
 208
 209    fn dev_server_id(&self) -> Option<DevServerId> {
 210        match &self.principal {
 211            Principal::User(_) | Principal::Impersonated { .. } => None,
 212            Principal::DevServer(dev_server) => Some(dev_server.id),
 213        }
 214    }
 215
 216    fn principal_id(&self) -> PrincipalId {
 217        match &self.principal {
 218            Principal::User(user) => PrincipalId::UserId(user.id),
 219            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 220            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 221        }
 222    }
 223}
 224
 225impl Debug for Session {
 226    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 227        let mut result = f.debug_struct("Session");
 228        match &self.principal {
 229            Principal::User(user) => {
 230                result.field("user", &user.github_login);
 231            }
 232            Principal::Impersonated { user, admin } => {
 233                result.field("user", &user.github_login);
 234                result.field("impersonator", &admin.github_login);
 235            }
 236            Principal::DevServer(dev_server) => {
 237                result.field("dev_server", &dev_server.id);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct UserSession(Session);
 245
 246impl UserSession {
 247    pub fn new(s: Session) -> Option<Self> {
 248        s.user_id().map(|_| UserSession(s))
 249    }
 250    pub fn user_id(&self) -> UserId {
 251        self.0.user_id().unwrap()
 252    }
 253
 254    pub fn email(&self) -> Option<String> {
 255        match &self.0.principal {
 256            Principal::User(user) => user.email_address.clone(),
 257            Principal::Impersonated { user, .. } => user.email_address.clone(),
 258            Principal::DevServer(..) => None,
 259        }
 260    }
 261}
 262
 263impl Deref for UserSession {
 264    type Target = Session;
 265
 266    fn deref(&self) -> &Self::Target {
 267        &self.0
 268    }
 269}
 270impl DerefMut for UserSession {
 271    fn deref_mut(&mut self) -> &mut Self::Target {
 272        &mut self.0
 273    }
 274}
 275
 276struct DevServerSession(Session);
 277
 278impl DevServerSession {
 279    pub fn new(s: Session) -> Option<Self> {
 280        s.dev_server_id().map(|_| DevServerSession(s))
 281    }
 282    pub fn dev_server_id(&self) -> DevServerId {
 283        self.0.dev_server_id().unwrap()
 284    }
 285
 286    fn dev_server(&self) -> &dev_server::Model {
 287        match &self.0.principal {
 288            Principal::DevServer(dev_server) => dev_server,
 289            _ => unreachable!(),
 290        }
 291    }
 292}
 293
 294impl Deref for DevServerSession {
 295    type Target = Session;
 296
 297    fn deref(&self) -> &Self::Target {
 298        &self.0
 299    }
 300}
 301impl DerefMut for DevServerSession {
 302    fn deref_mut(&mut self) -> &mut Self::Target {
 303        &mut self.0
 304    }
 305}
 306
 307fn user_handler<M: RequestMessage, Fut>(
 308    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 310where
 311    Fut: Send + Future<Output = Result<()>>,
 312{
 313    let handler = Arc::new(handler);
 314    move |message, response, session| {
 315        let handler = handler.clone();
 316        Box::pin(async move {
 317            if let Some(user_session) = session.for_user() {
 318                Ok(handler(message, response, user_session).await?)
 319            } else {
 320                Err(Error::Internal(anyhow!(
 321                    "must be a user to call {}",
 322                    M::NAME
 323                )))
 324            }
 325        })
 326    }
 327}
 328
 329fn dev_server_handler<M: RequestMessage, Fut>(
 330    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 332where
 333    Fut: Send + Future<Output = Result<()>>,
 334{
 335    let handler = Arc::new(handler);
 336    move |message, response, session| {
 337        let handler = handler.clone();
 338        Box::pin(async move {
 339            if let Some(dev_server_session) = session.for_dev_server() {
 340                Ok(handler(message, response, dev_server_session).await?)
 341            } else {
 342                Err(Error::Internal(anyhow!(
 343                    "must be a dev server to call {}",
 344                    M::NAME
 345                )))
 346            }
 347        })
 348    }
 349}
 350
 351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 352    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 354where
 355    InnertRetFut: Send + Future<Output = Result<()>>,
 356{
 357    let handler = Arc::new(handler);
 358    move |message, session| {
 359        let handler = handler.clone();
 360        Box::pin(async move {
 361            if let Some(user_session) = session.for_user() {
 362                Ok(handler(message, user_session).await?)
 363            } else {
 364                Err(Error::Internal(anyhow!(
 365                    "must be a user to call {}",
 366                    M::NAME
 367                )))
 368            }
 369        })
 370    }
 371}
 372
 373struct DbHandle(Arc<Database>);
 374
 375impl Deref for DbHandle {
 376    type Target = Database;
 377
 378    fn deref(&self) -> &Self::Target {
 379        self.0.as_ref()
 380    }
 381}
 382
/// The RPC server.
///
/// It owns the peer, the shared connection pool and the table of registered
/// message handlers.
pub struct Server {
    /// This server's id, behind a mutex (read via `lock()` in `start`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, presumably keyed by the `TypeId` of the message
    /// type each one handles — confirm against `add_*_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a watch channel created with `false`; presumably flipped
    /// to signal teardown.
    teardown: watch::Sender<bool>,
}
 391
/// A locked [`ConnectionPool`].
///
/// It wraps the `parking_lot` mutex guard and adds a `PhantomData<Rc<()>>`
/// marker, which makes the type `!Send` because `Rc` is `!Send`. As a result, a
/// future that holds this guard across an `.await` cannot be `Send`. This helps
/// keep the synchronous lock from being held across suspension points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // Zero-sized marker whose only purpose is to opt out of `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 396
/// A serializable view of the server's peer and connection pool.
///
/// The pool is held as a [`ConnectionPoolGuard`], so its lock stays held for
/// as long as the snapshot exists.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized via `serialize_deref`, i.e. as the `ConnectionPool` the guard
    // dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 403
 404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 405where
 406    S: Serializer,
 407    T: Deref<Target = U>,
 408    U: Serialize,
 409{
 410    Serialize::serialize(value.deref(), serializer)
 411}
 412
 413impl Server {
    /// Creates a server with the given `id` and registers every RPC handler on it.
    ///
    /// Most handlers are wrapped in `user_handler` / `user_message_handler`
    /// (user principals only) or `dev_server_handler` (dev-server principals
    /// only). Some (e.g. `ping`, `update_project`) are registered unwrapped. The
    /// finished server is returned in an `Arc`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/wraps if `ServerId`'s inner
            // value does not fit in a `u32` — confirm the id range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The closure keeps one `Arc<AppState>` and hands a fresh clone to
            // each `async move` future it creates, so every future owns its handle.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is cloned per call so each request gets its own copy.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 653
 654    pub async fn start(&self) -> Result<()> {
 655        let server_id = *self.id.lock();
 656        let app_state = self.app_state.clone();
 657        let peer = self.peer.clone();
 658        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 659        let pool = self.connection_pool.clone();
 660        let live_kit_client = self.app_state.live_kit_client.clone();
 661
 662        let span = info_span!("start server");
 663        self.app_state.executor.spawn_detached(
 664            async move {
 665                tracing::info!("waiting for cleanup timeout");
 666                timeout.await;
 667                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 668                if let Some((room_ids, channel_ids)) = app_state
 669                    .db
 670                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 671                    .await
 672                    .trace_err()
 673                {
 674                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 675                    tracing::info!(
 676                        stale_channel_buffer_count = channel_ids.len(),
 677                        "retrieved stale channel buffers"
 678                    );
 679
 680                    for channel_id in channel_ids {
 681                        if let Some(refreshed_channel_buffer) = app_state
 682                            .db
 683                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 684                            .await
 685                            .trace_err()
 686                        {
 687                            for connection_id in refreshed_channel_buffer.connection_ids {
 688                                peer.send(
 689                                    connection_id,
 690                                    proto::UpdateChannelBufferCollaborators {
 691                                        channel_id: channel_id.to_proto(),
 692                                        collaborators: refreshed_channel_buffer
 693                                            .collaborators
 694                                            .clone(),
 695                                    },
 696                                )
 697                                .trace_err();
 698                            }
 699                        }
 700                    }
 701
 702                    for room_id in room_ids {
 703                        let mut contacts_to_update = HashSet::default();
 704                        let mut canceled_calls_to_user_ids = Vec::new();
 705                        let mut live_kit_room = String::new();
 706                        let mut delete_live_kit_room = false;
 707
 708                        if let Some(mut refreshed_room) = app_state
 709                            .db
 710                            .clear_stale_room_participants(room_id, server_id)
 711                            .await
 712                            .trace_err()
 713                        {
 714                            tracing::info!(
 715                                room_id = room_id.0,
 716                                new_participant_count = refreshed_room.room.participants.len(),
 717                                "refreshed room"
 718                            );
 719                            room_updated(&refreshed_room.room, &peer);
 720                            if let Some(channel) = refreshed_room.channel.as_ref() {
 721                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 722                            }
 723                            contacts_to_update
 724                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 725                            contacts_to_update
 726                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 727                            canceled_calls_to_user_ids =
 728                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 729                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 730                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 731                        }
 732
 733                        {
 734                            let pool = pool.lock();
 735                            for canceled_user_id in canceled_calls_to_user_ids {
 736                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 737                                    peer.send(
 738                                        connection_id,
 739                                        proto::CallCanceled {
 740                                            room_id: room_id.to_proto(),
 741                                        },
 742                                    )
 743                                    .trace_err();
 744                                }
 745                            }
 746                        }
 747
 748                        for user_id in contacts_to_update {
 749                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 750                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 751                            if let Some((busy, contacts)) = busy.zip(contacts) {
 752                                let pool = pool.lock();
 753                                let updated_contact = contact_for_user(user_id, busy, &pool);
 754                                for contact in contacts {
 755                                    if let db::Contact::Accepted {
 756                                        user_id: contact_user_id,
 757                                        ..
 758                                    } = contact
 759                                    {
 760                                        for contact_conn_id in
 761                                            pool.user_connection_ids(contact_user_id)
 762                                        {
 763                                            peer.send(
 764                                                contact_conn_id,
 765                                                proto::UpdateContacts {
 766                                                    contacts: vec![updated_contact.clone()],
 767                                                    remove_contacts: Default::default(),
 768                                                    incoming_requests: Default::default(),
 769                                                    remove_incoming_requests: Default::default(),
 770                                                    outgoing_requests: Default::default(),
 771                                                    remove_outgoing_requests: Default::default(),
 772                                                },
 773                                            )
 774                                            .trace_err();
 775                                        }
 776                                    }
 777                                }
 778                            }
 779                        }
 780
 781                        if let Some(live_kit) = live_kit_client.as_ref() {
 782                            if delete_live_kit_room {
 783                                live_kit.delete_room(live_kit_room).await.trace_err();
 784                            }
 785                        }
 786                    }
 787                }
 788
 789                app_state
 790                    .db
 791                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 792                    .await
 793                    .trace_err();
 794            }
 795            .instrument(span),
 796        );
 797        Ok(())
 798    }
 799
 800    pub fn teardown(&self) {
 801        self.peer.teardown();
 802        self.connection_pool.lock().reset();
 803        let _ = self.teardown.send(true);
 804    }
 805
 806    #[cfg(test)]
 807    pub fn reset(&self, id: ServerId) {
 808        self.teardown();
 809        *self.id.lock() = id;
 810        self.peer.reset(id.0 as u32);
 811        let _ = self.teardown.send(false);
 812    }
 813
 814    #[cfg(test)]
 815    pub fn id(&self) -> ServerId {
 816        *self.id.lock()
 817    }
 818
 819    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 820    where
 821        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 822        Fut: 'static + Send + Future<Output = Result<()>>,
 823        M: EnvelopedMessage,
 824    {
 825        let prev_handler = self.handlers.insert(
 826            TypeId::of::<M>(),
 827            Box::new(move |envelope, session| {
 828                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 829                let received_at = envelope.received_at;
 830                tracing::info!("message received");
 831                let start_time = Instant::now();
 832                let future = (handler)(*envelope, session);
 833                async move {
 834                    let result = future.await;
 835                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 836                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 837                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 838                    let payload_type = M::NAME;
 839
 840                    match result {
 841                        Err(error) => {
 842                            tracing::error!(
 843                                ?error,
 844                                total_duration_ms,
 845                                processing_duration_ms,
 846                                queue_duration_ms,
 847                                payload_type,
 848                                "error handling message"
 849                            )
 850                        }
 851                        Ok(()) => tracing::info!(
 852                            total_duration_ms,
 853                            processing_duration_ms,
 854                            queue_duration_ms,
 855                            "finished handling message"
 856                        ),
 857                    }
 858                }
 859                .boxed()
 860            }),
 861        );
 862        if prev_handler.is_some() {
 863            panic!("registered a handler for the same message twice");
 864        }
 865        self
 866    }
 867
 868    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 869    where
 870        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 871        Fut: 'static + Send + Future<Output = Result<()>>,
 872        M: EnvelopedMessage,
 873    {
 874        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 875        self
 876    }
 877
 878    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 879    where
 880        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 881        Fut: Send + Future<Output = Result<()>>,
 882        M: RequestMessage,
 883    {
 884        let handler = Arc::new(handler);
 885        self.add_handler(move |envelope, session| {
 886            let receipt = envelope.receipt();
 887            let handler = handler.clone();
 888            async move {
 889                let peer = session.peer.clone();
 890                let responded = Arc::new(AtomicBool::default());
 891                let response = Response {
 892                    peer: peer.clone(),
 893                    responded: responded.clone(),
 894                    receipt,
 895                };
 896                match (handler)(envelope.payload, response, session).await {
 897                    Ok(()) => {
 898                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 899                            Ok(())
 900                        } else {
 901                            Err(anyhow!("handler did not send a response"))?
 902                        }
 903                    }
 904                    Err(error) => {
 905                        let proto_err = match &error {
 906                            Error::Internal(err) => err.to_proto(),
 907                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 908                        };
 909                        peer.respond_with_error(receipt, proto_err)?;
 910                        Err(error)
 911                    }
 912                }
 913            }
 914        })
 915    }
 916
 917    #[allow(clippy::too_many_arguments)]
 918    pub fn handle_connection(
 919        self: &Arc<Self>,
 920        connection: Connection,
 921        address: String,
 922        principal: Principal,
 923        zed_version: ZedVersion,
 924        geoip_country_code: Option<String>,
 925        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 926        executor: Executor,
 927    ) -> impl Future<Output = ()> {
 928        let this = self.clone();
 929        let span = info_span!("handle connection", %address,
 930            connection_id=field::Empty,
 931            user_id=field::Empty,
 932            login=field::Empty,
 933            impersonator=field::Empty,
 934            dev_server_id=field::Empty,
 935            geoip_country_code=field::Empty
 936        );
 937        principal.update_span(&span);
 938        if let Some(country_code) = geoip_country_code.as_ref() {
 939            span.record("geoip_country_code", country_code);
 940        }
 941
 942        let mut teardown = self.teardown.subscribe();
 943        async move {
 944            if *teardown.borrow() {
 945                tracing::error!("server is tearing down");
 946                return
 947            }
 948            let (connection_id, handle_io, mut incoming_rx) = this
 949                .peer
 950                .add_connection(connection, {
 951                    let executor = executor.clone();
 952                    move |duration| executor.sleep(duration)
 953                });
 954            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 955
 956            tracing::info!("connection opened");
 957
 958            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 959            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 960                Ok(http_client) => Arc::new(http_client),
 961                Err(error) => {
 962                    tracing::error!(?error, "failed to create HTTP client");
 963                    return;
 964                }
 965            };
 966
 967            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 968                Some(Arc::new(SupermavenAdminApi::new(
 969                    supermaven_admin_api_key.to_string(),
 970                    http_client.clone(),
 971                )))
 972            } else {
 973                None
 974            };
 975
 976            let session = Session {
 977                principal: principal.clone(),
 978                connection_id,
 979                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 980                peer: this.peer.clone(),
 981                connection_pool: this.connection_pool.clone(),
 982                app_state: this.app_state.clone(),
 983                http_client,
 984                geoip_country_code,
 985                _executor: executor.clone(),
 986                supermaven_client,
 987            };
 988
 989            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 990                tracing::error!(?error, "failed to send initial client update");
 991                return;
 992            }
 993
 994            let handle_io = handle_io.fuse();
 995            futures::pin_mut!(handle_io);
 996
 997            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 998            // This prevents deadlocks when e.g., client A performs a request to client B and
 999            // client B performs a request to client A. If both clients stop processing further
1000            // messages until their respective request completes, they won't have a chance to
1001            // respond to the other client's request and cause a deadlock.
1002            //
1003            // This arrangement ensures we will attempt to process earlier messages first, but fall
1004            // back to processing messages arrived later in the spirit of making progress.
1005            let mut foreground_message_handlers = FuturesUnordered::new();
1006            let concurrent_handlers = Arc::new(Semaphore::new(256));
1007            loop {
1008                let next_message = async {
1009                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1010                    let message = incoming_rx.next().await;
1011                    (permit, message)
1012                }.fuse();
1013                futures::pin_mut!(next_message);
1014                futures::select_biased! {
1015                    _ = teardown.changed().fuse() => return,
1016                    result = handle_io => {
1017                        if let Err(error) = result {
1018                            tracing::error!(?error, "error handling I/O");
1019                        }
1020                        break;
1021                    }
1022                    _ = foreground_message_handlers.next() => {}
1023                    next_message = next_message => {
1024                        let (permit, message) = next_message;
1025                        if let Some(message) = message {
1026                            let type_name = message.payload_type_name();
1027                            // note: we copy all the fields from the parent span so we can query them in the logs.
1028                            // (https://github.com/tokio-rs/tracing/issues/2670).
1029                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1030                                user_id=field::Empty,
1031                                login=field::Empty,
1032                                impersonator=field::Empty,
1033                                dev_server_id=field::Empty
1034                            );
1035                            principal.update_span(&span);
1036                            let span_enter = span.enter();
1037                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1038                                let is_background = message.is_background();
1039                                let handle_message = (handler)(message, session.clone());
1040                                drop(span_enter);
1041
1042                                let handle_message = async move {
1043                                    handle_message.await;
1044                                    drop(permit);
1045                                }.instrument(span);
1046                                if is_background {
1047                                    executor.spawn_detached(handle_message);
1048                                } else {
1049                                    foreground_message_handlers.push(handle_message);
1050                                }
1051                            } else {
1052                                tracing::error!("no message handler");
1053                            }
1054                        } else {
1055                            tracing::info!("connection closed");
1056                            break;
1057                        }
1058                    }
1059                }
1060            }
1061
1062            drop(foreground_message_handlers);
1063            tracing::info!("signing out");
1064            if let Err(error) = connection_lost(session, teardown, executor).await {
1065                tracing::error!(?error, "error signing out");
1066            }
1067
1068        }.instrument(span)
1069    }
1070
1071    async fn send_initial_client_update(
1072        &self,
1073        connection_id: ConnectionId,
1074        principal: &Principal,
1075        zed_version: ZedVersion,
1076        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1077        session: &Session,
1078    ) -> Result<()> {
1079        self.peer.send(
1080            connection_id,
1081            proto::Hello {
1082                peer_id: Some(connection_id.into()),
1083            },
1084        )?;
1085        tracing::info!("sent hello message");
1086        if let Some(send_connection_id) = send_connection_id.take() {
1087            let _ = send_connection_id.send(connection_id);
1088        }
1089
1090        match principal {
1091            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1092                if !user.connected_once {
1093                    self.peer.send(connection_id, proto::ShowContacts {})?;
1094                    self.app_state
1095                        .db
1096                        .set_user_connected_once(user.id, true)
1097                        .await?;
1098                }
1099
1100                update_user_plan(user.id, session).await?;
1101
1102                let (contacts, dev_server_projects) = future::try_join(
1103                    self.app_state.db.get_contacts(user.id),
1104                    self.app_state.db.dev_server_projects_update(user.id),
1105                )
1106                .await?;
1107
1108                {
1109                    let mut pool = self.connection_pool.lock();
1110                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1111                    self.peer.send(
1112                        connection_id,
1113                        build_initial_contacts_update(contacts, &pool),
1114                    )?;
1115                }
1116
1117                if should_auto_subscribe_to_channels(zed_version) {
1118                    subscribe_user_to_channels(user.id, session).await?;
1119                }
1120
1121                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1122
1123                if let Some(incoming_call) =
1124                    self.app_state.db.incoming_call_for_user(user.id).await?
1125                {
1126                    self.peer.send(connection_id, incoming_call)?;
1127                }
1128
1129                update_user_contacts(user.id, &session).await?;
1130            }
1131            Principal::DevServer(dev_server) => {
1132                {
1133                    let mut pool = self.connection_pool.lock();
1134                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1135                    {
1136                        self.peer.send(
1137                            stale_connection_id,
1138                            proto::ShutdownDevServer {
1139                                reason: Some(
1140                                    "another dev server connected with the same token".to_string(),
1141                                ),
1142                            },
1143                        )?;
1144                        pool.remove_connection(stale_connection_id)?;
1145                    };
1146                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1147                }
1148
1149                let projects = self
1150                    .app_state
1151                    .db
1152                    .get_projects_for_dev_server(dev_server.id)
1153                    .await?;
1154                self.peer
1155                    .send(connection_id, proto::DevServerInstructions { projects })?;
1156
1157                let status = self
1158                    .app_state
1159                    .db
1160                    .dev_server_projects_update(dev_server.user_id)
1161                    .await?;
1162                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1163            }
1164        }
1165
1166        Ok(())
1167    }
1168
1169    pub async fn invite_code_redeemed(
1170        self: &Arc<Self>,
1171        inviter_id: UserId,
1172        invitee_id: UserId,
1173    ) -> Result<()> {
1174        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1175            if let Some(code) = &user.invite_code {
1176                let pool = self.connection_pool.lock();
1177                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1178                for connection_id in pool.user_connection_ids(inviter_id) {
1179                    self.peer.send(
1180                        connection_id,
1181                        proto::UpdateContacts {
1182                            contacts: vec![invitee_contact.clone()],
1183                            ..Default::default()
1184                        },
1185                    )?;
1186                    self.peer.send(
1187                        connection_id,
1188                        proto::UpdateInviteInfo {
1189                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1190                            count: user.invite_count as u32,
1191                        },
1192                    )?;
1193                }
1194            }
1195        }
1196        Ok(())
1197    }
1198
1199    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1200        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1201            if let Some(invite_code) = &user.invite_code {
1202                let pool = self.connection_pool.lock();
1203                for connection_id in pool.user_connection_ids(user_id) {
1204                    self.peer.send(
1205                        connection_id,
1206                        proto::UpdateInviteInfo {
1207                            url: format!(
1208                                "{}{}",
1209                                self.app_state.config.invite_link_prefix, invite_code
1210                            ),
1211                            count: user.invite_count as u32,
1212                        },
1213                    )?;
1214                }
1215            }
1216        }
1217        Ok(())
1218    }
1219
1220    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1221        ServerSnapshot {
1222            connection_pool: ConnectionPoolGuard {
1223                guard: self.connection_pool.lock(),
1224                _not_send: PhantomData,
1225            },
1226            peer: &self.peer,
1227        }
1228    }
1229}
1230
1231impl<'a> Deref for ConnectionPoolGuard<'a> {
1232    type Target = ConnectionPool;
1233
1234    fn deref(&self) -> &Self::Target {
1235        &self.guard
1236    }
1237}
1238
1239impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1240    fn deref_mut(&mut self) -> &mut Self::Target {
1241        &mut self.guard
1242    }
1243}
1244
1245impl<'a> Drop for ConnectionPoolGuard<'a> {
1246    fn drop(&mut self) {
1247        #[cfg(test)]
1248        self.check_invariants();
1249    }
1250}
1251
1252fn broadcast<F>(
1253    sender_id: Option<ConnectionId>,
1254    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1255    mut f: F,
1256) where
1257    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1258{
1259    for receiver_id in receiver_ids {
1260        if Some(receiver_id) != sender_id {
1261            if let Err(error) = f(receiver_id) {
1262                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1263            }
1264        }
1265    }
1266}
1267
1268pub struct ProtocolVersion(u32);
1269
1270impl Header for ProtocolVersion {
1271    fn name() -> &'static HeaderName {
1272        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1273        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1274    }
1275
1276    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1277    where
1278        Self: Sized,
1279        I: Iterator<Item = &'i axum::http::HeaderValue>,
1280    {
1281        let version = values
1282            .next()
1283            .ok_or_else(axum::headers::Error::invalid)?
1284            .to_str()
1285            .map_err(|_| axum::headers::Error::invalid())?
1286            .parse()
1287            .map_err(|_| axum::headers::Error::invalid())?;
1288        Ok(Self(version))
1289    }
1290
1291    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1292        values.extend([self.0.to_string().parse().unwrap()]);
1293    }
1294}
1295
1296pub struct AppVersionHeader(SemanticVersion);
1297impl Header for AppVersionHeader {
1298    fn name() -> &'static HeaderName {
1299        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1300        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1301    }
1302
1303    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1304    where
1305        Self: Sized,
1306        I: Iterator<Item = &'i axum::http::HeaderValue>,
1307    {
1308        let version = values
1309            .next()
1310            .ok_or_else(axum::headers::Error::invalid)?
1311            .to_str()
1312            .map_err(|_| axum::headers::Error::invalid())?
1313            .parse()
1314            .map_err(|_| axum::headers::Error::invalid())?;
1315        Ok(Self(version))
1316    }
1317
1318    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1319        values.extend([self.0.to_string().parse().unwrap()]);
1320    }
1321}
1322
1323pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1324    Router::new()
1325        .route("/rpc", get(handle_websocket_request))
1326        .layer(
1327            ServiceBuilder::new()
1328                .layer(Extension(server.app_state.clone()))
1329                .layer(middleware::from_fn(auth::validate_header)),
1330        )
1331        .route("/metrics", get(handle_metrics))
1332        .layer(Extension(server))
1333}
1334
1335pub async fn handle_websocket_request(
1336    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1337    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1338    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1339    Extension(server): Extension<Arc<Server>>,
1340    Extension(principal): Extension<Principal>,
1341    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1342    ws: WebSocketUpgrade,
1343) -> axum::response::Response {
1344    if protocol_version != rpc::PROTOCOL_VERSION {
1345        return (
1346            StatusCode::UPGRADE_REQUIRED,
1347            "client must be upgraded".to_string(),
1348        )
1349            .into_response();
1350    }
1351
1352    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1353        return (
1354            StatusCode::UPGRADE_REQUIRED,
1355            "no version header found".to_string(),
1356        )
1357            .into_response();
1358    };
1359
1360    if !version.can_collaborate() {
1361        return (
1362            StatusCode::UPGRADE_REQUIRED,
1363            "client must be upgraded".to_string(),
1364        )
1365            .into_response();
1366    }
1367
1368    let socket_address = socket_address.to_string();
1369    ws.on_upgrade(move |socket| {
1370        let socket = socket
1371            .map_ok(to_tungstenite_message)
1372            .err_into()
1373            .with(|message| async move { to_axum_message(message) });
1374        let connection = Connection::new(Box::pin(socket));
1375        async move {
1376            server
1377                .handle_connection(
1378                    connection,
1379                    socket_address,
1380                    principal,
1381                    version,
1382                    country_code_header.map(|header| header.to_string()),
1383                    None,
1384                    Executor::Production,
1385                )
1386                .await;
1387        }
1388    })
1389}
1390
1391pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1392    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1393    let connections_metric = CONNECTIONS_METRIC
1394        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1395
1396    let connections = server
1397        .connection_pool
1398        .lock()
1399        .connections()
1400        .filter(|connection| !connection.admin)
1401        .count();
1402    connections_metric.set(connections as _);
1403
1404    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1405    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1406        register_int_gauge!(
1407            "shared_projects",
1408            "number of open projects with one or more guests"
1409        )
1410        .unwrap()
1411    });
1412
1413    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1414    shared_projects_metric.set(shared_projects as _);
1415
1416    let encoder = prometheus::TextEncoder::new();
1417    let metric_families = prometheus::gather();
1418    let encoded_metrics = encoder
1419        .encode_to_string(&metric_families)
1420        .map_err(|err| anyhow!("{}", err))?;
1421    Ok(encoded_metrics)
1422}
1423
/// Cleans up after a connection's message loop has ended.
///
/// Immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool;
/// - reports it to the database via `connection_lost`.
///
/// It then waits for `RECONNECT_TIMEOUT` (presumably to give the client a
/// chance to reconnect; confirm against the reconnection logic) before
/// releasing what the principal still holds:
/// - users leave their room and channel buffers. If the pool no longer
///   reports them online, `decline_call` runs for them. Finally their
///   contacts are updated.
/// - dev servers are handed to `lost_dev_server_connection`.
///
/// If the server starts tearing down before the timeout fires, the delayed
/// cleanup is skipped.
///
/// # Errors
///
/// Fails if removing the connection from the pool fails, or if
/// `update_user_contacts` / `lost_dev_server_connection` fail. Errors from
/// the database `connection_lost`, leaving the room or channel buffers, and
/// declining calls are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its arms in order. If the timeout and the
    // teardown signal are both ready, the cleanup arm wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Cannot fail: the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a call when no other connection of this
                    // user remains in the pool.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                // Cannot fail: the principal was just matched as a dev server.
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // The server is shutting down; skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1478
1479/// Acknowledges a ping from a client, used to keep the connection alive.
1480async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1481    response.send(proto::Ack {})?;
1482    Ok(())
1483}
1484
1485/// Creates a new room for calling (outside of channels)
1486async fn create_room(
1487    _request: proto::CreateRoom,
1488    response: Response<proto::CreateRoom>,
1489    session: UserSession,
1490) -> Result<()> {
1491    let live_kit_room = nanoid::nanoid!(30);
1492
1493    let live_kit_connection_info = util::maybe!(async {
1494        let live_kit = session.app_state.live_kit_client.as_ref();
1495        let live_kit = live_kit?;
1496        let user_id = session.user_id().to_string();
1497
1498        let token = live_kit
1499            .room_token(&live_kit_room, &user_id.to_string())
1500            .trace_err()?;
1501
1502        Some(proto::LiveKitConnectionInfo {
1503            server_url: live_kit.url().into(),
1504            token,
1505            can_publish: true,
1506        })
1507    })
1508    .await;
1509
1510    let room = session
1511        .db()
1512        .await
1513        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1514        .await?;
1515
1516    response.send(proto::CreateRoomResponse {
1517        room: Some(room.clone()),
1518        live_kit_connection_info,
1519    })?;
1520
1521    update_user_contacts(session.user_id(), &session).await?;
1522    Ok(())
1523}
1524
1525/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1526async fn join_room(
1527    request: proto::JoinRoom,
1528    response: Response<proto::JoinRoom>,
1529    session: UserSession,
1530) -> Result<()> {
1531    let room_id = RoomId::from_proto(request.id);
1532
1533    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1534
1535    if let Some(channel_id) = channel_id {
1536        return join_channel_internal(channel_id, Box::new(response), session).await;
1537    }
1538
1539    let joined_room = {
1540        let room = session
1541            .db()
1542            .await
1543            .join_room(room_id, session.user_id(), session.connection_id)
1544            .await?;
1545        room_updated(&room.room, &session.peer);
1546        room.into_inner()
1547    };
1548
1549    for connection_id in session
1550        .connection_pool()
1551        .await
1552        .user_connection_ids(session.user_id())
1553    {
1554        session
1555            .peer
1556            .send(
1557                connection_id,
1558                proto::CallCanceled {
1559                    room_id: room_id.to_proto(),
1560                },
1561            )
1562            .trace_err();
1563    }
1564
1565    let live_kit_connection_info =
1566        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1567            if let Some(token) = live_kit
1568                .room_token(
1569                    &joined_room.room.live_kit_room,
1570                    &session.user_id().to_string(),
1571                )
1572                .trace_err()
1573            {
1574                Some(proto::LiveKitConnectionInfo {
1575                    server_url: live_kit.url().into(),
1576                    token,
1577                    can_publish: true,
1578                })
1579            } else {
1580                None
1581            }
1582        } else {
1583            None
1584        };
1585
1586    response.send(proto::JoinRoomResponse {
1587        room: Some(joined_room.room),
1588        channel_id: None,
1589        live_kit_connection_info,
1590    })?;
1591
1592    update_user_contacts(session.user_id(), &session).await?;
1593    Ok(())
1594}
1595
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Records the rejoin in the database and replies with the room, the projects this
/// connection re-shared, and the projects it rejoined. Then it broadcasts the room update
/// and notifies the collaborators of those projects about the new connection id. If the
/// room is associated with a channel, the channel update is broadcast as well. Finally the
/// user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The database result is confined to this block: only `room` and `channel` are moved
    // out of it, before the connection pool is acquired further down.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The reply is sent before any of the notifications below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each re-shared project: tell its collaborators that this peer moved from
        // `old_connection_id` to the current connection (best-effort), then forward the
        // project's worktree metadata to them from this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-stream state for the projects this connection rejoined; this drains their
        // worktrees, hence the `&mut`.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1687
1688fn notify_rejoined_projects(
1689    rejoined_projects: &mut Vec<RejoinedProject>,
1690    session: &UserSession,
1691) -> Result<()> {
1692    for project in rejoined_projects.iter() {
1693        for collaborator in &project.collaborators {
1694            session
1695                .peer
1696                .send(
1697                    collaborator.connection_id,
1698                    proto::UpdateProjectCollaborator {
1699                        project_id: project.id.to_proto(),
1700                        old_peer_id: Some(project.old_connection_id.into()),
1701                        new_peer_id: Some(session.connection_id.into()),
1702                    },
1703                )
1704                .trace_err();
1705        }
1706    }
1707
1708    for project in rejoined_projects {
1709        for worktree in mem::take(&mut project.worktrees) {
1710            #[cfg(any(test, feature = "test-support"))]
1711            const MAX_CHUNK_SIZE: usize = 2;
1712            #[cfg(not(any(test, feature = "test-support")))]
1713            const MAX_CHUNK_SIZE: usize = 256;
1714
1715            // Stream this worktree's entries.
1716            let message = proto::UpdateWorktree {
1717                project_id: project.id.to_proto(),
1718                worktree_id: worktree.id,
1719                abs_path: worktree.abs_path.clone(),
1720                root_name: worktree.root_name,
1721                updated_entries: worktree.updated_entries,
1722                removed_entries: worktree.removed_entries,
1723                scan_id: worktree.scan_id,
1724                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1725                updated_repositories: worktree.updated_repositories,
1726                removed_repositories: worktree.removed_repositories,
1727            };
1728            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1729                session.peer.send(session.connection_id, update.clone())?;
1730            }
1731
1732            // Stream this worktree's diagnostics.
1733            for summary in worktree.diagnostic_summaries {
1734                session.peer.send(
1735                    session.connection_id,
1736                    proto::UpdateDiagnosticSummary {
1737                        project_id: project.id.to_proto(),
1738                        worktree_id: worktree.id,
1739                        summary: Some(summary),
1740                    },
1741                )?;
1742            }
1743
1744            for settings_file in worktree.settings_files {
1745                session.peer.send(
1746                    session.connection_id,
1747                    proto::UpdateWorktreeSettings {
1748                        project_id: project.id.to_proto(),
1749                        worktree_id: worktree.id,
1750                        path: settings_file.path,
1751                        content: Some(settings_file.content),
1752                    },
1753                )?;
1754            }
1755        }
1756
1757        for language_server in &project.language_servers {
1758            session.peer.send(
1759                session.connection_id,
1760                proto::UpdateLanguageServer {
1761                    project_id: project.id.to_proto(),
1762                    language_server_id: language_server.id,
1763                    variant: Some(
1764                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1765                            proto::LspDiskBasedDiagnosticsUpdated {},
1766                        ),
1767                    ),
1768                },
1769            )?;
1770        }
1771    }
1772    Ok(())
1773}
1774
1775/// leave room disconnects from the room.
1776async fn leave_room(
1777    _: proto::LeaveRoom,
1778    response: Response<proto::LeaveRoom>,
1779    session: UserSession,
1780) -> Result<()> {
1781    leave_room_for_session(&session, session.connection_id).await?;
1782    response.send(proto::Ack {})?;
1783    Ok(())
1784}
1785
1786/// Updates the permissions of someone else in the room.
1787async fn set_room_participant_role(
1788    request: proto::SetRoomParticipantRole,
1789    response: Response<proto::SetRoomParticipantRole>,
1790    session: UserSession,
1791) -> Result<()> {
1792    let user_id = UserId::from_proto(request.user_id);
1793    let role = ChannelRole::from(request.role());
1794
1795    let (live_kit_room, can_publish) = {
1796        let room = session
1797            .db()
1798            .await
1799            .set_room_participant_role(
1800                session.user_id(),
1801                RoomId::from_proto(request.room_id),
1802                user_id,
1803                role,
1804            )
1805            .await?;
1806
1807        let live_kit_room = room.live_kit_room.clone();
1808        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1809        room_updated(&room, &session.peer);
1810        (live_kit_room, can_publish)
1811    };
1812
1813    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1814        live_kit
1815            .update_participant(
1816                live_kit_room.clone(),
1817                request.user_id.to_string(),
1818                live_kit_server::proto::ParticipantPermission {
1819                    can_subscribe: true,
1820                    can_publish,
1821                    can_publish_data: can_publish,
1822                    hidden: false,
1823                    recorder: false,
1824                },
1825            )
1826            .await
1827            .trace_err();
1828    }
1829
1830    response.send(proto::Ack {})?;
1831    Ok(())
1832}
1833
1834/// Call someone else into the current room
1835async fn call(
1836    request: proto::Call,
1837    response: Response<proto::Call>,
1838    session: UserSession,
1839) -> Result<()> {
1840    let room_id = RoomId::from_proto(request.room_id);
1841    let calling_user_id = session.user_id();
1842    let calling_connection_id = session.connection_id;
1843    let called_user_id = UserId::from_proto(request.called_user_id);
1844    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1845    if !session
1846        .db()
1847        .await
1848        .has_contact(calling_user_id, called_user_id)
1849        .await?
1850    {
1851        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1852    }
1853
1854    let incoming_call = {
1855        let (room, incoming_call) = &mut *session
1856            .db()
1857            .await
1858            .call(
1859                room_id,
1860                calling_user_id,
1861                calling_connection_id,
1862                called_user_id,
1863                initial_project_id,
1864            )
1865            .await?;
1866        room_updated(&room, &session.peer);
1867        mem::take(incoming_call)
1868    };
1869    update_user_contacts(called_user_id, &session).await?;
1870
1871    let mut calls = session
1872        .connection_pool()
1873        .await
1874        .user_connection_ids(called_user_id)
1875        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1876        .collect::<FuturesUnordered<_>>();
1877
1878    while let Some(call_response) = calls.next().await {
1879        match call_response.as_ref() {
1880            Ok(_) => {
1881                response.send(proto::Ack {})?;
1882                return Ok(());
1883            }
1884            Err(_) => {
1885                call_response.trace_err();
1886            }
1887        }
1888    }
1889
1890    {
1891        let room = session
1892            .db()
1893            .await
1894            .call_failed(room_id, called_user_id)
1895            .await?;
1896        room_updated(&room, &session.peer);
1897    }
1898    update_user_contacts(called_user_id, &session).await?;
1899
1900    Err(anyhow!("failed to ring user"))?
1901}
1902
1903/// Cancel an outgoing call.
1904async fn cancel_call(
1905    request: proto::CancelCall,
1906    response: Response<proto::CancelCall>,
1907    session: UserSession,
1908) -> Result<()> {
1909    let called_user_id = UserId::from_proto(request.called_user_id);
1910    let room_id = RoomId::from_proto(request.room_id);
1911    {
1912        let room = session
1913            .db()
1914            .await
1915            .cancel_call(room_id, session.connection_id, called_user_id)
1916            .await?;
1917        room_updated(&room, &session.peer);
1918    }
1919
1920    for connection_id in session
1921        .connection_pool()
1922        .await
1923        .user_connection_ids(called_user_id)
1924    {
1925        session
1926            .peer
1927            .send(
1928                connection_id,
1929                proto::CallCanceled {
1930                    room_id: room_id.to_proto(),
1931                },
1932            )
1933            .trace_err();
1934    }
1935    response.send(proto::Ack {})?;
1936
1937    update_user_contacts(called_user_id, &session).await?;
1938    Ok(())
1939}
1940
1941/// Decline an incoming call.
1942async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1943    let room_id = RoomId::from_proto(message.room_id);
1944    {
1945        let room = session
1946            .db()
1947            .await
1948            .decline_call(Some(room_id), session.user_id())
1949            .await?
1950            .ok_or_else(|| anyhow!("failed to decline call"))?;
1951        room_updated(&room, &session.peer);
1952    }
1953
1954    for connection_id in session
1955        .connection_pool()
1956        .await
1957        .user_connection_ids(session.user_id())
1958    {
1959        session
1960            .peer
1961            .send(
1962                connection_id,
1963                proto::CallCanceled {
1964                    room_id: room_id.to_proto(),
1965                },
1966            )
1967            .trace_err();
1968    }
1969    update_user_contacts(session.user_id(), &session).await?;
1970    Ok(())
1971}
1972
1973/// Updates other participants in the room with your current location.
1974async fn update_participant_location(
1975    request: proto::UpdateParticipantLocation,
1976    response: Response<proto::UpdateParticipantLocation>,
1977    session: UserSession,
1978) -> Result<()> {
1979    let room_id = RoomId::from_proto(request.room_id);
1980    let location = request
1981        .location
1982        .ok_or_else(|| anyhow!("invalid location"))?;
1983
1984    let db = session.db().await;
1985    let room = db
1986        .update_room_participant_location(room_id, session.connection_id, location)
1987        .await?;
1988
1989    room_updated(&room, &session.peer);
1990    response.send(proto::Ack {})?;
1991    Ok(())
1992}
1993
1994/// Share a project into the room.
1995async fn share_project(
1996    request: proto::ShareProject,
1997    response: Response<proto::ShareProject>,
1998    session: UserSession,
1999) -> Result<()> {
2000    let (project_id, room) = &*session
2001        .db()
2002        .await
2003        .share_project(
2004            RoomId::from_proto(request.room_id),
2005            session.connection_id,
2006            &request.worktrees,
2007            request
2008                .dev_server_project_id
2009                .map(|id| DevServerProjectId::from_proto(id)),
2010        )
2011        .await?;
2012    response.send(proto::ShareProjectResponse {
2013        project_id: project_id.to_proto(),
2014    })?;
2015    room_updated(&room, &session.peer);
2016
2017    Ok(())
2018}
2019
2020/// Unshare a project from the room.
2021async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2022    let project_id = ProjectId::from_proto(message.project_id);
2023    unshare_project_internal(
2024        project_id,
2025        session.connection_id,
2026        session.user_id(),
2027        &session,
2028    )
2029    .await
2030}
2031
2032async fn unshare_project_internal(
2033    project_id: ProjectId,
2034    connection_id: ConnectionId,
2035    user_id: Option<UserId>,
2036    session: &Session,
2037) -> Result<()> {
2038    let delete = {
2039        let room_guard = session
2040            .db()
2041            .await
2042            .unshare_project(project_id, connection_id, user_id)
2043            .await?;
2044
2045        let (delete, room, guest_connection_ids) = &*room_guard;
2046
2047        let message = proto::UnshareProject {
2048            project_id: project_id.to_proto(),
2049        };
2050
2051        broadcast(
2052            Some(connection_id),
2053            guest_connection_ids.iter().copied(),
2054            |conn_id| session.peer.send(conn_id, message.clone()),
2055        );
2056        if let Some(room) = room {
2057            room_updated(room, &session.peer);
2058        }
2059
2060        *delete
2061    };
2062
2063    if delete {
2064        let db = session.db().await;
2065        db.delete_project(project_id).await?;
2066    }
2067
2068    Ok(())
2069}
2070
2071/// DevServer makes a project available online
2072async fn share_dev_server_project(
2073    request: proto::ShareDevServerProject,
2074    response: Response<proto::ShareDevServerProject>,
2075    session: DevServerSession,
2076) -> Result<()> {
2077    let (dev_server_project, user_id, status) = session
2078        .db()
2079        .await
2080        .share_dev_server_project(
2081            DevServerProjectId::from_proto(request.dev_server_project_id),
2082            session.dev_server_id(),
2083            session.connection_id,
2084            &request.worktrees,
2085        )
2086        .await?;
2087    let Some(project_id) = dev_server_project.project_id else {
2088        return Err(anyhow!("failed to share remote project"))?;
2089    };
2090
2091    send_dev_server_projects_update(user_id, status, &session).await;
2092
2093    response.send(proto::ShareProjectResponse { project_id })?;
2094
2095    Ok(())
2096}
2097
2098/// Join someone elses shared project.
2099async fn join_project(
2100    request: proto::JoinProject,
2101    response: Response<proto::JoinProject>,
2102    session: UserSession,
2103) -> Result<()> {
2104    let project_id = ProjectId::from_proto(request.project_id);
2105
2106    tracing::info!(%project_id, "join project");
2107
2108    let db = session.db().await;
2109    let (project, replica_id) = &mut *db
2110        .join_project(project_id, session.connection_id, session.user_id())
2111        .await?;
2112    drop(db);
2113    tracing::info!(%project_id, "join remote project");
2114    join_project_internal(response, session, project, replica_id)
2115}
2116
/// Abstracts over the responders whose reply is a `proto::JoinProjectResponse`
/// (`JoinProject` and `JoinHostedProject`), so `join_project_internal` can serve both.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to the inherent `Response::send`; the explicit path keeps this from
// resolving back to the trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Same delegation for hosted-project joins.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2130
2131fn join_project_internal(
2132    response: impl JoinProjectInternalResponse,
2133    session: UserSession,
2134    project: &mut Project,
2135    replica_id: &ReplicaId,
2136) -> Result<()> {
2137    let collaborators = project
2138        .collaborators
2139        .iter()
2140        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2141        .map(|collaborator| collaborator.to_proto())
2142        .collect::<Vec<_>>();
2143    let project_id = project.id;
2144    let guest_user_id = session.user_id();
2145
2146    let worktrees = project
2147        .worktrees
2148        .iter()
2149        .map(|(id, worktree)| proto::WorktreeMetadata {
2150            id: *id,
2151            root_name: worktree.root_name.clone(),
2152            visible: worktree.visible,
2153            abs_path: worktree.abs_path.clone(),
2154        })
2155        .collect::<Vec<_>>();
2156
2157    let add_project_collaborator = proto::AddProjectCollaborator {
2158        project_id: project_id.to_proto(),
2159        collaborator: Some(proto::Collaborator {
2160            peer_id: Some(session.connection_id.into()),
2161            replica_id: replica_id.0 as u32,
2162            user_id: guest_user_id.to_proto(),
2163        }),
2164    };
2165
2166    for collaborator in &collaborators {
2167        session
2168            .peer
2169            .send(
2170                collaborator.peer_id.unwrap().into(),
2171                add_project_collaborator.clone(),
2172            )
2173            .trace_err();
2174    }
2175
2176    // First, we send the metadata associated with each worktree.
2177    response.send(proto::JoinProjectResponse {
2178        project_id: project.id.0 as u64,
2179        worktrees: worktrees.clone(),
2180        replica_id: replica_id.0 as u32,
2181        collaborators: collaborators.clone(),
2182        language_servers: project.language_servers.clone(),
2183        role: project.role.into(),
2184        dev_server_project_id: project
2185            .dev_server_project_id
2186            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2187    })?;
2188
2189    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2190        #[cfg(any(test, feature = "test-support"))]
2191        const MAX_CHUNK_SIZE: usize = 2;
2192        #[cfg(not(any(test, feature = "test-support")))]
2193        const MAX_CHUNK_SIZE: usize = 256;
2194
2195        // Stream this worktree's entries.
2196        let message = proto::UpdateWorktree {
2197            project_id: project_id.to_proto(),
2198            worktree_id,
2199            abs_path: worktree.abs_path.clone(),
2200            root_name: worktree.root_name,
2201            updated_entries: worktree.entries,
2202            removed_entries: Default::default(),
2203            scan_id: worktree.scan_id,
2204            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2205            updated_repositories: worktree.repository_entries.into_values().collect(),
2206            removed_repositories: Default::default(),
2207        };
2208        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2209            session.peer.send(session.connection_id, update.clone())?;
2210        }
2211
2212        // Stream this worktree's diagnostics.
2213        for summary in worktree.diagnostic_summaries {
2214            session.peer.send(
2215                session.connection_id,
2216                proto::UpdateDiagnosticSummary {
2217                    project_id: project_id.to_proto(),
2218                    worktree_id: worktree.id,
2219                    summary: Some(summary),
2220                },
2221            )?;
2222        }
2223
2224        for settings_file in worktree.settings_files {
2225            session.peer.send(
2226                session.connection_id,
2227                proto::UpdateWorktreeSettings {
2228                    project_id: project_id.to_proto(),
2229                    worktree_id: worktree.id,
2230                    path: settings_file.path,
2231                    content: Some(settings_file.content),
2232                },
2233            )?;
2234        }
2235    }
2236
2237    for language_server in &project.language_servers {
2238        session.peer.send(
2239            session.connection_id,
2240            proto::UpdateLanguageServer {
2241                project_id: project_id.to_proto(),
2242                language_server_id: language_server.id,
2243                variant: Some(
2244                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2245                        proto::LspDiskBasedDiagnosticsUpdated {},
2246                    ),
2247                ),
2248            },
2249        )?;
2250    }
2251
2252    Ok(())
2253}
2254
2255/// Leave someone elses shared project.
2256async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2257    let sender_id = session.connection_id;
2258    let project_id = ProjectId::from_proto(request.project_id);
2259    let db = session.db().await;
2260    if db.is_hosted_project(project_id).await? {
2261        let project = db.leave_hosted_project(project_id, sender_id).await?;
2262        project_left(&project, &session);
2263        return Ok(());
2264    }
2265
2266    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2267    tracing::info!(
2268        %project_id,
2269        "leave project"
2270    );
2271
2272    project_left(&project, &session);
2273    if let Some(room) = room {
2274        room_updated(&room, &session.peer);
2275    }
2276
2277    Ok(())
2278}
2279
2280async fn join_hosted_project(
2281    request: proto::JoinHostedProject,
2282    response: Response<proto::JoinHostedProject>,
2283    session: UserSession,
2284) -> Result<()> {
2285    let (mut project, replica_id) = session
2286        .db()
2287        .await
2288        .join_hosted_project(
2289            ProjectId(request.project_id as i32),
2290            session.user_id(),
2291            session.connection_id,
2292        )
2293        .await?;
2294
2295    join_project_internal(response, session, &mut project, &replica_id)
2296}
2297
2298async fn list_remote_directory(
2299    request: proto::ListRemoteDirectory,
2300    response: Response<proto::ListRemoteDirectory>,
2301    session: UserSession,
2302) -> Result<()> {
2303    let dev_server_id = DevServerId(request.dev_server_id as i32);
2304    let dev_server_connection_id = session
2305        .connection_pool()
2306        .await
2307        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2308
2309    session
2310        .db()
2311        .await
2312        .get_dev_server_for_user(dev_server_id, session.user_id())
2313        .await?;
2314
2315    response.send(
2316        session
2317            .peer
2318            .forward_request(session.connection_id, dev_server_connection_id, request)
2319            .await?,
2320    )?;
2321    Ok(())
2322}
2323
2324async fn update_dev_server_project(
2325    request: proto::UpdateDevServerProject,
2326    response: Response<proto::UpdateDevServerProject>,
2327    session: UserSession,
2328) -> Result<()> {
2329    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2330
2331    let (dev_server_project, update) = session
2332        .db()
2333        .await
2334        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2335        .await?;
2336
2337    let projects = session
2338        .db()
2339        .await
2340        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2341        .await?;
2342
2343    let dev_server_connection_id = session
2344        .connection_pool()
2345        .await
2346        .dev_server_connection_id_supporting(
2347            dev_server_project.dev_server_id,
2348            ZedVersion::with_list_directory(),
2349        )?;
2350
2351    session.peer.send(
2352        dev_server_connection_id,
2353        proto::DevServerInstructions { projects },
2354    )?;
2355
2356    send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358    response.send(proto::Ack {})
2359}
2360
2361async fn create_dev_server_project(
2362    request: proto::CreateDevServerProject,
2363    response: Response<proto::CreateDevServerProject>,
2364    session: UserSession,
2365) -> Result<()> {
2366    let dev_server_id = DevServerId(request.dev_server_id as i32);
2367    let dev_server_connection_id = session
2368        .connection_pool()
2369        .await
2370        .dev_server_connection_id(dev_server_id);
2371    let Some(dev_server_connection_id) = dev_server_connection_id else {
2372        Err(ErrorCode::DevServerOffline
2373            .message("Cannot create a remote project when the dev server is offline".to_string())
2374            .anyhow())?
2375    };
2376
2377    let path = request.path.clone();
2378    //Check that the path exists on the dev server
2379    session
2380        .peer
2381        .forward_request(
2382            session.connection_id,
2383            dev_server_connection_id,
2384            proto::ValidateDevServerProjectRequest { path: path.clone() },
2385        )
2386        .await?;
2387
2388    let (dev_server_project, update) = session
2389        .db()
2390        .await
2391        .create_dev_server_project(
2392            DevServerId(request.dev_server_id as i32),
2393            &request.path,
2394            session.user_id(),
2395        )
2396        .await?;
2397
2398    let projects = session
2399        .db()
2400        .await
2401        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2402        .await?;
2403
2404    session.peer.send(
2405        dev_server_connection_id,
2406        proto::DevServerInstructions { projects },
2407    )?;
2408
2409    send_dev_server_projects_update(session.user_id(), update, &session).await;
2410
2411    response.send(proto::CreateDevServerProjectResponse {
2412        dev_server_project: Some(dev_server_project.to_proto(None)),
2413    })?;
2414    Ok(())
2415}
2416
2417async fn create_dev_server(
2418    request: proto::CreateDevServer,
2419    response: Response<proto::CreateDevServer>,
2420    session: UserSession,
2421) -> Result<()> {
2422    let access_token = auth::random_token();
2423    let hashed_access_token = auth::hash_access_token(&access_token);
2424
2425    if request.name.is_empty() {
2426        return Err(proto::ErrorCode::Forbidden
2427            .message("Dev server name cannot be empty".to_string())
2428            .anyhow())?;
2429    }
2430
2431    let (dev_server, status) = session
2432        .db()
2433        .await
2434        .create_dev_server(
2435            &request.name,
2436            request.ssh_connection_string.as_deref(),
2437            &hashed_access_token,
2438            session.user_id(),
2439        )
2440        .await?;
2441
2442    send_dev_server_projects_update(session.user_id(), status, &session).await;
2443
2444    response.send(proto::CreateDevServerResponse {
2445        dev_server_id: dev_server.id.0 as u64,
2446        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2447        name: request.name,
2448    })?;
2449    Ok(())
2450}
2451
2452async fn regenerate_dev_server_token(
2453    request: proto::RegenerateDevServerToken,
2454    response: Response<proto::RegenerateDevServerToken>,
2455    session: UserSession,
2456) -> Result<()> {
2457    let dev_server_id = DevServerId(request.dev_server_id as i32);
2458    let access_token = auth::random_token();
2459    let hashed_access_token = auth::hash_access_token(&access_token);
2460
2461    let connection_id = session
2462        .connection_pool()
2463        .await
2464        .dev_server_connection_id(dev_server_id);
2465    if let Some(connection_id) = connection_id {
2466        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2467        session.peer.send(
2468            connection_id,
2469            proto::ShutdownDevServer {
2470                reason: Some("dev server token was regenerated".to_string()),
2471            },
2472        )?;
2473        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2474    }
2475
2476    let status = session
2477        .db()
2478        .await
2479        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2480        .await?;
2481
2482    send_dev_server_projects_update(session.user_id(), status, &session).await;
2483
2484    response.send(proto::RegenerateDevServerTokenResponse {
2485        dev_server_id: dev_server_id.to_proto(),
2486        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2487    })?;
2488    Ok(())
2489}
2490
2491async fn rename_dev_server(
2492    request: proto::RenameDevServer,
2493    response: Response<proto::RenameDevServer>,
2494    session: UserSession,
2495) -> Result<()> {
2496    if request.name.trim().is_empty() {
2497        return Err(proto::ErrorCode::Forbidden
2498            .message("Dev server name cannot be empty".to_string())
2499            .anyhow())?;
2500    }
2501
2502    let dev_server_id = DevServerId(request.dev_server_id as i32);
2503    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2504    if dev_server.user_id != session.user_id() {
2505        return Err(anyhow!(ErrorCode::Forbidden))?;
2506    }
2507
2508    let status = session
2509        .db()
2510        .await
2511        .rename_dev_server(
2512            dev_server_id,
2513            &request.name,
2514            request.ssh_connection_string.as_deref(),
2515            session.user_id(),
2516        )
2517        .await?;
2518
2519    send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521    response.send(proto::Ack {})?;
2522    Ok(())
2523}
2524
2525async fn delete_dev_server(
2526    request: proto::DeleteDevServer,
2527    response: Response<proto::DeleteDevServer>,
2528    session: UserSession,
2529) -> Result<()> {
2530    let dev_server_id = DevServerId(request.dev_server_id as i32);
2531    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2532    if dev_server.user_id != session.user_id() {
2533        return Err(anyhow!(ErrorCode::Forbidden))?;
2534    }
2535
2536    let connection_id = session
2537        .connection_pool()
2538        .await
2539        .dev_server_connection_id(dev_server_id);
2540    if let Some(connection_id) = connection_id {
2541        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2542        session.peer.send(
2543            connection_id,
2544            proto::ShutdownDevServer {
2545                reason: Some("dev server was deleted".to_string()),
2546            },
2547        )?;
2548        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2549    }
2550
2551    let status = session
2552        .db()
2553        .await
2554        .delete_dev_server(dev_server_id, session.user_id())
2555        .await?;
2556
2557    send_dev_server_projects_update(session.user_id(), status, &session).await;
2558
2559    response.send(proto::Ack {})?;
2560    Ok(())
2561}
2562
2563async fn delete_dev_server_project(
2564    request: proto::DeleteDevServerProject,
2565    response: Response<proto::DeleteDevServerProject>,
2566    session: UserSession,
2567) -> Result<()> {
2568    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2569    let dev_server_project = session
2570        .db()
2571        .await
2572        .get_dev_server_project(dev_server_project_id)
2573        .await?;
2574
2575    let dev_server = session
2576        .db()
2577        .await
2578        .get_dev_server(dev_server_project.dev_server_id)
2579        .await?;
2580    if dev_server.user_id != session.user_id() {
2581        return Err(anyhow!(ErrorCode::Forbidden))?;
2582    }
2583
2584    let dev_server_connection_id = session
2585        .connection_pool()
2586        .await
2587        .dev_server_connection_id(dev_server.id);
2588
2589    if let Some(dev_server_connection_id) = dev_server_connection_id {
2590        let project = session
2591            .db()
2592            .await
2593            .find_dev_server_project(dev_server_project_id)
2594            .await;
2595        if let Ok(project) = project {
2596            unshare_project_internal(
2597                project.id,
2598                dev_server_connection_id,
2599                Some(session.user_id()),
2600                &session,
2601            )
2602            .await?;
2603        }
2604    }
2605
2606    let (projects, status) = session
2607        .db()
2608        .await
2609        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2610        .await?;
2611
2612    if let Some(dev_server_connection_id) = dev_server_connection_id {
2613        session.peer.send(
2614            dev_server_connection_id,
2615            proto::DevServerInstructions { projects },
2616        )?;
2617    }
2618
2619    send_dev_server_projects_update(session.user_id(), status, &session).await;
2620
2621    response.send(proto::Ack {})?;
2622    Ok(())
2623}
2624
2625async fn rejoin_dev_server_projects(
2626    request: proto::RejoinRemoteProjects,
2627    response: Response<proto::RejoinRemoteProjects>,
2628    session: UserSession,
2629) -> Result<()> {
2630    let mut rejoined_projects = {
2631        let db = session.db().await;
2632        db.rejoin_dev_server_projects(
2633            &request.rejoined_projects,
2634            session.user_id(),
2635            session.0.connection_id,
2636        )
2637        .await?
2638    };
2639    response.send(proto::RejoinRemoteProjectsResponse {
2640        rejoined_projects: rejoined_projects
2641            .iter()
2642            .map(|project| project.to_proto())
2643            .collect(),
2644    })?;
2645    notify_rejoined_projects(&mut rejoined_projects, &session)
2646}
2647
2648async fn reconnect_dev_server(
2649    request: proto::ReconnectDevServer,
2650    response: Response<proto::ReconnectDevServer>,
2651    session: DevServerSession,
2652) -> Result<()> {
2653    let reshared_projects = {
2654        let db = session.db().await;
2655        db.reshare_dev_server_projects(
2656            &request.reshared_projects,
2657            session.dev_server_id(),
2658            session.0.connection_id,
2659        )
2660        .await?
2661    };
2662
2663    for project in &reshared_projects {
2664        for collaborator in &project.collaborators {
2665            session
2666                .peer
2667                .send(
2668                    collaborator.connection_id,
2669                    proto::UpdateProjectCollaborator {
2670                        project_id: project.id.to_proto(),
2671                        old_peer_id: Some(project.old_connection_id.into()),
2672                        new_peer_id: Some(session.connection_id.into()),
2673                    },
2674                )
2675                .trace_err();
2676        }
2677
2678        broadcast(
2679            Some(session.connection_id),
2680            project
2681                .collaborators
2682                .iter()
2683                .map(|collaborator| collaborator.connection_id),
2684            |connection_id| {
2685                session.peer.forward_send(
2686                    session.connection_id,
2687                    connection_id,
2688                    proto::UpdateProject {
2689                        project_id: project.id.to_proto(),
2690                        worktrees: project.worktrees.clone(),
2691                    },
2692                )
2693            },
2694        );
2695    }
2696
2697    response.send(proto::ReconnectDevServerResponse {
2698        reshared_projects: reshared_projects
2699            .iter()
2700            .map(|project| proto::ResharedProject {
2701                id: project.id.to_proto(),
2702                collaborators: project
2703                    .collaborators
2704                    .iter()
2705                    .map(|collaborator| collaborator.to_proto())
2706                    .collect(),
2707            })
2708            .collect(),
2709    })?;
2710
2711    Ok(())
2712}
2713
2714async fn shutdown_dev_server(
2715    _: proto::ShutdownDevServer,
2716    response: Response<proto::ShutdownDevServer>,
2717    session: DevServerSession,
2718) -> Result<()> {
2719    response.send(proto::Ack {})?;
2720    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2721    remove_dev_server_connection(session.dev_server_id(), &session).await
2722}
2723
2724async fn shutdown_dev_server_internal(
2725    dev_server_id: DevServerId,
2726    connection_id: ConnectionId,
2727    session: &Session,
2728) -> Result<()> {
2729    let (dev_server_projects, dev_server) = {
2730        let db = session.db().await;
2731        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2732        let dev_server = db.get_dev_server(dev_server_id).await?;
2733        (dev_server_projects, dev_server)
2734    };
2735
2736    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2737        unshare_project_internal(
2738            ProjectId::from_proto(project_id),
2739            connection_id,
2740            None,
2741            session,
2742        )
2743        .await?;
2744    }
2745
2746    session
2747        .connection_pool()
2748        .await
2749        .set_dev_server_offline(dev_server_id);
2750
2751    let status = session
2752        .db()
2753        .await
2754        .dev_server_projects_update(dev_server.user_id)
2755        .await?;
2756    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2757
2758    Ok(())
2759}
2760
2761async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2762    let dev_server_connection = session
2763        .connection_pool()
2764        .await
2765        .dev_server_connection_id(dev_server_id);
2766
2767    if let Some(dev_server_connection) = dev_server_connection {
2768        session
2769            .connection_pool()
2770            .await
2771            .remove_connection(dev_server_connection)?;
2772    }
2773    Ok(())
2774}
2775
2776/// Updates other participants with changes to the project
2777async fn update_project(
2778    request: proto::UpdateProject,
2779    response: Response<proto::UpdateProject>,
2780    session: Session,
2781) -> Result<()> {
2782    let project_id = ProjectId::from_proto(request.project_id);
2783    let (room, guest_connection_ids) = &*session
2784        .db()
2785        .await
2786        .update_project(project_id, session.connection_id, &request.worktrees)
2787        .await?;
2788    broadcast(
2789        Some(session.connection_id),
2790        guest_connection_ids.iter().copied(),
2791        |connection_id| {
2792            session
2793                .peer
2794                .forward_send(session.connection_id, connection_id, request.clone())
2795        },
2796    );
2797    if let Some(room) = room {
2798        room_updated(&room, &session.peer);
2799    }
2800    response.send(proto::Ack {})?;
2801
2802    Ok(())
2803}
2804
2805/// Updates other participants with changes to the worktree
2806async fn update_worktree(
2807    request: proto::UpdateWorktree,
2808    response: Response<proto::UpdateWorktree>,
2809    session: Session,
2810) -> Result<()> {
2811    let guest_connection_ids = session
2812        .db()
2813        .await
2814        .update_worktree(&request, session.connection_id)
2815        .await?;
2816
2817    broadcast(
2818        Some(session.connection_id),
2819        guest_connection_ids.iter().copied(),
2820        |connection_id| {
2821            session
2822                .peer
2823                .forward_send(session.connection_id, connection_id, request.clone())
2824        },
2825    );
2826    response.send(proto::Ack {})?;
2827    Ok(())
2828}
2829
2830/// Updates other participants with changes to the diagnostics
2831async fn update_diagnostic_summary(
2832    message: proto::UpdateDiagnosticSummary,
2833    session: Session,
2834) -> Result<()> {
2835    let guest_connection_ids = session
2836        .db()
2837        .await
2838        .update_diagnostic_summary(&message, session.connection_id)
2839        .await?;
2840
2841    broadcast(
2842        Some(session.connection_id),
2843        guest_connection_ids.iter().copied(),
2844        |connection_id| {
2845            session
2846                .peer
2847                .forward_send(session.connection_id, connection_id, message.clone())
2848        },
2849    );
2850
2851    Ok(())
2852}
2853
2854/// Updates other participants with changes to the worktree settings
2855async fn update_worktree_settings(
2856    message: proto::UpdateWorktreeSettings,
2857    session: Session,
2858) -> Result<()> {
2859    let guest_connection_ids = session
2860        .db()
2861        .await
2862        .update_worktree_settings(&message, session.connection_id)
2863        .await?;
2864
2865    broadcast(
2866        Some(session.connection_id),
2867        guest_connection_ids.iter().copied(),
2868        |connection_id| {
2869            session
2870                .peer
2871                .forward_send(session.connection_id, connection_id, message.clone())
2872        },
2873    );
2874
2875    Ok(())
2876}
2877
2878/// Notify other participants that a  language server has started.
2879async fn start_language_server(
2880    request: proto::StartLanguageServer,
2881    session: Session,
2882) -> Result<()> {
2883    let guest_connection_ids = session
2884        .db()
2885        .await
2886        .start_language_server(&request, session.connection_id)
2887        .await?;
2888
2889    broadcast(
2890        Some(session.connection_id),
2891        guest_connection_ids.iter().copied(),
2892        |connection_id| {
2893            session
2894                .peer
2895                .forward_send(session.connection_id, connection_id, request.clone())
2896        },
2897    );
2898    Ok(())
2899}
2900
2901/// Notify other participants that a language server has changed.
2902async fn update_language_server(
2903    request: proto::UpdateLanguageServer,
2904    session: Session,
2905) -> Result<()> {
2906    let project_id = ProjectId::from_proto(request.project_id);
2907    let project_connection_ids = session
2908        .db()
2909        .await
2910        .project_connection_ids(project_id, session.connection_id, true)
2911        .await?;
2912    broadcast(
2913        Some(session.connection_id),
2914        project_connection_ids.iter().copied(),
2915        |connection_id| {
2916            session
2917                .peer
2918                .forward_send(session.connection_id, connection_id, request.clone())
2919        },
2920    );
2921    Ok(())
2922}
2923
2924/// forward a project request to the host. These requests should be read only
2925/// as guests are allowed to send them.
2926async fn forward_read_only_project_request<T>(
2927    request: T,
2928    response: Response<T>,
2929    session: UserSession,
2930) -> Result<()>
2931where
2932    T: EntityMessage + RequestMessage,
2933{
2934    let project_id = ProjectId::from_proto(request.remote_entity_id());
2935    let host_connection_id = session
2936        .db()
2937        .await
2938        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2939        .await?;
2940    let payload = session
2941        .peer
2942        .forward_request(session.connection_id, host_connection_id, request)
2943        .await?;
2944    response.send(payload)?;
2945    Ok(())
2946}
2947
2948async fn forward_find_search_candidates_request(
2949    request: proto::FindSearchCandidates,
2950    response: Response<proto::FindSearchCandidates>,
2951    session: UserSession,
2952) -> Result<()> {
2953    let project_id = ProjectId::from_proto(request.remote_entity_id());
2954    let host_connection_id = session
2955        .db()
2956        .await
2957        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2958        .await?;
2959
2960    let host_version = session
2961        .connection_pool()
2962        .await
2963        .connection(host_connection_id)
2964        .map(|c| c.zed_version);
2965
2966    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2967    {
2968        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2969        let search = proto::SearchProject {
2970            project_id: project_id.to_proto(),
2971            query: query.query,
2972            regex: query.regex,
2973            whole_word: query.whole_word,
2974            case_sensitive: query.case_sensitive,
2975            files_to_include: query.files_to_include,
2976            files_to_exclude: query.files_to_exclude,
2977            include_ignored: query.include_ignored,
2978        };
2979
2980        let payload = session
2981            .peer
2982            .forward_request(session.connection_id, host_connection_id, search)
2983            .await?;
2984        return response.send(proto::FindSearchCandidatesResponse {
2985            buffer_ids: payload
2986                .locations
2987                .into_iter()
2988                .map(|loc| loc.buffer_id)
2989                .collect(),
2990        });
2991    }
2992
2993    let payload = session
2994        .peer
2995        .forward_request(session.connection_id, host_connection_id, request)
2996        .await?;
2997    response.send(payload)?;
2998    Ok(())
2999}
3000
3001/// forward a project request to the dev server. Only allowed
3002/// if it's your dev server.
3003async fn forward_project_request_for_owner<T>(
3004    request: T,
3005    response: Response<T>,
3006    session: UserSession,
3007) -> Result<()>
3008where
3009    T: EntityMessage + RequestMessage,
3010{
3011    let project_id = ProjectId::from_proto(request.remote_entity_id());
3012
3013    let host_connection_id = session
3014        .db()
3015        .await
3016        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3017        .await?;
3018    let payload = session
3019        .peer
3020        .forward_request(session.connection_id, host_connection_id, request)
3021        .await?;
3022    response.send(payload)?;
3023    Ok(())
3024}
3025
3026/// forward a project request to the host. These requests are disallowed
3027/// for guests.
3028async fn forward_mutating_project_request<T>(
3029    request: T,
3030    response: Response<T>,
3031    session: UserSession,
3032) -> Result<()>
3033where
3034    T: EntityMessage + RequestMessage,
3035{
3036    let project_id = ProjectId::from_proto(request.remote_entity_id());
3037
3038    let host_connection_id = session
3039        .db()
3040        .await
3041        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3042        .await?;
3043    let payload = session
3044        .peer
3045        .forward_request(session.connection_id, host_connection_id, request)
3046        .await?;
3047    response.send(payload)?;
3048    Ok(())
3049}
3050
3051/// Notify other participants that a new buffer has been created
3052async fn create_buffer_for_peer(
3053    request: proto::CreateBufferForPeer,
3054    session: Session,
3055) -> Result<()> {
3056    session
3057        .db()
3058        .await
3059        .check_user_is_project_host(
3060            ProjectId::from_proto(request.project_id),
3061            session.connection_id,
3062        )
3063        .await?;
3064    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3065    session
3066        .peer
3067        .forward_send(session.connection_id, peer_id.into(), request)?;
3068    Ok(())
3069}
3070
3071/// Notify other participants that a buffer has been updated. This is
3072/// allowed for guests as long as the update is limited to selections.
3073async fn update_buffer(
3074    request: proto::UpdateBuffer,
3075    response: Response<proto::UpdateBuffer>,
3076    session: Session,
3077) -> Result<()> {
3078    let project_id = ProjectId::from_proto(request.project_id);
3079    let mut capability = Capability::ReadOnly;
3080
3081    for op in request.operations.iter() {
3082        match op.variant {
3083            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3084            Some(_) => capability = Capability::ReadWrite,
3085        }
3086    }
3087
3088    let host = {
3089        let guard = session
3090            .db()
3091            .await
3092            .connections_for_buffer_update(
3093                project_id,
3094                session.principal_id(),
3095                session.connection_id,
3096                capability,
3097            )
3098            .await?;
3099
3100        let (host, guests) = &*guard;
3101
3102        broadcast(
3103            Some(session.connection_id),
3104            guests.clone(),
3105            |connection_id| {
3106                session
3107                    .peer
3108                    .forward_send(session.connection_id, connection_id, request.clone())
3109            },
3110        );
3111
3112        *host
3113    };
3114
3115    if host != session.connection_id {
3116        session
3117            .peer
3118            .forward_request(session.connection_id, host, request.clone())
3119            .await?;
3120    }
3121
3122    response.send(proto::Ack {})?;
3123    Ok(())
3124}
3125
3126async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3127    let project_id = ProjectId::from_proto(message.project_id);
3128
3129    let operation = message.operation.as_ref().context("invalid operation")?;
3130    let capability = match operation.variant.as_ref() {
3131        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3132            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3133                match buffer_op.variant {
3134                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3135                        Capability::ReadOnly
3136                    }
3137                    _ => Capability::ReadWrite,
3138                }
3139            } else {
3140                Capability::ReadWrite
3141            }
3142        }
3143        Some(_) => Capability::ReadWrite,
3144        None => Capability::ReadOnly,
3145    };
3146
3147    let guard = session
3148        .db()
3149        .await
3150        .connections_for_buffer_update(
3151            project_id,
3152            session.principal_id(),
3153            session.connection_id,
3154            capability,
3155        )
3156        .await?;
3157
3158    let (host, guests) = &*guard;
3159
3160    broadcast(
3161        Some(session.connection_id),
3162        guests.iter().chain([host]).copied(),
3163        |connection_id| {
3164            session
3165                .peer
3166                .forward_send(session.connection_id, connection_id, message.clone())
3167        },
3168    );
3169
3170    Ok(())
3171}
3172
3173/// Notify other participants that a project has been updated.
3174async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3175    request: T,
3176    session: Session,
3177) -> Result<()> {
3178    let project_id = ProjectId::from_proto(request.remote_entity_id());
3179    let project_connection_ids = session
3180        .db()
3181        .await
3182        .project_connection_ids(project_id, session.connection_id, false)
3183        .await?;
3184
3185    broadcast(
3186        Some(session.connection_id),
3187        project_connection_ids.iter().copied(),
3188        |connection_id| {
3189            session
3190                .peer
3191                .forward_send(session.connection_id, connection_id, request.clone())
3192        },
3193    );
3194    Ok(())
3195}
3196
3197/// Start following another user in a call.
3198async fn follow(
3199    request: proto::Follow,
3200    response: Response<proto::Follow>,
3201    session: UserSession,
3202) -> Result<()> {
3203    let room_id = RoomId::from_proto(request.room_id);
3204    let project_id = request.project_id.map(ProjectId::from_proto);
3205    let leader_id = request
3206        .leader_id
3207        .ok_or_else(|| anyhow!("invalid leader id"))?
3208        .into();
3209    let follower_id = session.connection_id;
3210
3211    session
3212        .db()
3213        .await
3214        .check_room_participants(room_id, leader_id, session.connection_id)
3215        .await?;
3216
3217    let response_payload = session
3218        .peer
3219        .forward_request(session.connection_id, leader_id, request)
3220        .await?;
3221    response.send(response_payload)?;
3222
3223    if let Some(project_id) = project_id {
3224        let room = session
3225            .db()
3226            .await
3227            .follow(room_id, project_id, leader_id, follower_id)
3228            .await?;
3229        room_updated(&room, &session.peer);
3230    }
3231
3232    Ok(())
3233}
3234
3235/// Stop following another user in a call.
3236async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3237    let room_id = RoomId::from_proto(request.room_id);
3238    let project_id = request.project_id.map(ProjectId::from_proto);
3239    let leader_id = request
3240        .leader_id
3241        .ok_or_else(|| anyhow!("invalid leader id"))?
3242        .into();
3243    let follower_id = session.connection_id;
3244
3245    session
3246        .db()
3247        .await
3248        .check_room_participants(room_id, leader_id, session.connection_id)
3249        .await?;
3250
3251    session
3252        .peer
3253        .forward_send(session.connection_id, leader_id, request)?;
3254
3255    if let Some(project_id) = project_id {
3256        let room = session
3257            .db()
3258            .await
3259            .unfollow(room_id, project_id, leader_id, follower_id)
3260            .await?;
3261        room_updated(&room, &session.peer);
3262    }
3263
3264    Ok(())
3265}
3266
3267/// Notify everyone following you of your current location.
3268async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3269    let room_id = RoomId::from_proto(request.room_id);
3270    let database = session.db.lock().await;
3271
3272    let connection_ids = if let Some(project_id) = request.project_id {
3273        let project_id = ProjectId::from_proto(project_id);
3274        database
3275            .project_connection_ids(project_id, session.connection_id, true)
3276            .await?
3277    } else {
3278        database
3279            .room_connection_ids(room_id, session.connection_id)
3280            .await?
3281    };
3282
3283    // For now, don't send view update messages back to that view's current leader.
3284    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3285        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3286        _ => None,
3287    });
3288
3289    for connection_id in connection_ids.iter().cloned() {
3290        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3291            session
3292                .peer
3293                .forward_send(session.connection_id, connection_id, request.clone())?;
3294        }
3295    }
3296    Ok(())
3297}
3298
3299/// Get public data about users.
3300async fn get_users(
3301    request: proto::GetUsers,
3302    response: Response<proto::GetUsers>,
3303    session: Session,
3304) -> Result<()> {
3305    let user_ids = request
3306        .user_ids
3307        .into_iter()
3308        .map(UserId::from_proto)
3309        .collect();
3310    let users = session
3311        .db()
3312        .await
3313        .get_users_by_ids(user_ids)
3314        .await?
3315        .into_iter()
3316        .map(|user| proto::User {
3317            id: user.id.to_proto(),
3318            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3319            github_login: user.github_login,
3320        })
3321        .collect();
3322    response.send(proto::UsersResponse { users })?;
3323    Ok(())
3324}
3325
3326/// Search for users (to invite) buy Github login
3327async fn fuzzy_search_users(
3328    request: proto::FuzzySearchUsers,
3329    response: Response<proto::FuzzySearchUsers>,
3330    session: UserSession,
3331) -> Result<()> {
3332    let query = request.query;
3333    let users = match query.len() {
3334        0 => vec![],
3335        1 | 2 => session
3336            .db()
3337            .await
3338            .get_user_by_github_login(&query)
3339            .await?
3340            .into_iter()
3341            .collect(),
3342        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3343    };
3344    let users = users
3345        .into_iter()
3346        .filter(|user| user.id != session.user_id())
3347        .map(|user| proto::User {
3348            id: user.id.to_proto(),
3349            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3350            github_login: user.github_login,
3351        })
3352        .collect();
3353    response.send(proto::UsersResponse { users })?;
3354    Ok(())
3355}
3356
3357/// Send a contact request to another user.
3358async fn request_contact(
3359    request: proto::RequestContact,
3360    response: Response<proto::RequestContact>,
3361    session: UserSession,
3362) -> Result<()> {
3363    let requester_id = session.user_id();
3364    let responder_id = UserId::from_proto(request.responder_id);
3365    if requester_id == responder_id {
3366        return Err(anyhow!("cannot add yourself as a contact"))?;
3367    }
3368
3369    let notifications = session
3370        .db()
3371        .await
3372        .send_contact_request(requester_id, responder_id)
3373        .await?;
3374
3375    // Update outgoing contact requests of requester
3376    let mut update = proto::UpdateContacts::default();
3377    update.outgoing_requests.push(responder_id.to_proto());
3378    for connection_id in session
3379        .connection_pool()
3380        .await
3381        .user_connection_ids(requester_id)
3382    {
3383        session.peer.send(connection_id, update.clone())?;
3384    }
3385
3386    // Update incoming contact requests of responder
3387    let mut update = proto::UpdateContacts::default();
3388    update
3389        .incoming_requests
3390        .push(proto::IncomingContactRequest {
3391            requester_id: requester_id.to_proto(),
3392        });
3393    let connection_pool = session.connection_pool().await;
3394    for connection_id in connection_pool.user_connection_ids(responder_id) {
3395        session.peer.send(connection_id, update.clone())?;
3396    }
3397
3398    send_notifications(&connection_pool, &session.peer, notifications);
3399
3400    response.send(proto::Ack {})?;
3401    Ok(())
3402}
3403
3404/// Accept or decline a contact request
3405async fn respond_to_contact_request(
3406    request: proto::RespondToContactRequest,
3407    response: Response<proto::RespondToContactRequest>,
3408    session: UserSession,
3409) -> Result<()> {
3410    let responder_id = session.user_id();
3411    let requester_id = UserId::from_proto(request.requester_id);
3412    let db = session.db().await;
3413    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3414        db.dismiss_contact_notification(responder_id, requester_id)
3415            .await?;
3416    } else {
3417        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3418
3419        let notifications = db
3420            .respond_to_contact_request(responder_id, requester_id, accept)
3421            .await?;
3422        let requester_busy = db.is_user_busy(requester_id).await?;
3423        let responder_busy = db.is_user_busy(responder_id).await?;
3424
3425        let pool = session.connection_pool().await;
3426        // Update responder with new contact
3427        let mut update = proto::UpdateContacts::default();
3428        if accept {
3429            update
3430                .contacts
3431                .push(contact_for_user(requester_id, requester_busy, &pool));
3432        }
3433        update
3434            .remove_incoming_requests
3435            .push(requester_id.to_proto());
3436        for connection_id in pool.user_connection_ids(responder_id) {
3437            session.peer.send(connection_id, update.clone())?;
3438        }
3439
3440        // Update requester with new contact
3441        let mut update = proto::UpdateContacts::default();
3442        if accept {
3443            update
3444                .contacts
3445                .push(contact_for_user(responder_id, responder_busy, &pool));
3446        }
3447        update
3448            .remove_outgoing_requests
3449            .push(responder_id.to_proto());
3450
3451        for connection_id in pool.user_connection_ids(requester_id) {
3452            session.peer.send(connection_id, update.clone())?;
3453        }
3454
3455        send_notifications(&pool, &session.peer, notifications);
3456    }
3457
3458    response.send(proto::Ack {})?;
3459    Ok(())
3460}
3461
3462/// Remove a contact.
3463async fn remove_contact(
3464    request: proto::RemoveContact,
3465    response: Response<proto::RemoveContact>,
3466    session: UserSession,
3467) -> Result<()> {
3468    let requester_id = session.user_id();
3469    let responder_id = UserId::from_proto(request.user_id);
3470    let db = session.db().await;
3471    let (contact_accepted, deleted_notification_id) =
3472        db.remove_contact(requester_id, responder_id).await?;
3473
3474    let pool = session.connection_pool().await;
3475    // Update outgoing contact requests of requester
3476    let mut update = proto::UpdateContacts::default();
3477    if contact_accepted {
3478        update.remove_contacts.push(responder_id.to_proto());
3479    } else {
3480        update
3481            .remove_outgoing_requests
3482            .push(responder_id.to_proto());
3483    }
3484    for connection_id in pool.user_connection_ids(requester_id) {
3485        session.peer.send(connection_id, update.clone())?;
3486    }
3487
3488    // Update incoming contact requests of responder
3489    let mut update = proto::UpdateContacts::default();
3490    if contact_accepted {
3491        update.remove_contacts.push(requester_id.to_proto());
3492    } else {
3493        update
3494            .remove_incoming_requests
3495            .push(requester_id.to_proto());
3496    }
3497    for connection_id in pool.user_connection_ids(responder_id) {
3498        session.peer.send(connection_id, update.clone())?;
3499        if let Some(notification_id) = deleted_notification_id {
3500            session.peer.send(
3501                connection_id,
3502                proto::DeleteNotification {
3503                    notification_id: notification_id.to_proto(),
3504                },
3505            )?;
3506        }
3507    }
3508
3509    response.send(proto::Ack {})?;
3510    Ok(())
3511}
3512
3513fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3514    version.0.minor() < 139
3515}
3516
3517async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3518    let plan = session.current_plan(session.db().await).await?;
3519
3520    session
3521        .peer
3522        .send(
3523            session.connection_id,
3524            proto::UpdateUserPlan { plan: plan.into() },
3525        )
3526        .trace_err();
3527
3528    Ok(())
3529}
3530
3531async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3532    subscribe_user_to_channels(
3533        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3534        &session,
3535    )
3536    .await?;
3537    Ok(())
3538}
3539
3540async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3541    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3542    let mut pool = session.connection_pool().await;
3543    for membership in &channels_for_user.channel_memberships {
3544        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3545    }
3546    session.peer.send(
3547        session.connection_id,
3548        build_update_user_channels(&channels_for_user),
3549    )?;
3550    session.peer.send(
3551        session.connection_id,
3552        build_channels_update(channels_for_user),
3553    )?;
3554    Ok(())
3555}
3556
3557/// Creates a new channel.
3558async fn create_channel(
3559    request: proto::CreateChannel,
3560    response: Response<proto::CreateChannel>,
3561    session: UserSession,
3562) -> Result<()> {
3563    let db = session.db().await;
3564
3565    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3566    let (channel, membership) = db
3567        .create_channel(&request.name, parent_id, session.user_id())
3568        .await?;
3569
3570    let root_id = channel.root_id();
3571    let channel = Channel::from_model(channel);
3572
3573    response.send(proto::CreateChannelResponse {
3574        channel: Some(channel.to_proto()),
3575        parent_id: request.parent_id,
3576    })?;
3577
3578    let mut connection_pool = session.connection_pool().await;
3579    if let Some(membership) = membership {
3580        connection_pool.subscribe_to_channel(
3581            membership.user_id,
3582            membership.channel_id,
3583            membership.role,
3584        );
3585        let update = proto::UpdateUserChannels {
3586            channel_memberships: vec![proto::ChannelMembership {
3587                channel_id: membership.channel_id.to_proto(),
3588                role: membership.role.into(),
3589            }],
3590            ..Default::default()
3591        };
3592        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3593            session.peer.send(connection_id, update.clone())?;
3594        }
3595    }
3596
3597    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3598        if !role.can_see_channel(channel.visibility) {
3599            continue;
3600        }
3601
3602        let update = proto::UpdateChannels {
3603            channels: vec![channel.to_proto()],
3604            ..Default::default()
3605        };
3606        session.peer.send(connection_id, update.clone())?;
3607    }
3608
3609    Ok(())
3610}
3611
3612/// Delete a channel
3613async fn delete_channel(
3614    request: proto::DeleteChannel,
3615    response: Response<proto::DeleteChannel>,
3616    session: UserSession,
3617) -> Result<()> {
3618    let db = session.db().await;
3619
3620    let channel_id = request.channel_id;
3621    let (root_channel, removed_channels) = db
3622        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3623        .await?;
3624    response.send(proto::Ack {})?;
3625
3626    // Notify members of removed channels
3627    let mut update = proto::UpdateChannels::default();
3628    update
3629        .delete_channels
3630        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3631
3632    let connection_pool = session.connection_pool().await;
3633    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3634        session.peer.send(connection_id, update.clone())?;
3635    }
3636
3637    Ok(())
3638}
3639
3640/// Invite someone to join a channel.
3641async fn invite_channel_member(
3642    request: proto::InviteChannelMember,
3643    response: Response<proto::InviteChannelMember>,
3644    session: UserSession,
3645) -> Result<()> {
3646    let db = session.db().await;
3647    let channel_id = ChannelId::from_proto(request.channel_id);
3648    let invitee_id = UserId::from_proto(request.user_id);
3649    let InviteMemberResult {
3650        channel,
3651        notifications,
3652    } = db
3653        .invite_channel_member(
3654            channel_id,
3655            invitee_id,
3656            session.user_id(),
3657            request.role().into(),
3658        )
3659        .await?;
3660
3661    let update = proto::UpdateChannels {
3662        channel_invitations: vec![channel.to_proto()],
3663        ..Default::default()
3664    };
3665
3666    let connection_pool = session.connection_pool().await;
3667    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3668        session.peer.send(connection_id, update.clone())?;
3669    }
3670
3671    send_notifications(&connection_pool, &session.peer, notifications);
3672
3673    response.send(proto::Ack {})?;
3674    Ok(())
3675}
3676
3677/// remove someone from a channel
3678async fn remove_channel_member(
3679    request: proto::RemoveChannelMember,
3680    response: Response<proto::RemoveChannelMember>,
3681    session: UserSession,
3682) -> Result<()> {
3683    let db = session.db().await;
3684    let channel_id = ChannelId::from_proto(request.channel_id);
3685    let member_id = UserId::from_proto(request.user_id);
3686
3687    let RemoveChannelMemberResult {
3688        membership_update,
3689        notification_id,
3690    } = db
3691        .remove_channel_member(channel_id, member_id, session.user_id())
3692        .await?;
3693
3694    let mut connection_pool = session.connection_pool().await;
3695    notify_membership_updated(
3696        &mut connection_pool,
3697        membership_update,
3698        member_id,
3699        &session.peer,
3700    );
3701    for connection_id in connection_pool.user_connection_ids(member_id) {
3702        if let Some(notification_id) = notification_id {
3703            session
3704                .peer
3705                .send(
3706                    connection_id,
3707                    proto::DeleteNotification {
3708                        notification_id: notification_id.to_proto(),
3709                    },
3710                )
3711                .trace_err();
3712        }
3713    }
3714
3715    response.send(proto::Ack {})?;
3716    Ok(())
3717}
3718
3719/// Toggle the channel between public and private.
3720/// Care is taken to maintain the invariant that public channels only descend from public channels,
3721/// (though members-only channels can appear at any point in the hierarchy).
3722async fn set_channel_visibility(
3723    request: proto::SetChannelVisibility,
3724    response: Response<proto::SetChannelVisibility>,
3725    session: UserSession,
3726) -> Result<()> {
3727    let db = session.db().await;
3728    let channel_id = ChannelId::from_proto(request.channel_id);
3729    let visibility = request.visibility().into();
3730
3731    let channel_model = db
3732        .set_channel_visibility(channel_id, visibility, session.user_id())
3733        .await?;
3734    let root_id = channel_model.root_id();
3735    let channel = Channel::from_model(channel_model);
3736
3737    let mut connection_pool = session.connection_pool().await;
3738    for (user_id, role) in connection_pool
3739        .channel_user_ids(root_id)
3740        .collect::<Vec<_>>()
3741        .into_iter()
3742    {
3743        let update = if role.can_see_channel(channel.visibility) {
3744            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3745            proto::UpdateChannels {
3746                channels: vec![channel.to_proto()],
3747                ..Default::default()
3748            }
3749        } else {
3750            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3751            proto::UpdateChannels {
3752                delete_channels: vec![channel.id.to_proto()],
3753                ..Default::default()
3754            }
3755        };
3756
3757        for connection_id in connection_pool.user_connection_ids(user_id) {
3758            session.peer.send(connection_id, update.clone())?;
3759        }
3760    }
3761
3762    response.send(proto::Ack {})?;
3763    Ok(())
3764}
3765
3766/// Alter the role for a user in the channel.
3767async fn set_channel_member_role(
3768    request: proto::SetChannelMemberRole,
3769    response: Response<proto::SetChannelMemberRole>,
3770    session: UserSession,
3771) -> Result<()> {
3772    let db = session.db().await;
3773    let channel_id = ChannelId::from_proto(request.channel_id);
3774    let member_id = UserId::from_proto(request.user_id);
3775    let result = db
3776        .set_channel_member_role(
3777            channel_id,
3778            session.user_id(),
3779            member_id,
3780            request.role().into(),
3781        )
3782        .await?;
3783
3784    match result {
3785        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3786            let mut connection_pool = session.connection_pool().await;
3787            notify_membership_updated(
3788                &mut connection_pool,
3789                membership_update,
3790                member_id,
3791                &session.peer,
3792            )
3793        }
3794        db::SetMemberRoleResult::InviteUpdated(channel) => {
3795            let update = proto::UpdateChannels {
3796                channel_invitations: vec![channel.to_proto()],
3797                ..Default::default()
3798            };
3799
3800            for connection_id in session
3801                .connection_pool()
3802                .await
3803                .user_connection_ids(member_id)
3804            {
3805                session.peer.send(connection_id, update.clone())?;
3806            }
3807        }
3808    }
3809
3810    response.send(proto::Ack {})?;
3811    Ok(())
3812}
3813
3814/// Change the name of a channel
3815async fn rename_channel(
3816    request: proto::RenameChannel,
3817    response: Response<proto::RenameChannel>,
3818    session: UserSession,
3819) -> Result<()> {
3820    let db = session.db().await;
3821    let channel_id = ChannelId::from_proto(request.channel_id);
3822    let channel_model = db
3823        .rename_channel(channel_id, session.user_id(), &request.name)
3824        .await?;
3825    let root_id = channel_model.root_id();
3826    let channel = Channel::from_model(channel_model);
3827
3828    response.send(proto::RenameChannelResponse {
3829        channel: Some(channel.to_proto()),
3830    })?;
3831
3832    let connection_pool = session.connection_pool().await;
3833    let update = proto::UpdateChannels {
3834        channels: vec![channel.to_proto()],
3835        ..Default::default()
3836    };
3837    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3838        if role.can_see_channel(channel.visibility) {
3839            session.peer.send(connection_id, update.clone())?;
3840        }
3841    }
3842
3843    Ok(())
3844}
3845
3846/// Move a channel to a new parent.
3847async fn move_channel(
3848    request: proto::MoveChannel,
3849    response: Response<proto::MoveChannel>,
3850    session: UserSession,
3851) -> Result<()> {
3852    let channel_id = ChannelId::from_proto(request.channel_id);
3853    let to = ChannelId::from_proto(request.to);
3854
3855    let (root_id, channels) = session
3856        .db()
3857        .await
3858        .move_channel(channel_id, to, session.user_id())
3859        .await?;
3860
3861    let connection_pool = session.connection_pool().await;
3862    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3863        let channels = channels
3864            .iter()
3865            .filter_map(|channel| {
3866                if role.can_see_channel(channel.visibility) {
3867                    Some(channel.to_proto())
3868                } else {
3869                    None
3870                }
3871            })
3872            .collect::<Vec<_>>();
3873        if channels.is_empty() {
3874            continue;
3875        }
3876
3877        let update = proto::UpdateChannels {
3878            channels,
3879            ..Default::default()
3880        };
3881
3882        session.peer.send(connection_id, update.clone())?;
3883    }
3884
3885    response.send(Ack {})?;
3886    Ok(())
3887}
3888
3889/// Get the list of channel members
3890async fn get_channel_members(
3891    request: proto::GetChannelMembers,
3892    response: Response<proto::GetChannelMembers>,
3893    session: UserSession,
3894) -> Result<()> {
3895    let db = session.db().await;
3896    let channel_id = ChannelId::from_proto(request.channel_id);
3897    let limit = if request.limit == 0 {
3898        u16::MAX as u64
3899    } else {
3900        request.limit
3901    };
3902    let (members, users) = db
3903        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3904        .await?;
3905    response.send(proto::GetChannelMembersResponse { members, users })?;
3906    Ok(())
3907}
3908
3909/// Accept or decline a channel invitation.
3910async fn respond_to_channel_invite(
3911    request: proto::RespondToChannelInvite,
3912    response: Response<proto::RespondToChannelInvite>,
3913    session: UserSession,
3914) -> Result<()> {
3915    let db = session.db().await;
3916    let channel_id = ChannelId::from_proto(request.channel_id);
3917    let RespondToChannelInvite {
3918        membership_update,
3919        notifications,
3920    } = db
3921        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3922        .await?;
3923
3924    let mut connection_pool = session.connection_pool().await;
3925    if let Some(membership_update) = membership_update {
3926        notify_membership_updated(
3927            &mut connection_pool,
3928            membership_update,
3929            session.user_id(),
3930            &session.peer,
3931        );
3932    } else {
3933        let update = proto::UpdateChannels {
3934            remove_channel_invitations: vec![channel_id.to_proto()],
3935            ..Default::default()
3936        };
3937
3938        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3939            session.peer.send(connection_id, update.clone())?;
3940        }
3941    };
3942
3943    send_notifications(&connection_pool, &session.peer, notifications);
3944
3945    response.send(proto::Ack {})?;
3946
3947    Ok(())
3948}
3949
3950/// Join the channels' room
3951async fn join_channel(
3952    request: proto::JoinChannel,
3953    response: Response<proto::JoinChannel>,
3954    session: UserSession,
3955) -> Result<()> {
3956    let channel_id = ChannelId::from_proto(request.channel_id);
3957    join_channel_internal(channel_id, Box::new(response), session).await
3958}
3959
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to
/// either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully qualified call resolves to `Response`'s inherent `send`, because
// inherent methods take precedence over trait methods. It therefore does not
// recurse into this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding, for handles created by a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3973
/// Shared implementation for joining `channel_id`'s room as the current
/// user. It replies through any `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, mint a token:
///    - `ChannelRole::Guest` gets `guest_token` with `can_publish: false`;
///    - every other role gets `room_token` with `can_publish: true`.
///    A token error is reported via `trace_err`, and the reply then carries
///    no LiveKit info.
/// 4. Reply, broadcast any membership change, and broadcast the room update.
/// 5. After the scoped database and connection-pool guards are released,
///    broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of these fail:
/// - a database call;
/// - leaving the stale room;
/// - sending the reply;
/// - updating contacts.
///
/// It also errors if the database result carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard (and the connection-pool guard taken below) live only
    // inside this block and are released before the trailing broadcasts.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room.
            // NOTE(review): presumably `leave_room_for_session` acquires the db
            // itself, so holding the guard here would deadlock. Confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token
        // failed (`trace_err()?` short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel here is an error, even though the reply above already
    // went out (with `channel_id: None` in that case).
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4069
4070/// Start editing the channel notes
4071async fn join_channel_buffer(
4072    request: proto::JoinChannelBuffer,
4073    response: Response<proto::JoinChannelBuffer>,
4074    session: UserSession,
4075) -> Result<()> {
4076    let db = session.db().await;
4077    let channel_id = ChannelId::from_proto(request.channel_id);
4078
4079    let open_response = db
4080        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4081        .await?;
4082
4083    let collaborators = open_response.collaborators.clone();
4084    response.send(open_response)?;
4085
4086    let update = UpdateChannelBufferCollaborators {
4087        channel_id: channel_id.to_proto(),
4088        collaborators: collaborators.clone(),
4089    };
4090    channel_buffer_updated(
4091        session.connection_id,
4092        collaborators
4093            .iter()
4094            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4095        &update,
4096        &session.peer,
4097    );
4098
4099    Ok(())
4100}
4101
4102/// Edit the channel notes
4103async fn update_channel_buffer(
4104    request: proto::UpdateChannelBuffer,
4105    session: UserSession,
4106) -> Result<()> {
4107    let db = session.db().await;
4108    let channel_id = ChannelId::from_proto(request.channel_id);
4109
4110    let (collaborators, epoch, version) = db
4111        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4112        .await?;
4113
4114    channel_buffer_updated(
4115        session.connection_id,
4116        collaborators.clone(),
4117        &proto::UpdateChannelBuffer {
4118            channel_id: channel_id.to_proto(),
4119            operations: request.operations,
4120        },
4121        &session.peer,
4122    );
4123
4124    let pool = &*session.connection_pool().await;
4125
4126    let non_collaborators =
4127        pool.channel_connection_ids(channel_id)
4128            .filter_map(|(connection_id, _)| {
4129                if collaborators.contains(&connection_id) {
4130                    None
4131                } else {
4132                    Some(connection_id)
4133                }
4134            });
4135
4136    broadcast(None, non_collaborators, |peer_id| {
4137        session.peer.send(
4138            peer_id,
4139            proto::UpdateChannels {
4140                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4141                    channel_id: channel_id.to_proto(),
4142                    epoch: epoch as u64,
4143                    version: version.clone(),
4144                }],
4145                ..Default::default()
4146            },
4147        )
4148    });
4149
4150    Ok(())
4151}
4152
4153/// Rejoin the channel notes after a connection blip
4154async fn rejoin_channel_buffers(
4155    request: proto::RejoinChannelBuffers,
4156    response: Response<proto::RejoinChannelBuffers>,
4157    session: UserSession,
4158) -> Result<()> {
4159    let db = session.db().await;
4160    let buffers = db
4161        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4162        .await?;
4163
4164    for rejoined_buffer in &buffers {
4165        let collaborators_to_notify = rejoined_buffer
4166            .buffer
4167            .collaborators
4168            .iter()
4169            .filter_map(|c| Some(c.peer_id?.into()));
4170        channel_buffer_updated(
4171            session.connection_id,
4172            collaborators_to_notify,
4173            &proto::UpdateChannelBufferCollaborators {
4174                channel_id: rejoined_buffer.buffer.channel_id,
4175                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4176            },
4177            &session.peer,
4178        );
4179    }
4180
4181    response.send(proto::RejoinChannelBuffersResponse {
4182        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4183    })?;
4184
4185    Ok(())
4186}
4187
4188/// Stop editing the channel notes
4189async fn leave_channel_buffer(
4190    request: proto::LeaveChannelBuffer,
4191    response: Response<proto::LeaveChannelBuffer>,
4192    session: UserSession,
4193) -> Result<()> {
4194    let db = session.db().await;
4195    let channel_id = ChannelId::from_proto(request.channel_id);
4196
4197    let left_buffer = db
4198        .leave_channel_buffer(channel_id, session.connection_id)
4199        .await?;
4200
4201    response.send(Ack {})?;
4202
4203    channel_buffer_updated(
4204        session.connection_id,
4205        left_buffer.connections,
4206        &proto::UpdateChannelBufferCollaborators {
4207            channel_id: channel_id.to_proto(),
4208            collaborators: left_buffer.collaborators,
4209        },
4210        &session.peer,
4211    );
4212
4213    Ok(())
4214}
4215
4216fn channel_buffer_updated<T: EnvelopedMessage>(
4217    sender_id: ConnectionId,
4218    collaborators: impl IntoIterator<Item = ConnectionId>,
4219    message: &T,
4220    peer: &Peer,
4221) {
4222    broadcast(Some(sender_id), collaborators, |peer_id| {
4223        peer.send(peer_id, message.clone())
4224    });
4225}
4226
4227fn send_notifications(
4228    connection_pool: &ConnectionPool,
4229    peer: &Peer,
4230    notifications: db::NotificationBatch,
4231) {
4232    for (user_id, notification) in notifications {
4233        for connection_id in connection_pool.user_connection_ids(user_id) {
4234            if let Err(error) = peer.send(
4235                connection_id,
4236                proto::AddNotification {
4237                    notification: Some(notification.clone()),
4238                },
4239            ) {
4240                tracing::error!(
4241                    "failed to send notification to {:?} {}",
4242                    connection_id,
4243                    error
4244                );
4245            }
4246        }
4247    }
4248}
4249
4250/// Send a message to the channel
4251async fn send_channel_message(
4252    request: proto::SendChannelMessage,
4253    response: Response<proto::SendChannelMessage>,
4254    session: UserSession,
4255) -> Result<()> {
4256    // Validate the message body.
4257    let body = request.body.trim().to_string();
4258    if body.len() > MAX_MESSAGE_LEN {
4259        return Err(anyhow!("message is too long"))?;
4260    }
4261    if body.is_empty() {
4262        return Err(anyhow!("message can't be blank"))?;
4263    }
4264
4265    // TODO: adjust mentions if body is trimmed
4266
4267    let timestamp = OffsetDateTime::now_utc();
4268    let nonce = request
4269        .nonce
4270        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4271
4272    let channel_id = ChannelId::from_proto(request.channel_id);
4273    let CreatedChannelMessage {
4274        message_id,
4275        participant_connection_ids,
4276        notifications,
4277    } = session
4278        .db()
4279        .await
4280        .create_channel_message(
4281            channel_id,
4282            session.user_id(),
4283            &body,
4284            &request.mentions,
4285            timestamp,
4286            nonce.clone().into(),
4287            match request.reply_to_message_id {
4288                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4289                None => None,
4290            },
4291        )
4292        .await?;
4293
4294    let message = proto::ChannelMessage {
4295        sender_id: session.user_id().to_proto(),
4296        id: message_id.to_proto(),
4297        body,
4298        mentions: request.mentions,
4299        timestamp: timestamp.unix_timestamp() as u64,
4300        nonce: Some(nonce),
4301        reply_to_message_id: request.reply_to_message_id,
4302        edited_at: None,
4303    };
4304    broadcast(
4305        Some(session.connection_id),
4306        participant_connection_ids.clone(),
4307        |connection| {
4308            session.peer.send(
4309                connection,
4310                proto::ChannelMessageSent {
4311                    channel_id: channel_id.to_proto(),
4312                    message: Some(message.clone()),
4313                },
4314            )
4315        },
4316    );
4317    response.send(proto::SendChannelMessageResponse {
4318        message: Some(message),
4319    })?;
4320
4321    let pool = &*session.connection_pool().await;
4322    let non_participants =
4323        pool.channel_connection_ids(channel_id)
4324            .filter_map(|(connection_id, _)| {
4325                if participant_connection_ids.contains(&connection_id) {
4326                    None
4327                } else {
4328                    Some(connection_id)
4329                }
4330            });
4331    broadcast(None, non_participants, |peer_id| {
4332        session.peer.send(
4333            peer_id,
4334            proto::UpdateChannels {
4335                latest_channel_message_ids: vec![proto::ChannelMessageId {
4336                    channel_id: channel_id.to_proto(),
4337                    message_id: message_id.to_proto(),
4338                }],
4339                ..Default::default()
4340            },
4341        )
4342    });
4343    send_notifications(pool, &session.peer, notifications);
4344
4345    Ok(())
4346}
4347
4348/// Delete a channel message
4349async fn remove_channel_message(
4350    request: proto::RemoveChannelMessage,
4351    response: Response<proto::RemoveChannelMessage>,
4352    session: UserSession,
4353) -> Result<()> {
4354    let channel_id = ChannelId::from_proto(request.channel_id);
4355    let message_id = MessageId::from_proto(request.message_id);
4356    let (connection_ids, existing_notification_ids) = session
4357        .db()
4358        .await
4359        .remove_channel_message(channel_id, message_id, session.user_id())
4360        .await?;
4361
4362    broadcast(
4363        Some(session.connection_id),
4364        connection_ids,
4365        move |connection| {
4366            session.peer.send(connection, request.clone())?;
4367
4368            for notification_id in &existing_notification_ids {
4369                session.peer.send(
4370                    connection,
4371                    proto::DeleteNotification {
4372                        notification_id: (*notification_id).to_proto(),
4373                    },
4374                )?;
4375            }
4376
4377            Ok(())
4378        },
4379    );
4380    response.send(proto::Ack {})?;
4381    Ok(())
4382}
4383
4384async fn update_channel_message(
4385    request: proto::UpdateChannelMessage,
4386    response: Response<proto::UpdateChannelMessage>,
4387    session: UserSession,
4388) -> Result<()> {
4389    let channel_id = ChannelId::from_proto(request.channel_id);
4390    let message_id = MessageId::from_proto(request.message_id);
4391    let updated_at = OffsetDateTime::now_utc();
4392    let UpdatedChannelMessage {
4393        message_id,
4394        participant_connection_ids,
4395        notifications,
4396        reply_to_message_id,
4397        timestamp,
4398        deleted_mention_notification_ids,
4399        updated_mention_notifications,
4400    } = session
4401        .db()
4402        .await
4403        .update_channel_message(
4404            channel_id,
4405            message_id,
4406            session.user_id(),
4407            request.body.as_str(),
4408            &request.mentions,
4409            updated_at,
4410        )
4411        .await?;
4412
4413    let nonce = request
4414        .nonce
4415        .clone()
4416        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4417
4418    let message = proto::ChannelMessage {
4419        sender_id: session.user_id().to_proto(),
4420        id: message_id.to_proto(),
4421        body: request.body.clone(),
4422        mentions: request.mentions.clone(),
4423        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4424        nonce: Some(nonce),
4425        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4426        edited_at: Some(updated_at.unix_timestamp() as u64),
4427    };
4428
4429    response.send(proto::Ack {})?;
4430
4431    let pool = &*session.connection_pool().await;
4432    broadcast(
4433        Some(session.connection_id),
4434        participant_connection_ids,
4435        |connection| {
4436            session.peer.send(
4437                connection,
4438                proto::ChannelMessageUpdate {
4439                    channel_id: channel_id.to_proto(),
4440                    message: Some(message.clone()),
4441                },
4442            )?;
4443
4444            for notification_id in &deleted_mention_notification_ids {
4445                session.peer.send(
4446                    connection,
4447                    proto::DeleteNotification {
4448                        notification_id: (*notification_id).to_proto(),
4449                    },
4450                )?;
4451            }
4452
4453            for notification in &updated_mention_notifications {
4454                session.peer.send(
4455                    connection,
4456                    proto::UpdateNotification {
4457                        notification: Some(notification.clone()),
4458                    },
4459                )?;
4460            }
4461
4462            Ok(())
4463        },
4464    );
4465
4466    send_notifications(pool, &session.peer, notifications);
4467
4468    Ok(())
4469}
4470
4471/// Mark a channel message as read
4472async fn acknowledge_channel_message(
4473    request: proto::AckChannelMessage,
4474    session: UserSession,
4475) -> Result<()> {
4476    let channel_id = ChannelId::from_proto(request.channel_id);
4477    let message_id = MessageId::from_proto(request.message_id);
4478    let notifications = session
4479        .db()
4480        .await
4481        .observe_channel_message(channel_id, session.user_id(), message_id)
4482        .await?;
4483    send_notifications(
4484        &*session.connection_pool().await,
4485        &session.peer,
4486        notifications,
4487    );
4488    Ok(())
4489}
4490
4491/// Mark a buffer version as synced
4492async fn acknowledge_buffer_version(
4493    request: proto::AckBufferOperation,
4494    session: UserSession,
4495) -> Result<()> {
4496    let buffer_id = BufferId::from_proto(request.buffer_id);
4497    session
4498        .db()
4499        .await
4500        .observe_buffer_version(
4501            buffer_id,
4502            session.user_id(),
4503            request.epoch as i32,
4504            &request.version,
4505        )
4506        .await?;
4507    Ok(())
4508}
4509
4510async fn count_language_model_tokens(
4511    request: proto::CountLanguageModelTokens,
4512    response: Response<proto::CountLanguageModelTokens>,
4513    session: Session,
4514    config: &Config,
4515) -> Result<()> {
4516    let Some(session) = session.for_user() else {
4517        return Err(anyhow!("user not found"))?;
4518    };
4519    authorize_access_to_legacy_llm_endpoints(&session).await?;
4520
4521    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4522        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4523        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4524    };
4525
4526    session
4527        .app_state
4528        .rate_limiter
4529        .check(&*rate_limit, session.user_id())
4530        .await?;
4531
4532    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4533        Some(proto::LanguageModelProvider::Google) => {
4534            let api_key = config
4535                .google_ai_api_key
4536                .as_ref()
4537                .context("no Google AI API key configured on the server")?;
4538            google_ai::count_tokens(
4539                session.http_client.as_ref(),
4540                google_ai::API_URL,
4541                api_key,
4542                serde_json::from_str(&request.request)?,
4543            )
4544            .await?
4545        }
4546        _ => return Err(anyhow!("unsupported provider"))?,
4547    };
4548
4549    response.send(proto::CountLanguageModelTokensResponse {
4550        token_count: result.total_tokens as u32,
4551    })?;
4552
4553    Ok(())
4554}
4555
4556struct ZedProCountLanguageModelTokensRateLimit;
4557
4558impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4559    fn capacity(&self) -> usize {
4560        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4561            .ok()
4562            .and_then(|v| v.parse().ok())
4563            .unwrap_or(600) // Picked arbitrarily
4564    }
4565
4566    fn refill_duration(&self) -> chrono::Duration {
4567        chrono::Duration::hours(1)
4568    }
4569
4570    fn db_name(&self) -> &'static str {
4571        "zed-pro:count-language-model-tokens"
4572    }
4573}
4574
4575struct FreeCountLanguageModelTokensRateLimit;
4576
4577impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4578    fn capacity(&self) -> usize {
4579        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4580            .ok()
4581            .and_then(|v| v.parse().ok())
4582            .unwrap_or(600 / 10) // Picked arbitrarily
4583    }
4584
4585    fn refill_duration(&self) -> chrono::Duration {
4586        chrono::Duration::hours(1)
4587    }
4588
4589    fn db_name(&self) -> &'static str {
4590        "free:count-language-model-tokens"
4591    }
4592}
4593
4594struct ZedProComputeEmbeddingsRateLimit;
4595
4596impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4597    fn capacity(&self) -> usize {
4598        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4599            .ok()
4600            .and_then(|v| v.parse().ok())
4601            .unwrap_or(5000) // Picked arbitrarily
4602    }
4603
4604    fn refill_duration(&self) -> chrono::Duration {
4605        chrono::Duration::hours(1)
4606    }
4607
4608    fn db_name(&self) -> &'static str {
4609        "zed-pro:compute-embeddings"
4610    }
4611}
4612
4613struct FreeComputeEmbeddingsRateLimit;
4614
4615impl RateLimit for FreeComputeEmbeddingsRateLimit {
4616    fn capacity(&self) -> usize {
4617        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4618            .ok()
4619            .and_then(|v| v.parse().ok())
4620            .unwrap_or(5000 / 10) // Picked arbitrarily
4621    }
4622
4623    fn refill_duration(&self) -> chrono::Duration {
4624        chrono::Duration::hours(1)
4625    }
4626
4627    fn db_name(&self) -> &'static str {
4628        "free:compute-embeddings"
4629    }
4630}
4631
4632async fn compute_embeddings(
4633    request: proto::ComputeEmbeddings,
4634    response: Response<proto::ComputeEmbeddings>,
4635    session: UserSession,
4636    api_key: Option<Arc<str>>,
4637) -> Result<()> {
4638    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4639    authorize_access_to_legacy_llm_endpoints(&session).await?;
4640
4641    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4642        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4643        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4644    };
4645
4646    session
4647        .app_state
4648        .rate_limiter
4649        .check(&*rate_limit, session.user_id())
4650        .await?;
4651
4652    let embeddings = match request.model.as_str() {
4653        "openai/text-embedding-3-small" => {
4654            open_ai::embed(
4655                session.http_client.as_ref(),
4656                OPEN_AI_API_URL,
4657                &api_key,
4658                OpenAiEmbeddingModel::TextEmbedding3Small,
4659                request.texts.iter().map(|text| text.as_str()),
4660            )
4661            .await?
4662        }
4663        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4664    };
4665
4666    let embeddings = request
4667        .texts
4668        .iter()
4669        .map(|text| {
4670            let mut hasher = sha2::Sha256::new();
4671            hasher.update(text.as_bytes());
4672            let result = hasher.finalize();
4673            result.to_vec()
4674        })
4675        .zip(
4676            embeddings
4677                .data
4678                .into_iter()
4679                .map(|embedding| embedding.embedding),
4680        )
4681        .collect::<HashMap<_, _>>();
4682
4683    let db = session.db().await;
4684    db.save_embeddings(&request.model, &embeddings)
4685        .await
4686        .context("failed to save embeddings")
4687        .trace_err();
4688
4689    response.send(proto::ComputeEmbeddingsResponse {
4690        embeddings: embeddings
4691            .into_iter()
4692            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4693            .collect(),
4694    })?;
4695    Ok(())
4696}
4697
4698async fn get_cached_embeddings(
4699    request: proto::GetCachedEmbeddings,
4700    response: Response<proto::GetCachedEmbeddings>,
4701    session: UserSession,
4702) -> Result<()> {
4703    authorize_access_to_legacy_llm_endpoints(&session).await?;
4704
4705    let db = session.db().await;
4706    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4707
4708    response.send(proto::GetCachedEmbeddingsResponse {
4709        embeddings: embeddings
4710            .into_iter()
4711            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4712            .collect(),
4713    })?;
4714    Ok(())
4715}
4716
4717/// This is leftover from before the LLM service.
4718///
4719/// The endpoints protected by this check will be moved there eventually.
4720async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4721    if session.is_staff() {
4722        Ok(())
4723    } else {
4724        Err(anyhow!("permission denied"))?
4725    }
4726}
4727
4728/// Get a Supermaven API key for the user
4729async fn get_supermaven_api_key(
4730    _request: proto::GetSupermavenApiKey,
4731    response: Response<proto::GetSupermavenApiKey>,
4732    session: UserSession,
4733) -> Result<()> {
4734    let user_id: String = session.user_id().to_string();
4735    if !session.is_staff() {
4736        return Err(anyhow!("supermaven not enabled for this account"))?;
4737    }
4738
4739    let email = session
4740        .email()
4741        .ok_or_else(|| anyhow!("user must have an email"))?;
4742
4743    let supermaven_admin_api = session
4744        .supermaven_client
4745        .as_ref()
4746        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4747
4748    let result = supermaven_admin_api
4749        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4750        .await?;
4751
4752    response.send(proto::GetSupermavenApiKeyResponse {
4753        api_key: result.api_key,
4754    })?;
4755
4756    Ok(())
4757}
4758
4759/// Start receiving chat updates for a channel
4760async fn join_channel_chat(
4761    request: proto::JoinChannelChat,
4762    response: Response<proto::JoinChannelChat>,
4763    session: UserSession,
4764) -> Result<()> {
4765    let channel_id = ChannelId::from_proto(request.channel_id);
4766
4767    let db = session.db().await;
4768    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4769        .await?;
4770    let messages = db
4771        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4772        .await?;
4773    response.send(proto::JoinChannelChatResponse {
4774        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4775        messages,
4776    })?;
4777    Ok(())
4778}
4779
4780/// Stop receiving chat updates for a channel
4781async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4782    let channel_id = ChannelId::from_proto(request.channel_id);
4783    session
4784        .db()
4785        .await
4786        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4787        .await?;
4788    Ok(())
4789}
4790
4791/// Retrieve the chat history for a channel
4792async fn get_channel_messages(
4793    request: proto::GetChannelMessages,
4794    response: Response<proto::GetChannelMessages>,
4795    session: UserSession,
4796) -> Result<()> {
4797    let channel_id = ChannelId::from_proto(request.channel_id);
4798    let messages = session
4799        .db()
4800        .await
4801        .get_channel_messages(
4802            channel_id,
4803            session.user_id(),
4804            MESSAGE_COUNT_PER_PAGE,
4805            Some(MessageId::from_proto(request.before_message_id)),
4806        )
4807        .await?;
4808    response.send(proto::GetChannelMessagesResponse {
4809        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4810        messages,
4811    })?;
4812    Ok(())
4813}
4814
4815/// Retrieve specific chat messages
4816async fn get_channel_messages_by_id(
4817    request: proto::GetChannelMessagesById,
4818    response: Response<proto::GetChannelMessagesById>,
4819    session: UserSession,
4820) -> Result<()> {
4821    let message_ids = request
4822        .message_ids
4823        .iter()
4824        .map(|id| MessageId::from_proto(*id))
4825        .collect::<Vec<_>>();
4826    let messages = session
4827        .db()
4828        .await
4829        .get_channel_messages_by_id(session.user_id(), &message_ids)
4830        .await?;
4831    response.send(proto::GetChannelMessagesResponse {
4832        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4833        messages,
4834    })?;
4835    Ok(())
4836}
4837
4838/// Retrieve the current users notifications
4839async fn get_notifications(
4840    request: proto::GetNotifications,
4841    response: Response<proto::GetNotifications>,
4842    session: UserSession,
4843) -> Result<()> {
4844    let notifications = session
4845        .db()
4846        .await
4847        .get_notifications(
4848            session.user_id(),
4849            NOTIFICATION_COUNT_PER_PAGE,
4850            request
4851                .before_id
4852                .map(|id| db::NotificationId::from_proto(id)),
4853        )
4854        .await?;
4855    response.send(proto::GetNotificationsResponse {
4856        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4857        notifications,
4858    })?;
4859    Ok(())
4860}
4861
4862/// Mark notifications as read
4863async fn mark_notification_as_read(
4864    request: proto::MarkNotificationRead,
4865    response: Response<proto::MarkNotificationRead>,
4866    session: UserSession,
4867) -> Result<()> {
4868    let database = &session.db().await;
4869    let notifications = database
4870        .mark_notification_as_read_by_id(
4871            session.user_id(),
4872            NotificationId::from_proto(request.notification_id),
4873        )
4874        .await?;
4875    send_notifications(
4876        &*session.connection_pool().await,
4877        &session.peer,
4878        notifications,
4879    );
4880    response.send(proto::Ack {})?;
4881    Ok(())
4882}
4883
4884/// Get the current users information
4885async fn get_private_user_info(
4886    _request: proto::GetPrivateUserInfo,
4887    response: Response<proto::GetPrivateUserInfo>,
4888    session: UserSession,
4889) -> Result<()> {
4890    let db = session.db().await;
4891
4892    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4893    let user = db
4894        .get_user_by_id(session.user_id())
4895        .await?
4896        .ok_or_else(|| anyhow!("user not found"))?;
4897    let flags = db.get_user_flags(session.user_id()).await?;
4898
4899    response.send(proto::GetPrivateUserInfoResponse {
4900        metrics_id,
4901        staff: user.admin,
4902        flags,
4903        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4904    })?;
4905    Ok(())
4906}
4907
4908/// Accept the terms of service (tos) on behalf of the current user
4909async fn accept_terms_of_service(
4910    _request: proto::AcceptTermsOfService,
4911    response: Response<proto::AcceptTermsOfService>,
4912    session: UserSession,
4913) -> Result<()> {
4914    let db = session.db().await;
4915
4916    let accepted_tos_at = Utc::now();
4917    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4918        .await?;
4919
4920    response.send(proto::AcceptTermsOfServiceResponse {
4921        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4922    })?;
4923    Ok(())
4924}
4925
/// The minimum account age an account must have in order to use the LLM service.
///
/// Thirty days. `get_llm_api_token` measures age from the earlier of the Zed
/// account's creation time and the linked GitHub account's creation time (when known).
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4928
4929async fn get_llm_api_token(
4930    _request: proto::GetLlmToken,
4931    response: Response<proto::GetLlmToken>,
4932    session: UserSession,
4933) -> Result<()> {
4934    let db = session.db().await;
4935
4936    let flags = db.get_user_flags(session.user_id()).await?;
4937    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4938    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4939
4940    if !session.is_staff() && !has_language_models_feature_flag {
4941        Err(anyhow!("permission denied"))?
4942    }
4943
4944    let user_id = session.user_id();
4945    let user = db
4946        .get_user_by_id(user_id)
4947        .await?
4948        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4949
4950    if user.accepted_tos_at.is_none() {
4951        Err(anyhow!("terms of service not accepted"))?
4952    }
4953
4954    let mut account_created_at = user.created_at;
4955    if let Some(github_created_at) = user.github_user_created_at {
4956        account_created_at = account_created_at.min(github_created_at);
4957    }
4958    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4959        Err(anyhow!("account too young"))?
4960    }
4961    let token = LlmTokenClaims::create(
4962        user.id,
4963        user.github_login.clone(),
4964        session.is_staff(),
4965        has_llm_closed_beta_feature_flag,
4966        session.current_plan(db).await?,
4967        &session.app_state.config,
4968    )?;
4969    response.send(proto::GetLlmTokenResponse { token })?;
4970    Ok(())
4971}
4972
4973fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4974    let message = match message {
4975        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4976        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4977        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4978        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4979        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4980            code: frame.code.into(),
4981            reason: frame.reason,
4982        })),
4983        // We should never receive a frame while reading the message, according
4984        // to the `tungstenite` maintainers:
4985        //
4986        // > It cannot occur when you read messages from the WebSocket, but it
4987        // > can be used when you want to send the raw frames (e.g. you want to
4988        // > send the frames to the WebSocket without composing the full message first).
4989        // >
4990        // > — https://github.com/snapview/tungstenite-rs/issues/268
4991        TungsteniteMessage::Frame(_) => {
4992            bail!("received an unexpected frame while reading the message")
4993        }
4994    };
4995
4996    Ok(message)
4997}
4998
4999fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5000    match message {
5001        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5002        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5003        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5004        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5005        AxumMessage::Close(frame) => {
5006            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5007                code: frame.code.into(),
5008                reason: frame.reason,
5009            }))
5010        }
5011    }
5012}
5013
5014fn notify_membership_updated(
5015    connection_pool: &mut ConnectionPool,
5016    result: MembershipUpdated,
5017    user_id: UserId,
5018    peer: &Peer,
5019) {
5020    for membership in &result.new_channels.channel_memberships {
5021        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5022    }
5023    for channel_id in &result.removed_channels {
5024        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5025    }
5026
5027    let user_channels_update = proto::UpdateUserChannels {
5028        channel_memberships: result
5029            .new_channels
5030            .channel_memberships
5031            .iter()
5032            .map(|cm| proto::ChannelMembership {
5033                channel_id: cm.channel_id.to_proto(),
5034                role: cm.role.into(),
5035            })
5036            .collect(),
5037        ..Default::default()
5038    };
5039
5040    let mut update = build_channels_update(result.new_channels);
5041    update.delete_channels = result
5042        .removed_channels
5043        .into_iter()
5044        .map(|id| id.to_proto())
5045        .collect();
5046    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5047
5048    for connection_id in connection_pool.user_connection_ids(user_id) {
5049        peer.send(connection_id, user_channels_update.clone())
5050            .trace_err();
5051        peer.send(connection_id, update.clone()).trace_err();
5052    }
5053}
5054
5055fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5056    proto::UpdateUserChannels {
5057        channel_memberships: channels
5058            .channel_memberships
5059            .iter()
5060            .map(|m| proto::ChannelMembership {
5061                channel_id: m.channel_id.to_proto(),
5062                role: m.role.into(),
5063            })
5064            .collect(),
5065        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5066        observed_channel_message_id: channels.observed_channel_messages.clone(),
5067    }
5068}
5069
5070fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5071    let mut update = proto::UpdateChannels::default();
5072
5073    for channel in channels.channels {
5074        update.channels.push(channel.to_proto());
5075    }
5076
5077    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5078    update.latest_channel_message_ids = channels.latest_channel_messages;
5079
5080    for (channel_id, participants) in channels.channel_participants {
5081        update
5082            .channel_participants
5083            .push(proto::ChannelParticipants {
5084                channel_id: channel_id.to_proto(),
5085                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5086            });
5087    }
5088
5089    for channel in channels.invited_channels {
5090        update.channel_invitations.push(channel.to_proto());
5091    }
5092
5093    update.hosted_projects = channels.hosted_projects;
5094    update
5095}
5096
5097fn build_initial_contacts_update(
5098    contacts: Vec<db::Contact>,
5099    pool: &ConnectionPool,
5100) -> proto::UpdateContacts {
5101    let mut update = proto::UpdateContacts::default();
5102
5103    for contact in contacts {
5104        match contact {
5105            db::Contact::Accepted { user_id, busy } => {
5106                update.contacts.push(contact_for_user(user_id, busy, &pool));
5107            }
5108            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5109            db::Contact::Incoming { user_id } => {
5110                update
5111                    .incoming_requests
5112                    .push(proto::IncomingContactRequest {
5113                        requester_id: user_id.to_proto(),
5114                    })
5115            }
5116        }
5117    }
5118
5119    update
5120}
5121
5122fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5123    proto::Contact {
5124        user_id: user_id.to_proto(),
5125        online: pool.is_user_online(user_id),
5126        busy,
5127    }
5128}
5129
5130fn room_updated(room: &proto::Room, peer: &Peer) {
5131    broadcast(
5132        None,
5133        room.participants
5134            .iter()
5135            .filter_map(|participant| Some(participant.peer_id?.into())),
5136        |peer_id| {
5137            peer.send(
5138                peer_id,
5139                proto::RoomUpdated {
5140                    room: Some(room.clone()),
5141                },
5142            )
5143        },
5144    );
5145}
5146
5147fn channel_updated(
5148    channel: &db::channel::Model,
5149    room: &proto::Room,
5150    peer: &Peer,
5151    pool: &ConnectionPool,
5152) {
5153    let participants = room
5154        .participants
5155        .iter()
5156        .map(|p| p.user_id)
5157        .collect::<Vec<_>>();
5158
5159    broadcast(
5160        None,
5161        pool.channel_connection_ids(channel.root_id())
5162            .filter_map(|(channel_id, role)| {
5163                role.can_see_channel(channel.visibility).then(|| channel_id)
5164            }),
5165        |peer_id| {
5166            peer.send(
5167                peer_id,
5168                proto::UpdateChannels {
5169                    channel_participants: vec![proto::ChannelParticipants {
5170                        channel_id: channel.id.to_proto(),
5171                        participant_user_ids: participants.clone(),
5172                    }],
5173                    ..Default::default()
5174                },
5175            )
5176        },
5177    );
5178}
5179
5180async fn send_dev_server_projects_update(
5181    user_id: UserId,
5182    mut status: proto::DevServerProjectsUpdate,
5183    session: &Session,
5184) {
5185    let pool = session.connection_pool().await;
5186    for dev_server in &mut status.dev_servers {
5187        dev_server.status =
5188            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5189    }
5190    let connections = pool.user_connection_ids(user_id);
5191    for connection_id in connections {
5192        session.peer.send(connection_id, status.clone()).trace_err();
5193    }
5194}
5195
5196async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5197    let db = session.db().await;
5198
5199    let contacts = db.get_contacts(user_id).await?;
5200    let busy = db.is_user_busy(user_id).await?;
5201
5202    let pool = session.connection_pool().await;
5203    let updated_contact = contact_for_user(user_id, busy, &pool);
5204    for contact in contacts {
5205        if let db::Contact::Accepted {
5206            user_id: contact_user_id,
5207            ..
5208        } = contact
5209        {
5210            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5211                session
5212                    .peer
5213                    .send(
5214                        contact_conn_id,
5215                        proto::UpdateContacts {
5216                            contacts: vec![updated_contact.clone()],
5217                            remove_contacts: Default::default(),
5218                            incoming_requests: Default::default(),
5219                            remove_incoming_requests: Default::default(),
5220                            outgoing_requests: Default::default(),
5221                            remove_outgoing_requests: Default::default(),
5222                        },
5223                    )
5224                    .trace_err();
5225            }
5226        }
5227    }
5228    Ok(())
5229}
5230
5231async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5232    log::info!("lost dev server connection, unsharing projects");
5233    let project_ids = session
5234        .db()
5235        .await
5236        .get_stale_dev_server_projects(session.connection_id)
5237        .await?;
5238
5239    for project_id in project_ids {
5240        // not unshare re-checks the connection ids match, so we get away with no transaction
5241        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5242    }
5243
5244    let user_id = session.dev_server().user_id;
5245    let update = session
5246        .db()
5247        .await
5248        .dev_server_projects_update(user_id)
5249        .await?;
5250
5251    send_dev_server_projects_update(user_id, update, session).await;
5252
5253    Ok(())
5254}
5255
/// Removes `connection_id` from the room the database reports it in and notifies
/// everyone affected.
///
/// In order, it:
/// - tells the connections of each left project that the session left;
/// - broadcasts the updated room, plus a channel update if the room belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact state for the leaving user and those users;
/// - removes the user from the LiveKit room, and deletes that room if the
///   database marked it as deleted.
///
/// Returns `Ok(())` without doing anything else if `leave_room` returns `None`.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. A failing
/// contact update aborts the remaining updates and skips the LiveKit cleanup.
/// LiveKit and send failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // A set, so each user's contacts are refreshed at most once.
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. The `else` branch returns, so these
    // are always initialized by the time the code after it runs.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // the `if let` scrutinee. Under pre-2024-edition temporary rules it lives for
    // the whole `if let` body. If it is a lock guard, the DB handle stays held
    // during `project_left`/`room_updated` — confirm that is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so the room broadcast by
        // `room_updated` has `live_kit_room` reset to its default value.
        // NOTE(review): presumably intentional — confirm clients don't read it
        // from `RoomUpdated`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool guard is a temporary released at the end of this statement.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5329
5330async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5331    let left_channel_buffers = session
5332        .db()
5333        .await
5334        .leave_channel_buffers(session.connection_id)
5335        .await?;
5336
5337    for left_buffer in left_channel_buffers {
5338        channel_buffer_updated(
5339            session.connection_id,
5340            left_buffer.connections,
5341            &proto::UpdateChannelBufferCollaborators {
5342                channel_id: left_buffer.channel_id.to_proto(),
5343                collaborators: left_buffer.collaborators,
5344            },
5345            &session.peer,
5346        );
5347    }
5348
5349    Ok(())
5350}
5351
5352fn project_left(project: &db::LeftProject, session: &UserSession) {
5353    for connection_id in &project.connection_ids {
5354        if project.should_unshare {
5355            session
5356                .peer
5357                .send(
5358                    *connection_id,
5359                    proto::UnshareProject {
5360                        project_id: project.id.to_proto(),
5361                    },
5362                )
5363                .trace_err();
5364        } else {
5365            session
5366                .peer
5367                .send(
5368                    *connection_id,
5369                    proto::RemoveProjectCollaborator {
5370                        project_id: project.id.to_proto(),
5371                        peer_id: Some(session.connection_id.into()),
5372                    },
5373                )
5374                .trace_err();
5375        }
5376    }
5377}
5378
5379pub trait ResultExt {
5380    type Ok;
5381
5382    fn trace_err(self) -> Option<Self::Ok>;
5383}
5384
5385impl<T, E> ResultExt for Result<T, E>
5386where
5387    E: std::fmt::Debug,
5388{
5389    type Ok = T;
5390
5391    #[track_caller]
5392    fn trace_err(self) -> Option<T> {
5393        match self {
5394            Ok(value) => Some(value),
5395            Err(error) => {
5396                tracing::error!("{:?}", error);
5397                None
5398            }
5399        }
5400    }
5401}