rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111    DevServer(dev_server::Model),
 112}
 113
 114impl Principal {
 115    fn update_span(&self, span: &tracing::Span) {
 116        match &self {
 117            Principal::User(user) => {
 118                span.record("user_id", user.id.0);
 119                span.record("login", &user.github_login);
 120            }
 121            Principal::Impersonated { user, admin } => {
 122                span.record("user_id", user.id.0);
 123                span.record("login", &user.github_login);
 124                span.record("impersonator", &admin.github_login);
 125            }
 126            Principal::DevServer(dev_server) => {
 127                span.record("dev_server_id", dev_server.id.0);
 128            }
 129        }
 130    }
 131}
 132
 133#[derive(Clone)]
 134struct Session {
 135    principal: Principal,
 136    connection_id: ConnectionId,
 137    db: Arc<tokio::sync::Mutex<DbHandle>>,
 138    peer: Arc<Peer>,
 139    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 140    app_state: Arc<AppState>,
 141    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 142    http_client: Arc<IsahcHttpClient>,
 143    /// The GeoIP country code for the user.
 144    #[allow(unused)]
 145    geoip_country_code: Option<String>,
 146    _executor: Executor,
 147}
 148
 149impl Session {
 150    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 151        #[cfg(test)]
 152        tokio::task::yield_now().await;
 153        let guard = self.db.lock().await;
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        guard
 157    }
 158
 159    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.connection_pool.lock();
 163        ConnectionPoolGuard {
 164            guard,
 165            _not_send: PhantomData,
 166        }
 167    }
 168
 169    fn for_user(self) -> Option<UserSession> {
 170        UserSession::new(self)
 171    }
 172
 173    fn for_dev_server(self) -> Option<DevServerSession> {
 174        DevServerSession::new(self)
 175    }
 176
 177    fn user_id(&self) -> Option<UserId> {
 178        match &self.principal {
 179            Principal::User(user) => Some(user.id),
 180            Principal::Impersonated { user, .. } => Some(user.id),
 181            Principal::DevServer(_) => None,
 182        }
 183    }
 184
 185    fn is_staff(&self) -> bool {
 186        match &self.principal {
 187            Principal::User(user) => user.admin,
 188            Principal::Impersonated { .. } => true,
 189            Principal::DevServer(_) => false,
 190        }
 191    }
 192
 193    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 194        if self.is_staff() {
 195            return Ok(proto::Plan::ZedPro);
 196        }
 197
 198        let Some(user_id) = self.user_id() else {
 199            return Ok(proto::Plan::Free);
 200        };
 201
 202        if db.has_active_billing_subscription(user_id).await? {
 203            Ok(proto::Plan::ZedPro)
 204        } else {
 205            Ok(proto::Plan::Free)
 206        }
 207    }
 208
 209    fn dev_server_id(&self) -> Option<DevServerId> {
 210        match &self.principal {
 211            Principal::User(_) | Principal::Impersonated { .. } => None,
 212            Principal::DevServer(dev_server) => Some(dev_server.id),
 213        }
 214    }
 215
 216    fn principal_id(&self) -> PrincipalId {
 217        match &self.principal {
 218            Principal::User(user) => PrincipalId::UserId(user.id),
 219            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 220            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 221        }
 222    }
 223}
 224
 225impl Debug for Session {
 226    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 227        let mut result = f.debug_struct("Session");
 228        match &self.principal {
 229            Principal::User(user) => {
 230                result.field("user", &user.github_login);
 231            }
 232            Principal::Impersonated { user, admin } => {
 233                result.field("user", &user.github_login);
 234                result.field("impersonator", &admin.github_login);
 235            }
 236            Principal::DevServer(dev_server) => {
 237                result.field("dev_server", &dev_server.id);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct UserSession(Session);
 245
 246impl UserSession {
 247    pub fn new(s: Session) -> Option<Self> {
 248        s.user_id().map(|_| UserSession(s))
 249    }
 250    pub fn user_id(&self) -> UserId {
 251        self.0.user_id().unwrap()
 252    }
 253
 254    pub fn email(&self) -> Option<String> {
 255        match &self.0.principal {
 256            Principal::User(user) => user.email_address.clone(),
 257            Principal::Impersonated { user, .. } => user.email_address.clone(),
 258            Principal::DevServer(..) => None,
 259        }
 260    }
 261}
 262
 263impl Deref for UserSession {
 264    type Target = Session;
 265
 266    fn deref(&self) -> &Self::Target {
 267        &self.0
 268    }
 269}
 270impl DerefMut for UserSession {
 271    fn deref_mut(&mut self) -> &mut Self::Target {
 272        &mut self.0
 273    }
 274}
 275
 276struct DevServerSession(Session);
 277
 278impl DevServerSession {
 279    pub fn new(s: Session) -> Option<Self> {
 280        s.dev_server_id().map(|_| DevServerSession(s))
 281    }
 282    pub fn dev_server_id(&self) -> DevServerId {
 283        self.0.dev_server_id().unwrap()
 284    }
 285
 286    fn dev_server(&self) -> &dev_server::Model {
 287        match &self.0.principal {
 288            Principal::DevServer(dev_server) => dev_server,
 289            _ => unreachable!(),
 290        }
 291    }
 292}
 293
 294impl Deref for DevServerSession {
 295    type Target = Session;
 296
 297    fn deref(&self) -> &Self::Target {
 298        &self.0
 299    }
 300}
 301impl DerefMut for DevServerSession {
 302    fn deref_mut(&mut self) -> &mut Self::Target {
 303        &mut self.0
 304    }
 305}
 306
 307fn user_handler<M: RequestMessage, Fut>(
 308    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 310where
 311    Fut: Send + Future<Output = Result<()>>,
 312{
 313    let handler = Arc::new(handler);
 314    move |message, response, session| {
 315        let handler = handler.clone();
 316        Box::pin(async move {
 317            if let Some(user_session) = session.for_user() {
 318                Ok(handler(message, response, user_session).await?)
 319            } else {
 320                Err(Error::Internal(anyhow!(
 321                    "must be a user to call {}",
 322                    M::NAME
 323                )))
 324            }
 325        })
 326    }
 327}
 328
 329fn dev_server_handler<M: RequestMessage, Fut>(
 330    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 332where
 333    Fut: Send + Future<Output = Result<()>>,
 334{
 335    let handler = Arc::new(handler);
 336    move |message, response, session| {
 337        let handler = handler.clone();
 338        Box::pin(async move {
 339            if let Some(dev_server_session) = session.for_dev_server() {
 340                Ok(handler(message, response, dev_server_session).await?)
 341            } else {
 342                Err(Error::Internal(anyhow!(
 343                    "must be a dev server to call {}",
 344                    M::NAME
 345                )))
 346            }
 347        })
 348    }
 349}
 350
 351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 352    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 354where
 355    InnertRetFut: Send + Future<Output = Result<()>>,
 356{
 357    let handler = Arc::new(handler);
 358    move |message, session| {
 359        let handler = handler.clone();
 360        Box::pin(async move {
 361            if let Some(user_session) = session.for_user() {
 362                Ok(handler(message, user_session).await?)
 363            } else {
 364                Err(Error::Internal(anyhow!(
 365                    "must be a user to call {}",
 366                    M::NAME
 367                )))
 368            }
 369        })
 370    }
 371}
 372
 373struct DbHandle(Arc<Database>);
 374
 375impl Deref for DbHandle {
 376    type Target = Database;
 377
 378    fn deref(&self) -> &Self::Target {
 379        self.0.as_ref()
 380    }
 381}
 382
 383pub struct Server {
 384    id: parking_lot::Mutex<ServerId>,
 385    peer: Arc<Peer>,
 386    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 387    app_state: Arc<AppState>,
 388    handlers: HashMap<TypeId, MessageHandler>,
 389    teardown: watch::Sender<bool>,
 390}
 391
 392pub(crate) struct ConnectionPoolGuard<'a> {
 393    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 394    _not_send: PhantomData<Rc<()>>,
 395}
 396
 397#[derive(Serialize)]
 398pub struct ServerSnapshot<'a> {
 399    peer: &'a Peer,
 400    #[serde(serialize_with = "serialize_deref")]
 401    connection_pool: ConnectionPoolGuard<'a>,
 402}
 403
 404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 405where
 406    S: Serializer,
 407    T: Deref<Target = U>,
 408    U: Serialize,
 409{
 410    Serialize::serialize(value.deref(), serializer)
 411}
 412
 413impl Server {
 414    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 415        let mut server = Self {
 416            id: parking_lot::Mutex::new(id),
 417            peer: Peer::new(id.0 as u32),
 418            app_state: app_state.clone(),
 419            connection_pool: Default::default(),
 420            handlers: Default::default(),
 421            teardown: watch::channel(false).0,
 422        };
 423
 424        server
 425            .add_request_handler(ping)
 426            .add_request_handler(user_handler(create_room))
 427            .add_request_handler(user_handler(join_room))
 428            .add_request_handler(user_handler(rejoin_room))
 429            .add_request_handler(user_handler(leave_room))
 430            .add_request_handler(user_handler(set_room_participant_role))
 431            .add_request_handler(user_handler(call))
 432            .add_request_handler(user_handler(cancel_call))
 433            .add_message_handler(user_message_handler(decline_call))
 434            .add_request_handler(user_handler(update_participant_location))
 435            .add_request_handler(user_handler(share_project))
 436            .add_message_handler(unshare_project)
 437            .add_request_handler(user_handler(join_project))
 438            .add_request_handler(user_handler(join_hosted_project))
 439            .add_request_handler(user_handler(rejoin_dev_server_projects))
 440            .add_request_handler(user_handler(create_dev_server_project))
 441            .add_request_handler(user_handler(update_dev_server_project))
 442            .add_request_handler(user_handler(delete_dev_server_project))
 443            .add_request_handler(user_handler(create_dev_server))
 444            .add_request_handler(user_handler(regenerate_dev_server_token))
 445            .add_request_handler(user_handler(rename_dev_server))
 446            .add_request_handler(user_handler(delete_dev_server))
 447            .add_request_handler(user_handler(list_remote_directory))
 448            .add_request_handler(dev_server_handler(share_dev_server_project))
 449            .add_request_handler(dev_server_handler(shutdown_dev_server))
 450            .add_request_handler(dev_server_handler(reconnect_dev_server))
 451            .add_message_handler(user_message_handler(leave_project))
 452            .add_request_handler(update_project)
 453            .add_request_handler(update_worktree)
 454            .add_message_handler(start_language_server)
 455            .add_message_handler(update_language_server)
 456            .add_message_handler(update_diagnostic_summary)
 457            .add_message_handler(update_worktree_settings)
 458            .add_request_handler(user_handler(
 459                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_project_request_for_owner::<proto::TaskTemplates>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetHover>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::GetDefinition>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetTypeDefinition>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetReferences>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::SearchProject>,
 478            ))
 479            .add_request_handler(user_handler(forward_find_search_candidates_request))
 480            .add_request_handler(user_handler(
 481                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 482            ))
 483            .add_request_handler(user_handler(
 484                forward_read_only_project_request::<proto::GetProjectSymbols>,
 485            ))
 486            .add_request_handler(user_handler(
 487                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 488            ))
 489            .add_request_handler(user_handler(
 490                forward_read_only_project_request::<proto::OpenBufferById>,
 491            ))
 492            .add_request_handler(user_handler(
 493                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 494            ))
 495            .add_request_handler(user_handler(
 496                forward_read_only_project_request::<proto::InlayHints>,
 497            ))
 498            .add_request_handler(user_handler(
 499                forward_read_only_project_request::<proto::ResolveInlayHint>,
 500            ))
 501            .add_request_handler(user_handler(
 502                forward_read_only_project_request::<proto::OpenBufferByPath>,
 503            ))
 504            .add_request_handler(user_handler(
 505                forward_mutating_project_request::<proto::GetCompletions>,
 506            ))
 507            .add_request_handler(user_handler(
 508                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 509            ))
 510            .add_request_handler(user_handler(
 511                forward_mutating_project_request::<proto::OpenNewBuffer>,
 512            ))
 513            .add_request_handler(user_handler(
 514                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 515            ))
 516            .add_request_handler(user_handler(
 517                forward_mutating_project_request::<proto::GetCodeActions>,
 518            ))
 519            .add_request_handler(user_handler(
 520                forward_mutating_project_request::<proto::ApplyCodeAction>,
 521            ))
 522            .add_request_handler(user_handler(
 523                forward_mutating_project_request::<proto::PrepareRename>,
 524            ))
 525            .add_request_handler(user_handler(
 526                forward_mutating_project_request::<proto::PerformRename>,
 527            ))
 528            .add_request_handler(user_handler(
 529                forward_mutating_project_request::<proto::ReloadBuffers>,
 530            ))
 531            .add_request_handler(user_handler(
 532                forward_mutating_project_request::<proto::FormatBuffers>,
 533            ))
 534            .add_request_handler(user_handler(
 535                forward_mutating_project_request::<proto::CreateProjectEntry>,
 536            ))
 537            .add_request_handler(user_handler(
 538                forward_mutating_project_request::<proto::RenameProjectEntry>,
 539            ))
 540            .add_request_handler(user_handler(
 541                forward_mutating_project_request::<proto::CopyProjectEntry>,
 542            ))
 543            .add_request_handler(user_handler(
 544                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 545            ))
 546            .add_request_handler(user_handler(
 547                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 548            ))
 549            .add_request_handler(user_handler(
 550                forward_mutating_project_request::<proto::OnTypeFormatting>,
 551            ))
 552            .add_request_handler(user_handler(
 553                forward_mutating_project_request::<proto::SaveBuffer>,
 554            ))
 555            .add_request_handler(user_handler(
 556                forward_mutating_project_request::<proto::BlameBuffer>,
 557            ))
 558            .add_request_handler(user_handler(
 559                forward_mutating_project_request::<proto::MultiLspQuery>,
 560            ))
 561            .add_request_handler(user_handler(
 562                forward_mutating_project_request::<proto::RestartLanguageServers>,
 563            ))
 564            .add_request_handler(user_handler(
 565                forward_mutating_project_request::<proto::LinkedEditingRange>,
 566            ))
 567            .add_message_handler(create_buffer_for_peer)
 568            .add_request_handler(update_buffer)
 569            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 570            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 571            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 572            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 573            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 574            .add_request_handler(get_users)
 575            .add_request_handler(user_handler(fuzzy_search_users))
 576            .add_request_handler(user_handler(request_contact))
 577            .add_request_handler(user_handler(remove_contact))
 578            .add_request_handler(user_handler(respond_to_contact_request))
 579            .add_message_handler(subscribe_to_channels)
 580            .add_request_handler(user_handler(create_channel))
 581            .add_request_handler(user_handler(delete_channel))
 582            .add_request_handler(user_handler(invite_channel_member))
 583            .add_request_handler(user_handler(remove_channel_member))
 584            .add_request_handler(user_handler(set_channel_member_role))
 585            .add_request_handler(user_handler(set_channel_visibility))
 586            .add_request_handler(user_handler(rename_channel))
 587            .add_request_handler(user_handler(join_channel_buffer))
 588            .add_request_handler(user_handler(leave_channel_buffer))
 589            .add_message_handler(user_message_handler(update_channel_buffer))
 590            .add_request_handler(user_handler(rejoin_channel_buffers))
 591            .add_request_handler(user_handler(get_channel_members))
 592            .add_request_handler(user_handler(respond_to_channel_invite))
 593            .add_request_handler(user_handler(join_channel))
 594            .add_request_handler(user_handler(join_channel_chat))
 595            .add_message_handler(user_message_handler(leave_channel_chat))
 596            .add_request_handler(user_handler(send_channel_message))
 597            .add_request_handler(user_handler(remove_channel_message))
 598            .add_request_handler(user_handler(update_channel_message))
 599            .add_request_handler(user_handler(get_channel_messages))
 600            .add_request_handler(user_handler(get_channel_messages_by_id))
 601            .add_request_handler(user_handler(get_notifications))
 602            .add_request_handler(user_handler(mark_notification_as_read))
 603            .add_request_handler(user_handler(move_channel))
 604            .add_request_handler(user_handler(follow))
 605            .add_message_handler(user_message_handler(unfollow))
 606            .add_message_handler(user_message_handler(update_followers))
 607            .add_request_handler(user_handler(get_private_user_info))
 608            .add_request_handler(user_handler(get_llm_api_token))
 609            .add_request_handler(user_handler(accept_terms_of_service))
 610            .add_message_handler(user_message_handler(acknowledge_channel_message))
 611            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 612            .add_request_handler(user_handler(get_supermaven_api_key))
 613            .add_request_handler(user_handler(
 614                forward_mutating_project_request::<proto::OpenContext>,
 615            ))
 616            .add_request_handler(user_handler(
 617                forward_mutating_project_request::<proto::CreateContext>,
 618            ))
 619            .add_request_handler(user_handler(
 620                forward_mutating_project_request::<proto::SynchronizeContexts>,
 621            ))
 622            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 623            .add_message_handler(update_context)
 624            .add_request_handler({
 625                let app_state = app_state.clone();
 626                move |request, response, session| {
 627                    let app_state = app_state.clone();
 628                    async move {
 629                        count_language_model_tokens(request, response, session, &app_state.config)
 630                            .await
 631                    }
 632                }
 633            })
 634            .add_request_handler({
 635                user_handler(move |request, response, session| {
 636                    get_cached_embeddings(request, response, session)
 637                })
 638            })
 639            .add_request_handler({
 640                let app_state = app_state.clone();
 641                user_handler(move |request, response, session| {
 642                    compute_embeddings(
 643                        request,
 644                        response,
 645                        session,
 646                        app_state.config.openai_api_key.clone(),
 647                    )
 648                })
 649            });
 650
 651        Arc::new(server)
 652    }
 653
 654    pub async fn start(&self) -> Result<()> {
 655        let server_id = *self.id.lock();
 656        let app_state = self.app_state.clone();
 657        let peer = self.peer.clone();
 658        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 659        let pool = self.connection_pool.clone();
 660        let live_kit_client = self.app_state.live_kit_client.clone();
 661
 662        let span = info_span!("start server");
 663        self.app_state.executor.spawn_detached(
 664            async move {
 665                tracing::info!("waiting for cleanup timeout");
 666                timeout.await;
 667                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 668                if let Some((room_ids, channel_ids)) = app_state
 669                    .db
 670                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 671                    .await
 672                    .trace_err()
 673                {
 674                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 675                    tracing::info!(
 676                        stale_channel_buffer_count = channel_ids.len(),
 677                        "retrieved stale channel buffers"
 678                    );
 679
 680                    for channel_id in channel_ids {
 681                        if let Some(refreshed_channel_buffer) = app_state
 682                            .db
 683                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 684                            .await
 685                            .trace_err()
 686                        {
 687                            for connection_id in refreshed_channel_buffer.connection_ids {
 688                                peer.send(
 689                                    connection_id,
 690                                    proto::UpdateChannelBufferCollaborators {
 691                                        channel_id: channel_id.to_proto(),
 692                                        collaborators: refreshed_channel_buffer
 693                                            .collaborators
 694                                            .clone(),
 695                                    },
 696                                )
 697                                .trace_err();
 698                            }
 699                        }
 700                    }
 701
 702                    for room_id in room_ids {
 703                        let mut contacts_to_update = HashSet::default();
 704                        let mut canceled_calls_to_user_ids = Vec::new();
 705                        let mut live_kit_room = String::new();
 706                        let mut delete_live_kit_room = false;
 707
 708                        if let Some(mut refreshed_room) = app_state
 709                            .db
 710                            .clear_stale_room_participants(room_id, server_id)
 711                            .await
 712                            .trace_err()
 713                        {
 714                            tracing::info!(
 715                                room_id = room_id.0,
 716                                new_participant_count = refreshed_room.room.participants.len(),
 717                                "refreshed room"
 718                            );
 719                            room_updated(&refreshed_room.room, &peer);
 720                            if let Some(channel) = refreshed_room.channel.as_ref() {
 721                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 722                            }
 723                            contacts_to_update
 724                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 725                            contacts_to_update
 726                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 727                            canceled_calls_to_user_ids =
 728                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 729                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 730                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 731                        }
 732
 733                        {
 734                            let pool = pool.lock();
 735                            for canceled_user_id in canceled_calls_to_user_ids {
 736                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 737                                    peer.send(
 738                                        connection_id,
 739                                        proto::CallCanceled {
 740                                            room_id: room_id.to_proto(),
 741                                        },
 742                                    )
 743                                    .trace_err();
 744                                }
 745                            }
 746                        }
 747
 748                        for user_id in contacts_to_update {
 749                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 750                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 751                            if let Some((busy, contacts)) = busy.zip(contacts) {
 752                                let pool = pool.lock();
 753                                let updated_contact = contact_for_user(user_id, busy, &pool);
 754                                for contact in contacts {
 755                                    if let db::Contact::Accepted {
 756                                        user_id: contact_user_id,
 757                                        ..
 758                                    } = contact
 759                                    {
 760                                        for contact_conn_id in
 761                                            pool.user_connection_ids(contact_user_id)
 762                                        {
 763                                            peer.send(
 764                                                contact_conn_id,
 765                                                proto::UpdateContacts {
 766                                                    contacts: vec![updated_contact.clone()],
 767                                                    remove_contacts: Default::default(),
 768                                                    incoming_requests: Default::default(),
 769                                                    remove_incoming_requests: Default::default(),
 770                                                    outgoing_requests: Default::default(),
 771                                                    remove_outgoing_requests: Default::default(),
 772                                                },
 773                                            )
 774                                            .trace_err();
 775                                        }
 776                                    }
 777                                }
 778                            }
 779                        }
 780
 781                        if let Some(live_kit) = live_kit_client.as_ref() {
 782                            if delete_live_kit_room {
 783                                live_kit.delete_room(live_kit_room).await.trace_err();
 784                            }
 785                        }
 786                    }
 787                }
 788
 789                app_state
 790                    .db
 791                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 792                    .await
 793                    .trace_err();
 794            }
 795            .instrument(span),
 796        );
 797        Ok(())
 798    }
 799
 800    pub fn teardown(&self) {
 801        self.peer.teardown();
 802        self.connection_pool.lock().reset();
 803        let _ = self.teardown.send(true);
 804    }
 805
 806    #[cfg(test)]
 807    pub fn reset(&self, id: ServerId) {
 808        self.teardown();
 809        *self.id.lock() = id;
 810        self.peer.reset(id.0 as u32);
 811        let _ = self.teardown.send(false);
 812    }
 813
 814    #[cfg(test)]
 815    pub fn id(&self) -> ServerId {
 816        *self.id.lock()
 817    }
 818
 819    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 820    where
 821        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 822        Fut: 'static + Send + Future<Output = Result<()>>,
 823        M: EnvelopedMessage,
 824    {
 825        let prev_handler = self.handlers.insert(
 826            TypeId::of::<M>(),
 827            Box::new(move |envelope, session| {
 828                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 829                let received_at = envelope.received_at;
 830                tracing::info!("message received");
 831                let start_time = Instant::now();
 832                let future = (handler)(*envelope, session);
 833                async move {
 834                    let result = future.await;
 835                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 836                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 837                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 838                    let payload_type = M::NAME;
 839
 840                    match result {
 841                        Err(error) => {
 842                            tracing::error!(
 843                                ?error,
 844                                total_duration_ms,
 845                                processing_duration_ms,
 846                                queue_duration_ms,
 847                                payload_type,
 848                                "error handling message"
 849                            )
 850                        }
 851                        Ok(()) => tracing::info!(
 852                            total_duration_ms,
 853                            processing_duration_ms,
 854                            queue_duration_ms,
 855                            "finished handling message"
 856                        ),
 857                    }
 858                }
 859                .boxed()
 860            }),
 861        );
 862        if prev_handler.is_some() {
 863            panic!("registered a handler for the same message twice");
 864        }
 865        self
 866    }
 867
 868    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 869    where
 870        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 871        Fut: 'static + Send + Future<Output = Result<()>>,
 872        M: EnvelopedMessage,
 873    {
 874        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 875        self
 876    }
 877
 878    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 879    where
 880        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 881        Fut: Send + Future<Output = Result<()>>,
 882        M: RequestMessage,
 883    {
 884        let handler = Arc::new(handler);
 885        self.add_handler(move |envelope, session| {
 886            let receipt = envelope.receipt();
 887            let handler = handler.clone();
 888            async move {
 889                let peer = session.peer.clone();
 890                let responded = Arc::new(AtomicBool::default());
 891                let response = Response {
 892                    peer: peer.clone(),
 893                    responded: responded.clone(),
 894                    receipt,
 895                };
 896                match (handler)(envelope.payload, response, session).await {
 897                    Ok(()) => {
 898                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 899                            Ok(())
 900                        } else {
 901                            Err(anyhow!("handler did not send a response"))?
 902                        }
 903                    }
 904                    Err(error) => {
 905                        let proto_err = match &error {
 906                            Error::Internal(err) => err.to_proto(),
 907                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 908                        };
 909                        peer.respond_with_error(receipt, proto_err)?;
 910                        Err(error)
 911                    }
 912                }
 913            }
 914        })
 915    }
 916
 917    #[allow(clippy::too_many_arguments)]
 918    pub fn handle_connection(
 919        self: &Arc<Self>,
 920        connection: Connection,
 921        address: String,
 922        principal: Principal,
 923        zed_version: ZedVersion,
 924        geoip_country_code: Option<String>,
 925        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 926        executor: Executor,
 927    ) -> impl Future<Output = ()> {
 928        let this = self.clone();
 929        let span = info_span!("handle connection", %address,
 930            connection_id=field::Empty,
 931            user_id=field::Empty,
 932            login=field::Empty,
 933            impersonator=field::Empty,
 934            dev_server_id=field::Empty,
 935            geoip_country_code=field::Empty
 936        );
 937        principal.update_span(&span);
 938        if let Some(country_code) = geoip_country_code.as_ref() {
 939            span.record("geoip_country_code", country_code);
 940        }
 941
 942        let mut teardown = self.teardown.subscribe();
 943        async move {
 944            if *teardown.borrow() {
 945                tracing::error!("server is tearing down");
 946                return
 947            }
 948            let (connection_id, handle_io, mut incoming_rx) = this
 949                .peer
 950                .add_connection(connection, {
 951                    let executor = executor.clone();
 952                    move |duration| executor.sleep(duration)
 953                });
 954            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 955
 956            tracing::info!("connection opened");
 957
 958            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 959            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 960                Ok(http_client) => Arc::new(http_client),
 961                Err(error) => {
 962                    tracing::error!(?error, "failed to create HTTP client");
 963                    return;
 964                }
 965            };
 966
 967            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
 968                    supermaven_admin_api_key.to_string(),
 969                    http_client.clone(),
 970                )));
 971
 972            let session = Session {
 973                principal: principal.clone(),
 974                connection_id,
 975                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 976                peer: this.peer.clone(),
 977                connection_pool: this.connection_pool.clone(),
 978                app_state: this.app_state.clone(),
 979                http_client,
 980                geoip_country_code,
 981                _executor: executor.clone(),
 982                supermaven_client,
 983            };
 984
 985            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 986                tracing::error!(?error, "failed to send initial client update");
 987                return;
 988            }
 989
 990            let handle_io = handle_io.fuse();
 991            futures::pin_mut!(handle_io);
 992
 993            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 994            // This prevents deadlocks when e.g., client A performs a request to client B and
 995            // client B performs a request to client A. If both clients stop processing further
 996            // messages until their respective request completes, they won't have a chance to
 997            // respond to the other client's request and cause a deadlock.
 998            //
 999            // This arrangement ensures we will attempt to process earlier messages first, but fall
1000            // back to processing messages arrived later in the spirit of making progress.
1001            let mut foreground_message_handlers = FuturesUnordered::new();
1002            let concurrent_handlers = Arc::new(Semaphore::new(256));
1003            loop {
1004                let next_message = async {
1005                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1006                    let message = incoming_rx.next().await;
1007                    (permit, message)
1008                }.fuse();
1009                futures::pin_mut!(next_message);
1010                futures::select_biased! {
1011                    _ = teardown.changed().fuse() => return,
1012                    result = handle_io => {
1013                        if let Err(error) = result {
1014                            tracing::error!(?error, "error handling I/O");
1015                        }
1016                        break;
1017                    }
1018                    _ = foreground_message_handlers.next() => {}
1019                    next_message = next_message => {
1020                        let (permit, message) = next_message;
1021                        if let Some(message) = message {
1022                            let type_name = message.payload_type_name();
1023                            // note: we copy all the fields from the parent span so we can query them in the logs.
1024                            // (https://github.com/tokio-rs/tracing/issues/2670).
1025                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1026                                user_id=field::Empty,
1027                                login=field::Empty,
1028                                impersonator=field::Empty,
1029                                dev_server_id=field::Empty
1030                            );
1031                            principal.update_span(&span);
1032                            let span_enter = span.enter();
1033                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1034                                let is_background = message.is_background();
1035                                let handle_message = (handler)(message, session.clone());
1036                                drop(span_enter);
1037
1038                                let handle_message = async move {
1039                                    handle_message.await;
1040                                    drop(permit);
1041                                }.instrument(span);
1042                                if is_background {
1043                                    executor.spawn_detached(handle_message);
1044                                } else {
1045                                    foreground_message_handlers.push(handle_message);
1046                                }
1047                            } else {
1048                                tracing::error!("no message handler");
1049                            }
1050                        } else {
1051                            tracing::info!("connection closed");
1052                            break;
1053                        }
1054                    }
1055                }
1056            }
1057
1058            drop(foreground_message_handlers);
1059            tracing::info!("signing out");
1060            if let Err(error) = connection_lost(session, teardown, executor).await {
1061                tracing::error!(?error, "error signing out");
1062            }
1063
1064        }.instrument(span)
1065    }
1066
1067    async fn send_initial_client_update(
1068        &self,
1069        connection_id: ConnectionId,
1070        principal: &Principal,
1071        zed_version: ZedVersion,
1072        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1073        session: &Session,
1074    ) -> Result<()> {
1075        self.peer.send(
1076            connection_id,
1077            proto::Hello {
1078                peer_id: Some(connection_id.into()),
1079            },
1080        )?;
1081        tracing::info!("sent hello message");
1082        if let Some(send_connection_id) = send_connection_id.take() {
1083            let _ = send_connection_id.send(connection_id);
1084        }
1085
1086        match principal {
1087            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1088                if !user.connected_once {
1089                    self.peer.send(connection_id, proto::ShowContacts {})?;
1090                    self.app_state
1091                        .db
1092                        .set_user_connected_once(user.id, true)
1093                        .await?;
1094                }
1095
1096                update_user_plan(user.id, session).await?;
1097
1098                let (contacts, dev_server_projects) = future::try_join(
1099                    self.app_state.db.get_contacts(user.id),
1100                    self.app_state.db.dev_server_projects_update(user.id),
1101                )
1102                .await?;
1103
1104                {
1105                    let mut pool = self.connection_pool.lock();
1106                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1107                    self.peer.send(
1108                        connection_id,
1109                        build_initial_contacts_update(contacts, &pool),
1110                    )?;
1111                }
1112
1113                if should_auto_subscribe_to_channels(zed_version) {
1114                    subscribe_user_to_channels(user.id, session).await?;
1115                }
1116
1117                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1118
1119                if let Some(incoming_call) =
1120                    self.app_state.db.incoming_call_for_user(user.id).await?
1121                {
1122                    self.peer.send(connection_id, incoming_call)?;
1123                }
1124
1125                update_user_contacts(user.id, session).await?;
1126            }
1127            Principal::DevServer(dev_server) => {
1128                {
1129                    let mut pool = self.connection_pool.lock();
1130                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1131                    {
1132                        self.peer.send(
1133                            stale_connection_id,
1134                            proto::ShutdownDevServer {
1135                                reason: Some(
1136                                    "another dev server connected with the same token".to_string(),
1137                                ),
1138                            },
1139                        )?;
1140                        pool.remove_connection(stale_connection_id)?;
1141                    };
1142                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1143                }
1144
1145                let projects = self
1146                    .app_state
1147                    .db
1148                    .get_projects_for_dev_server(dev_server.id)
1149                    .await?;
1150                self.peer
1151                    .send(connection_id, proto::DevServerInstructions { projects })?;
1152
1153                let status = self
1154                    .app_state
1155                    .db
1156                    .dev_server_projects_update(dev_server.user_id)
1157                    .await?;
1158                send_dev_server_projects_update(dev_server.user_id, status, session).await;
1159            }
1160        }
1161
1162        Ok(())
1163    }
1164
1165    pub async fn invite_code_redeemed(
1166        self: &Arc<Self>,
1167        inviter_id: UserId,
1168        invitee_id: UserId,
1169    ) -> Result<()> {
1170        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1171            if let Some(code) = &user.invite_code {
1172                let pool = self.connection_pool.lock();
1173                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1174                for connection_id in pool.user_connection_ids(inviter_id) {
1175                    self.peer.send(
1176                        connection_id,
1177                        proto::UpdateContacts {
1178                            contacts: vec![invitee_contact.clone()],
1179                            ..Default::default()
1180                        },
1181                    )?;
1182                    self.peer.send(
1183                        connection_id,
1184                        proto::UpdateInviteInfo {
1185                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1186                            count: user.invite_count as u32,
1187                        },
1188                    )?;
1189                }
1190            }
1191        }
1192        Ok(())
1193    }
1194
1195    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1196        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1197            if let Some(invite_code) = &user.invite_code {
1198                let pool = self.connection_pool.lock();
1199                for connection_id in pool.user_connection_ids(user_id) {
1200                    self.peer.send(
1201                        connection_id,
1202                        proto::UpdateInviteInfo {
1203                            url: format!(
1204                                "{}{}",
1205                                self.app_state.config.invite_link_prefix, invite_code
1206                            ),
1207                            count: user.invite_count as u32,
1208                        },
1209                    )?;
1210                }
1211            }
1212        }
1213        Ok(())
1214    }
1215
1216    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1217        ServerSnapshot {
1218            connection_pool: ConnectionPoolGuard {
1219                guard: self.connection_pool.lock(),
1220                _not_send: PhantomData,
1221            },
1222            peer: &self.peer,
1223        }
1224    }
1225}
1226
1227impl<'a> Deref for ConnectionPoolGuard<'a> {
1228    type Target = ConnectionPool;
1229
1230    fn deref(&self) -> &Self::Target {
1231        &self.guard
1232    }
1233}
1234
1235impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1236    fn deref_mut(&mut self) -> &mut Self::Target {
1237        &mut self.guard
1238    }
1239}
1240
1241impl<'a> Drop for ConnectionPoolGuard<'a> {
1242    fn drop(&mut self) {
1243        #[cfg(test)]
1244        self.check_invariants();
1245    }
1246}
1247
1248fn broadcast<F>(
1249    sender_id: Option<ConnectionId>,
1250    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1251    mut f: F,
1252) where
1253    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1254{
1255    for receiver_id in receiver_ids {
1256        if Some(receiver_id) != sender_id {
1257            if let Err(error) = f(receiver_id) {
1258                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1259            }
1260        }
1261    }
1262}
1263
1264pub struct ProtocolVersion(u32);
1265
1266impl Header for ProtocolVersion {
1267    fn name() -> &'static HeaderName {
1268        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1269        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1270    }
1271
1272    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1273    where
1274        Self: Sized,
1275        I: Iterator<Item = &'i axum::http::HeaderValue>,
1276    {
1277        let version = values
1278            .next()
1279            .ok_or_else(axum::headers::Error::invalid)?
1280            .to_str()
1281            .map_err(|_| axum::headers::Error::invalid())?
1282            .parse()
1283            .map_err(|_| axum::headers::Error::invalid())?;
1284        Ok(Self(version))
1285    }
1286
1287    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1288        values.extend([self.0.to_string().parse().unwrap()]);
1289    }
1290}
1291
1292pub struct AppVersionHeader(SemanticVersion);
1293impl Header for AppVersionHeader {
1294    fn name() -> &'static HeaderName {
1295        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1297    }
1298
1299    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300    where
1301        Self: Sized,
1302        I: Iterator<Item = &'i axum::http::HeaderValue>,
1303    {
1304        let version = values
1305            .next()
1306            .ok_or_else(axum::headers::Error::invalid)?
1307            .to_str()
1308            .map_err(|_| axum::headers::Error::invalid())?
1309            .parse()
1310            .map_err(|_| axum::headers::Error::invalid())?;
1311        Ok(Self(version))
1312    }
1313
1314    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315        values.extend([self.0.to_string().parse().unwrap()]);
1316    }
1317}
1318
1319pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1320    Router::new()
1321        .route("/rpc", get(handle_websocket_request))
1322        .layer(
1323            ServiceBuilder::new()
1324                .layer(Extension(server.app_state.clone()))
1325                .layer(middleware::from_fn(auth::validate_header)),
1326        )
1327        .route("/metrics", get(handle_metrics))
1328        .layer(Extension(server))
1329}
1330
1331pub async fn handle_websocket_request(
1332    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1333    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1334    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1335    Extension(server): Extension<Arc<Server>>,
1336    Extension(principal): Extension<Principal>,
1337    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1338    ws: WebSocketUpgrade,
1339) -> axum::response::Response {
1340    if protocol_version != rpc::PROTOCOL_VERSION {
1341        return (
1342            StatusCode::UPGRADE_REQUIRED,
1343            "client must be upgraded".to_string(),
1344        )
1345            .into_response();
1346    }
1347
1348    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1349        return (
1350            StatusCode::UPGRADE_REQUIRED,
1351            "no version header found".to_string(),
1352        )
1353            .into_response();
1354    };
1355
1356    if !version.can_collaborate() {
1357        return (
1358            StatusCode::UPGRADE_REQUIRED,
1359            "client must be upgraded".to_string(),
1360        )
1361            .into_response();
1362    }
1363
1364    let socket_address = socket_address.to_string();
1365    ws.on_upgrade(move |socket| {
1366        let socket = socket
1367            .map_ok(to_tungstenite_message)
1368            .err_into()
1369            .with(|message| async move { to_axum_message(message) });
1370        let connection = Connection::new(Box::pin(socket));
1371        async move {
1372            server
1373                .handle_connection(
1374                    connection,
1375                    socket_address,
1376                    principal,
1377                    version,
1378                    country_code_header.map(|header| header.to_string()),
1379                    None,
1380                    Executor::Production,
1381                )
1382                .await;
1383        }
1384    })
1385}
1386
1387pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1388    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1389    let connections_metric = CONNECTIONS_METRIC
1390        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1391
1392    let connections = server
1393        .connection_pool
1394        .lock()
1395        .connections()
1396        .filter(|connection| !connection.admin)
1397        .count();
1398    connections_metric.set(connections as _);
1399
1400    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1402        register_int_gauge!(
1403            "shared_projects",
1404            "number of open projects with one or more guests"
1405        )
1406        .unwrap()
1407    });
1408
1409    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1410    shared_projects_metric.set(shared_projects as _);
1411
1412    let encoder = prometheus::TextEncoder::new();
1413    let metric_families = prometheus::gather();
1414    let encoded_metrics = encoder
1415        .encode_to_string(&metric_families)
1416        .map_err(|err| anyhow!("{}", err))?;
1417    Ok(encoded_metrics)
1418}
1419
1420#[instrument(err, skip(executor))]
1421async fn connection_lost(
1422    session: Session,
1423    mut teardown: watch::Receiver<bool>,
1424    executor: Executor,
1425) -> Result<()> {
1426    session.peer.disconnect(session.connection_id);
1427    session
1428        .connection_pool()
1429        .await
1430        .remove_connection(session.connection_id)?;
1431
1432    session
1433        .db()
1434        .await
1435        .connection_lost(session.connection_id)
1436        .await
1437        .trace_err();
1438
1439    futures::select_biased! {
1440        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1441            match &session.principal {
1442                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1443                    let session = session.for_user().unwrap();
1444
1445                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1446                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1447                    leave_channel_buffers_for_session(&session)
1448                        .await
1449                        .trace_err();
1450
1451                    if !session
1452                        .connection_pool()
1453                        .await
1454                        .is_user_online(session.user_id())
1455                    {
1456                        let db = session.db().await;
1457                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1458                            room_updated(&room, &session.peer);
1459                        }
1460                    }
1461
1462                    update_user_contacts(session.user_id(), &session).await?;
1463                },
1464            Principal::DevServer(_) => {
1465                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1466            },
1467        }
1468        },
1469        _ = teardown.changed().fuse() => {}
1470    }
1471
1472    Ok(())
1473}
1474
1475/// Acknowledges a ping from a client, used to keep the connection alive.
1476async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1477    response.send(proto::Ack {})?;
1478    Ok(())
1479}
1480
1481/// Creates a new room for calling (outside of channels)
1482async fn create_room(
1483    _request: proto::CreateRoom,
1484    response: Response<proto::CreateRoom>,
1485    session: UserSession,
1486) -> Result<()> {
1487    let live_kit_room = nanoid::nanoid!(30);
1488
1489    let live_kit_connection_info = util::maybe!(async {
1490        let live_kit = session.app_state.live_kit_client.as_ref();
1491        let live_kit = live_kit?;
1492        let user_id = session.user_id().to_string();
1493
1494        let token = live_kit
1495            .room_token(&live_kit_room, &user_id.to_string())
1496            .trace_err()?;
1497
1498        Some(proto::LiveKitConnectionInfo {
1499            server_url: live_kit.url().into(),
1500            token,
1501            can_publish: true,
1502        })
1503    })
1504    .await;
1505
1506    let room = session
1507        .db()
1508        .await
1509        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1510        .await?;
1511
1512    response.send(proto::CreateRoomResponse {
1513        room: Some(room.clone()),
1514        live_kit_connection_info,
1515    })?;
1516
1517    update_user_contacts(session.user_id(), &session).await?;
1518    Ok(())
1519}
1520
1521/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1522async fn join_room(
1523    request: proto::JoinRoom,
1524    response: Response<proto::JoinRoom>,
1525    session: UserSession,
1526) -> Result<()> {
1527    let room_id = RoomId::from_proto(request.id);
1528
1529    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1530
1531    if let Some(channel_id) = channel_id {
1532        return join_channel_internal(channel_id, Box::new(response), session).await;
1533    }
1534
1535    let joined_room = {
1536        let room = session
1537            .db()
1538            .await
1539            .join_room(room_id, session.user_id(), session.connection_id)
1540            .await?;
1541        room_updated(&room.room, &session.peer);
1542        room.into_inner()
1543    };
1544
1545    for connection_id in session
1546        .connection_pool()
1547        .await
1548        .user_connection_ids(session.user_id())
1549    {
1550        session
1551            .peer
1552            .send(
1553                connection_id,
1554                proto::CallCanceled {
1555                    room_id: room_id.to_proto(),
1556                },
1557            )
1558            .trace_err();
1559    }
1560
1561    let live_kit_connection_info =
1562        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1563            live_kit
1564                .room_token(
1565                    &joined_room.room.live_kit_room,
1566                    &session.user_id().to_string(),
1567                )
1568                .trace_err()
1569                .map(|token| proto::LiveKitConnectionInfo {
1570                    server_url: live_kit.url().into(),
1571                    token,
1572                    can_publish: true,
1573                })
1574        } else {
1575            None
1576        };
1577
1578    response.send(proto::JoinRoomResponse {
1579        room: Some(joined_room.room),
1580        channel_id: None,
1581        live_kit_connection_info,
1582    })?;
1583
1584    update_user_contacts(session.user_id(), &session).await?;
1585    Ok(())
1586}
1587
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Replies with the room together with its reshared and rejoined projects,
/// tells the collaborators of every reshared project that this peer's
/// connection id changed, re-sends project state for rejoined projects via
/// `notify_rejoined_projects`, and finally refreshes the channel (when the
/// room has one) and the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in inside the block below; only these two values outlive it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the requester first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the peer formerly known by
            // `old_connection_id` is now reachable at this session's connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktrees to its collaborators on
            // behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the inner value so the room and channel can be
        // moved out of this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1679
1680fn notify_rejoined_projects(
1681    rejoined_projects: &mut Vec<RejoinedProject>,
1682    session: &UserSession,
1683) -> Result<()> {
1684    for project in rejoined_projects.iter() {
1685        for collaborator in &project.collaborators {
1686            session
1687                .peer
1688                .send(
1689                    collaborator.connection_id,
1690                    proto::UpdateProjectCollaborator {
1691                        project_id: project.id.to_proto(),
1692                        old_peer_id: Some(project.old_connection_id.into()),
1693                        new_peer_id: Some(session.connection_id.into()),
1694                    },
1695                )
1696                .trace_err();
1697        }
1698    }
1699
1700    for project in rejoined_projects {
1701        for worktree in mem::take(&mut project.worktrees) {
1702            #[cfg(any(test, feature = "test-support"))]
1703            const MAX_CHUNK_SIZE: usize = 2;
1704            #[cfg(not(any(test, feature = "test-support")))]
1705            const MAX_CHUNK_SIZE: usize = 256;
1706
1707            // Stream this worktree's entries.
1708            let message = proto::UpdateWorktree {
1709                project_id: project.id.to_proto(),
1710                worktree_id: worktree.id,
1711                abs_path: worktree.abs_path.clone(),
1712                root_name: worktree.root_name,
1713                updated_entries: worktree.updated_entries,
1714                removed_entries: worktree.removed_entries,
1715                scan_id: worktree.scan_id,
1716                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1717                updated_repositories: worktree.updated_repositories,
1718                removed_repositories: worktree.removed_repositories,
1719            };
1720            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1721                session.peer.send(session.connection_id, update.clone())?;
1722            }
1723
1724            // Stream this worktree's diagnostics.
1725            for summary in worktree.diagnostic_summaries {
1726                session.peer.send(
1727                    session.connection_id,
1728                    proto::UpdateDiagnosticSummary {
1729                        project_id: project.id.to_proto(),
1730                        worktree_id: worktree.id,
1731                        summary: Some(summary),
1732                    },
1733                )?;
1734            }
1735
1736            for settings_file in worktree.settings_files {
1737                session.peer.send(
1738                    session.connection_id,
1739                    proto::UpdateWorktreeSettings {
1740                        project_id: project.id.to_proto(),
1741                        worktree_id: worktree.id,
1742                        path: settings_file.path,
1743                        content: Some(settings_file.content),
1744                    },
1745                )?;
1746            }
1747        }
1748
1749        for language_server in &project.language_servers {
1750            session.peer.send(
1751                session.connection_id,
1752                proto::UpdateLanguageServer {
1753                    project_id: project.id.to_proto(),
1754                    language_server_id: language_server.id,
1755                    variant: Some(
1756                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1757                            proto::LspDiskBasedDiagnosticsUpdated {},
1758                        ),
1759                    ),
1760                },
1761            )?;
1762        }
1763    }
1764    Ok(())
1765}
1766
1767/// leave room disconnects from the room.
1768async fn leave_room(
1769    _: proto::LeaveRoom,
1770    response: Response<proto::LeaveRoom>,
1771    session: UserSession,
1772) -> Result<()> {
1773    leave_room_for_session(&session, session.connection_id).await?;
1774    response.send(proto::Ack {})?;
1775    Ok(())
1776}
1777
1778/// Updates the permissions of someone else in the room.
1779async fn set_room_participant_role(
1780    request: proto::SetRoomParticipantRole,
1781    response: Response<proto::SetRoomParticipantRole>,
1782    session: UserSession,
1783) -> Result<()> {
1784    let user_id = UserId::from_proto(request.user_id);
1785    let role = ChannelRole::from(request.role());
1786
1787    let (live_kit_room, can_publish) = {
1788        let room = session
1789            .db()
1790            .await
1791            .set_room_participant_role(
1792                session.user_id(),
1793                RoomId::from_proto(request.room_id),
1794                user_id,
1795                role,
1796            )
1797            .await?;
1798
1799        let live_kit_room = room.live_kit_room.clone();
1800        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1801        room_updated(&room, &session.peer);
1802        (live_kit_room, can_publish)
1803    };
1804
1805    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1806        live_kit
1807            .update_participant(
1808                live_kit_room.clone(),
1809                request.user_id.to_string(),
1810                live_kit_server::proto::ParticipantPermission {
1811                    can_subscribe: true,
1812                    can_publish,
1813                    can_publish_data: can_publish,
1814                    hidden: false,
1815                    recorder: false,
1816                },
1817            )
1818            .await
1819            .trace_err();
1820    }
1821
1822    response.send(proto::Ack {})?;
1823    Ok(())
1824}
1825
1826/// Call someone else into the current room
1827async fn call(
1828    request: proto::Call,
1829    response: Response<proto::Call>,
1830    session: UserSession,
1831) -> Result<()> {
1832    let room_id = RoomId::from_proto(request.room_id);
1833    let calling_user_id = session.user_id();
1834    let calling_connection_id = session.connection_id;
1835    let called_user_id = UserId::from_proto(request.called_user_id);
1836    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1837    if !session
1838        .db()
1839        .await
1840        .has_contact(calling_user_id, called_user_id)
1841        .await?
1842    {
1843        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1844    }
1845
1846    let incoming_call = {
1847        let (room, incoming_call) = &mut *session
1848            .db()
1849            .await
1850            .call(
1851                room_id,
1852                calling_user_id,
1853                calling_connection_id,
1854                called_user_id,
1855                initial_project_id,
1856            )
1857            .await?;
1858        room_updated(room, &session.peer);
1859        mem::take(incoming_call)
1860    };
1861    update_user_contacts(called_user_id, &session).await?;
1862
1863    let mut calls = session
1864        .connection_pool()
1865        .await
1866        .user_connection_ids(called_user_id)
1867        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1868        .collect::<FuturesUnordered<_>>();
1869
1870    while let Some(call_response) = calls.next().await {
1871        match call_response.as_ref() {
1872            Ok(_) => {
1873                response.send(proto::Ack {})?;
1874                return Ok(());
1875            }
1876            Err(_) => {
1877                call_response.trace_err();
1878            }
1879        }
1880    }
1881
1882    {
1883        let room = session
1884            .db()
1885            .await
1886            .call_failed(room_id, called_user_id)
1887            .await?;
1888        room_updated(&room, &session.peer);
1889    }
1890    update_user_contacts(called_user_id, &session).await?;
1891
1892    Err(anyhow!("failed to ring user"))?
1893}
1894
1895/// Cancel an outgoing call.
1896async fn cancel_call(
1897    request: proto::CancelCall,
1898    response: Response<proto::CancelCall>,
1899    session: UserSession,
1900) -> Result<()> {
1901    let called_user_id = UserId::from_proto(request.called_user_id);
1902    let room_id = RoomId::from_proto(request.room_id);
1903    {
1904        let room = session
1905            .db()
1906            .await
1907            .cancel_call(room_id, session.connection_id, called_user_id)
1908            .await?;
1909        room_updated(&room, &session.peer);
1910    }
1911
1912    for connection_id in session
1913        .connection_pool()
1914        .await
1915        .user_connection_ids(called_user_id)
1916    {
1917        session
1918            .peer
1919            .send(
1920                connection_id,
1921                proto::CallCanceled {
1922                    room_id: room_id.to_proto(),
1923                },
1924            )
1925            .trace_err();
1926    }
1927    response.send(proto::Ack {})?;
1928
1929    update_user_contacts(called_user_id, &session).await?;
1930    Ok(())
1931}
1932
1933/// Decline an incoming call.
1934async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1935    let room_id = RoomId::from_proto(message.room_id);
1936    {
1937        let room = session
1938            .db()
1939            .await
1940            .decline_call(Some(room_id), session.user_id())
1941            .await?
1942            .ok_or_else(|| anyhow!("failed to decline call"))?;
1943        room_updated(&room, &session.peer);
1944    }
1945
1946    for connection_id in session
1947        .connection_pool()
1948        .await
1949        .user_connection_ids(session.user_id())
1950    {
1951        session
1952            .peer
1953            .send(
1954                connection_id,
1955                proto::CallCanceled {
1956                    room_id: room_id.to_proto(),
1957                },
1958            )
1959            .trace_err();
1960    }
1961    update_user_contacts(session.user_id(), &session).await?;
1962    Ok(())
1963}
1964
1965/// Updates other participants in the room with your current location.
1966async fn update_participant_location(
1967    request: proto::UpdateParticipantLocation,
1968    response: Response<proto::UpdateParticipantLocation>,
1969    session: UserSession,
1970) -> Result<()> {
1971    let room_id = RoomId::from_proto(request.room_id);
1972    let location = request
1973        .location
1974        .ok_or_else(|| anyhow!("invalid location"))?;
1975
1976    let db = session.db().await;
1977    let room = db
1978        .update_room_participant_location(room_id, session.connection_id, location)
1979        .await?;
1980
1981    room_updated(&room, &session.peer);
1982    response.send(proto::Ack {})?;
1983    Ok(())
1984}
1985
1986/// Share a project into the room.
1987async fn share_project(
1988    request: proto::ShareProject,
1989    response: Response<proto::ShareProject>,
1990    session: UserSession,
1991) -> Result<()> {
1992    let (project_id, room) = &*session
1993        .db()
1994        .await
1995        .share_project(
1996            RoomId::from_proto(request.room_id),
1997            session.connection_id,
1998            &request.worktrees,
1999            request.is_ssh_project,
2000            request
2001                .dev_server_project_id
2002                .map(DevServerProjectId::from_proto),
2003        )
2004        .await?;
2005    response.send(proto::ShareProjectResponse {
2006        project_id: project_id.to_proto(),
2007    })?;
2008    room_updated(room, &session.peer);
2009
2010    Ok(())
2011}
2012
2013/// Unshare a project from the room.
2014async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2015    let project_id = ProjectId::from_proto(message.project_id);
2016    unshare_project_internal(
2017        project_id,
2018        session.connection_id,
2019        session.user_id(),
2020        &session,
2021    )
2022    .await
2023}
2024
2025async fn unshare_project_internal(
2026    project_id: ProjectId,
2027    connection_id: ConnectionId,
2028    user_id: Option<UserId>,
2029    session: &Session,
2030) -> Result<()> {
2031    let delete = {
2032        let room_guard = session
2033            .db()
2034            .await
2035            .unshare_project(project_id, connection_id, user_id)
2036            .await?;
2037
2038        let (delete, room, guest_connection_ids) = &*room_guard;
2039
2040        let message = proto::UnshareProject {
2041            project_id: project_id.to_proto(),
2042        };
2043
2044        broadcast(
2045            Some(connection_id),
2046            guest_connection_ids.iter().copied(),
2047            |conn_id| session.peer.send(conn_id, message.clone()),
2048        );
2049        if let Some(room) = room {
2050            room_updated(room, &session.peer);
2051        }
2052
2053        *delete
2054    };
2055
2056    if delete {
2057        let db = session.db().await;
2058        db.delete_project(project_id).await?;
2059    }
2060
2061    Ok(())
2062}
2063
2064/// DevServer makes a project available online
2065async fn share_dev_server_project(
2066    request: proto::ShareDevServerProject,
2067    response: Response<proto::ShareDevServerProject>,
2068    session: DevServerSession,
2069) -> Result<()> {
2070    let (dev_server_project, user_id, status) = session
2071        .db()
2072        .await
2073        .share_dev_server_project(
2074            DevServerProjectId::from_proto(request.dev_server_project_id),
2075            session.dev_server_id(),
2076            session.connection_id,
2077            &request.worktrees,
2078        )
2079        .await?;
2080    let Some(project_id) = dev_server_project.project_id else {
2081        return Err(anyhow!("failed to share remote project"))?;
2082    };
2083
2084    send_dev_server_projects_update(user_id, status, &session).await;
2085
2086    response.send(proto::ShareProjectResponse { project_id })?;
2087
2088    Ok(())
2089}
2090
/// Join someone else's shared project.
///
/// Registers the session's connection and user as a participant of the
/// project through the database, then delegates to `join_project_internal`,
/// which replies to the request and sends the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` are borrowed from the value returned by
    // `join_project`, which stays alive until the end of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Explicitly release the handle obtained from `session.db()` before
    // handing off to `join_project_internal`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2109
/// Abstraction over the response handles that can reply with a
/// `proto::JoinProjectResponse`, so that `join_project_internal` can accept
/// the response of either a `JoinProject` or a `JoinHostedProject` request.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and replies with `result`.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully qualified calls below name the `Response` type explicitly rather
// than going through `self.send(..)`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2123
2124fn join_project_internal(
2125    response: impl JoinProjectInternalResponse,
2126    session: UserSession,
2127    project: &mut Project,
2128    replica_id: &ReplicaId,
2129) -> Result<()> {
2130    let collaborators = project
2131        .collaborators
2132        .iter()
2133        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2134        .map(|collaborator| collaborator.to_proto())
2135        .collect::<Vec<_>>();
2136    let project_id = project.id;
2137    let guest_user_id = session.user_id();
2138
2139    let worktrees = project
2140        .worktrees
2141        .iter()
2142        .map(|(id, worktree)| proto::WorktreeMetadata {
2143            id: *id,
2144            root_name: worktree.root_name.clone(),
2145            visible: worktree.visible,
2146            abs_path: worktree.abs_path.clone(),
2147        })
2148        .collect::<Vec<_>>();
2149
2150    let add_project_collaborator = proto::AddProjectCollaborator {
2151        project_id: project_id.to_proto(),
2152        collaborator: Some(proto::Collaborator {
2153            peer_id: Some(session.connection_id.into()),
2154            replica_id: replica_id.0 as u32,
2155            user_id: guest_user_id.to_proto(),
2156        }),
2157    };
2158
2159    for collaborator in &collaborators {
2160        session
2161            .peer
2162            .send(
2163                collaborator.peer_id.unwrap().into(),
2164                add_project_collaborator.clone(),
2165            )
2166            .trace_err();
2167    }
2168
2169    // First, we send the metadata associated with each worktree.
2170    response.send(proto::JoinProjectResponse {
2171        project_id: project.id.0 as u64,
2172        worktrees: worktrees.clone(),
2173        replica_id: replica_id.0 as u32,
2174        collaborators: collaborators.clone(),
2175        language_servers: project.language_servers.clone(),
2176        role: project.role.into(),
2177        dev_server_project_id: project
2178            .dev_server_project_id
2179            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2180    })?;
2181
2182    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2183        #[cfg(any(test, feature = "test-support"))]
2184        const MAX_CHUNK_SIZE: usize = 2;
2185        #[cfg(not(any(test, feature = "test-support")))]
2186        const MAX_CHUNK_SIZE: usize = 256;
2187
2188        // Stream this worktree's entries.
2189        let message = proto::UpdateWorktree {
2190            project_id: project_id.to_proto(),
2191            worktree_id,
2192            abs_path: worktree.abs_path.clone(),
2193            root_name: worktree.root_name,
2194            updated_entries: worktree.entries,
2195            removed_entries: Default::default(),
2196            scan_id: worktree.scan_id,
2197            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2198            updated_repositories: worktree.repository_entries.into_values().collect(),
2199            removed_repositories: Default::default(),
2200        };
2201        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2202            session.peer.send(session.connection_id, update.clone())?;
2203        }
2204
2205        // Stream this worktree's diagnostics.
2206        for summary in worktree.diagnostic_summaries {
2207            session.peer.send(
2208                session.connection_id,
2209                proto::UpdateDiagnosticSummary {
2210                    project_id: project_id.to_proto(),
2211                    worktree_id: worktree.id,
2212                    summary: Some(summary),
2213                },
2214            )?;
2215        }
2216
2217        for settings_file in worktree.settings_files {
2218            session.peer.send(
2219                session.connection_id,
2220                proto::UpdateWorktreeSettings {
2221                    project_id: project_id.to_proto(),
2222                    worktree_id: worktree.id,
2223                    path: settings_file.path,
2224                    content: Some(settings_file.content),
2225                },
2226            )?;
2227        }
2228    }
2229
2230    for language_server in &project.language_servers {
2231        session.peer.send(
2232            session.connection_id,
2233            proto::UpdateLanguageServer {
2234                project_id: project_id.to_proto(),
2235                language_server_id: language_server.id,
2236                variant: Some(
2237                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2238                        proto::LspDiskBasedDiagnosticsUpdated {},
2239                    ),
2240                ),
2241            },
2242        )?;
2243    }
2244
2245    Ok(())
2246}
2247
2248/// Leave someone elses shared project.
2249async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2250    let sender_id = session.connection_id;
2251    let project_id = ProjectId::from_proto(request.project_id);
2252    let db = session.db().await;
2253    if db.is_hosted_project(project_id).await? {
2254        let project = db.leave_hosted_project(project_id, sender_id).await?;
2255        project_left(&project, &session);
2256        return Ok(());
2257    }
2258
2259    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2260    tracing::info!(
2261        %project_id,
2262        "leave project"
2263    );
2264
2265    project_left(project, &session);
2266    if let Some(room) = room {
2267        room_updated(room, &session.peer);
2268    }
2269
2270    Ok(())
2271}
2272
2273async fn join_hosted_project(
2274    request: proto::JoinHostedProject,
2275    response: Response<proto::JoinHostedProject>,
2276    session: UserSession,
2277) -> Result<()> {
2278    let (mut project, replica_id) = session
2279        .db()
2280        .await
2281        .join_hosted_project(
2282            ProjectId(request.project_id as i32),
2283            session.user_id(),
2284            session.connection_id,
2285        )
2286        .await?;
2287
2288    join_project_internal(response, session, &mut project, &replica_id)
2289}
2290
2291async fn list_remote_directory(
2292    request: proto::ListRemoteDirectory,
2293    response: Response<proto::ListRemoteDirectory>,
2294    session: UserSession,
2295) -> Result<()> {
2296    let dev_server_id = DevServerId(request.dev_server_id as i32);
2297    let dev_server_connection_id = session
2298        .connection_pool()
2299        .await
2300        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2301
2302    session
2303        .db()
2304        .await
2305        .get_dev_server_for_user(dev_server_id, session.user_id())
2306        .await?;
2307
2308    response.send(
2309        session
2310            .peer
2311            .forward_request(session.connection_id, dev_server_connection_id, request)
2312            .await?,
2313    )?;
2314    Ok(())
2315}
2316
2317async fn update_dev_server_project(
2318    request: proto::UpdateDevServerProject,
2319    response: Response<proto::UpdateDevServerProject>,
2320    session: UserSession,
2321) -> Result<()> {
2322    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2323
2324    let (dev_server_project, update) = session
2325        .db()
2326        .await
2327        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2328        .await?;
2329
2330    let projects = session
2331        .db()
2332        .await
2333        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2334        .await?;
2335
2336    let dev_server_connection_id = session
2337        .connection_pool()
2338        .await
2339        .dev_server_connection_id_supporting(
2340            dev_server_project.dev_server_id,
2341            ZedVersion::with_list_directory(),
2342        )?;
2343
2344    session.peer.send(
2345        dev_server_connection_id,
2346        proto::DevServerInstructions { projects },
2347    )?;
2348
2349    send_dev_server_projects_update(session.user_id(), update, &session).await;
2350
2351    response.send(proto::Ack {})
2352}
2353
2354async fn create_dev_server_project(
2355    request: proto::CreateDevServerProject,
2356    response: Response<proto::CreateDevServerProject>,
2357    session: UserSession,
2358) -> Result<()> {
2359    let dev_server_id = DevServerId(request.dev_server_id as i32);
2360    let dev_server_connection_id = session
2361        .connection_pool()
2362        .await
2363        .dev_server_connection_id(dev_server_id);
2364    let Some(dev_server_connection_id) = dev_server_connection_id else {
2365        Err(ErrorCode::DevServerOffline
2366            .message("Cannot create a remote project when the dev server is offline".to_string())
2367            .anyhow())?
2368    };
2369
2370    let path = request.path.clone();
2371    //Check that the path exists on the dev server
2372    session
2373        .peer
2374        .forward_request(
2375            session.connection_id,
2376            dev_server_connection_id,
2377            proto::ValidateDevServerProjectRequest { path: path.clone() },
2378        )
2379        .await?;
2380
2381    let (dev_server_project, update) = session
2382        .db()
2383        .await
2384        .create_dev_server_project(
2385            DevServerId(request.dev_server_id as i32),
2386            &request.path,
2387            session.user_id(),
2388        )
2389        .await?;
2390
2391    let projects = session
2392        .db()
2393        .await
2394        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2395        .await?;
2396
2397    session.peer.send(
2398        dev_server_connection_id,
2399        proto::DevServerInstructions { projects },
2400    )?;
2401
2402    send_dev_server_projects_update(session.user_id(), update, &session).await;
2403
2404    response.send(proto::CreateDevServerProjectResponse {
2405        dev_server_project: Some(dev_server_project.to_proto(None)),
2406    })?;
2407    Ok(())
2408}
2409
2410async fn create_dev_server(
2411    request: proto::CreateDevServer,
2412    response: Response<proto::CreateDevServer>,
2413    session: UserSession,
2414) -> Result<()> {
2415    let access_token = auth::random_token();
2416    let hashed_access_token = auth::hash_access_token(&access_token);
2417
2418    if request.name.is_empty() {
2419        return Err(proto::ErrorCode::Forbidden
2420            .message("Dev server name cannot be empty".to_string())
2421            .anyhow())?;
2422    }
2423
2424    let (dev_server, status) = session
2425        .db()
2426        .await
2427        .create_dev_server(
2428            &request.name,
2429            request.ssh_connection_string.as_deref(),
2430            &hashed_access_token,
2431            session.user_id(),
2432        )
2433        .await?;
2434
2435    send_dev_server_projects_update(session.user_id(), status, &session).await;
2436
2437    response.send(proto::CreateDevServerResponse {
2438        dev_server_id: dev_server.id.0 as u64,
2439        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2440        name: request.name,
2441    })?;
2442    Ok(())
2443}
2444
2445async fn regenerate_dev_server_token(
2446    request: proto::RegenerateDevServerToken,
2447    response: Response<proto::RegenerateDevServerToken>,
2448    session: UserSession,
2449) -> Result<()> {
2450    let dev_server_id = DevServerId(request.dev_server_id as i32);
2451    let access_token = auth::random_token();
2452    let hashed_access_token = auth::hash_access_token(&access_token);
2453
2454    let connection_id = session
2455        .connection_pool()
2456        .await
2457        .dev_server_connection_id(dev_server_id);
2458    if let Some(connection_id) = connection_id {
2459        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2460        session.peer.send(
2461            connection_id,
2462            proto::ShutdownDevServer {
2463                reason: Some("dev server token was regenerated".to_string()),
2464            },
2465        )?;
2466        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2467    }
2468
2469    let status = session
2470        .db()
2471        .await
2472        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2473        .await?;
2474
2475    send_dev_server_projects_update(session.user_id(), status, &session).await;
2476
2477    response.send(proto::RegenerateDevServerTokenResponse {
2478        dev_server_id: dev_server_id.to_proto(),
2479        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2480    })?;
2481    Ok(())
2482}
2483
2484async fn rename_dev_server(
2485    request: proto::RenameDevServer,
2486    response: Response<proto::RenameDevServer>,
2487    session: UserSession,
2488) -> Result<()> {
2489    if request.name.trim().is_empty() {
2490        return Err(proto::ErrorCode::Forbidden
2491            .message("Dev server name cannot be empty".to_string())
2492            .anyhow())?;
2493    }
2494
2495    let dev_server_id = DevServerId(request.dev_server_id as i32);
2496    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2497    if dev_server.user_id != session.user_id() {
2498        return Err(anyhow!(ErrorCode::Forbidden))?;
2499    }
2500
2501    let status = session
2502        .db()
2503        .await
2504        .rename_dev_server(
2505            dev_server_id,
2506            &request.name,
2507            request.ssh_connection_string.as_deref(),
2508            session.user_id(),
2509        )
2510        .await?;
2511
2512    send_dev_server_projects_update(session.user_id(), status, &session).await;
2513
2514    response.send(proto::Ack {})?;
2515    Ok(())
2516}
2517
2518async fn delete_dev_server(
2519    request: proto::DeleteDevServer,
2520    response: Response<proto::DeleteDevServer>,
2521    session: UserSession,
2522) -> Result<()> {
2523    let dev_server_id = DevServerId(request.dev_server_id as i32);
2524    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2525    if dev_server.user_id != session.user_id() {
2526        return Err(anyhow!(ErrorCode::Forbidden))?;
2527    }
2528
2529    let connection_id = session
2530        .connection_pool()
2531        .await
2532        .dev_server_connection_id(dev_server_id);
2533    if let Some(connection_id) = connection_id {
2534        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2535        session.peer.send(
2536            connection_id,
2537            proto::ShutdownDevServer {
2538                reason: Some("dev server was deleted".to_string()),
2539            },
2540        )?;
2541        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2542    }
2543
2544    let status = session
2545        .db()
2546        .await
2547        .delete_dev_server(dev_server_id, session.user_id())
2548        .await?;
2549
2550    send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552    response.send(proto::Ack {})?;
2553    Ok(())
2554}
2555
2556async fn delete_dev_server_project(
2557    request: proto::DeleteDevServerProject,
2558    response: Response<proto::DeleteDevServerProject>,
2559    session: UserSession,
2560) -> Result<()> {
2561    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2562    let dev_server_project = session
2563        .db()
2564        .await
2565        .get_dev_server_project(dev_server_project_id)
2566        .await?;
2567
2568    let dev_server = session
2569        .db()
2570        .await
2571        .get_dev_server(dev_server_project.dev_server_id)
2572        .await?;
2573    if dev_server.user_id != session.user_id() {
2574        return Err(anyhow!(ErrorCode::Forbidden))?;
2575    }
2576
2577    let dev_server_connection_id = session
2578        .connection_pool()
2579        .await
2580        .dev_server_connection_id(dev_server.id);
2581
2582    if let Some(dev_server_connection_id) = dev_server_connection_id {
2583        let project = session
2584            .db()
2585            .await
2586            .find_dev_server_project(dev_server_project_id)
2587            .await;
2588        if let Ok(project) = project {
2589            unshare_project_internal(
2590                project.id,
2591                dev_server_connection_id,
2592                Some(session.user_id()),
2593                &session,
2594            )
2595            .await?;
2596        }
2597    }
2598
2599    let (projects, status) = session
2600        .db()
2601        .await
2602        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2603        .await?;
2604
2605    if let Some(dev_server_connection_id) = dev_server_connection_id {
2606        session.peer.send(
2607            dev_server_connection_id,
2608            proto::DevServerInstructions { projects },
2609        )?;
2610    }
2611
2612    send_dev_server_projects_update(session.user_id(), status, &session).await;
2613
2614    response.send(proto::Ack {})?;
2615    Ok(())
2616}
2617
2618async fn rejoin_dev_server_projects(
2619    request: proto::RejoinRemoteProjects,
2620    response: Response<proto::RejoinRemoteProjects>,
2621    session: UserSession,
2622) -> Result<()> {
2623    let mut rejoined_projects = {
2624        let db = session.db().await;
2625        db.rejoin_dev_server_projects(
2626            &request.rejoined_projects,
2627            session.user_id(),
2628            session.0.connection_id,
2629        )
2630        .await?
2631    };
2632    response.send(proto::RejoinRemoteProjectsResponse {
2633        rejoined_projects: rejoined_projects
2634            .iter()
2635            .map(|project| project.to_proto())
2636            .collect(),
2637    })?;
2638    notify_rejoined_projects(&mut rejoined_projects, &session)
2639}
2640
2641async fn reconnect_dev_server(
2642    request: proto::ReconnectDevServer,
2643    response: Response<proto::ReconnectDevServer>,
2644    session: DevServerSession,
2645) -> Result<()> {
2646    let reshared_projects = {
2647        let db = session.db().await;
2648        db.reshare_dev_server_projects(
2649            &request.reshared_projects,
2650            session.dev_server_id(),
2651            session.0.connection_id,
2652        )
2653        .await?
2654    };
2655
2656    for project in &reshared_projects {
2657        for collaborator in &project.collaborators {
2658            session
2659                .peer
2660                .send(
2661                    collaborator.connection_id,
2662                    proto::UpdateProjectCollaborator {
2663                        project_id: project.id.to_proto(),
2664                        old_peer_id: Some(project.old_connection_id.into()),
2665                        new_peer_id: Some(session.connection_id.into()),
2666                    },
2667                )
2668                .trace_err();
2669        }
2670
2671        broadcast(
2672            Some(session.connection_id),
2673            project
2674                .collaborators
2675                .iter()
2676                .map(|collaborator| collaborator.connection_id),
2677            |connection_id| {
2678                session.peer.forward_send(
2679                    session.connection_id,
2680                    connection_id,
2681                    proto::UpdateProject {
2682                        project_id: project.id.to_proto(),
2683                        worktrees: project.worktrees.clone(),
2684                    },
2685                )
2686            },
2687        );
2688    }
2689
2690    response.send(proto::ReconnectDevServerResponse {
2691        reshared_projects: reshared_projects
2692            .iter()
2693            .map(|project| proto::ResharedProject {
2694                id: project.id.to_proto(),
2695                collaborators: project
2696                    .collaborators
2697                    .iter()
2698                    .map(|collaborator| collaborator.to_proto())
2699                    .collect(),
2700            })
2701            .collect(),
2702    })?;
2703
2704    Ok(())
2705}
2706
2707async fn shutdown_dev_server(
2708    _: proto::ShutdownDevServer,
2709    response: Response<proto::ShutdownDevServer>,
2710    session: DevServerSession,
2711) -> Result<()> {
2712    response.send(proto::Ack {})?;
2713    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2714    remove_dev_server_connection(session.dev_server_id(), &session).await
2715}
2716
2717async fn shutdown_dev_server_internal(
2718    dev_server_id: DevServerId,
2719    connection_id: ConnectionId,
2720    session: &Session,
2721) -> Result<()> {
2722    let (dev_server_projects, dev_server) = {
2723        let db = session.db().await;
2724        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2725        let dev_server = db.get_dev_server(dev_server_id).await?;
2726        (dev_server_projects, dev_server)
2727    };
2728
2729    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2730        unshare_project_internal(
2731            ProjectId::from_proto(project_id),
2732            connection_id,
2733            None,
2734            session,
2735        )
2736        .await?;
2737    }
2738
2739    session
2740        .connection_pool()
2741        .await
2742        .set_dev_server_offline(dev_server_id);
2743
2744    let status = session
2745        .db()
2746        .await
2747        .dev_server_projects_update(dev_server.user_id)
2748        .await?;
2749    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2750
2751    Ok(())
2752}
2753
2754async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2755    let dev_server_connection = session
2756        .connection_pool()
2757        .await
2758        .dev_server_connection_id(dev_server_id);
2759
2760    if let Some(dev_server_connection) = dev_server_connection {
2761        session
2762            .connection_pool()
2763            .await
2764            .remove_connection(dev_server_connection)?;
2765    }
2766    Ok(())
2767}
2768
2769/// Updates other participants with changes to the project
2770async fn update_project(
2771    request: proto::UpdateProject,
2772    response: Response<proto::UpdateProject>,
2773    session: Session,
2774) -> Result<()> {
2775    let project_id = ProjectId::from_proto(request.project_id);
2776    let (room, guest_connection_ids) = &*session
2777        .db()
2778        .await
2779        .update_project(project_id, session.connection_id, &request.worktrees)
2780        .await?;
2781    broadcast(
2782        Some(session.connection_id),
2783        guest_connection_ids.iter().copied(),
2784        |connection_id| {
2785            session
2786                .peer
2787                .forward_send(session.connection_id, connection_id, request.clone())
2788        },
2789    );
2790    if let Some(room) = room {
2791        room_updated(room, &session.peer);
2792    }
2793    response.send(proto::Ack {})?;
2794
2795    Ok(())
2796}
2797
2798/// Updates other participants with changes to the worktree
2799async fn update_worktree(
2800    request: proto::UpdateWorktree,
2801    response: Response<proto::UpdateWorktree>,
2802    session: Session,
2803) -> Result<()> {
2804    let guest_connection_ids = session
2805        .db()
2806        .await
2807        .update_worktree(&request, session.connection_id)
2808        .await?;
2809
2810    broadcast(
2811        Some(session.connection_id),
2812        guest_connection_ids.iter().copied(),
2813        |connection_id| {
2814            session
2815                .peer
2816                .forward_send(session.connection_id, connection_id, request.clone())
2817        },
2818    );
2819    response.send(proto::Ack {})?;
2820    Ok(())
2821}
2822
2823/// Updates other participants with changes to the diagnostics
2824async fn update_diagnostic_summary(
2825    message: proto::UpdateDiagnosticSummary,
2826    session: Session,
2827) -> Result<()> {
2828    let guest_connection_ids = session
2829        .db()
2830        .await
2831        .update_diagnostic_summary(&message, session.connection_id)
2832        .await?;
2833
2834    broadcast(
2835        Some(session.connection_id),
2836        guest_connection_ids.iter().copied(),
2837        |connection_id| {
2838            session
2839                .peer
2840                .forward_send(session.connection_id, connection_id, message.clone())
2841        },
2842    );
2843
2844    Ok(())
2845}
2846
2847/// Updates other participants with changes to the worktree settings
2848async fn update_worktree_settings(
2849    message: proto::UpdateWorktreeSettings,
2850    session: Session,
2851) -> Result<()> {
2852    let guest_connection_ids = session
2853        .db()
2854        .await
2855        .update_worktree_settings(&message, session.connection_id)
2856        .await?;
2857
2858    broadcast(
2859        Some(session.connection_id),
2860        guest_connection_ids.iter().copied(),
2861        |connection_id| {
2862            session
2863                .peer
2864                .forward_send(session.connection_id, connection_id, message.clone())
2865        },
2866    );
2867
2868    Ok(())
2869}
2870
2871/// Notify other participants that a  language server has started.
2872async fn start_language_server(
2873    request: proto::StartLanguageServer,
2874    session: Session,
2875) -> Result<()> {
2876    let guest_connection_ids = session
2877        .db()
2878        .await
2879        .start_language_server(&request, session.connection_id)
2880        .await?;
2881
2882    broadcast(
2883        Some(session.connection_id),
2884        guest_connection_ids.iter().copied(),
2885        |connection_id| {
2886            session
2887                .peer
2888                .forward_send(session.connection_id, connection_id, request.clone())
2889        },
2890    );
2891    Ok(())
2892}
2893
2894/// Notify other participants that a language server has changed.
2895async fn update_language_server(
2896    request: proto::UpdateLanguageServer,
2897    session: Session,
2898) -> Result<()> {
2899    let project_id = ProjectId::from_proto(request.project_id);
2900    let project_connection_ids = session
2901        .db()
2902        .await
2903        .project_connection_ids(project_id, session.connection_id, true)
2904        .await?;
2905    broadcast(
2906        Some(session.connection_id),
2907        project_connection_ids.iter().copied(),
2908        |connection_id| {
2909            session
2910                .peer
2911                .forward_send(session.connection_id, connection_id, request.clone())
2912        },
2913    );
2914    Ok(())
2915}
2916
2917/// forward a project request to the host. These requests should be read only
2918/// as guests are allowed to send them.
2919async fn forward_read_only_project_request<T>(
2920    request: T,
2921    response: Response<T>,
2922    session: UserSession,
2923) -> Result<()>
2924where
2925    T: EntityMessage + RequestMessage,
2926{
2927    let project_id = ProjectId::from_proto(request.remote_entity_id());
2928    let host_connection_id = session
2929        .db()
2930        .await
2931        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2932        .await?;
2933    let payload = session
2934        .peer
2935        .forward_request(session.connection_id, host_connection_id, request)
2936        .await?;
2937    response.send(payload)?;
2938    Ok(())
2939}
2940
2941async fn forward_find_search_candidates_request(
2942    request: proto::FindSearchCandidates,
2943    response: Response<proto::FindSearchCandidates>,
2944    session: UserSession,
2945) -> Result<()> {
2946    let project_id = ProjectId::from_proto(request.remote_entity_id());
2947    let host_connection_id = session
2948        .db()
2949        .await
2950        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2951        .await?;
2952
2953    let host_version = session
2954        .connection_pool()
2955        .await
2956        .connection(host_connection_id)
2957        .map(|c| c.zed_version);
2958
2959    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2960    {
2961        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2962        let search = proto::SearchProject {
2963            project_id: project_id.to_proto(),
2964            query: query.query,
2965            regex: query.regex,
2966            whole_word: query.whole_word,
2967            case_sensitive: query.case_sensitive,
2968            files_to_include: query.files_to_include,
2969            files_to_exclude: query.files_to_exclude,
2970            include_ignored: query.include_ignored,
2971        };
2972
2973        let payload = session
2974            .peer
2975            .forward_request(session.connection_id, host_connection_id, search)
2976            .await?;
2977        return response.send(proto::FindSearchCandidatesResponse {
2978            buffer_ids: payload
2979                .locations
2980                .into_iter()
2981                .map(|loc| loc.buffer_id)
2982                .collect(),
2983        });
2984    }
2985
2986    let payload = session
2987        .peer
2988        .forward_request(session.connection_id, host_connection_id, request)
2989        .await?;
2990    response.send(payload)?;
2991    Ok(())
2992}
2993
2994/// forward a project request to the dev server. Only allowed
2995/// if it's your dev server.
2996async fn forward_project_request_for_owner<T>(
2997    request: T,
2998    response: Response<T>,
2999    session: UserSession,
3000) -> Result<()>
3001where
3002    T: EntityMessage + RequestMessage,
3003{
3004    let project_id = ProjectId::from_proto(request.remote_entity_id());
3005
3006    let host_connection_id = session
3007        .db()
3008        .await
3009        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3010        .await?;
3011    let payload = session
3012        .peer
3013        .forward_request(session.connection_id, host_connection_id, request)
3014        .await?;
3015    response.send(payload)?;
3016    Ok(())
3017}
3018
3019/// forward a project request to the host. These requests are disallowed
3020/// for guests.
3021async fn forward_mutating_project_request<T>(
3022    request: T,
3023    response: Response<T>,
3024    session: UserSession,
3025) -> Result<()>
3026where
3027    T: EntityMessage + RequestMessage,
3028{
3029    let project_id = ProjectId::from_proto(request.remote_entity_id());
3030
3031    let host_connection_id = session
3032        .db()
3033        .await
3034        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3035        .await?;
3036    let payload = session
3037        .peer
3038        .forward_request(session.connection_id, host_connection_id, request)
3039        .await?;
3040    response.send(payload)?;
3041    Ok(())
3042}
3043
3044/// Notify other participants that a new buffer has been created
3045async fn create_buffer_for_peer(
3046    request: proto::CreateBufferForPeer,
3047    session: Session,
3048) -> Result<()> {
3049    session
3050        .db()
3051        .await
3052        .check_user_is_project_host(
3053            ProjectId::from_proto(request.project_id),
3054            session.connection_id,
3055        )
3056        .await?;
3057    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3058    session
3059        .peer
3060        .forward_send(session.connection_id, peer_id.into(), request)?;
3061    Ok(())
3062}
3063
3064/// Notify other participants that a buffer has been updated. This is
3065/// allowed for guests as long as the update is limited to selections.
3066async fn update_buffer(
3067    request: proto::UpdateBuffer,
3068    response: Response<proto::UpdateBuffer>,
3069    session: Session,
3070) -> Result<()> {
3071    let project_id = ProjectId::from_proto(request.project_id);
3072    let mut capability = Capability::ReadOnly;
3073
3074    for op in request.operations.iter() {
3075        match op.variant {
3076            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3077            Some(_) => capability = Capability::ReadWrite,
3078        }
3079    }
3080
3081    let host = {
3082        let guard = session
3083            .db()
3084            .await
3085            .connections_for_buffer_update(
3086                project_id,
3087                session.principal_id(),
3088                session.connection_id,
3089                capability,
3090            )
3091            .await?;
3092
3093        let (host, guests) = &*guard;
3094
3095        broadcast(
3096            Some(session.connection_id),
3097            guests.clone(),
3098            |connection_id| {
3099                session
3100                    .peer
3101                    .forward_send(session.connection_id, connection_id, request.clone())
3102            },
3103        );
3104
3105        *host
3106    };
3107
3108    if host != session.connection_id {
3109        session
3110            .peer
3111            .forward_request(session.connection_id, host, request.clone())
3112            .await?;
3113    }
3114
3115    response.send(proto::Ack {})?;
3116    Ok(())
3117}
3118
3119async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3120    let project_id = ProjectId::from_proto(message.project_id);
3121
3122    let operation = message.operation.as_ref().context("invalid operation")?;
3123    let capability = match operation.variant.as_ref() {
3124        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3125            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3126                match buffer_op.variant {
3127                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3128                        Capability::ReadOnly
3129                    }
3130                    _ => Capability::ReadWrite,
3131                }
3132            } else {
3133                Capability::ReadWrite
3134            }
3135        }
3136        Some(_) => Capability::ReadWrite,
3137        None => Capability::ReadOnly,
3138    };
3139
3140    let guard = session
3141        .db()
3142        .await
3143        .connections_for_buffer_update(
3144            project_id,
3145            session.principal_id(),
3146            session.connection_id,
3147            capability,
3148        )
3149        .await?;
3150
3151    let (host, guests) = &*guard;
3152
3153    broadcast(
3154        Some(session.connection_id),
3155        guests.iter().chain([host]).copied(),
3156        |connection_id| {
3157            session
3158                .peer
3159                .forward_send(session.connection_id, connection_id, message.clone())
3160        },
3161    );
3162
3163    Ok(())
3164}
3165
3166/// Notify other participants that a project has been updated.
3167async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3168    request: T,
3169    session: Session,
3170) -> Result<()> {
3171    let project_id = ProjectId::from_proto(request.remote_entity_id());
3172    let project_connection_ids = session
3173        .db()
3174        .await
3175        .project_connection_ids(project_id, session.connection_id, false)
3176        .await?;
3177
3178    broadcast(
3179        Some(session.connection_id),
3180        project_connection_ids.iter().copied(),
3181        |connection_id| {
3182            session
3183                .peer
3184                .forward_send(session.connection_id, connection_id, request.clone())
3185        },
3186    );
3187    Ok(())
3188}
3189
3190/// Start following another user in a call.
3191async fn follow(
3192    request: proto::Follow,
3193    response: Response<proto::Follow>,
3194    session: UserSession,
3195) -> Result<()> {
3196    let room_id = RoomId::from_proto(request.room_id);
3197    let project_id = request.project_id.map(ProjectId::from_proto);
3198    let leader_id = request
3199        .leader_id
3200        .ok_or_else(|| anyhow!("invalid leader id"))?
3201        .into();
3202    let follower_id = session.connection_id;
3203
3204    session
3205        .db()
3206        .await
3207        .check_room_participants(room_id, leader_id, session.connection_id)
3208        .await?;
3209
3210    let response_payload = session
3211        .peer
3212        .forward_request(session.connection_id, leader_id, request)
3213        .await?;
3214    response.send(response_payload)?;
3215
3216    if let Some(project_id) = project_id {
3217        let room = session
3218            .db()
3219            .await
3220            .follow(room_id, project_id, leader_id, follower_id)
3221            .await?;
3222        room_updated(&room, &session.peer);
3223    }
3224
3225    Ok(())
3226}
3227
3228/// Stop following another user in a call.
3229async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3230    let room_id = RoomId::from_proto(request.room_id);
3231    let project_id = request.project_id.map(ProjectId::from_proto);
3232    let leader_id = request
3233        .leader_id
3234        .ok_or_else(|| anyhow!("invalid leader id"))?
3235        .into();
3236    let follower_id = session.connection_id;
3237
3238    session
3239        .db()
3240        .await
3241        .check_room_participants(room_id, leader_id, session.connection_id)
3242        .await?;
3243
3244    session
3245        .peer
3246        .forward_send(session.connection_id, leader_id, request)?;
3247
3248    if let Some(project_id) = project_id {
3249        let room = session
3250            .db()
3251            .await
3252            .unfollow(room_id, project_id, leader_id, follower_id)
3253            .await?;
3254        room_updated(&room, &session.peer);
3255    }
3256
3257    Ok(())
3258}
3259
3260/// Notify everyone following you of your current location.
3261async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3262    let room_id = RoomId::from_proto(request.room_id);
3263    let database = session.db.lock().await;
3264
3265    let connection_ids = if let Some(project_id) = request.project_id {
3266        let project_id = ProjectId::from_proto(project_id);
3267        database
3268            .project_connection_ids(project_id, session.connection_id, true)
3269            .await?
3270    } else {
3271        database
3272            .room_connection_ids(room_id, session.connection_id)
3273            .await?
3274    };
3275
3276    // For now, don't send view update messages back to that view's current leader.
3277    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3278        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3279        _ => None,
3280    });
3281
3282    for connection_id in connection_ids.iter().cloned() {
3283        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3284            session
3285                .peer
3286                .forward_send(session.connection_id, connection_id, request.clone())?;
3287        }
3288    }
3289    Ok(())
3290}
3291
3292/// Get public data about users.
3293async fn get_users(
3294    request: proto::GetUsers,
3295    response: Response<proto::GetUsers>,
3296    session: Session,
3297) -> Result<()> {
3298    let user_ids = request
3299        .user_ids
3300        .into_iter()
3301        .map(UserId::from_proto)
3302        .collect();
3303    let users = session
3304        .db()
3305        .await
3306        .get_users_by_ids(user_ids)
3307        .await?
3308        .into_iter()
3309        .map(|user| proto::User {
3310            id: user.id.to_proto(),
3311            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3312            github_login: user.github_login,
3313        })
3314        .collect();
3315    response.send(proto::UsersResponse { users })?;
3316    Ok(())
3317}
3318
3319/// Search for users (to invite) buy Github login
3320async fn fuzzy_search_users(
3321    request: proto::FuzzySearchUsers,
3322    response: Response<proto::FuzzySearchUsers>,
3323    session: UserSession,
3324) -> Result<()> {
3325    let query = request.query;
3326    let users = match query.len() {
3327        0 => vec![],
3328        1 | 2 => session
3329            .db()
3330            .await
3331            .get_user_by_github_login(&query)
3332            .await?
3333            .into_iter()
3334            .collect(),
3335        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3336    };
3337    let users = users
3338        .into_iter()
3339        .filter(|user| user.id != session.user_id())
3340        .map(|user| proto::User {
3341            id: user.id.to_proto(),
3342            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3343            github_login: user.github_login,
3344        })
3345        .collect();
3346    response.send(proto::UsersResponse { users })?;
3347    Ok(())
3348}
3349
3350/// Send a contact request to another user.
3351async fn request_contact(
3352    request: proto::RequestContact,
3353    response: Response<proto::RequestContact>,
3354    session: UserSession,
3355) -> Result<()> {
3356    let requester_id = session.user_id();
3357    let responder_id = UserId::from_proto(request.responder_id);
3358    if requester_id == responder_id {
3359        return Err(anyhow!("cannot add yourself as a contact"))?;
3360    }
3361
3362    let notifications = session
3363        .db()
3364        .await
3365        .send_contact_request(requester_id, responder_id)
3366        .await?;
3367
3368    // Update outgoing contact requests of requester
3369    let mut update = proto::UpdateContacts::default();
3370    update.outgoing_requests.push(responder_id.to_proto());
3371    for connection_id in session
3372        .connection_pool()
3373        .await
3374        .user_connection_ids(requester_id)
3375    {
3376        session.peer.send(connection_id, update.clone())?;
3377    }
3378
3379    // Update incoming contact requests of responder
3380    let mut update = proto::UpdateContacts::default();
3381    update
3382        .incoming_requests
3383        .push(proto::IncomingContactRequest {
3384            requester_id: requester_id.to_proto(),
3385        });
3386    let connection_pool = session.connection_pool().await;
3387    for connection_id in connection_pool.user_connection_ids(responder_id) {
3388        session.peer.send(connection_id, update.clone())?;
3389    }
3390
3391    send_notifications(&connection_pool, &session.peer, notifications);
3392
3393    response.send(proto::Ack {})?;
3394    Ok(())
3395}
3396
3397/// Accept or decline a contact request
3398async fn respond_to_contact_request(
3399    request: proto::RespondToContactRequest,
3400    response: Response<proto::RespondToContactRequest>,
3401    session: UserSession,
3402) -> Result<()> {
3403    let responder_id = session.user_id();
3404    let requester_id = UserId::from_proto(request.requester_id);
3405    let db = session.db().await;
3406    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3407        db.dismiss_contact_notification(responder_id, requester_id)
3408            .await?;
3409    } else {
3410        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3411
3412        let notifications = db
3413            .respond_to_contact_request(responder_id, requester_id, accept)
3414            .await?;
3415        let requester_busy = db.is_user_busy(requester_id).await?;
3416        let responder_busy = db.is_user_busy(responder_id).await?;
3417
3418        let pool = session.connection_pool().await;
3419        // Update responder with new contact
3420        let mut update = proto::UpdateContacts::default();
3421        if accept {
3422            update
3423                .contacts
3424                .push(contact_for_user(requester_id, requester_busy, &pool));
3425        }
3426        update
3427            .remove_incoming_requests
3428            .push(requester_id.to_proto());
3429        for connection_id in pool.user_connection_ids(responder_id) {
3430            session.peer.send(connection_id, update.clone())?;
3431        }
3432
3433        // Update requester with new contact
3434        let mut update = proto::UpdateContacts::default();
3435        if accept {
3436            update
3437                .contacts
3438                .push(contact_for_user(responder_id, responder_busy, &pool));
3439        }
3440        update
3441            .remove_outgoing_requests
3442            .push(responder_id.to_proto());
3443
3444        for connection_id in pool.user_connection_ids(requester_id) {
3445            session.peer.send(connection_id, update.clone())?;
3446        }
3447
3448        send_notifications(&pool, &session.peer, notifications);
3449    }
3450
3451    response.send(proto::Ack {})?;
3452    Ok(())
3453}
3454
3455/// Remove a contact.
3456async fn remove_contact(
3457    request: proto::RemoveContact,
3458    response: Response<proto::RemoveContact>,
3459    session: UserSession,
3460) -> Result<()> {
3461    let requester_id = session.user_id();
3462    let responder_id = UserId::from_proto(request.user_id);
3463    let db = session.db().await;
3464    let (contact_accepted, deleted_notification_id) =
3465        db.remove_contact(requester_id, responder_id).await?;
3466
3467    let pool = session.connection_pool().await;
3468    // Update outgoing contact requests of requester
3469    let mut update = proto::UpdateContacts::default();
3470    if contact_accepted {
3471        update.remove_contacts.push(responder_id.to_proto());
3472    } else {
3473        update
3474            .remove_outgoing_requests
3475            .push(responder_id.to_proto());
3476    }
3477    for connection_id in pool.user_connection_ids(requester_id) {
3478        session.peer.send(connection_id, update.clone())?;
3479    }
3480
3481    // Update incoming contact requests of responder
3482    let mut update = proto::UpdateContacts::default();
3483    if contact_accepted {
3484        update.remove_contacts.push(requester_id.to_proto());
3485    } else {
3486        update
3487            .remove_incoming_requests
3488            .push(requester_id.to_proto());
3489    }
3490    for connection_id in pool.user_connection_ids(responder_id) {
3491        session.peer.send(connection_id, update.clone())?;
3492        if let Some(notification_id) = deleted_notification_id {
3493            session.peer.send(
3494                connection_id,
3495                proto::DeleteNotification {
3496                    notification_id: notification_id.to_proto(),
3497                },
3498            )?;
3499        }
3500    }
3501
3502    response.send(proto::Ack {})?;
3503    Ok(())
3504}
3505
3506fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3507    version.0.minor() < 139
3508}
3509
3510async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3511    let plan = session.current_plan(session.db().await).await?;
3512
3513    session
3514        .peer
3515        .send(
3516            session.connection_id,
3517            proto::UpdateUserPlan { plan: plan.into() },
3518        )
3519        .trace_err();
3520
3521    Ok(())
3522}
3523
3524async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3525    subscribe_user_to_channels(
3526        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3527        &session,
3528    )
3529    .await?;
3530    Ok(())
3531}
3532
3533async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3534    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3535    let mut pool = session.connection_pool().await;
3536    for membership in &channels_for_user.channel_memberships {
3537        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3538    }
3539    session.peer.send(
3540        session.connection_id,
3541        build_update_user_channels(&channels_for_user),
3542    )?;
3543    session.peer.send(
3544        session.connection_id,
3545        build_channels_update(channels_for_user),
3546    )?;
3547    Ok(())
3548}
3549
3550/// Creates a new channel.
3551async fn create_channel(
3552    request: proto::CreateChannel,
3553    response: Response<proto::CreateChannel>,
3554    session: UserSession,
3555) -> Result<()> {
3556    let db = session.db().await;
3557
3558    let parent_id = request.parent_id.map(ChannelId::from_proto);
3559    let (channel, membership) = db
3560        .create_channel(&request.name, parent_id, session.user_id())
3561        .await?;
3562
3563    let root_id = channel.root_id();
3564    let channel = Channel::from_model(channel);
3565
3566    response.send(proto::CreateChannelResponse {
3567        channel: Some(channel.to_proto()),
3568        parent_id: request.parent_id,
3569    })?;
3570
3571    let mut connection_pool = session.connection_pool().await;
3572    if let Some(membership) = membership {
3573        connection_pool.subscribe_to_channel(
3574            membership.user_id,
3575            membership.channel_id,
3576            membership.role,
3577        );
3578        let update = proto::UpdateUserChannels {
3579            channel_memberships: vec![proto::ChannelMembership {
3580                channel_id: membership.channel_id.to_proto(),
3581                role: membership.role.into(),
3582            }],
3583            ..Default::default()
3584        };
3585        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3586            session.peer.send(connection_id, update.clone())?;
3587        }
3588    }
3589
3590    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3591        if !role.can_see_channel(channel.visibility) {
3592            continue;
3593        }
3594
3595        let update = proto::UpdateChannels {
3596            channels: vec![channel.to_proto()],
3597            ..Default::default()
3598        };
3599        session.peer.send(connection_id, update.clone())?;
3600    }
3601
3602    Ok(())
3603}
3604
3605/// Delete a channel
3606async fn delete_channel(
3607    request: proto::DeleteChannel,
3608    response: Response<proto::DeleteChannel>,
3609    session: UserSession,
3610) -> Result<()> {
3611    let db = session.db().await;
3612
3613    let channel_id = request.channel_id;
3614    let (root_channel, removed_channels) = db
3615        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3616        .await?;
3617    response.send(proto::Ack {})?;
3618
3619    // Notify members of removed channels
3620    let mut update = proto::UpdateChannels::default();
3621    update
3622        .delete_channels
3623        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3624
3625    let connection_pool = session.connection_pool().await;
3626    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3627        session.peer.send(connection_id, update.clone())?;
3628    }
3629
3630    Ok(())
3631}
3632
3633/// Invite someone to join a channel.
3634async fn invite_channel_member(
3635    request: proto::InviteChannelMember,
3636    response: Response<proto::InviteChannelMember>,
3637    session: UserSession,
3638) -> Result<()> {
3639    let db = session.db().await;
3640    let channel_id = ChannelId::from_proto(request.channel_id);
3641    let invitee_id = UserId::from_proto(request.user_id);
3642    let InviteMemberResult {
3643        channel,
3644        notifications,
3645    } = db
3646        .invite_channel_member(
3647            channel_id,
3648            invitee_id,
3649            session.user_id(),
3650            request.role().into(),
3651        )
3652        .await?;
3653
3654    let update = proto::UpdateChannels {
3655        channel_invitations: vec![channel.to_proto()],
3656        ..Default::default()
3657    };
3658
3659    let connection_pool = session.connection_pool().await;
3660    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3661        session.peer.send(connection_id, update.clone())?;
3662    }
3663
3664    send_notifications(&connection_pool, &session.peer, notifications);
3665
3666    response.send(proto::Ack {})?;
3667    Ok(())
3668}
3669
3670/// remove someone from a channel
3671async fn remove_channel_member(
3672    request: proto::RemoveChannelMember,
3673    response: Response<proto::RemoveChannelMember>,
3674    session: UserSession,
3675) -> Result<()> {
3676    let db = session.db().await;
3677    let channel_id = ChannelId::from_proto(request.channel_id);
3678    let member_id = UserId::from_proto(request.user_id);
3679
3680    let RemoveChannelMemberResult {
3681        membership_update,
3682        notification_id,
3683    } = db
3684        .remove_channel_member(channel_id, member_id, session.user_id())
3685        .await?;
3686
3687    let mut connection_pool = session.connection_pool().await;
3688    notify_membership_updated(
3689        &mut connection_pool,
3690        membership_update,
3691        member_id,
3692        &session.peer,
3693    );
3694    for connection_id in connection_pool.user_connection_ids(member_id) {
3695        if let Some(notification_id) = notification_id {
3696            session
3697                .peer
3698                .send(
3699                    connection_id,
3700                    proto::DeleteNotification {
3701                        notification_id: notification_id.to_proto(),
3702                    },
3703                )
3704                .trace_err();
3705        }
3706    }
3707
3708    response.send(proto::Ack {})?;
3709    Ok(())
3710}
3711
3712/// Toggle the channel between public and private.
3713/// Care is taken to maintain the invariant that public channels only descend from public channels,
3714/// (though members-only channels can appear at any point in the hierarchy).
3715async fn set_channel_visibility(
3716    request: proto::SetChannelVisibility,
3717    response: Response<proto::SetChannelVisibility>,
3718    session: UserSession,
3719) -> Result<()> {
3720    let db = session.db().await;
3721    let channel_id = ChannelId::from_proto(request.channel_id);
3722    let visibility = request.visibility().into();
3723
3724    let channel_model = db
3725        .set_channel_visibility(channel_id, visibility, session.user_id())
3726        .await?;
3727    let root_id = channel_model.root_id();
3728    let channel = Channel::from_model(channel_model);
3729
3730    let mut connection_pool = session.connection_pool().await;
3731    for (user_id, role) in connection_pool
3732        .channel_user_ids(root_id)
3733        .collect::<Vec<_>>()
3734        .into_iter()
3735    {
3736        let update = if role.can_see_channel(channel.visibility) {
3737            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3738            proto::UpdateChannels {
3739                channels: vec![channel.to_proto()],
3740                ..Default::default()
3741            }
3742        } else {
3743            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3744            proto::UpdateChannels {
3745                delete_channels: vec![channel.id.to_proto()],
3746                ..Default::default()
3747            }
3748        };
3749
3750        for connection_id in connection_pool.user_connection_ids(user_id) {
3751            session.peer.send(connection_id, update.clone())?;
3752        }
3753    }
3754
3755    response.send(proto::Ack {})?;
3756    Ok(())
3757}
3758
3759/// Alter the role for a user in the channel.
3760async fn set_channel_member_role(
3761    request: proto::SetChannelMemberRole,
3762    response: Response<proto::SetChannelMemberRole>,
3763    session: UserSession,
3764) -> Result<()> {
3765    let db = session.db().await;
3766    let channel_id = ChannelId::from_proto(request.channel_id);
3767    let member_id = UserId::from_proto(request.user_id);
3768    let result = db
3769        .set_channel_member_role(
3770            channel_id,
3771            session.user_id(),
3772            member_id,
3773            request.role().into(),
3774        )
3775        .await?;
3776
3777    match result {
3778        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3779            let mut connection_pool = session.connection_pool().await;
3780            notify_membership_updated(
3781                &mut connection_pool,
3782                membership_update,
3783                member_id,
3784                &session.peer,
3785            )
3786        }
3787        db::SetMemberRoleResult::InviteUpdated(channel) => {
3788            let update = proto::UpdateChannels {
3789                channel_invitations: vec![channel.to_proto()],
3790                ..Default::default()
3791            };
3792
3793            for connection_id in session
3794                .connection_pool()
3795                .await
3796                .user_connection_ids(member_id)
3797            {
3798                session.peer.send(connection_id, update.clone())?;
3799            }
3800        }
3801    }
3802
3803    response.send(proto::Ack {})?;
3804    Ok(())
3805}
3806
3807/// Change the name of a channel
3808async fn rename_channel(
3809    request: proto::RenameChannel,
3810    response: Response<proto::RenameChannel>,
3811    session: UserSession,
3812) -> Result<()> {
3813    let db = session.db().await;
3814    let channel_id = ChannelId::from_proto(request.channel_id);
3815    let channel_model = db
3816        .rename_channel(channel_id, session.user_id(), &request.name)
3817        .await?;
3818    let root_id = channel_model.root_id();
3819    let channel = Channel::from_model(channel_model);
3820
3821    response.send(proto::RenameChannelResponse {
3822        channel: Some(channel.to_proto()),
3823    })?;
3824
3825    let connection_pool = session.connection_pool().await;
3826    let update = proto::UpdateChannels {
3827        channels: vec![channel.to_proto()],
3828        ..Default::default()
3829    };
3830    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3831        if role.can_see_channel(channel.visibility) {
3832            session.peer.send(connection_id, update.clone())?;
3833        }
3834    }
3835
3836    Ok(())
3837}
3838
3839/// Move a channel to a new parent.
3840async fn move_channel(
3841    request: proto::MoveChannel,
3842    response: Response<proto::MoveChannel>,
3843    session: UserSession,
3844) -> Result<()> {
3845    let channel_id = ChannelId::from_proto(request.channel_id);
3846    let to = ChannelId::from_proto(request.to);
3847
3848    let (root_id, channels) = session
3849        .db()
3850        .await
3851        .move_channel(channel_id, to, session.user_id())
3852        .await?;
3853
3854    let connection_pool = session.connection_pool().await;
3855    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3856        let channels = channels
3857            .iter()
3858            .filter_map(|channel| {
3859                if role.can_see_channel(channel.visibility) {
3860                    Some(channel.to_proto())
3861                } else {
3862                    None
3863                }
3864            })
3865            .collect::<Vec<_>>();
3866        if channels.is_empty() {
3867            continue;
3868        }
3869
3870        let update = proto::UpdateChannels {
3871            channels,
3872            ..Default::default()
3873        };
3874
3875        session.peer.send(connection_id, update.clone())?;
3876    }
3877
3878    response.send(Ack {})?;
3879    Ok(())
3880}
3881
3882/// Get the list of channel members
3883async fn get_channel_members(
3884    request: proto::GetChannelMembers,
3885    response: Response<proto::GetChannelMembers>,
3886    session: UserSession,
3887) -> Result<()> {
3888    let db = session.db().await;
3889    let channel_id = ChannelId::from_proto(request.channel_id);
3890    let limit = if request.limit == 0 {
3891        u16::MAX as u64
3892    } else {
3893        request.limit
3894    };
3895    let (members, users) = db
3896        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3897        .await?;
3898    response.send(proto::GetChannelMembersResponse { members, users })?;
3899    Ok(())
3900}
3901
3902/// Accept or decline a channel invitation.
3903async fn respond_to_channel_invite(
3904    request: proto::RespondToChannelInvite,
3905    response: Response<proto::RespondToChannelInvite>,
3906    session: UserSession,
3907) -> Result<()> {
3908    let db = session.db().await;
3909    let channel_id = ChannelId::from_proto(request.channel_id);
3910    let RespondToChannelInvite {
3911        membership_update,
3912        notifications,
3913    } = db
3914        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3915        .await?;
3916
3917    let mut connection_pool = session.connection_pool().await;
3918    if let Some(membership_update) = membership_update {
3919        notify_membership_updated(
3920            &mut connection_pool,
3921            membership_update,
3922            session.user_id(),
3923            &session.peer,
3924        );
3925    } else {
3926        let update = proto::UpdateChannels {
3927            remove_channel_invitations: vec![channel_id.to_proto()],
3928            ..Default::default()
3929        };
3930
3931        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3932            session.peer.send(connection_id, update.clone())?;
3933        }
3934    };
3935
3936    send_notifications(&connection_pool, &session.peer, notifications);
3937
3938    response.send(proto::Ack {})?;
3939
3940    Ok(())
3941}
3942
3943/// Join the channels' room
3944async fn join_channel(
3945    request: proto::JoinChannel,
3946    response: Response<proto::JoinChannel>,
3947    session: UserSession,
3948) -> Result<()> {
3949    let channel_id = ChannelId::from_proto(request.channel_id);
3950    join_channel_internal(channel_id, Box::new(response), session).await
3951}
3952
3953trait JoinChannelInternalResponse {
3954    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3955}
3956impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3957    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3958        Response::<proto::JoinChannel>::send(self, result)
3959    }
3960}
3961impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3962    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3963        Response::<proto::JoinRoom>::send(self, result)
3964    }
3965}
3966
3967async fn join_channel_internal(
3968    channel_id: ChannelId,
3969    response: Box<impl JoinChannelInternalResponse>,
3970    session: UserSession,
3971) -> Result<()> {
3972    let joined_room = {
3973        let mut db = session.db().await;
3974        // If zed quits without leaving the room, and the user re-opens zed before the
3975        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3976        // room they were in.
3977        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3978            tracing::info!(
3979                stale_connection_id = %connection,
3980                "cleaning up stale connection",
3981            );
3982            drop(db);
3983            leave_room_for_session(&session, connection).await?;
3984            db = session.db().await;
3985        }
3986
3987        let (joined_room, membership_updated, role) = db
3988            .join_channel(channel_id, session.user_id(), session.connection_id)
3989            .await?;
3990
3991        let live_kit_connection_info =
3992            session
3993                .app_state
3994                .live_kit_client
3995                .as_ref()
3996                .and_then(|live_kit| {
3997                    let (can_publish, token) = if role == ChannelRole::Guest {
3998                        (
3999                            false,
4000                            live_kit
4001                                .guest_token(
4002                                    &joined_room.room.live_kit_room,
4003                                    &session.user_id().to_string(),
4004                                )
4005                                .trace_err()?,
4006                        )
4007                    } else {
4008                        (
4009                            true,
4010                            live_kit
4011                                .room_token(
4012                                    &joined_room.room.live_kit_room,
4013                                    &session.user_id().to_string(),
4014                                )
4015                                .trace_err()?,
4016                        )
4017                    };
4018
4019                    Some(LiveKitConnectionInfo {
4020                        server_url: live_kit.url().into(),
4021                        token,
4022                        can_publish,
4023                    })
4024                });
4025
4026        response.send(proto::JoinRoomResponse {
4027            room: Some(joined_room.room.clone()),
4028            channel_id: joined_room
4029                .channel
4030                .as_ref()
4031                .map(|channel| channel.id.to_proto()),
4032            live_kit_connection_info,
4033        })?;
4034
4035        let mut connection_pool = session.connection_pool().await;
4036        if let Some(membership_updated) = membership_updated {
4037            notify_membership_updated(
4038                &mut connection_pool,
4039                membership_updated,
4040                session.user_id(),
4041                &session.peer,
4042            );
4043        }
4044
4045        room_updated(&joined_room.room, &session.peer);
4046
4047        joined_room
4048    };
4049
4050    channel_updated(
4051        &joined_room
4052            .channel
4053            .ok_or_else(|| anyhow!("channel not returned"))?,
4054        &joined_room.room,
4055        &session.peer,
4056        &*session.connection_pool().await,
4057    );
4058
4059    update_user_contacts(session.user_id(), &session).await?;
4060    Ok(())
4061}
4062
4063/// Start editing the channel notes
4064async fn join_channel_buffer(
4065    request: proto::JoinChannelBuffer,
4066    response: Response<proto::JoinChannelBuffer>,
4067    session: UserSession,
4068) -> Result<()> {
4069    let db = session.db().await;
4070    let channel_id = ChannelId::from_proto(request.channel_id);
4071
4072    let open_response = db
4073        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4074        .await?;
4075
4076    let collaborators = open_response.collaborators.clone();
4077    response.send(open_response)?;
4078
4079    let update = UpdateChannelBufferCollaborators {
4080        channel_id: channel_id.to_proto(),
4081        collaborators: collaborators.clone(),
4082    };
4083    channel_buffer_updated(
4084        session.connection_id,
4085        collaborators
4086            .iter()
4087            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4088        &update,
4089        &session.peer,
4090    );
4091
4092    Ok(())
4093}
4094
4095/// Edit the channel notes
4096async fn update_channel_buffer(
4097    request: proto::UpdateChannelBuffer,
4098    session: UserSession,
4099) -> Result<()> {
4100    let db = session.db().await;
4101    let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103    let (collaborators, epoch, version) = db
4104        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4105        .await?;
4106
4107    channel_buffer_updated(
4108        session.connection_id,
4109        collaborators.clone(),
4110        &proto::UpdateChannelBuffer {
4111            channel_id: channel_id.to_proto(),
4112            operations: request.operations,
4113        },
4114        &session.peer,
4115    );
4116
4117    let pool = &*session.connection_pool().await;
4118
4119    let non_collaborators =
4120        pool.channel_connection_ids(channel_id)
4121            .filter_map(|(connection_id, _)| {
4122                if collaborators.contains(&connection_id) {
4123                    None
4124                } else {
4125                    Some(connection_id)
4126                }
4127            });
4128
4129    broadcast(None, non_collaborators, |peer_id| {
4130        session.peer.send(
4131            peer_id,
4132            proto::UpdateChannels {
4133                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4134                    channel_id: channel_id.to_proto(),
4135                    epoch: epoch as u64,
4136                    version: version.clone(),
4137                }],
4138                ..Default::default()
4139            },
4140        )
4141    });
4142
4143    Ok(())
4144}
4145
4146/// Rejoin the channel notes after a connection blip
4147async fn rejoin_channel_buffers(
4148    request: proto::RejoinChannelBuffers,
4149    response: Response<proto::RejoinChannelBuffers>,
4150    session: UserSession,
4151) -> Result<()> {
4152    let db = session.db().await;
4153    let buffers = db
4154        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4155        .await?;
4156
4157    for rejoined_buffer in &buffers {
4158        let collaborators_to_notify = rejoined_buffer
4159            .buffer
4160            .collaborators
4161            .iter()
4162            .filter_map(|c| Some(c.peer_id?.into()));
4163        channel_buffer_updated(
4164            session.connection_id,
4165            collaborators_to_notify,
4166            &proto::UpdateChannelBufferCollaborators {
4167                channel_id: rejoined_buffer.buffer.channel_id,
4168                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4169            },
4170            &session.peer,
4171        );
4172    }
4173
4174    response.send(proto::RejoinChannelBuffersResponse {
4175        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4176    })?;
4177
4178    Ok(())
4179}
4180
4181/// Stop editing the channel notes
4182async fn leave_channel_buffer(
4183    request: proto::LeaveChannelBuffer,
4184    response: Response<proto::LeaveChannelBuffer>,
4185    session: UserSession,
4186) -> Result<()> {
4187    let db = session.db().await;
4188    let channel_id = ChannelId::from_proto(request.channel_id);
4189
4190    let left_buffer = db
4191        .leave_channel_buffer(channel_id, session.connection_id)
4192        .await?;
4193
4194    response.send(Ack {})?;
4195
4196    channel_buffer_updated(
4197        session.connection_id,
4198        left_buffer.connections,
4199        &proto::UpdateChannelBufferCollaborators {
4200            channel_id: channel_id.to_proto(),
4201            collaborators: left_buffer.collaborators,
4202        },
4203        &session.peer,
4204    );
4205
4206    Ok(())
4207}
4208
4209fn channel_buffer_updated<T: EnvelopedMessage>(
4210    sender_id: ConnectionId,
4211    collaborators: impl IntoIterator<Item = ConnectionId>,
4212    message: &T,
4213    peer: &Peer,
4214) {
4215    broadcast(Some(sender_id), collaborators, |peer_id| {
4216        peer.send(peer_id, message.clone())
4217    });
4218}
4219
4220fn send_notifications(
4221    connection_pool: &ConnectionPool,
4222    peer: &Peer,
4223    notifications: db::NotificationBatch,
4224) {
4225    for (user_id, notification) in notifications {
4226        for connection_id in connection_pool.user_connection_ids(user_id) {
4227            if let Err(error) = peer.send(
4228                connection_id,
4229                proto::AddNotification {
4230                    notification: Some(notification.clone()),
4231                },
4232            ) {
4233                tracing::error!(
4234                    "failed to send notification to {:?} {}",
4235                    connection_id,
4236                    error
4237                );
4238            }
4239        }
4240    }
4241}
4242
4243/// Send a message to the channel
4244async fn send_channel_message(
4245    request: proto::SendChannelMessage,
4246    response: Response<proto::SendChannelMessage>,
4247    session: UserSession,
4248) -> Result<()> {
4249    // Validate the message body.
4250    let body = request.body.trim().to_string();
4251    if body.len() > MAX_MESSAGE_LEN {
4252        return Err(anyhow!("message is too long"))?;
4253    }
4254    if body.is_empty() {
4255        return Err(anyhow!("message can't be blank"))?;
4256    }
4257
4258    // TODO: adjust mentions if body is trimmed
4259
4260    let timestamp = OffsetDateTime::now_utc();
4261    let nonce = request
4262        .nonce
4263        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4264
4265    let channel_id = ChannelId::from_proto(request.channel_id);
4266    let CreatedChannelMessage {
4267        message_id,
4268        participant_connection_ids,
4269        notifications,
4270    } = session
4271        .db()
4272        .await
4273        .create_channel_message(
4274            channel_id,
4275            session.user_id(),
4276            &body,
4277            &request.mentions,
4278            timestamp,
4279            nonce.clone().into(),
4280            request.reply_to_message_id.map(MessageId::from_proto),
4281        )
4282        .await?;
4283
4284    let message = proto::ChannelMessage {
4285        sender_id: session.user_id().to_proto(),
4286        id: message_id.to_proto(),
4287        body,
4288        mentions: request.mentions,
4289        timestamp: timestamp.unix_timestamp() as u64,
4290        nonce: Some(nonce),
4291        reply_to_message_id: request.reply_to_message_id,
4292        edited_at: None,
4293    };
4294    broadcast(
4295        Some(session.connection_id),
4296        participant_connection_ids.clone(),
4297        |connection| {
4298            session.peer.send(
4299                connection,
4300                proto::ChannelMessageSent {
4301                    channel_id: channel_id.to_proto(),
4302                    message: Some(message.clone()),
4303                },
4304            )
4305        },
4306    );
4307    response.send(proto::SendChannelMessageResponse {
4308        message: Some(message),
4309    })?;
4310
4311    let pool = &*session.connection_pool().await;
4312    let non_participants =
4313        pool.channel_connection_ids(channel_id)
4314            .filter_map(|(connection_id, _)| {
4315                if participant_connection_ids.contains(&connection_id) {
4316                    None
4317                } else {
4318                    Some(connection_id)
4319                }
4320            });
4321    broadcast(None, non_participants, |peer_id| {
4322        session.peer.send(
4323            peer_id,
4324            proto::UpdateChannels {
4325                latest_channel_message_ids: vec![proto::ChannelMessageId {
4326                    channel_id: channel_id.to_proto(),
4327                    message_id: message_id.to_proto(),
4328                }],
4329                ..Default::default()
4330            },
4331        )
4332    });
4333    send_notifications(pool, &session.peer, notifications);
4334
4335    Ok(())
4336}
4337
4338/// Delete a channel message
4339async fn remove_channel_message(
4340    request: proto::RemoveChannelMessage,
4341    response: Response<proto::RemoveChannelMessage>,
4342    session: UserSession,
4343) -> Result<()> {
4344    let channel_id = ChannelId::from_proto(request.channel_id);
4345    let message_id = MessageId::from_proto(request.message_id);
4346    let (connection_ids, existing_notification_ids) = session
4347        .db()
4348        .await
4349        .remove_channel_message(channel_id, message_id, session.user_id())
4350        .await?;
4351
4352    broadcast(
4353        Some(session.connection_id),
4354        connection_ids,
4355        move |connection| {
4356            session.peer.send(connection, request.clone())?;
4357
4358            for notification_id in &existing_notification_ids {
4359                session.peer.send(
4360                    connection,
4361                    proto::DeleteNotification {
4362                        notification_id: (*notification_id).to_proto(),
4363                    },
4364                )?;
4365            }
4366
4367            Ok(())
4368        },
4369    );
4370    response.send(proto::Ack {})?;
4371    Ok(())
4372}
4373
4374async fn update_channel_message(
4375    request: proto::UpdateChannelMessage,
4376    response: Response<proto::UpdateChannelMessage>,
4377    session: UserSession,
4378) -> Result<()> {
4379    let channel_id = ChannelId::from_proto(request.channel_id);
4380    let message_id = MessageId::from_proto(request.message_id);
4381    let updated_at = OffsetDateTime::now_utc();
4382    let UpdatedChannelMessage {
4383        message_id,
4384        participant_connection_ids,
4385        notifications,
4386        reply_to_message_id,
4387        timestamp,
4388        deleted_mention_notification_ids,
4389        updated_mention_notifications,
4390    } = session
4391        .db()
4392        .await
4393        .update_channel_message(
4394            channel_id,
4395            message_id,
4396            session.user_id(),
4397            request.body.as_str(),
4398            &request.mentions,
4399            updated_at,
4400        )
4401        .await?;
4402
4403    let nonce = request
4404        .nonce
4405        .clone()
4406        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4407
4408    let message = proto::ChannelMessage {
4409        sender_id: session.user_id().to_proto(),
4410        id: message_id.to_proto(),
4411        body: request.body.clone(),
4412        mentions: request.mentions.clone(),
4413        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4414        nonce: Some(nonce),
4415        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4416        edited_at: Some(updated_at.unix_timestamp() as u64),
4417    };
4418
4419    response.send(proto::Ack {})?;
4420
4421    let pool = &*session.connection_pool().await;
4422    broadcast(
4423        Some(session.connection_id),
4424        participant_connection_ids,
4425        |connection| {
4426            session.peer.send(
4427                connection,
4428                proto::ChannelMessageUpdate {
4429                    channel_id: channel_id.to_proto(),
4430                    message: Some(message.clone()),
4431                },
4432            )?;
4433
4434            for notification_id in &deleted_mention_notification_ids {
4435                session.peer.send(
4436                    connection,
4437                    proto::DeleteNotification {
4438                        notification_id: (*notification_id).to_proto(),
4439                    },
4440                )?;
4441            }
4442
4443            for notification in &updated_mention_notifications {
4444                session.peer.send(
4445                    connection,
4446                    proto::UpdateNotification {
4447                        notification: Some(notification.clone()),
4448                    },
4449                )?;
4450            }
4451
4452            Ok(())
4453        },
4454    );
4455
4456    send_notifications(pool, &session.peer, notifications);
4457
4458    Ok(())
4459}
4460
4461/// Mark a channel message as read
4462async fn acknowledge_channel_message(
4463    request: proto::AckChannelMessage,
4464    session: UserSession,
4465) -> Result<()> {
4466    let channel_id = ChannelId::from_proto(request.channel_id);
4467    let message_id = MessageId::from_proto(request.message_id);
4468    let notifications = session
4469        .db()
4470        .await
4471        .observe_channel_message(channel_id, session.user_id(), message_id)
4472        .await?;
4473    send_notifications(
4474        &*session.connection_pool().await,
4475        &session.peer,
4476        notifications,
4477    );
4478    Ok(())
4479}
4480
4481/// Mark a buffer version as synced
4482async fn acknowledge_buffer_version(
4483    request: proto::AckBufferOperation,
4484    session: UserSession,
4485) -> Result<()> {
4486    let buffer_id = BufferId::from_proto(request.buffer_id);
4487    session
4488        .db()
4489        .await
4490        .observe_buffer_version(
4491            buffer_id,
4492            session.user_id(),
4493            request.epoch as i32,
4494            &request.version,
4495        )
4496        .await?;
4497    Ok(())
4498}
4499
4500async fn count_language_model_tokens(
4501    request: proto::CountLanguageModelTokens,
4502    response: Response<proto::CountLanguageModelTokens>,
4503    session: Session,
4504    config: &Config,
4505) -> Result<()> {
4506    let Some(session) = session.for_user() else {
4507        return Err(anyhow!("user not found"))?;
4508    };
4509    authorize_access_to_legacy_llm_endpoints(&session).await?;
4510
4511    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4512        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4513        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4514    };
4515
4516    session
4517        .app_state
4518        .rate_limiter
4519        .check(&*rate_limit, session.user_id())
4520        .await?;
4521
4522    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4523        Some(proto::LanguageModelProvider::Google) => {
4524            let api_key = config
4525                .google_ai_api_key
4526                .as_ref()
4527                .context("no Google AI API key configured on the server")?;
4528            google_ai::count_tokens(
4529                session.http_client.as_ref(),
4530                google_ai::API_URL,
4531                api_key,
4532                serde_json::from_str(&request.request)?,
4533                None,
4534            )
4535            .await?
4536        }
4537        _ => return Err(anyhow!("unsupported provider"))?,
4538    };
4539
4540    response.send(proto::CountLanguageModelTokensResponse {
4541        token_count: result.total_tokens as u32,
4542    })?;
4543
4544    Ok(())
4545}
4546
4547struct ZedProCountLanguageModelTokensRateLimit;
4548
4549impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4550    fn capacity(&self) -> usize {
4551        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4552            .ok()
4553            .and_then(|v| v.parse().ok())
4554            .unwrap_or(600) // Picked arbitrarily
4555    }
4556
4557    fn refill_duration(&self) -> chrono::Duration {
4558        chrono::Duration::hours(1)
4559    }
4560
4561    fn db_name(&self) -> &'static str {
4562        "zed-pro:count-language-model-tokens"
4563    }
4564}
4565
4566struct FreeCountLanguageModelTokensRateLimit;
4567
4568impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4569    fn capacity(&self) -> usize {
4570        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4571            .ok()
4572            .and_then(|v| v.parse().ok())
4573            .unwrap_or(600 / 10) // Picked arbitrarily
4574    }
4575
4576    fn refill_duration(&self) -> chrono::Duration {
4577        chrono::Duration::hours(1)
4578    }
4579
4580    fn db_name(&self) -> &'static str {
4581        "free:count-language-model-tokens"
4582    }
4583}
4584
4585struct ZedProComputeEmbeddingsRateLimit;
4586
4587impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4588    fn capacity(&self) -> usize {
4589        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4590            .ok()
4591            .and_then(|v| v.parse().ok())
4592            .unwrap_or(5000) // Picked arbitrarily
4593    }
4594
4595    fn refill_duration(&self) -> chrono::Duration {
4596        chrono::Duration::hours(1)
4597    }
4598
4599    fn db_name(&self) -> &'static str {
4600        "zed-pro:compute-embeddings"
4601    }
4602}
4603
4604struct FreeComputeEmbeddingsRateLimit;
4605
4606impl RateLimit for FreeComputeEmbeddingsRateLimit {
4607    fn capacity(&self) -> usize {
4608        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4609            .ok()
4610            .and_then(|v| v.parse().ok())
4611            .unwrap_or(5000 / 10) // Picked arbitrarily
4612    }
4613
4614    fn refill_duration(&self) -> chrono::Duration {
4615        chrono::Duration::hours(1)
4616    }
4617
4618    fn db_name(&self) -> &'static str {
4619        "free:compute-embeddings"
4620    }
4621}
4622
4623async fn compute_embeddings(
4624    request: proto::ComputeEmbeddings,
4625    response: Response<proto::ComputeEmbeddings>,
4626    session: UserSession,
4627    api_key: Option<Arc<str>>,
4628) -> Result<()> {
4629    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4630    authorize_access_to_legacy_llm_endpoints(&session).await?;
4631
4632    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4633        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4634        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4635    };
4636
4637    session
4638        .app_state
4639        .rate_limiter
4640        .check(&*rate_limit, session.user_id())
4641        .await?;
4642
4643    let embeddings = match request.model.as_str() {
4644        "openai/text-embedding-3-small" => {
4645            open_ai::embed(
4646                session.http_client.as_ref(),
4647                OPEN_AI_API_URL,
4648                &api_key,
4649                OpenAiEmbeddingModel::TextEmbedding3Small,
4650                request.texts.iter().map(|text| text.as_str()),
4651            )
4652            .await?
4653        }
4654        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4655    };
4656
4657    let embeddings = request
4658        .texts
4659        .iter()
4660        .map(|text| {
4661            let mut hasher = sha2::Sha256::new();
4662            hasher.update(text.as_bytes());
4663            let result = hasher.finalize();
4664            result.to_vec()
4665        })
4666        .zip(
4667            embeddings
4668                .data
4669                .into_iter()
4670                .map(|embedding| embedding.embedding),
4671        )
4672        .collect::<HashMap<_, _>>();
4673
4674    let db = session.db().await;
4675    db.save_embeddings(&request.model, &embeddings)
4676        .await
4677        .context("failed to save embeddings")
4678        .trace_err();
4679
4680    response.send(proto::ComputeEmbeddingsResponse {
4681        embeddings: embeddings
4682            .into_iter()
4683            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4684            .collect(),
4685    })?;
4686    Ok(())
4687}
4688
4689async fn get_cached_embeddings(
4690    request: proto::GetCachedEmbeddings,
4691    response: Response<proto::GetCachedEmbeddings>,
4692    session: UserSession,
4693) -> Result<()> {
4694    authorize_access_to_legacy_llm_endpoints(&session).await?;
4695
4696    let db = session.db().await;
4697    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4698
4699    response.send(proto::GetCachedEmbeddingsResponse {
4700        embeddings: embeddings
4701            .into_iter()
4702            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4703            .collect(),
4704    })?;
4705    Ok(())
4706}
4707
4708/// This is leftover from before the LLM service.
4709///
4710/// The endpoints protected by this check will be moved there eventually.
4711async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4712    if session.is_staff() {
4713        Ok(())
4714    } else {
4715        Err(anyhow!("permission denied"))?
4716    }
4717}
4718
4719/// Get a Supermaven API key for the user
4720async fn get_supermaven_api_key(
4721    _request: proto::GetSupermavenApiKey,
4722    response: Response<proto::GetSupermavenApiKey>,
4723    session: UserSession,
4724) -> Result<()> {
4725    let user_id: String = session.user_id().to_string();
4726    if !session.is_staff() {
4727        return Err(anyhow!("supermaven not enabled for this account"))?;
4728    }
4729
4730    let email = session
4731        .email()
4732        .ok_or_else(|| anyhow!("user must have an email"))?;
4733
4734    let supermaven_admin_api = session
4735        .supermaven_client
4736        .as_ref()
4737        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4738
4739    let result = supermaven_admin_api
4740        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4741        .await?;
4742
4743    response.send(proto::GetSupermavenApiKeyResponse {
4744        api_key: result.api_key,
4745    })?;
4746
4747    Ok(())
4748}
4749
4750/// Start receiving chat updates for a channel
4751async fn join_channel_chat(
4752    request: proto::JoinChannelChat,
4753    response: Response<proto::JoinChannelChat>,
4754    session: UserSession,
4755) -> Result<()> {
4756    let channel_id = ChannelId::from_proto(request.channel_id);
4757
4758    let db = session.db().await;
4759    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4760        .await?;
4761    let messages = db
4762        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4763        .await?;
4764    response.send(proto::JoinChannelChatResponse {
4765        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4766        messages,
4767    })?;
4768    Ok(())
4769}
4770
4771/// Stop receiving chat updates for a channel
4772async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4773    let channel_id = ChannelId::from_proto(request.channel_id);
4774    session
4775        .db()
4776        .await
4777        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4778        .await?;
4779    Ok(())
4780}
4781
4782/// Retrieve the chat history for a channel
4783async fn get_channel_messages(
4784    request: proto::GetChannelMessages,
4785    response: Response<proto::GetChannelMessages>,
4786    session: UserSession,
4787) -> Result<()> {
4788    let channel_id = ChannelId::from_proto(request.channel_id);
4789    let messages = session
4790        .db()
4791        .await
4792        .get_channel_messages(
4793            channel_id,
4794            session.user_id(),
4795            MESSAGE_COUNT_PER_PAGE,
4796            Some(MessageId::from_proto(request.before_message_id)),
4797        )
4798        .await?;
4799    response.send(proto::GetChannelMessagesResponse {
4800        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4801        messages,
4802    })?;
4803    Ok(())
4804}
4805
4806/// Retrieve specific chat messages
4807async fn get_channel_messages_by_id(
4808    request: proto::GetChannelMessagesById,
4809    response: Response<proto::GetChannelMessagesById>,
4810    session: UserSession,
4811) -> Result<()> {
4812    let message_ids = request
4813        .message_ids
4814        .iter()
4815        .map(|id| MessageId::from_proto(*id))
4816        .collect::<Vec<_>>();
4817    let messages = session
4818        .db()
4819        .await
4820        .get_channel_messages_by_id(session.user_id(), &message_ids)
4821        .await?;
4822    response.send(proto::GetChannelMessagesResponse {
4823        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4824        messages,
4825    })?;
4826    Ok(())
4827}
4828
4829/// Retrieve the current users notifications
4830async fn get_notifications(
4831    request: proto::GetNotifications,
4832    response: Response<proto::GetNotifications>,
4833    session: UserSession,
4834) -> Result<()> {
4835    let notifications = session
4836        .db()
4837        .await
4838        .get_notifications(
4839            session.user_id(),
4840            NOTIFICATION_COUNT_PER_PAGE,
4841            request.before_id.map(db::NotificationId::from_proto),
4842        )
4843        .await?;
4844    response.send(proto::GetNotificationsResponse {
4845        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4846        notifications,
4847    })?;
4848    Ok(())
4849}
4850
4851/// Mark notifications as read
4852async fn mark_notification_as_read(
4853    request: proto::MarkNotificationRead,
4854    response: Response<proto::MarkNotificationRead>,
4855    session: UserSession,
4856) -> Result<()> {
4857    let database = &session.db().await;
4858    let notifications = database
4859        .mark_notification_as_read_by_id(
4860            session.user_id(),
4861            NotificationId::from_proto(request.notification_id),
4862        )
4863        .await?;
4864    send_notifications(
4865        &*session.connection_pool().await,
4866        &session.peer,
4867        notifications,
4868    );
4869    response.send(proto::Ack {})?;
4870    Ok(())
4871}
4872
4873/// Get the current users information
4874async fn get_private_user_info(
4875    _request: proto::GetPrivateUserInfo,
4876    response: Response<proto::GetPrivateUserInfo>,
4877    session: UserSession,
4878) -> Result<()> {
4879    let db = session.db().await;
4880
4881    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4882    let user = db
4883        .get_user_by_id(session.user_id())
4884        .await?
4885        .ok_or_else(|| anyhow!("user not found"))?;
4886    let flags = db.get_user_flags(session.user_id()).await?;
4887
4888    response.send(proto::GetPrivateUserInfoResponse {
4889        metrics_id,
4890        staff: user.admin,
4891        flags,
4892        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4893    })?;
4894    Ok(())
4895}
4896
4897/// Accept the terms of service (tos) on behalf of the current user
4898async fn accept_terms_of_service(
4899    _request: proto::AcceptTermsOfService,
4900    response: Response<proto::AcceptTermsOfService>,
4901    session: UserSession,
4902) -> Result<()> {
4903    let db = session.db().await;
4904
4905    let accepted_tos_at = Utc::now();
4906    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4907        .await?;
4908
4909    response.send(proto::AcceptTermsOfServiceResponse {
4910        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4911    })?;
4912    Ok(())
4913}
4914
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures the age from the earlier of the user's
/// creation time and their GitHub account creation time (when known), and
/// rejects younger accounts with "account too young".
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4917
4918async fn get_llm_api_token(
4919    _request: proto::GetLlmToken,
4920    response: Response<proto::GetLlmToken>,
4921    session: UserSession,
4922) -> Result<()> {
4923    let db = session.db().await;
4924
4925    let flags = db.get_user_flags(session.user_id()).await?;
4926    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4927    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4928
4929    if !session.is_staff() && !has_language_models_feature_flag {
4930        Err(anyhow!("permission denied"))?
4931    }
4932
4933    let user_id = session.user_id();
4934    let user = db
4935        .get_user_by_id(user_id)
4936        .await?
4937        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4938
4939    if user.accepted_tos_at.is_none() {
4940        Err(anyhow!("terms of service not accepted"))?
4941    }
4942
4943    let mut account_created_at = user.created_at;
4944    if let Some(github_created_at) = user.github_user_created_at {
4945        account_created_at = account_created_at.min(github_created_at);
4946    }
4947    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4948        Err(anyhow!("account too young"))?
4949    }
4950    let token = LlmTokenClaims::create(
4951        user.id,
4952        user.github_login.clone(),
4953        session.is_staff(),
4954        has_llm_closed_beta_feature_flag,
4955        session.current_plan(db).await?,
4956        &session.app_state.config,
4957    )?;
4958    response.send(proto::GetLlmTokenResponse { token })?;
4959    Ok(())
4960}
4961
4962fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4963    let message = match message {
4964        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4965        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4966        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4967        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4968        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4969            code: frame.code.into(),
4970            reason: frame.reason,
4971        })),
4972        // We should never receive a frame while reading the message, according
4973        // to the `tungstenite` maintainers:
4974        //
4975        // > It cannot occur when you read messages from the WebSocket, but it
4976        // > can be used when you want to send the raw frames (e.g. you want to
4977        // > send the frames to the WebSocket without composing the full message first).
4978        // >
4979        // > — https://github.com/snapview/tungstenite-rs/issues/268
4980        TungsteniteMessage::Frame(_) => {
4981            bail!("received an unexpected frame while reading the message")
4982        }
4983    };
4984
4985    Ok(message)
4986}
4987
4988fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4989    match message {
4990        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4991        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4992        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4993        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4994        AxumMessage::Close(frame) => {
4995            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4996                code: frame.code.into(),
4997                reason: frame.reason,
4998            }))
4999        }
5000    }
5001}
5002
5003fn notify_membership_updated(
5004    connection_pool: &mut ConnectionPool,
5005    result: MembershipUpdated,
5006    user_id: UserId,
5007    peer: &Peer,
5008) {
5009    for membership in &result.new_channels.channel_memberships {
5010        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5011    }
5012    for channel_id in &result.removed_channels {
5013        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5014    }
5015
5016    let user_channels_update = proto::UpdateUserChannels {
5017        channel_memberships: result
5018            .new_channels
5019            .channel_memberships
5020            .iter()
5021            .map(|cm| proto::ChannelMembership {
5022                channel_id: cm.channel_id.to_proto(),
5023                role: cm.role.into(),
5024            })
5025            .collect(),
5026        ..Default::default()
5027    };
5028
5029    let mut update = build_channels_update(result.new_channels);
5030    update.delete_channels = result
5031        .removed_channels
5032        .into_iter()
5033        .map(|id| id.to_proto())
5034        .collect();
5035    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5036
5037    for connection_id in connection_pool.user_connection_ids(user_id) {
5038        peer.send(connection_id, user_channels_update.clone())
5039            .trace_err();
5040        peer.send(connection_id, update.clone()).trace_err();
5041    }
5042}
5043
5044fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5045    proto::UpdateUserChannels {
5046        channel_memberships: channels
5047            .channel_memberships
5048            .iter()
5049            .map(|m| proto::ChannelMembership {
5050                channel_id: m.channel_id.to_proto(),
5051                role: m.role.into(),
5052            })
5053            .collect(),
5054        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5055        observed_channel_message_id: channels.observed_channel_messages.clone(),
5056    }
5057}
5058
5059fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5060    let mut update = proto::UpdateChannels::default();
5061
5062    for channel in channels.channels {
5063        update.channels.push(channel.to_proto());
5064    }
5065
5066    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5067    update.latest_channel_message_ids = channels.latest_channel_messages;
5068
5069    for (channel_id, participants) in channels.channel_participants {
5070        update
5071            .channel_participants
5072            .push(proto::ChannelParticipants {
5073                channel_id: channel_id.to_proto(),
5074                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5075            });
5076    }
5077
5078    for channel in channels.invited_channels {
5079        update.channel_invitations.push(channel.to_proto());
5080    }
5081
5082    update.hosted_projects = channels.hosted_projects;
5083    update
5084}
5085
5086fn build_initial_contacts_update(
5087    contacts: Vec<db::Contact>,
5088    pool: &ConnectionPool,
5089) -> proto::UpdateContacts {
5090    let mut update = proto::UpdateContacts::default();
5091
5092    for contact in contacts {
5093        match contact {
5094            db::Contact::Accepted { user_id, busy } => {
5095                update.contacts.push(contact_for_user(user_id, busy, pool));
5096            }
5097            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5098            db::Contact::Incoming { user_id } => {
5099                update
5100                    .incoming_requests
5101                    .push(proto::IncomingContactRequest {
5102                        requester_id: user_id.to_proto(),
5103                    })
5104            }
5105        }
5106    }
5107
5108    update
5109}
5110
5111fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5112    proto::Contact {
5113        user_id: user_id.to_proto(),
5114        online: pool.is_user_online(user_id),
5115        busy,
5116    }
5117}
5118
5119fn room_updated(room: &proto::Room, peer: &Peer) {
5120    broadcast(
5121        None,
5122        room.participants
5123            .iter()
5124            .filter_map(|participant| Some(participant.peer_id?.into())),
5125        |peer_id| {
5126            peer.send(
5127                peer_id,
5128                proto::RoomUpdated {
5129                    room: Some(room.clone()),
5130                },
5131            )
5132        },
5133    );
5134}
5135
5136fn channel_updated(
5137    channel: &db::channel::Model,
5138    room: &proto::Room,
5139    peer: &Peer,
5140    pool: &ConnectionPool,
5141) {
5142    let participants = room
5143        .participants
5144        .iter()
5145        .map(|p| p.user_id)
5146        .collect::<Vec<_>>();
5147
5148    broadcast(
5149        None,
5150        pool.channel_connection_ids(channel.root_id())
5151            .filter_map(|(channel_id, role)| {
5152                role.can_see_channel(channel.visibility)
5153                    .then_some(channel_id)
5154            }),
5155        |peer_id| {
5156            peer.send(
5157                peer_id,
5158                proto::UpdateChannels {
5159                    channel_participants: vec![proto::ChannelParticipants {
5160                        channel_id: channel.id.to_proto(),
5161                        participant_user_ids: participants.clone(),
5162                    }],
5163                    ..Default::default()
5164                },
5165            )
5166        },
5167    );
5168}
5169
5170async fn send_dev_server_projects_update(
5171    user_id: UserId,
5172    mut status: proto::DevServerProjectsUpdate,
5173    session: &Session,
5174) {
5175    let pool = session.connection_pool().await;
5176    for dev_server in &mut status.dev_servers {
5177        dev_server.status =
5178            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5179    }
5180    let connections = pool.user_connection_ids(user_id);
5181    for connection_id in connections {
5182        session.peer.send(connection_id, status.clone()).trace_err();
5183    }
5184}
5185
5186async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5187    let db = session.db().await;
5188
5189    let contacts = db.get_contacts(user_id).await?;
5190    let busy = db.is_user_busy(user_id).await?;
5191
5192    let pool = session.connection_pool().await;
5193    let updated_contact = contact_for_user(user_id, busy, &pool);
5194    for contact in contacts {
5195        if let db::Contact::Accepted {
5196            user_id: contact_user_id,
5197            ..
5198        } = contact
5199        {
5200            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5201                session
5202                    .peer
5203                    .send(
5204                        contact_conn_id,
5205                        proto::UpdateContacts {
5206                            contacts: vec![updated_contact.clone()],
5207                            remove_contacts: Default::default(),
5208                            incoming_requests: Default::default(),
5209                            remove_incoming_requests: Default::default(),
5210                            outgoing_requests: Default::default(),
5211                            remove_outgoing_requests: Default::default(),
5212                        },
5213                    )
5214                    .trace_err();
5215            }
5216        }
5217    }
5218    Ok(())
5219}
5220
5221async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5222    log::info!("lost dev server connection, unsharing projects");
5223    let project_ids = session
5224        .db()
5225        .await
5226        .get_stale_dev_server_projects(session.connection_id)
5227        .await?;
5228
5229    for project_id in project_ids {
5230        // not unshare re-checks the connection ids match, so we get away with no transaction
5231        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5232    }
5233
5234    let user_id = session.dev_server().user_id;
5235    let update = session
5236        .db()
5237        .await
5238        .dev_server_projects_update(user_id)
5239        .await?;
5240
5241    send_dev_server_projects_update(user_id, update, session).await;
5242
5243    Ok(())
5244}
5245
5246async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5247    let mut contacts_to_update = HashSet::default();
5248
5249    let room_id;
5250    let canceled_calls_to_user_ids;
5251    let live_kit_room;
5252    let delete_live_kit_room;
5253    let room;
5254    let channel;
5255
5256    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5257        contacts_to_update.insert(session.user_id());
5258
5259        for project in left_room.left_projects.values() {
5260            project_left(project, session);
5261        }
5262
5263        room_id = RoomId::from_proto(left_room.room.id);
5264        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5265        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5266        delete_live_kit_room = left_room.deleted;
5267        room = mem::take(&mut left_room.room);
5268        channel = mem::take(&mut left_room.channel);
5269
5270        room_updated(&room, &session.peer);
5271    } else {
5272        return Ok(());
5273    }
5274
5275    if let Some(channel) = channel {
5276        channel_updated(
5277            &channel,
5278            &room,
5279            &session.peer,
5280            &*session.connection_pool().await,
5281        );
5282    }
5283
5284    {
5285        let pool = session.connection_pool().await;
5286        for canceled_user_id in canceled_calls_to_user_ids {
5287            for connection_id in pool.user_connection_ids(canceled_user_id) {
5288                session
5289                    .peer
5290                    .send(
5291                        connection_id,
5292                        proto::CallCanceled {
5293                            room_id: room_id.to_proto(),
5294                        },
5295                    )
5296                    .trace_err();
5297            }
5298            contacts_to_update.insert(canceled_user_id);
5299        }
5300    }
5301
5302    for contact_user_id in contacts_to_update {
5303        update_user_contacts(contact_user_id, session).await?;
5304    }
5305
5306    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5307        live_kit
5308            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5309            .await
5310            .trace_err();
5311
5312        if delete_live_kit_room {
5313            live_kit.delete_room(live_kit_room).await.trace_err();
5314        }
5315    }
5316
5317    Ok(())
5318}
5319
5320async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5321    let left_channel_buffers = session
5322        .db()
5323        .await
5324        .leave_channel_buffers(session.connection_id)
5325        .await?;
5326
5327    for left_buffer in left_channel_buffers {
5328        channel_buffer_updated(
5329            session.connection_id,
5330            left_buffer.connections,
5331            &proto::UpdateChannelBufferCollaborators {
5332                channel_id: left_buffer.channel_id.to_proto(),
5333                collaborators: left_buffer.collaborators,
5334            },
5335            &session.peer,
5336        );
5337    }
5338
5339    Ok(())
5340}
5341
5342fn project_left(project: &db::LeftProject, session: &UserSession) {
5343    for connection_id in &project.connection_ids {
5344        if project.should_unshare {
5345            session
5346                .peer
5347                .send(
5348                    *connection_id,
5349                    proto::UnshareProject {
5350                        project_id: project.id.to_proto(),
5351                    },
5352                )
5353                .trace_err();
5354        } else {
5355            session
5356                .peer
5357                .send(
5358                    *connection_id,
5359                    proto::RemoveProjectCollaborator {
5360                        project_id: project.id.to_proto(),
5361                        peer_id: Some(session.connection_id.into()),
5362                    },
5363                )
5364                .trace_err();
5365        }
5366    }
5367}
5368
5369pub trait ResultExt {
5370    type Ok;
5371
5372    fn trace_err(self) -> Option<Self::Ok>;
5373}
5374
5375impl<T, E> ResultExt for Result<T, E>
5376where
5377    E: std::fmt::Debug,
5378{
5379    type Ok = T;
5380
5381    #[track_caller]
5382    fn trace_err(self) -> Option<T> {
5383        match self {
5384            Ok(value) => Some(value),
5385            Err(error) => {
5386                tracing::error!("{:?}", error);
5387                None
5388            }
5389        }
5390    }
5391}