rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81use self::connection_pool::VersionedMessage;
  82
  83pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  84
  85// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  86pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  87
  88const MESSAGE_COUNT_PER_PAGE: usize = 100;
  89const MAX_MESSAGE_LEN: usize = 1024;
  90const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  91
  92type MessageHandler =
  93    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  94
  95struct Response<R> {
  96    peer: Arc<Peer>,
  97    receipt: Receipt<R>,
  98    responded: Arc<AtomicBool>,
  99}
 100
 101impl<R: RequestMessage> Response<R> {
 102    fn send(self, payload: R::Response) -> Result<()> {
 103        self.responded.store(true, SeqCst);
 104        self.peer.respond(self.receipt, payload)?;
 105        Ok(())
 106    }
 107}
 108
 109#[derive(Clone, Debug)]
 110pub enum Principal {
 111    User(User),
 112    Impersonated { user: User, admin: User },
 113    DevServer(dev_server::Model),
 114}
 115
 116impl Principal {
 117    fn update_span(&self, span: &tracing::Span) {
 118        match &self {
 119            Principal::User(user) => {
 120                span.record("user_id", &user.id.0);
 121                span.record("login", &user.github_login);
 122            }
 123            Principal::Impersonated { user, admin } => {
 124                span.record("user_id", &user.id.0);
 125                span.record("login", &user.github_login);
 126                span.record("impersonator", &admin.github_login);
 127            }
 128            Principal::DevServer(dev_server) => {
 129                span.record("dev_server_id", &dev_server.id.0);
 130            }
 131        }
 132    }
 133}
 134
 135#[derive(Clone)]
 136struct Session {
 137    principal: Principal,
 138    connection_id: ConnectionId,
 139    db: Arc<tokio::sync::Mutex<DbHandle>>,
 140    peer: Arc<Peer>,
 141    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 142    app_state: Arc<AppState>,
 143    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 144    http_client: Arc<IsahcHttpClient>,
 145    /// The GeoIP country code for the user.
 146    #[allow(unused)]
 147    geoip_country_code: Option<String>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn for_dev_server(self) -> Option<DevServerSession> {
 176        DevServerSession::new(self)
 177    }
 178
 179    fn user_id(&self) -> Option<UserId> {
 180        match &self.principal {
 181            Principal::User(user) => Some(user.id),
 182            Principal::Impersonated { user, .. } => Some(user.id),
 183            Principal::DevServer(_) => None,
 184        }
 185    }
 186
 187    fn is_staff(&self) -> bool {
 188        match &self.principal {
 189            Principal::User(user) => user.admin,
 190            Principal::Impersonated { .. } => true,
 191            Principal::DevServer(_) => false,
 192        }
 193    }
 194
 195    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
 196        if self.is_staff() {
 197            return Ok(proto::Plan::ZedPro);
 198        }
 199
 200        let Some(user_id) = self.user_id() else {
 201            return Ok(proto::Plan::Free);
 202        };
 203
 204        let db = self.db().await;
 205        if db.has_active_billing_subscription(user_id).await? {
 206            Ok(proto::Plan::ZedPro)
 207        } else {
 208            Ok(proto::Plan::Free)
 209        }
 210    }
 211
 212    fn dev_server_id(&self) -> Option<DevServerId> {
 213        match &self.principal {
 214            Principal::User(_) | Principal::Impersonated { .. } => None,
 215            Principal::DevServer(dev_server) => Some(dev_server.id),
 216        }
 217    }
 218
 219    fn principal_id(&self) -> PrincipalId {
 220        match &self.principal {
 221            Principal::User(user) => PrincipalId::UserId(user.id),
 222            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 223            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 224        }
 225    }
 226}
 227
 228impl Debug for Session {
 229    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 230        let mut result = f.debug_struct("Session");
 231        match &self.principal {
 232            Principal::User(user) => {
 233                result.field("user", &user.github_login);
 234            }
 235            Principal::Impersonated { user, admin } => {
 236                result.field("user", &user.github_login);
 237                result.field("impersonator", &admin.github_login);
 238            }
 239            Principal::DevServer(dev_server) => {
 240                result.field("dev_server", &dev_server.id);
 241            }
 242        }
 243        result.field("connection_id", &self.connection_id).finish()
 244    }
 245}
 246
 247struct UserSession(Session);
 248
 249impl UserSession {
 250    pub fn new(s: Session) -> Option<Self> {
 251        s.user_id().map(|_| UserSession(s))
 252    }
 253    pub fn user_id(&self) -> UserId {
 254        self.0.user_id().unwrap()
 255    }
 256
 257    pub fn email(&self) -> Option<String> {
 258        match &self.0.principal {
 259            Principal::User(user) => user.email_address.clone(),
 260            Principal::Impersonated { user, .. } => user.email_address.clone(),
 261            Principal::DevServer(..) => None,
 262        }
 263    }
 264}
 265
 266impl Deref for UserSession {
 267    type Target = Session;
 268
 269    fn deref(&self) -> &Self::Target {
 270        &self.0
 271    }
 272}
 273impl DerefMut for UserSession {
 274    fn deref_mut(&mut self) -> &mut Self::Target {
 275        &mut self.0
 276    }
 277}
 278
 279struct DevServerSession(Session);
 280
 281impl DevServerSession {
 282    pub fn new(s: Session) -> Option<Self> {
 283        s.dev_server_id().map(|_| DevServerSession(s))
 284    }
 285    pub fn dev_server_id(&self) -> DevServerId {
 286        self.0.dev_server_id().unwrap()
 287    }
 288
 289    fn dev_server(&self) -> &dev_server::Model {
 290        match &self.0.principal {
 291            Principal::DevServer(dev_server) => dev_server,
 292            _ => unreachable!(),
 293        }
 294    }
 295}
 296
 297impl Deref for DevServerSession {
 298    type Target = Session;
 299
 300    fn deref(&self) -> &Self::Target {
 301        &self.0
 302    }
 303}
 304impl DerefMut for DevServerSession {
 305    fn deref_mut(&mut self) -> &mut Self::Target {
 306        &mut self.0
 307    }
 308}
 309
 310fn user_handler<M: RequestMessage, Fut>(
 311    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 312) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 313where
 314    Fut: Send + Future<Output = Result<()>>,
 315{
 316    let handler = Arc::new(handler);
 317    move |message, response, session| {
 318        let handler = handler.clone();
 319        Box::pin(async move {
 320            if let Some(user_session) = session.for_user() {
 321                Ok(handler(message, response, user_session).await?)
 322            } else {
 323                Err(Error::Internal(anyhow!(
 324                    "must be a user to call {}",
 325                    M::NAME
 326                )))
 327            }
 328        })
 329    }
 330}
 331
 332fn dev_server_handler<M: RequestMessage, Fut>(
 333    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 334) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 335where
 336    Fut: Send + Future<Output = Result<()>>,
 337{
 338    let handler = Arc::new(handler);
 339    move |message, response, session| {
 340        let handler = handler.clone();
 341        Box::pin(async move {
 342            if let Some(dev_server_session) = session.for_dev_server() {
 343                Ok(handler(message, response, dev_server_session).await?)
 344            } else {
 345                Err(Error::Internal(anyhow!(
 346                    "must be a dev server to call {}",
 347                    M::NAME
 348                )))
 349            }
 350        })
 351    }
 352}
 353
 354fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 355    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 356) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 357where
 358    InnertRetFut: Send + Future<Output = Result<()>>,
 359{
 360    let handler = Arc::new(handler);
 361    move |message, session| {
 362        let handler = handler.clone();
 363        Box::pin(async move {
 364            if let Some(user_session) = session.for_user() {
 365                Ok(handler(message, user_session).await?)
 366            } else {
 367                Err(Error::Internal(anyhow!(
 368                    "must be a user to call {}",
 369                    M::NAME
 370                )))
 371            }
 372        })
 373    }
 374}
 375
 376struct DbHandle(Arc<Database>);
 377
 378impl Deref for DbHandle {
 379    type Target = Database;
 380
 381    fn deref(&self) -> &Self::Target {
 382        self.0.as_ref()
 383    }
 384}
 385
 386pub struct Server {
 387    id: parking_lot::Mutex<ServerId>,
 388    peer: Arc<Peer>,
 389    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 390    app_state: Arc<AppState>,
 391    handlers: HashMap<TypeId, MessageHandler>,
 392    teardown: watch::Sender<bool>,
 393}
 394
 395pub(crate) struct ConnectionPoolGuard<'a> {
 396    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 397    _not_send: PhantomData<Rc<()>>,
 398}
 399
 400#[derive(Serialize)]
 401pub struct ServerSnapshot<'a> {
 402    peer: &'a Peer,
 403    #[serde(serialize_with = "serialize_deref")]
 404    connection_pool: ConnectionPoolGuard<'a>,
 405}
 406
 407pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 408where
 409    S: Serializer,
 410    T: Deref<Target = U>,
 411    U: Serialize,
 412{
 413    Serialize::serialize(value.deref(), serializer)
 414}
 415
 416impl Server {
 417    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 418        let mut server = Self {
 419            id: parking_lot::Mutex::new(id),
 420            peer: Peer::new(id.0 as u32),
 421            app_state: app_state.clone(),
 422            connection_pool: Default::default(),
 423            handlers: Default::default(),
 424            teardown: watch::channel(false).0,
 425        };
 426
 427        server
 428            .add_request_handler(ping)
 429            .add_request_handler(user_handler(create_room))
 430            .add_request_handler(user_handler(join_room))
 431            .add_request_handler(user_handler(rejoin_room))
 432            .add_request_handler(user_handler(leave_room))
 433            .add_request_handler(user_handler(set_room_participant_role))
 434            .add_request_handler(user_handler(call))
 435            .add_request_handler(user_handler(cancel_call))
 436            .add_message_handler(user_message_handler(decline_call))
 437            .add_request_handler(user_handler(update_participant_location))
 438            .add_request_handler(user_handler(share_project))
 439            .add_message_handler(unshare_project)
 440            .add_request_handler(user_handler(join_project))
 441            .add_request_handler(user_handler(join_hosted_project))
 442            .add_request_handler(user_handler(rejoin_dev_server_projects))
 443            .add_request_handler(user_handler(create_dev_server_project))
 444            .add_request_handler(user_handler(update_dev_server_project))
 445            .add_request_handler(user_handler(delete_dev_server_project))
 446            .add_request_handler(user_handler(create_dev_server))
 447            .add_request_handler(user_handler(regenerate_dev_server_token))
 448            .add_request_handler(user_handler(rename_dev_server))
 449            .add_request_handler(user_handler(delete_dev_server))
 450            .add_request_handler(user_handler(list_remote_directory))
 451            .add_request_handler(dev_server_handler(share_dev_server_project))
 452            .add_request_handler(dev_server_handler(shutdown_dev_server))
 453            .add_request_handler(dev_server_handler(reconnect_dev_server))
 454            .add_message_handler(user_message_handler(leave_project))
 455            .add_request_handler(update_project)
 456            .add_request_handler(update_worktree)
 457            .add_message_handler(start_language_server)
 458            .add_message_handler(update_language_server)
 459            .add_message_handler(update_diagnostic_summary)
 460            .add_message_handler(update_worktree_settings)
 461            .add_request_handler(user_handler(
 462                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_project_request_for_owner::<proto::TaskTemplates>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::GetHover>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDefinition>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetTypeDefinition>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::GetReferences>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::SearchProject>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::GetProjectSymbols>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_read_only_project_request::<proto::OpenBufferById>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_read_only_project_request::<proto::InlayHints>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_read_only_project_request::<proto::OpenBufferByPath>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCompletions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::GetCodeActions>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::ApplyCodeAction>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::PrepareRename>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::PerformRename>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::ReloadBuffers>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::FormatBuffers>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::CreateProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::RenameProjectEntry>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_mutating_project_request::<proto::CopyProjectEntry>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 547            ))
 548            .add_request_handler(user_handler(
 549                forward_mutating_project_request::<proto::OnTypeFormatting>,
 550            ))
 551            .add_request_handler(user_handler(
 552                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 553            ))
 554            .add_request_handler(user_handler(
 555                forward_mutating_project_request::<proto::BlameBuffer>,
 556            ))
 557            .add_request_handler(user_handler(
 558                forward_mutating_project_request::<proto::MultiLspQuery>,
 559            ))
 560            .add_request_handler(user_handler(
 561                forward_mutating_project_request::<proto::RestartLanguageServers>,
 562            ))
 563            .add_request_handler(user_handler(
 564                forward_mutating_project_request::<proto::LinkedEditingRange>,
 565            ))
 566            .add_message_handler(create_buffer_for_peer)
 567            .add_request_handler(update_buffer)
 568            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 569            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 570            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 571            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 572            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 573            .add_request_handler(get_users)
 574            .add_request_handler(user_handler(fuzzy_search_users))
 575            .add_request_handler(user_handler(request_contact))
 576            .add_request_handler(user_handler(remove_contact))
 577            .add_request_handler(user_handler(respond_to_contact_request))
 578            .add_message_handler(subscribe_to_channels)
 579            .add_request_handler(user_handler(create_channel))
 580            .add_request_handler(user_handler(delete_channel))
 581            .add_request_handler(user_handler(invite_channel_member))
 582            .add_request_handler(user_handler(remove_channel_member))
 583            .add_request_handler(user_handler(set_channel_member_role))
 584            .add_request_handler(user_handler(set_channel_visibility))
 585            .add_request_handler(user_handler(rename_channel))
 586            .add_request_handler(user_handler(join_channel_buffer))
 587            .add_request_handler(user_handler(leave_channel_buffer))
 588            .add_message_handler(user_message_handler(update_channel_buffer))
 589            .add_request_handler(user_handler(rejoin_channel_buffers))
 590            .add_request_handler(user_handler(get_channel_members))
 591            .add_request_handler(user_handler(respond_to_channel_invite))
 592            .add_request_handler(user_handler(join_channel))
 593            .add_request_handler(user_handler(join_channel_chat))
 594            .add_message_handler(user_message_handler(leave_channel_chat))
 595            .add_request_handler(user_handler(send_channel_message))
 596            .add_request_handler(user_handler(remove_channel_message))
 597            .add_request_handler(user_handler(update_channel_message))
 598            .add_request_handler(user_handler(get_channel_messages))
 599            .add_request_handler(user_handler(get_channel_messages_by_id))
 600            .add_request_handler(user_handler(get_notifications))
 601            .add_request_handler(user_handler(mark_notification_as_read))
 602            .add_request_handler(user_handler(move_channel))
 603            .add_request_handler(user_handler(follow))
 604            .add_message_handler(user_message_handler(unfollow))
 605            .add_message_handler(user_message_handler(update_followers))
 606            .add_request_handler(user_handler(get_private_user_info))
 607            .add_request_handler(user_handler(get_llm_api_token))
 608            .add_request_handler(user_handler(accept_terms_of_service))
 609            .add_message_handler(user_message_handler(acknowledge_channel_message))
 610            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 611            .add_request_handler(user_handler(get_supermaven_api_key))
 612            .add_request_handler(user_handler(
 613                forward_mutating_project_request::<proto::OpenContext>,
 614            ))
 615            .add_request_handler(user_handler(
 616                forward_mutating_project_request::<proto::CreateContext>,
 617            ))
 618            .add_request_handler(user_handler(
 619                forward_mutating_project_request::<proto::SynchronizeContexts>,
 620            ))
 621            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 622            .add_message_handler(update_context)
 623            .add_request_handler({
 624                let app_state = app_state.clone();
 625                move |request, response, session| {
 626                    let app_state = app_state.clone();
 627                    async move {
 628                        count_language_model_tokens(request, response, session, &app_state.config)
 629                            .await
 630                    }
 631                }
 632            })
 633            .add_request_handler({
 634                user_handler(move |request, response, session| {
 635                    get_cached_embeddings(request, response, session)
 636                })
 637            })
 638            .add_request_handler({
 639                let app_state = app_state.clone();
 640                user_handler(move |request, response, session| {
 641                    compute_embeddings(
 642                        request,
 643                        response,
 644                        session,
 645                        app_state.config.openai_api_key.clone(),
 646                    )
 647                })
 648            });
 649
 650        Arc::new(server)
 651    }
 652
 653    pub async fn start(&self) -> Result<()> {
 654        let server_id = *self.id.lock();
 655        let app_state = self.app_state.clone();
 656        let peer = self.peer.clone();
 657        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 658        let pool = self.connection_pool.clone();
 659        let live_kit_client = self.app_state.live_kit_client.clone();
 660
 661        let span = info_span!("start server");
 662        self.app_state.executor.spawn_detached(
 663            async move {
 664                tracing::info!("waiting for cleanup timeout");
 665                timeout.await;
 666                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 667                if let Some((room_ids, channel_ids)) = app_state
 668                    .db
 669                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 670                    .await
 671                    .trace_err()
 672                {
 673                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 674                    tracing::info!(
 675                        stale_channel_buffer_count = channel_ids.len(),
 676                        "retrieved stale channel buffers"
 677                    );
 678
 679                    for channel_id in channel_ids {
 680                        if let Some(refreshed_channel_buffer) = app_state
 681                            .db
 682                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 683                            .await
 684                            .trace_err()
 685                        {
 686                            for connection_id in refreshed_channel_buffer.connection_ids {
 687                                peer.send(
 688                                    connection_id,
 689                                    proto::UpdateChannelBufferCollaborators {
 690                                        channel_id: channel_id.to_proto(),
 691                                        collaborators: refreshed_channel_buffer
 692                                            .collaborators
 693                                            .clone(),
 694                                    },
 695                                )
 696                                .trace_err();
 697                            }
 698                        }
 699                    }
 700
 701                    for room_id in room_ids {
 702                        let mut contacts_to_update = HashSet::default();
 703                        let mut canceled_calls_to_user_ids = Vec::new();
 704                        let mut live_kit_room = String::new();
 705                        let mut delete_live_kit_room = false;
 706
 707                        if let Some(mut refreshed_room) = app_state
 708                            .db
 709                            .clear_stale_room_participants(room_id, server_id)
 710                            .await
 711                            .trace_err()
 712                        {
 713                            tracing::info!(
 714                                room_id = room_id.0,
 715                                new_participant_count = refreshed_room.room.participants.len(),
 716                                "refreshed room"
 717                            );
 718                            room_updated(&refreshed_room.room, &peer);
 719                            if let Some(channel) = refreshed_room.channel.as_ref() {
 720                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 721                            }
 722                            contacts_to_update
 723                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 724                            contacts_to_update
 725                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 726                            canceled_calls_to_user_ids =
 727                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 728                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 729                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 730                        }
 731
 732                        {
 733                            let pool = pool.lock();
 734                            for canceled_user_id in canceled_calls_to_user_ids {
 735                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 736                                    peer.send(
 737                                        connection_id,
 738                                        proto::CallCanceled {
 739                                            room_id: room_id.to_proto(),
 740                                        },
 741                                    )
 742                                    .trace_err();
 743                                }
 744                            }
 745                        }
 746
 747                        for user_id in contacts_to_update {
 748                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 749                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 750                            if let Some((busy, contacts)) = busy.zip(contacts) {
 751                                let pool = pool.lock();
 752                                let updated_contact = contact_for_user(user_id, busy, &pool);
 753                                for contact in contacts {
 754                                    if let db::Contact::Accepted {
 755                                        user_id: contact_user_id,
 756                                        ..
 757                                    } = contact
 758                                    {
 759                                        for contact_conn_id in
 760                                            pool.user_connection_ids(contact_user_id)
 761                                        {
 762                                            peer.send(
 763                                                contact_conn_id,
 764                                                proto::UpdateContacts {
 765                                                    contacts: vec![updated_contact.clone()],
 766                                                    remove_contacts: Default::default(),
 767                                                    incoming_requests: Default::default(),
 768                                                    remove_incoming_requests: Default::default(),
 769                                                    outgoing_requests: Default::default(),
 770                                                    remove_outgoing_requests: Default::default(),
 771                                                },
 772                                            )
 773                                            .trace_err();
 774                                        }
 775                                    }
 776                                }
 777                            }
 778                        }
 779
 780                        if let Some(live_kit) = live_kit_client.as_ref() {
 781                            if delete_live_kit_room {
 782                                live_kit.delete_room(live_kit_room).await.trace_err();
 783                            }
 784                        }
 785                    }
 786                }
 787
 788                app_state
 789                    .db
 790                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 791                    .await
 792                    .trace_err();
 793            }
 794            .instrument(span),
 795        );
 796        Ok(())
 797    }
 798
 799    pub fn teardown(&self) {
 800        self.peer.teardown();
 801        self.connection_pool.lock().reset();
 802        let _ = self.teardown.send(true);
 803    }
 804
 805    #[cfg(test)]
 806    pub fn reset(&self, id: ServerId) {
 807        self.teardown();
 808        *self.id.lock() = id;
 809        self.peer.reset(id.0 as u32);
 810        let _ = self.teardown.send(false);
 811    }
 812
 813    #[cfg(test)]
 814    pub fn id(&self) -> ServerId {
 815        *self.id.lock()
 816    }
 817
 818    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 819    where
 820        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 821        Fut: 'static + Send + Future<Output = Result<()>>,
 822        M: EnvelopedMessage,
 823    {
 824        let prev_handler = self.handlers.insert(
 825            TypeId::of::<M>(),
 826            Box::new(move |envelope, session| {
 827                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 828                let received_at = envelope.received_at;
 829                tracing::info!("message received");
 830                let start_time = Instant::now();
 831                let future = (handler)(*envelope, session);
 832                async move {
 833                    let result = future.await;
 834                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 835                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 836                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 837                    let payload_type = M::NAME;
 838
 839                    match result {
 840                        Err(error) => {
 841                            tracing::error!(
 842                                ?error,
 843                                total_duration_ms,
 844                                processing_duration_ms,
 845                                queue_duration_ms,
 846                                payload_type,
 847                                "error handling message"
 848                            )
 849                        }
 850                        Ok(()) => tracing::info!(
 851                            total_duration_ms,
 852                            processing_duration_ms,
 853                            queue_duration_ms,
 854                            "finished handling message"
 855                        ),
 856                    }
 857                }
 858                .boxed()
 859            }),
 860        );
 861        if prev_handler.is_some() {
 862            panic!("registered a handler for the same message twice");
 863        }
 864        self
 865    }
 866
 867    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 868    where
 869        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 870        Fut: 'static + Send + Future<Output = Result<()>>,
 871        M: EnvelopedMessage,
 872    {
 873        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 874        self
 875    }
 876
 877    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 878    where
 879        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 880        Fut: Send + Future<Output = Result<()>>,
 881        M: RequestMessage,
 882    {
 883        let handler = Arc::new(handler);
 884        self.add_handler(move |envelope, session| {
 885            let receipt = envelope.receipt();
 886            let handler = handler.clone();
 887            async move {
 888                let peer = session.peer.clone();
 889                let responded = Arc::new(AtomicBool::default());
 890                let response = Response {
 891                    peer: peer.clone(),
 892                    responded: responded.clone(),
 893                    receipt,
 894                };
 895                match (handler)(envelope.payload, response, session).await {
 896                    Ok(()) => {
 897                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 898                            Ok(())
 899                        } else {
 900                            Err(anyhow!("handler did not send a response"))?
 901                        }
 902                    }
 903                    Err(error) => {
 904                        let proto_err = match &error {
 905                            Error::Internal(err) => err.to_proto(),
 906                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 907                        };
 908                        peer.respond_with_error(receipt, proto_err)?;
 909                        Err(error)
 910                    }
 911                }
 912            }
 913        })
 914    }
 915
 916    #[allow(clippy::too_many_arguments)]
        /// Builds the future that services a single client connection from start to finish.
        ///
        /// A `handle connection` tracing span is created up front; the principal's fields
        /// and, when present, the GeoIP country code are recorded on it, and the returned
        /// future runs inside that span.
        ///
        /// The future:
        /// 1. returns immediately (logging an error) if the teardown flag is already set;
        /// 2. registers `connection` with the peer, obtaining the connection id, the I/O
        ///    future and the stream of incoming messages;
        /// 3. builds an HTTP client (returning on failure) and, when a Supermaven admin API
        ///    key is configured, a Supermaven admin client;
        /// 4. builds the `Session` and sends the initial client update (returning on failure);
        /// 5. loops, dispatching incoming messages to the registered handlers, until the I/O
        ///    future finishes, the incoming stream ends, or teardown is signalled;
        /// 6. unless it returned because of teardown, drops any unfinished foreground
        ///    handlers and awaits `connection_lost`.
        ///
        /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`.
 917    pub fn handle_connection(
 918        self: &Arc<Self>,
 919        connection: Connection,
 920        address: String,
 921        principal: Principal,
 922        zed_version: ZedVersion,
 923        geoip_country_code: Option<String>,
 924        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 925        executor: Executor,
 926    ) -> impl Future<Output = ()> {
 927        let this = self.clone();
 928        let span = info_span!("handle connection", %address,
 929            connection_id=field::Empty,
 930            user_id=field::Empty,
 931            login=field::Empty,
 932            impersonator=field::Empty,
 933            dev_server_id=field::Empty,
 934            geoip_country_code=field::Empty
 935        );
 936        principal.update_span(&span);
 937        if let Some(country_code) = geoip_country_code.as_ref() {
 938            span.record("geoip_country_code", country_code);
 939        }
 940
 941        let mut teardown = self.teardown.subscribe();
 942        async move {
                // Refuse the connection outright if the teardown flag is already set.
 943            if *teardown.borrow() {
 944                tracing::error!("server is tearing down");
 945                return
 946            }
 947            let (connection_id, handle_io, mut incoming_rx) = this
 948                .peer
 949                .add_connection(connection, {
 950                    let executor = executor.clone();
 951                    move |duration| executor.sleep(duration)
 952                });
 953            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 954
 955            tracing::info!("connection opened");
 956
 957            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
 958            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
 959                Ok(http_client) => Arc::new(http_client),
 960                Err(error) => {
 961                    tracing::error!(?error, "failed to create HTTP client");
 962                    return;
 963                }
 964            };
 965
 966            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 967                Some(Arc::new(SupermavenAdminApi::new(
 968                    supermaven_admin_api_key.to_string(),
 969                    http_client.clone(),
 970                )))
 971            } else {
 972                None
 973            };
 974
 975            let session = Session {
 976                principal: principal.clone(),
 977                connection_id,
 978                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 979                peer: this.peer.clone(),
 980                connection_pool: this.connection_pool.clone(),
 981                app_state: this.app_state.clone(),
 982                http_client,
 983                geoip_country_code,
 984                _executor: executor.clone(),
 985                supermaven_client,
 986            };
 987
 988            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 989                tracing::error!(?error, "failed to send initial client update");
 990                return;
 991            }
 992
 993            let handle_io = handle_io.fuse();
 994            futures::pin_mut!(handle_io);
 995
 996            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 997            // This prevents deadlocks when e.g., client A performs a request to client B and
 998            // client B performs a request to client A. If both clients stop processing further
 999            // messages until their respective request completes, they won't have a chance to
1000            // respond to the other client's request and cause a deadlock.
1001            //
1002            // This arrangement ensures we will attempt to process earlier messages first, but fall
1003            // back to processing messages arrived later in the spirit of making progress.
1004            let mut foreground_message_handlers = FuturesUnordered::new();
                // A permit is acquired before each message is pulled from `incoming_rx` and is
                // released once that message's handler finishes, so at most 256 handlers for
                // this connection hold a permit at any one time.
1005            let concurrent_handlers = Arc::new(Semaphore::new(256));
1006            loop {
1007                let next_message = async {
1008                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1009                    let message = incoming_rx.next().await;
1010                    (permit, message)
1011                }.fuse();
1012                futures::pin_mut!(next_message);
                    // `select_biased!` polls the arms in the order written: teardown, connection
                    // I/O, in-flight foreground handlers, then the next incoming message.
1013                futures::select_biased! {
                        // Teardown exits right away and skips the `connection_lost` call below.
1014                    _ = teardown.changed().fuse() => return,
1015                    result = handle_io => {
1016                        if let Err(error) = result {
1017                            tracing::error!(?error, "error handling I/O");
1018                        }
1019                        break;
1020                    }
1021                    _ = foreground_message_handlers.next() => {}
1022                    next_message = next_message => {
1023                        let (permit, message) = next_message;
1024                        if let Some(message) = message {
1025                            let type_name = message.payload_type_name();
1026                            // note: we copy all the fields from the parent span so we can query them in the logs.
1027                            // (https://github.com/tokio-rs/tracing/issues/2670).
1028                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1029                                user_id=field::Empty,
1030                                login=field::Empty,
1031                                impersonator=field::Empty,
1032                                dev_server_id=field::Empty
1033                            );
1034                            principal.update_span(&span);
1035                            let span_enter = span.enter();
1036                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1037                                let is_background = message.is_background();
1038                                let handle_message = (handler)(message, session.clone());
1039                                drop(span_enter);
1040
1041                                let handle_message = async move {
1042                                    handle_message.await;
1043                                    drop(permit);
1044                                }.instrument(span);
                                    // Background messages run detached on the executor; foreground
                                    // ones are driven by this loop via `foreground_message_handlers`.
1045                                if is_background {
1046                                    executor.spawn_detached(handle_message);
1047                                } else {
1048                                    foreground_message_handlers.push(handle_message);
1049                                }
1050                            } else {
1051                                tracing::error!("no message handler");
1052                            }
1053                        } else {
1054                            tracing::info!("connection closed");
1055                            break;
1056                        }
1057                    }
1058                }
1059            }
1060
                // Unfinished foreground handlers are dropped before the connection is signed out.
1061            drop(foreground_message_handlers);
1062            tracing::info!("signing out");
1063            if let Err(error) = connection_lost(session, teardown, executor).await {
1064                tracing::error!(?error, "error signing out");
1065            }
1066
1067        }.instrument(span)
1068    }
1069
1070    async fn send_initial_client_update(
1071        &self,
1072        connection_id: ConnectionId,
1073        principal: &Principal,
1074        zed_version: ZedVersion,
1075        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1076        session: &Session,
1077    ) -> Result<()> {
1078        self.peer.send(
1079            connection_id,
1080            proto::Hello {
1081                peer_id: Some(connection_id.into()),
1082            },
1083        )?;
1084        tracing::info!("sent hello message");
1085        if let Some(send_connection_id) = send_connection_id.take() {
1086            let _ = send_connection_id.send(connection_id);
1087        }
1088
1089        match principal {
1090            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1091                if !user.connected_once {
1092                    self.peer.send(connection_id, proto::ShowContacts {})?;
1093                    self.app_state
1094                        .db
1095                        .set_user_connected_once(user.id, true)
1096                        .await?;
1097                }
1098
1099                update_user_plan(user.id, session).await?;
1100
1101                let (contacts, dev_server_projects) = future::try_join(
1102                    self.app_state.db.get_contacts(user.id),
1103                    self.app_state.db.dev_server_projects_update(user.id),
1104                )
1105                .await?;
1106
1107                {
1108                    let mut pool = self.connection_pool.lock();
1109                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1110                    self.peer.send(
1111                        connection_id,
1112                        build_initial_contacts_update(contacts, &pool),
1113                    )?;
1114                }
1115
1116                if should_auto_subscribe_to_channels(zed_version) {
1117                    subscribe_user_to_channels(user.id, session).await?;
1118                }
1119
1120                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1121
1122                if let Some(incoming_call) =
1123                    self.app_state.db.incoming_call_for_user(user.id).await?
1124                {
1125                    self.peer.send(connection_id, incoming_call)?;
1126                }
1127
1128                update_user_contacts(user.id, &session).await?;
1129            }
1130            Principal::DevServer(dev_server) => {
1131                {
1132                    let mut pool = self.connection_pool.lock();
1133                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1134                    {
1135                        self.peer.send(
1136                            stale_connection_id,
1137                            proto::ShutdownDevServer {
1138                                reason: Some(
1139                                    "another dev server connected with the same token".to_string(),
1140                                ),
1141                            },
1142                        )?;
1143                        pool.remove_connection(stale_connection_id)?;
1144                    };
1145                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1146                }
1147
1148                let projects = self
1149                    .app_state
1150                    .db
1151                    .get_projects_for_dev_server(dev_server.id)
1152                    .await?;
1153                self.peer
1154                    .send(connection_id, proto::DevServerInstructions { projects })?;
1155
1156                let status = self
1157                    .app_state
1158                    .db
1159                    .dev_server_projects_update(dev_server.user_id)
1160                    .await?;
1161                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1162            }
1163        }
1164
1165        Ok(())
1166    }
1167
1168    pub async fn invite_code_redeemed(
1169        self: &Arc<Self>,
1170        inviter_id: UserId,
1171        invitee_id: UserId,
1172    ) -> Result<()> {
1173        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1174            if let Some(code) = &user.invite_code {
1175                let pool = self.connection_pool.lock();
1176                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1177                for connection_id in pool.user_connection_ids(inviter_id) {
1178                    self.peer.send(
1179                        connection_id,
1180                        proto::UpdateContacts {
1181                            contacts: vec![invitee_contact.clone()],
1182                            ..Default::default()
1183                        },
1184                    )?;
1185                    self.peer.send(
1186                        connection_id,
1187                        proto::UpdateInviteInfo {
1188                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1189                            count: user.invite_count as u32,
1190                        },
1191                    )?;
1192                }
1193            }
1194        }
1195        Ok(())
1196    }
1197
1198    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1199        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1200            if let Some(invite_code) = &user.invite_code {
1201                let pool = self.connection_pool.lock();
1202                for connection_id in pool.user_connection_ids(user_id) {
1203                    self.peer.send(
1204                        connection_id,
1205                        proto::UpdateInviteInfo {
1206                            url: format!(
1207                                "{}{}",
1208                                self.app_state.config.invite_link_prefix, invite_code
1209                            ),
1210                            count: user.invite_count as u32,
1211                        },
1212                    )?;
1213                }
1214            }
1215        }
1216        Ok(())
1217    }
1218
1219    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1220        ServerSnapshot {
1221            connection_pool: ConnectionPoolGuard {
1222                guard: self.connection_pool.lock(),
1223                _not_send: PhantomData,
1224            },
1225            peer: &self.peer,
1226        }
1227    }
1228}
1229
1230impl<'a> Deref for ConnectionPoolGuard<'a> {
1231    type Target = ConnectionPool;
1232
1233    fn deref(&self) -> &Self::Target {
1234        &self.guard
1235    }
1236}
1237
1238impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1239    fn deref_mut(&mut self) -> &mut Self::Target {
1240        &mut self.guard
1241    }
1242}
1243
1244impl<'a> Drop for ConnectionPoolGuard<'a> {
1245    fn drop(&mut self) {
1246        #[cfg(test)]
1247        self.check_invariants();
1248    }
1249}
1250
1251fn broadcast<F>(
1252    sender_id: Option<ConnectionId>,
1253    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1254    mut f: F,
1255) where
1256    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1257{
1258    for receiver_id in receiver_ids {
1259        if Some(receiver_id) != sender_id {
1260            if let Err(error) = f(receiver_id) {
1261                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1262            }
1263        }
1264    }
1265}
1266
1267pub struct ProtocolVersion(u32);
1268
1269impl Header for ProtocolVersion {
1270    fn name() -> &'static HeaderName {
1271        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1272        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1273    }
1274
1275    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1276    where
1277        Self: Sized,
1278        I: Iterator<Item = &'i axum::http::HeaderValue>,
1279    {
1280        let version = values
1281            .next()
1282            .ok_or_else(axum::headers::Error::invalid)?
1283            .to_str()
1284            .map_err(|_| axum::headers::Error::invalid())?
1285            .parse()
1286            .map_err(|_| axum::headers::Error::invalid())?;
1287        Ok(Self(version))
1288    }
1289
1290    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1291        values.extend([self.0.to_string().parse().unwrap()]);
1292    }
1293}
1294
1295pub struct AppVersionHeader(SemanticVersion);
1296impl Header for AppVersionHeader {
1297    fn name() -> &'static HeaderName {
1298        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1299        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1300    }
1301
1302    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1303    where
1304        Self: Sized,
1305        I: Iterator<Item = &'i axum::http::HeaderValue>,
1306    {
1307        let version = values
1308            .next()
1309            .ok_or_else(axum::headers::Error::invalid)?
1310            .to_str()
1311            .map_err(|_| axum::headers::Error::invalid())?
1312            .parse()
1313            .map_err(|_| axum::headers::Error::invalid())?;
1314        Ok(Self(version))
1315    }
1316
1317    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1318        values.extend([self.0.to_string().parse().unwrap()]);
1319    }
1320}
1321
1322pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1323    Router::new()
1324        .route("/rpc", get(handle_websocket_request))
1325        .layer(
1326            ServiceBuilder::new()
1327                .layer(Extension(server.app_state.clone()))
1328                .layer(middleware::from_fn(auth::validate_header)),
1329        )
1330        .route("/metrics", get(handle_metrics))
1331        .layer(Extension(server))
1332}
1333
1334pub async fn handle_websocket_request(
1335    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1336    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1337    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1338    Extension(server): Extension<Arc<Server>>,
1339    Extension(principal): Extension<Principal>,
1340    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1341    ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343    if protocol_version != rpc::PROTOCOL_VERSION {
1344        return (
1345            StatusCode::UPGRADE_REQUIRED,
1346            "client must be upgraded".to_string(),
1347        )
1348            .into_response();
1349    }
1350
1351    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352        return (
1353            StatusCode::UPGRADE_REQUIRED,
1354            "no version header found".to_string(),
1355        )
1356            .into_response();
1357    };
1358
1359    if !version.can_collaborate() {
1360        return (
1361            StatusCode::UPGRADE_REQUIRED,
1362            "client must be upgraded".to_string(),
1363        )
1364            .into_response();
1365    }
1366
1367    let socket_address = socket_address.to_string();
1368    ws.on_upgrade(move |socket| {
1369        let socket = socket
1370            .map_ok(to_tungstenite_message)
1371            .err_into()
1372            .with(|message| async move { to_axum_message(message) });
1373        let connection = Connection::new(Box::pin(socket));
1374        async move {
1375            server
1376                .handle_connection(
1377                    connection,
1378                    socket_address,
1379                    principal,
1380                    version,
1381                    country_code_header.map(|header| header.to_string()),
1382                    None,
1383                    Executor::Production,
1384                )
1385                .await;
1386        }
1387    })
1388}
1389
1390pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1391    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1392    let connections_metric = CONNECTIONS_METRIC
1393        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1394
1395    let connections = server
1396        .connection_pool
1397        .lock()
1398        .connections()
1399        .filter(|connection| !connection.admin)
1400        .count();
1401    connections_metric.set(connections as _);
1402
1403    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1405        register_int_gauge!(
1406            "shared_projects",
1407            "number of open projects with one or more guests"
1408        )
1409        .unwrap()
1410    });
1411
1412    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1413    shared_projects_metric.set(shared_projects as _);
1414
1415    let encoder = prometheus::TextEncoder::new();
1416    let metric_families = prometheus::gather();
1417    let encoded_metrics = encoder
1418        .encode_to_string(&metric_families)
1419        .map_err(|err| anyhow!("{}", err))?;
1420    Ok(encoded_metrics)
1421}
1422
1423#[instrument(err, skip(executor))]
1424async fn connection_lost(
1425    session: Session,
1426    mut teardown: watch::Receiver<bool>,
1427    executor: Executor,
1428) -> Result<()> {
1429    session.peer.disconnect(session.connection_id);
1430    session
1431        .connection_pool()
1432        .await
1433        .remove_connection(session.connection_id)?;
1434
1435    session
1436        .db()
1437        .await
1438        .connection_lost(session.connection_id)
1439        .await
1440        .trace_err();
1441
1442    futures::select_biased! {
1443        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1444            match &session.principal {
1445                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1446                    let session = session.for_user().unwrap();
1447
1448                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1449                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1450                    leave_channel_buffers_for_session(&session)
1451                        .await
1452                        .trace_err();
1453
1454                    if !session
1455                        .connection_pool()
1456                        .await
1457                        .is_user_online(session.user_id())
1458                    {
1459                        let db = session.db().await;
1460                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1461                            room_updated(&room, &session.peer);
1462                        }
1463                    }
1464
1465                    update_user_contacts(session.user_id(), &session).await?;
1466                },
1467            Principal::DevServer(_) => {
1468                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1469            },
1470        }
1471        },
1472        _ = teardown.changed().fuse() => {}
1473    }
1474
1475    Ok(())
1476}
1477
1478/// Acknowledges a ping from a client, used to keep the connection alive.
1479async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1480    response.send(proto::Ack {})?;
1481    Ok(())
1482}
1483
1484/// Creates a new room for calling (outside of channels)
1485async fn create_room(
1486    _request: proto::CreateRoom,
1487    response: Response<proto::CreateRoom>,
1488    session: UserSession,
1489) -> Result<()> {
1490    let live_kit_room = nanoid::nanoid!(30);
1491
1492    let live_kit_connection_info = util::maybe!(async {
1493        let live_kit = session.app_state.live_kit_client.as_ref();
1494        let live_kit = live_kit?;
1495        let user_id = session.user_id().to_string();
1496
1497        let token = live_kit
1498            .room_token(&live_kit_room, &user_id.to_string())
1499            .trace_err()?;
1500
1501        Some(proto::LiveKitConnectionInfo {
1502            server_url: live_kit.url().into(),
1503            token,
1504            can_publish: true,
1505        })
1506    })
1507    .await;
1508
1509    let room = session
1510        .db()
1511        .await
1512        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1513        .await?;
1514
1515    response.send(proto::CreateRoomResponse {
1516        room: Some(room.clone()),
1517        live_kit_connection_info,
1518    })?;
1519
1520    update_user_contacts(session.user_id(), &session).await?;
1521    Ok(())
1522}
1523
1524/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1525async fn join_room(
1526    request: proto::JoinRoom,
1527    response: Response<proto::JoinRoom>,
1528    session: UserSession,
1529) -> Result<()> {
1530    let room_id = RoomId::from_proto(request.id);
1531
1532    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1533
1534    if let Some(channel_id) = channel_id {
1535        return join_channel_internal(channel_id, Box::new(response), session).await;
1536    }
1537
1538    let joined_room = {
1539        let room = session
1540            .db()
1541            .await
1542            .join_room(room_id, session.user_id(), session.connection_id)
1543            .await?;
1544        room_updated(&room.room, &session.peer);
1545        room.into_inner()
1546    };
1547
1548    for connection_id in session
1549        .connection_pool()
1550        .await
1551        .user_connection_ids(session.user_id())
1552    {
1553        session
1554            .peer
1555            .send(
1556                connection_id,
1557                proto::CallCanceled {
1558                    room_id: room_id.to_proto(),
1559                },
1560            )
1561            .trace_err();
1562    }
1563
1564    let live_kit_connection_info =
1565        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1566            if let Some(token) = live_kit
1567                .room_token(
1568                    &joined_room.room.live_kit_room,
1569                    &session.user_id().to_string(),
1570                )
1571                .trace_err()
1572            {
1573                Some(proto::LiveKitConnectionInfo {
1574                    server_url: live_kit.url().into(),
1575                    token,
1576                    can_publish: true,
1577                })
1578            } else {
1579                None
1580            }
1581        } else {
1582            None
1583        };
1584
1585    response.send(proto::JoinRoomResponse {
1586        room: Some(joined_room.room),
1587        channel_id: None,
1588        live_kit_connection_info,
1589    })?;
1590
1591    update_user_contacts(session.user_id(), &session).await?;
1592    Ok(())
1593}
1594
1595/// Rejoin room is used to reconnect to a room after connection errors.
1596async fn rejoin_room(
1597    request: proto::RejoinRoom,
1598    response: Response<proto::RejoinRoom>,
1599    session: UserSession,
1600) -> Result<()> {
1601    let room;
1602    let channel;
1603    {
1604        let mut rejoined_room = session
1605            .db()
1606            .await
1607            .rejoin_room(request, session.user_id(), session.connection_id)
1608            .await?;
1609
1610        response.send(proto::RejoinRoomResponse {
1611            room: Some(rejoined_room.room.clone()),
1612            reshared_projects: rejoined_room
1613                .reshared_projects
1614                .iter()
1615                .map(|project| proto::ResharedProject {
1616                    id: project.id.to_proto(),
1617                    collaborators: project
1618                        .collaborators
1619                        .iter()
1620                        .map(|collaborator| collaborator.to_proto())
1621                        .collect(),
1622                })
1623                .collect(),
1624            rejoined_projects: rejoined_room
1625                .rejoined_projects
1626                .iter()
1627                .map(|rejoined_project| rejoined_project.to_proto())
1628                .collect(),
1629        })?;
1630        room_updated(&rejoined_room.room, &session.peer);
1631
1632        for project in &rejoined_room.reshared_projects {
1633            for collaborator in &project.collaborators {
1634                session
1635                    .peer
1636                    .send(
1637                        collaborator.connection_id,
1638                        proto::UpdateProjectCollaborator {
1639                            project_id: project.id.to_proto(),
1640                            old_peer_id: Some(project.old_connection_id.into()),
1641                            new_peer_id: Some(session.connection_id.into()),
1642                        },
1643                    )
1644                    .trace_err();
1645            }
1646
1647            broadcast(
1648                Some(session.connection_id),
1649                project
1650                    .collaborators
1651                    .iter()
1652                    .map(|collaborator| collaborator.connection_id),
1653                |connection_id| {
1654                    session.peer.forward_send(
1655                        session.connection_id,
1656                        connection_id,
1657                        proto::UpdateProject {
1658                            project_id: project.id.to_proto(),
1659                            worktrees: project.worktrees.clone(),
1660                        },
1661                    )
1662                },
1663            );
1664        }
1665
1666        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1667
1668        let rejoined_room = rejoined_room.into_inner();
1669
1670        room = rejoined_room.room;
1671        channel = rejoined_room.channel;
1672    }
1673
1674    if let Some(channel) = channel {
1675        channel_updated(
1676            &channel,
1677            &room,
1678            &session.peer,
1679            &*session.connection_pool().await,
1680        );
1681    }
1682
1683    update_user_contacts(session.user_id(), &session).await?;
1684    Ok(())
1685}
1686
1687fn notify_rejoined_projects(
1688    rejoined_projects: &mut Vec<RejoinedProject>,
1689    session: &UserSession,
1690) -> Result<()> {
1691    for project in rejoined_projects.iter() {
1692        for collaborator in &project.collaborators {
1693            session
1694                .peer
1695                .send(
1696                    collaborator.connection_id,
1697                    proto::UpdateProjectCollaborator {
1698                        project_id: project.id.to_proto(),
1699                        old_peer_id: Some(project.old_connection_id.into()),
1700                        new_peer_id: Some(session.connection_id.into()),
1701                    },
1702                )
1703                .trace_err();
1704        }
1705    }
1706
1707    for project in rejoined_projects {
1708        for worktree in mem::take(&mut project.worktrees) {
1709            #[cfg(any(test, feature = "test-support"))]
1710            const MAX_CHUNK_SIZE: usize = 2;
1711            #[cfg(not(any(test, feature = "test-support")))]
1712            const MAX_CHUNK_SIZE: usize = 256;
1713
1714            // Stream this worktree's entries.
1715            let message = proto::UpdateWorktree {
1716                project_id: project.id.to_proto(),
1717                worktree_id: worktree.id,
1718                abs_path: worktree.abs_path.clone(),
1719                root_name: worktree.root_name,
1720                updated_entries: worktree.updated_entries,
1721                removed_entries: worktree.removed_entries,
1722                scan_id: worktree.scan_id,
1723                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1724                updated_repositories: worktree.updated_repositories,
1725                removed_repositories: worktree.removed_repositories,
1726            };
1727            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1728                session.peer.send(session.connection_id, update.clone())?;
1729            }
1730
1731            // Stream this worktree's diagnostics.
1732            for summary in worktree.diagnostic_summaries {
1733                session.peer.send(
1734                    session.connection_id,
1735                    proto::UpdateDiagnosticSummary {
1736                        project_id: project.id.to_proto(),
1737                        worktree_id: worktree.id,
1738                        summary: Some(summary),
1739                    },
1740                )?;
1741            }
1742
1743            for settings_file in worktree.settings_files {
1744                session.peer.send(
1745                    session.connection_id,
1746                    proto::UpdateWorktreeSettings {
1747                        project_id: project.id.to_proto(),
1748                        worktree_id: worktree.id,
1749                        path: settings_file.path,
1750                        content: Some(settings_file.content),
1751                    },
1752                )?;
1753            }
1754        }
1755
1756        for language_server in &project.language_servers {
1757            session.peer.send(
1758                session.connection_id,
1759                proto::UpdateLanguageServer {
1760                    project_id: project.id.to_proto(),
1761                    language_server_id: language_server.id,
1762                    variant: Some(
1763                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1764                            proto::LspDiskBasedDiagnosticsUpdated {},
1765                        ),
1766                    ),
1767                },
1768            )?;
1769        }
1770    }
1771    Ok(())
1772}
1773
1774/// leave room disconnects from the room.
1775async fn leave_room(
1776    _: proto::LeaveRoom,
1777    response: Response<proto::LeaveRoom>,
1778    session: UserSession,
1779) -> Result<()> {
1780    leave_room_for_session(&session, session.connection_id).await?;
1781    response.send(proto::Ack {})?;
1782    Ok(())
1783}
1784
1785/// Updates the permissions of someone else in the room.
1786async fn set_room_participant_role(
1787    request: proto::SetRoomParticipantRole,
1788    response: Response<proto::SetRoomParticipantRole>,
1789    session: UserSession,
1790) -> Result<()> {
1791    let user_id = UserId::from_proto(request.user_id);
1792    let role = ChannelRole::from(request.role());
1793
1794    let (live_kit_room, can_publish) = {
1795        let room = session
1796            .db()
1797            .await
1798            .set_room_participant_role(
1799                session.user_id(),
1800                RoomId::from_proto(request.room_id),
1801                user_id,
1802                role,
1803            )
1804            .await?;
1805
1806        let live_kit_room = room.live_kit_room.clone();
1807        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1808        room_updated(&room, &session.peer);
1809        (live_kit_room, can_publish)
1810    };
1811
1812    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1813        live_kit
1814            .update_participant(
1815                live_kit_room.clone(),
1816                request.user_id.to_string(),
1817                live_kit_server::proto::ParticipantPermission {
1818                    can_subscribe: true,
1819                    can_publish,
1820                    can_publish_data: can_publish,
1821                    hidden: false,
1822                    recorder: false,
1823                },
1824            )
1825            .await
1826            .trace_err();
1827    }
1828
1829    response.send(proto::Ack {})?;
1830    Ok(())
1831}
1832
1833/// Call someone else into the current room
1834async fn call(
1835    request: proto::Call,
1836    response: Response<proto::Call>,
1837    session: UserSession,
1838) -> Result<()> {
1839    let room_id = RoomId::from_proto(request.room_id);
1840    let calling_user_id = session.user_id();
1841    let calling_connection_id = session.connection_id;
1842    let called_user_id = UserId::from_proto(request.called_user_id);
1843    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1844    if !session
1845        .db()
1846        .await
1847        .has_contact(calling_user_id, called_user_id)
1848        .await?
1849    {
1850        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1851    }
1852
1853    let incoming_call = {
1854        let (room, incoming_call) = &mut *session
1855            .db()
1856            .await
1857            .call(
1858                room_id,
1859                calling_user_id,
1860                calling_connection_id,
1861                called_user_id,
1862                initial_project_id,
1863            )
1864            .await?;
1865        room_updated(&room, &session.peer);
1866        mem::take(incoming_call)
1867    };
1868    update_user_contacts(called_user_id, &session).await?;
1869
1870    let mut calls = session
1871        .connection_pool()
1872        .await
1873        .user_connection_ids(called_user_id)
1874        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1875        .collect::<FuturesUnordered<_>>();
1876
1877    while let Some(call_response) = calls.next().await {
1878        match call_response.as_ref() {
1879            Ok(_) => {
1880                response.send(proto::Ack {})?;
1881                return Ok(());
1882            }
1883            Err(_) => {
1884                call_response.trace_err();
1885            }
1886        }
1887    }
1888
1889    {
1890        let room = session
1891            .db()
1892            .await
1893            .call_failed(room_id, called_user_id)
1894            .await?;
1895        room_updated(&room, &session.peer);
1896    }
1897    update_user_contacts(called_user_id, &session).await?;
1898
1899    Err(anyhow!("failed to ring user"))?
1900}
1901
1902/// Cancel an outgoing call.
1903async fn cancel_call(
1904    request: proto::CancelCall,
1905    response: Response<proto::CancelCall>,
1906    session: UserSession,
1907) -> Result<()> {
1908    let called_user_id = UserId::from_proto(request.called_user_id);
1909    let room_id = RoomId::from_proto(request.room_id);
1910    {
1911        let room = session
1912            .db()
1913            .await
1914            .cancel_call(room_id, session.connection_id, called_user_id)
1915            .await?;
1916        room_updated(&room, &session.peer);
1917    }
1918
1919    for connection_id in session
1920        .connection_pool()
1921        .await
1922        .user_connection_ids(called_user_id)
1923    {
1924        session
1925            .peer
1926            .send(
1927                connection_id,
1928                proto::CallCanceled {
1929                    room_id: room_id.to_proto(),
1930                },
1931            )
1932            .trace_err();
1933    }
1934    response.send(proto::Ack {})?;
1935
1936    update_user_contacts(called_user_id, &session).await?;
1937    Ok(())
1938}
1939
1940/// Decline an incoming call.
1941async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1942    let room_id = RoomId::from_proto(message.room_id);
1943    {
1944        let room = session
1945            .db()
1946            .await
1947            .decline_call(Some(room_id), session.user_id())
1948            .await?
1949            .ok_or_else(|| anyhow!("failed to decline call"))?;
1950        room_updated(&room, &session.peer);
1951    }
1952
1953    for connection_id in session
1954        .connection_pool()
1955        .await
1956        .user_connection_ids(session.user_id())
1957    {
1958        session
1959            .peer
1960            .send(
1961                connection_id,
1962                proto::CallCanceled {
1963                    room_id: room_id.to_proto(),
1964                },
1965            )
1966            .trace_err();
1967    }
1968    update_user_contacts(session.user_id(), &session).await?;
1969    Ok(())
1970}
1971
1972/// Updates other participants in the room with your current location.
1973async fn update_participant_location(
1974    request: proto::UpdateParticipantLocation,
1975    response: Response<proto::UpdateParticipantLocation>,
1976    session: UserSession,
1977) -> Result<()> {
1978    let room_id = RoomId::from_proto(request.room_id);
1979    let location = request
1980        .location
1981        .ok_or_else(|| anyhow!("invalid location"))?;
1982
1983    let db = session.db().await;
1984    let room = db
1985        .update_room_participant_location(room_id, session.connection_id, location)
1986        .await?;
1987
1988    room_updated(&room, &session.peer);
1989    response.send(proto::Ack {})?;
1990    Ok(())
1991}
1992
1993/// Share a project into the room.
1994async fn share_project(
1995    request: proto::ShareProject,
1996    response: Response<proto::ShareProject>,
1997    session: UserSession,
1998) -> Result<()> {
1999    let (project_id, room) = &*session
2000        .db()
2001        .await
2002        .share_project(
2003            RoomId::from_proto(request.room_id),
2004            session.connection_id,
2005            &request.worktrees,
2006            request
2007                .dev_server_project_id
2008                .map(|id| DevServerProjectId::from_proto(id)),
2009        )
2010        .await?;
2011    response.send(proto::ShareProjectResponse {
2012        project_id: project_id.to_proto(),
2013    })?;
2014    room_updated(&room, &session.peer);
2015
2016    Ok(())
2017}
2018
2019/// Unshare a project from the room.
2020async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2021    let project_id = ProjectId::from_proto(message.project_id);
2022    unshare_project_internal(
2023        project_id,
2024        session.connection_id,
2025        session.user_id(),
2026        &session,
2027    )
2028    .await
2029}
2030
2031async fn unshare_project_internal(
2032    project_id: ProjectId,
2033    connection_id: ConnectionId,
2034    user_id: Option<UserId>,
2035    session: &Session,
2036) -> Result<()> {
2037    let delete = {
2038        let room_guard = session
2039            .db()
2040            .await
2041            .unshare_project(project_id, connection_id, user_id)
2042            .await?;
2043
2044        let (delete, room, guest_connection_ids) = &*room_guard;
2045
2046        let message = proto::UnshareProject {
2047            project_id: project_id.to_proto(),
2048        };
2049
2050        broadcast(
2051            Some(connection_id),
2052            guest_connection_ids.iter().copied(),
2053            |conn_id| session.peer.send(conn_id, message.clone()),
2054        );
2055        if let Some(room) = room {
2056            room_updated(room, &session.peer);
2057        }
2058
2059        *delete
2060    };
2061
2062    if delete {
2063        let db = session.db().await;
2064        db.delete_project(project_id).await?;
2065    }
2066
2067    Ok(())
2068}
2069
2070/// DevServer makes a project available online
2071async fn share_dev_server_project(
2072    request: proto::ShareDevServerProject,
2073    response: Response<proto::ShareDevServerProject>,
2074    session: DevServerSession,
2075) -> Result<()> {
2076    let (dev_server_project, user_id, status) = session
2077        .db()
2078        .await
2079        .share_dev_server_project(
2080            DevServerProjectId::from_proto(request.dev_server_project_id),
2081            session.dev_server_id(),
2082            session.connection_id,
2083            &request.worktrees,
2084        )
2085        .await?;
2086    let Some(project_id) = dev_server_project.project_id else {
2087        return Err(anyhow!("failed to share remote project"))?;
2088    };
2089
2090    send_dev_server_projects_update(user_id, status, &session).await;
2091
2092    response.send(proto::ShareProjectResponse { project_id })?;
2093
2094    Ok(())
2095}
2096
2097/// Join someone elses shared project.
2098async fn join_project(
2099    request: proto::JoinProject,
2100    response: Response<proto::JoinProject>,
2101    session: UserSession,
2102) -> Result<()> {
2103    let project_id = ProjectId::from_proto(request.project_id);
2104
2105    tracing::info!(%project_id, "join project");
2106
2107    let db = session.db().await;
2108    let (project, replica_id) = &mut *db
2109        .join_project(project_id, session.connection_id, session.user_id())
2110        .await?;
2111    drop(db);
2112    tracing::info!(%project_id, "join remote project");
2113    join_project_internal(response, session, project, replica_id)
2114}
2115
/// A response handle through which a `proto::JoinProjectResponse` can be sent.
///
/// Implemented for the response types of both `proto::JoinProject` and
/// `proto::JoinHostedProject`, so `join_project_internal` can answer either request.
trait JoinProjectInternalResponse {
    /// Consumes the handle and sends `result` through it.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // NOTE(review): the fully qualified path presumably selects `Response`'s own
        // `send` rather than this trait method — confirm before simplifying.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // NOTE(review): same qualification as the `JoinProject` impl above.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2129
2130fn join_project_internal(
2131    response: impl JoinProjectInternalResponse,
2132    session: UserSession,
2133    project: &mut Project,
2134    replica_id: &ReplicaId,
2135) -> Result<()> {
2136    let collaborators = project
2137        .collaborators
2138        .iter()
2139        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2140        .map(|collaborator| collaborator.to_proto())
2141        .collect::<Vec<_>>();
2142    let project_id = project.id;
2143    let guest_user_id = session.user_id();
2144
2145    let worktrees = project
2146        .worktrees
2147        .iter()
2148        .map(|(id, worktree)| proto::WorktreeMetadata {
2149            id: *id,
2150            root_name: worktree.root_name.clone(),
2151            visible: worktree.visible,
2152            abs_path: worktree.abs_path.clone(),
2153        })
2154        .collect::<Vec<_>>();
2155
2156    let add_project_collaborator = proto::AddProjectCollaborator {
2157        project_id: project_id.to_proto(),
2158        collaborator: Some(proto::Collaborator {
2159            peer_id: Some(session.connection_id.into()),
2160            replica_id: replica_id.0 as u32,
2161            user_id: guest_user_id.to_proto(),
2162        }),
2163    };
2164
2165    for collaborator in &collaborators {
2166        session
2167            .peer
2168            .send(
2169                collaborator.peer_id.unwrap().into(),
2170                add_project_collaborator.clone(),
2171            )
2172            .trace_err();
2173    }
2174
2175    // First, we send the metadata associated with each worktree.
2176    response.send(proto::JoinProjectResponse {
2177        project_id: project.id.0 as u64,
2178        worktrees: worktrees.clone(),
2179        replica_id: replica_id.0 as u32,
2180        collaborators: collaborators.clone(),
2181        language_servers: project.language_servers.clone(),
2182        role: project.role.into(),
2183        dev_server_project_id: project
2184            .dev_server_project_id
2185            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2186    })?;
2187
2188    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2189        #[cfg(any(test, feature = "test-support"))]
2190        const MAX_CHUNK_SIZE: usize = 2;
2191        #[cfg(not(any(test, feature = "test-support")))]
2192        const MAX_CHUNK_SIZE: usize = 256;
2193
2194        // Stream this worktree's entries.
2195        let message = proto::UpdateWorktree {
2196            project_id: project_id.to_proto(),
2197            worktree_id,
2198            abs_path: worktree.abs_path.clone(),
2199            root_name: worktree.root_name,
2200            updated_entries: worktree.entries,
2201            removed_entries: Default::default(),
2202            scan_id: worktree.scan_id,
2203            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2204            updated_repositories: worktree.repository_entries.into_values().collect(),
2205            removed_repositories: Default::default(),
2206        };
2207        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2208            session.peer.send(session.connection_id, update.clone())?;
2209        }
2210
2211        // Stream this worktree's diagnostics.
2212        for summary in worktree.diagnostic_summaries {
2213            session.peer.send(
2214                session.connection_id,
2215                proto::UpdateDiagnosticSummary {
2216                    project_id: project_id.to_proto(),
2217                    worktree_id: worktree.id,
2218                    summary: Some(summary),
2219                },
2220            )?;
2221        }
2222
2223        for settings_file in worktree.settings_files {
2224            session.peer.send(
2225                session.connection_id,
2226                proto::UpdateWorktreeSettings {
2227                    project_id: project_id.to_proto(),
2228                    worktree_id: worktree.id,
2229                    path: settings_file.path,
2230                    content: Some(settings_file.content),
2231                },
2232            )?;
2233        }
2234    }
2235
2236    for language_server in &project.language_servers {
2237        session.peer.send(
2238            session.connection_id,
2239            proto::UpdateLanguageServer {
2240                project_id: project_id.to_proto(),
2241                language_server_id: language_server.id,
2242                variant: Some(
2243                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2244                        proto::LspDiskBasedDiagnosticsUpdated {},
2245                    ),
2246                ),
2247            },
2248        )?;
2249    }
2250
2251    Ok(())
2252}
2253
2254/// Leave someone elses shared project.
2255async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2256    let sender_id = session.connection_id;
2257    let project_id = ProjectId::from_proto(request.project_id);
2258    let db = session.db().await;
2259    if db.is_hosted_project(project_id).await? {
2260        let project = db.leave_hosted_project(project_id, sender_id).await?;
2261        project_left(&project, &session);
2262        return Ok(());
2263    }
2264
2265    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2266    tracing::info!(
2267        %project_id,
2268        "leave project"
2269    );
2270
2271    project_left(&project, &session);
2272    if let Some(room) = room {
2273        room_updated(&room, &session.peer);
2274    }
2275
2276    Ok(())
2277}
2278
2279async fn join_hosted_project(
2280    request: proto::JoinHostedProject,
2281    response: Response<proto::JoinHostedProject>,
2282    session: UserSession,
2283) -> Result<()> {
2284    let (mut project, replica_id) = session
2285        .db()
2286        .await
2287        .join_hosted_project(
2288            ProjectId(request.project_id as i32),
2289            session.user_id(),
2290            session.connection_id,
2291        )
2292        .await?;
2293
2294    join_project_internal(response, session, &mut project, &replica_id)
2295}
2296
2297async fn list_remote_directory(
2298    request: proto::ListRemoteDirectory,
2299    response: Response<proto::ListRemoteDirectory>,
2300    session: UserSession,
2301) -> Result<()> {
2302    let dev_server_id = DevServerId(request.dev_server_id as i32);
2303    let dev_server_connection_id = session
2304        .connection_pool()
2305        .await
2306        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2307
2308    session
2309        .db()
2310        .await
2311        .get_dev_server_for_user(dev_server_id, session.user_id())
2312        .await?;
2313
2314    response.send(
2315        session
2316            .peer
2317            .forward_request(session.connection_id, dev_server_connection_id, request)
2318            .await?,
2319    )?;
2320    Ok(())
2321}
2322
2323async fn update_dev_server_project(
2324    request: proto::UpdateDevServerProject,
2325    response: Response<proto::UpdateDevServerProject>,
2326    session: UserSession,
2327) -> Result<()> {
2328    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2329
2330    let (dev_server_project, update) = session
2331        .db()
2332        .await
2333        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2334        .await?;
2335
2336    let projects = session
2337        .db()
2338        .await
2339        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2340        .await?;
2341
2342    let dev_server_connection_id = session
2343        .connection_pool()
2344        .await
2345        .dev_server_connection_id_supporting(
2346            dev_server_project.dev_server_id,
2347            ZedVersion::with_list_directory(),
2348        )?;
2349
2350    session.peer.send(
2351        dev_server_connection_id,
2352        proto::DevServerInstructions { projects },
2353    )?;
2354
2355    send_dev_server_projects_update(session.user_id(), update, &session).await;
2356
2357    response.send(proto::Ack {})
2358}
2359
2360async fn create_dev_server_project(
2361    request: proto::CreateDevServerProject,
2362    response: Response<proto::CreateDevServerProject>,
2363    session: UserSession,
2364) -> Result<()> {
2365    let dev_server_id = DevServerId(request.dev_server_id as i32);
2366    let dev_server_connection_id = session
2367        .connection_pool()
2368        .await
2369        .dev_server_connection_id(dev_server_id);
2370    let Some(dev_server_connection_id) = dev_server_connection_id else {
2371        Err(ErrorCode::DevServerOffline
2372            .message("Cannot create a remote project when the dev server is offline".to_string())
2373            .anyhow())?
2374    };
2375
2376    let path = request.path.clone();
2377    //Check that the path exists on the dev server
2378    session
2379        .peer
2380        .forward_request(
2381            session.connection_id,
2382            dev_server_connection_id,
2383            proto::ValidateDevServerProjectRequest { path: path.clone() },
2384        )
2385        .await?;
2386
2387    let (dev_server_project, update) = session
2388        .db()
2389        .await
2390        .create_dev_server_project(
2391            DevServerId(request.dev_server_id as i32),
2392            &request.path,
2393            session.user_id(),
2394        )
2395        .await?;
2396
2397    let projects = session
2398        .db()
2399        .await
2400        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2401        .await?;
2402
2403    session.peer.send(
2404        dev_server_connection_id,
2405        proto::DevServerInstructions { projects },
2406    )?;
2407
2408    send_dev_server_projects_update(session.user_id(), update, &session).await;
2409
2410    response.send(proto::CreateDevServerProjectResponse {
2411        dev_server_project: Some(dev_server_project.to_proto(None)),
2412    })?;
2413    Ok(())
2414}
2415
2416async fn create_dev_server(
2417    request: proto::CreateDevServer,
2418    response: Response<proto::CreateDevServer>,
2419    session: UserSession,
2420) -> Result<()> {
2421    let access_token = auth::random_token();
2422    let hashed_access_token = auth::hash_access_token(&access_token);
2423
2424    if request.name.is_empty() {
2425        return Err(proto::ErrorCode::Forbidden
2426            .message("Dev server name cannot be empty".to_string())
2427            .anyhow())?;
2428    }
2429
2430    let (dev_server, status) = session
2431        .db()
2432        .await
2433        .create_dev_server(
2434            &request.name,
2435            request.ssh_connection_string.as_deref(),
2436            &hashed_access_token,
2437            session.user_id(),
2438        )
2439        .await?;
2440
2441    send_dev_server_projects_update(session.user_id(), status, &session).await;
2442
2443    response.send(proto::CreateDevServerResponse {
2444        dev_server_id: dev_server.id.0 as u64,
2445        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2446        name: request.name,
2447    })?;
2448    Ok(())
2449}
2450
2451async fn regenerate_dev_server_token(
2452    request: proto::RegenerateDevServerToken,
2453    response: Response<proto::RegenerateDevServerToken>,
2454    session: UserSession,
2455) -> Result<()> {
2456    let dev_server_id = DevServerId(request.dev_server_id as i32);
2457    let access_token = auth::random_token();
2458    let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460    let connection_id = session
2461        .connection_pool()
2462        .await
2463        .dev_server_connection_id(dev_server_id);
2464    if let Some(connection_id) = connection_id {
2465        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2466        session.peer.send(
2467            connection_id,
2468            proto::ShutdownDevServer {
2469                reason: Some("dev server token was regenerated".to_string()),
2470            },
2471        )?;
2472        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2473    }
2474
2475    let status = session
2476        .db()
2477        .await
2478        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2479        .await?;
2480
2481    send_dev_server_projects_update(session.user_id(), status, &session).await;
2482
2483    response.send(proto::RegenerateDevServerTokenResponse {
2484        dev_server_id: dev_server_id.to_proto(),
2485        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2486    })?;
2487    Ok(())
2488}
2489
2490async fn rename_dev_server(
2491    request: proto::RenameDevServer,
2492    response: Response<proto::RenameDevServer>,
2493    session: UserSession,
2494) -> Result<()> {
2495    if request.name.trim().is_empty() {
2496        return Err(proto::ErrorCode::Forbidden
2497            .message("Dev server name cannot be empty".to_string())
2498            .anyhow())?;
2499    }
2500
2501    let dev_server_id = DevServerId(request.dev_server_id as i32);
2502    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2503    if dev_server.user_id != session.user_id() {
2504        return Err(anyhow!(ErrorCode::Forbidden))?;
2505    }
2506
2507    let status = session
2508        .db()
2509        .await
2510        .rename_dev_server(
2511            dev_server_id,
2512            &request.name,
2513            request.ssh_connection_string.as_deref(),
2514            session.user_id(),
2515        )
2516        .await?;
2517
2518    send_dev_server_projects_update(session.user_id(), status, &session).await;
2519
2520    response.send(proto::Ack {})?;
2521    Ok(())
2522}
2523
2524async fn delete_dev_server(
2525    request: proto::DeleteDevServer,
2526    response: Response<proto::DeleteDevServer>,
2527    session: UserSession,
2528) -> Result<()> {
2529    let dev_server_id = DevServerId(request.dev_server_id as i32);
2530    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2531    if dev_server.user_id != session.user_id() {
2532        return Err(anyhow!(ErrorCode::Forbidden))?;
2533    }
2534
2535    let connection_id = session
2536        .connection_pool()
2537        .await
2538        .dev_server_connection_id(dev_server_id);
2539    if let Some(connection_id) = connection_id {
2540        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2541        session.peer.send(
2542            connection_id,
2543            proto::ShutdownDevServer {
2544                reason: Some("dev server was deleted".to_string()),
2545            },
2546        )?;
2547        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2548    }
2549
2550    let status = session
2551        .db()
2552        .await
2553        .delete_dev_server(dev_server_id, session.user_id())
2554        .await?;
2555
2556    send_dev_server_projects_update(session.user_id(), status, &session).await;
2557
2558    response.send(proto::Ack {})?;
2559    Ok(())
2560}
2561
2562async fn delete_dev_server_project(
2563    request: proto::DeleteDevServerProject,
2564    response: Response<proto::DeleteDevServerProject>,
2565    session: UserSession,
2566) -> Result<()> {
2567    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2568    let dev_server_project = session
2569        .db()
2570        .await
2571        .get_dev_server_project(dev_server_project_id)
2572        .await?;
2573
2574    let dev_server = session
2575        .db()
2576        .await
2577        .get_dev_server(dev_server_project.dev_server_id)
2578        .await?;
2579    if dev_server.user_id != session.user_id() {
2580        return Err(anyhow!(ErrorCode::Forbidden))?;
2581    }
2582
2583    let dev_server_connection_id = session
2584        .connection_pool()
2585        .await
2586        .dev_server_connection_id(dev_server.id);
2587
2588    if let Some(dev_server_connection_id) = dev_server_connection_id {
2589        let project = session
2590            .db()
2591            .await
2592            .find_dev_server_project(dev_server_project_id)
2593            .await;
2594        if let Ok(project) = project {
2595            unshare_project_internal(
2596                project.id,
2597                dev_server_connection_id,
2598                Some(session.user_id()),
2599                &session,
2600            )
2601            .await?;
2602        }
2603    }
2604
2605    let (projects, status) = session
2606        .db()
2607        .await
2608        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2609        .await?;
2610
2611    if let Some(dev_server_connection_id) = dev_server_connection_id {
2612        session.peer.send(
2613            dev_server_connection_id,
2614            proto::DevServerInstructions { projects },
2615        )?;
2616    }
2617
2618    send_dev_server_projects_update(session.user_id(), status, &session).await;
2619
2620    response.send(proto::Ack {})?;
2621    Ok(())
2622}
2623
2624async fn rejoin_dev_server_projects(
2625    request: proto::RejoinRemoteProjects,
2626    response: Response<proto::RejoinRemoteProjects>,
2627    session: UserSession,
2628) -> Result<()> {
2629    let mut rejoined_projects = {
2630        let db = session.db().await;
2631        db.rejoin_dev_server_projects(
2632            &request.rejoined_projects,
2633            session.user_id(),
2634            session.0.connection_id,
2635        )
2636        .await?
2637    };
2638    response.send(proto::RejoinRemoteProjectsResponse {
2639        rejoined_projects: rejoined_projects
2640            .iter()
2641            .map(|project| project.to_proto())
2642            .collect(),
2643    })?;
2644    notify_rejoined_projects(&mut rejoined_projects, &session)
2645}
2646
2647async fn reconnect_dev_server(
2648    request: proto::ReconnectDevServer,
2649    response: Response<proto::ReconnectDevServer>,
2650    session: DevServerSession,
2651) -> Result<()> {
2652    let reshared_projects = {
2653        let db = session.db().await;
2654        db.reshare_dev_server_projects(
2655            &request.reshared_projects,
2656            session.dev_server_id(),
2657            session.0.connection_id,
2658        )
2659        .await?
2660    };
2661
2662    for project in &reshared_projects {
2663        for collaborator in &project.collaborators {
2664            session
2665                .peer
2666                .send(
2667                    collaborator.connection_id,
2668                    proto::UpdateProjectCollaborator {
2669                        project_id: project.id.to_proto(),
2670                        old_peer_id: Some(project.old_connection_id.into()),
2671                        new_peer_id: Some(session.connection_id.into()),
2672                    },
2673                )
2674                .trace_err();
2675        }
2676
2677        broadcast(
2678            Some(session.connection_id),
2679            project
2680                .collaborators
2681                .iter()
2682                .map(|collaborator| collaborator.connection_id),
2683            |connection_id| {
2684                session.peer.forward_send(
2685                    session.connection_id,
2686                    connection_id,
2687                    proto::UpdateProject {
2688                        project_id: project.id.to_proto(),
2689                        worktrees: project.worktrees.clone(),
2690                    },
2691                )
2692            },
2693        );
2694    }
2695
2696    response.send(proto::ReconnectDevServerResponse {
2697        reshared_projects: reshared_projects
2698            .iter()
2699            .map(|project| proto::ResharedProject {
2700                id: project.id.to_proto(),
2701                collaborators: project
2702                    .collaborators
2703                    .iter()
2704                    .map(|collaborator| collaborator.to_proto())
2705                    .collect(),
2706            })
2707            .collect(),
2708    })?;
2709
2710    Ok(())
2711}
2712
2713async fn shutdown_dev_server(
2714    _: proto::ShutdownDevServer,
2715    response: Response<proto::ShutdownDevServer>,
2716    session: DevServerSession,
2717) -> Result<()> {
2718    response.send(proto::Ack {})?;
2719    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2720    remove_dev_server_connection(session.dev_server_id(), &session).await
2721}
2722
2723async fn shutdown_dev_server_internal(
2724    dev_server_id: DevServerId,
2725    connection_id: ConnectionId,
2726    session: &Session,
2727) -> Result<()> {
2728    let (dev_server_projects, dev_server) = {
2729        let db = session.db().await;
2730        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2731        let dev_server = db.get_dev_server(dev_server_id).await?;
2732        (dev_server_projects, dev_server)
2733    };
2734
2735    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2736        unshare_project_internal(
2737            ProjectId::from_proto(project_id),
2738            connection_id,
2739            None,
2740            session,
2741        )
2742        .await?;
2743    }
2744
2745    session
2746        .connection_pool()
2747        .await
2748        .set_dev_server_offline(dev_server_id);
2749
2750    let status = session
2751        .db()
2752        .await
2753        .dev_server_projects_update(dev_server.user_id)
2754        .await?;
2755    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2756
2757    Ok(())
2758}
2759
2760async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2761    let dev_server_connection = session
2762        .connection_pool()
2763        .await
2764        .dev_server_connection_id(dev_server_id);
2765
2766    if let Some(dev_server_connection) = dev_server_connection {
2767        session
2768            .connection_pool()
2769            .await
2770            .remove_connection(dev_server_connection)?;
2771    }
2772    Ok(())
2773}
2774
2775/// Updates other participants with changes to the project
2776async fn update_project(
2777    request: proto::UpdateProject,
2778    response: Response<proto::UpdateProject>,
2779    session: Session,
2780) -> Result<()> {
2781    let project_id = ProjectId::from_proto(request.project_id);
2782    let (room, guest_connection_ids) = &*session
2783        .db()
2784        .await
2785        .update_project(project_id, session.connection_id, &request.worktrees)
2786        .await?;
2787    broadcast(
2788        Some(session.connection_id),
2789        guest_connection_ids.iter().copied(),
2790        |connection_id| {
2791            session
2792                .peer
2793                .forward_send(session.connection_id, connection_id, request.clone())
2794        },
2795    );
2796    if let Some(room) = room {
2797        room_updated(&room, &session.peer);
2798    }
2799    response.send(proto::Ack {})?;
2800
2801    Ok(())
2802}
2803
2804/// Updates other participants with changes to the worktree
2805async fn update_worktree(
2806    request: proto::UpdateWorktree,
2807    response: Response<proto::UpdateWorktree>,
2808    session: Session,
2809) -> Result<()> {
2810    let guest_connection_ids = session
2811        .db()
2812        .await
2813        .update_worktree(&request, session.connection_id)
2814        .await?;
2815
2816    broadcast(
2817        Some(session.connection_id),
2818        guest_connection_ids.iter().copied(),
2819        |connection_id| {
2820            session
2821                .peer
2822                .forward_send(session.connection_id, connection_id, request.clone())
2823        },
2824    );
2825    response.send(proto::Ack {})?;
2826    Ok(())
2827}
2828
2829/// Updates other participants with changes to the diagnostics
2830async fn update_diagnostic_summary(
2831    message: proto::UpdateDiagnosticSummary,
2832    session: Session,
2833) -> Result<()> {
2834    let guest_connection_ids = session
2835        .db()
2836        .await
2837        .update_diagnostic_summary(&message, session.connection_id)
2838        .await?;
2839
2840    broadcast(
2841        Some(session.connection_id),
2842        guest_connection_ids.iter().copied(),
2843        |connection_id| {
2844            session
2845                .peer
2846                .forward_send(session.connection_id, connection_id, message.clone())
2847        },
2848    );
2849
2850    Ok(())
2851}
2852
2853/// Updates other participants with changes to the worktree settings
2854async fn update_worktree_settings(
2855    message: proto::UpdateWorktreeSettings,
2856    session: Session,
2857) -> Result<()> {
2858    let guest_connection_ids = session
2859        .db()
2860        .await
2861        .update_worktree_settings(&message, session.connection_id)
2862        .await?;
2863
2864    broadcast(
2865        Some(session.connection_id),
2866        guest_connection_ids.iter().copied(),
2867        |connection_id| {
2868            session
2869                .peer
2870                .forward_send(session.connection_id, connection_id, message.clone())
2871        },
2872    );
2873
2874    Ok(())
2875}
2876
2877/// Notify other participants that a  language server has started.
2878async fn start_language_server(
2879    request: proto::StartLanguageServer,
2880    session: Session,
2881) -> Result<()> {
2882    let guest_connection_ids = session
2883        .db()
2884        .await
2885        .start_language_server(&request, session.connection_id)
2886        .await?;
2887
2888    broadcast(
2889        Some(session.connection_id),
2890        guest_connection_ids.iter().copied(),
2891        |connection_id| {
2892            session
2893                .peer
2894                .forward_send(session.connection_id, connection_id, request.clone())
2895        },
2896    );
2897    Ok(())
2898}
2899
2900/// Notify other participants that a language server has changed.
2901async fn update_language_server(
2902    request: proto::UpdateLanguageServer,
2903    session: Session,
2904) -> Result<()> {
2905    let project_id = ProjectId::from_proto(request.project_id);
2906    let project_connection_ids = session
2907        .db()
2908        .await
2909        .project_connection_ids(project_id, session.connection_id, true)
2910        .await?;
2911    broadcast(
2912        Some(session.connection_id),
2913        project_connection_ids.iter().copied(),
2914        |connection_id| {
2915            session
2916                .peer
2917                .forward_send(session.connection_id, connection_id, request.clone())
2918        },
2919    );
2920    Ok(())
2921}
2922
2923/// forward a project request to the host. These requests should be read only
2924/// as guests are allowed to send them.
2925async fn forward_read_only_project_request<T>(
2926    request: T,
2927    response: Response<T>,
2928    session: UserSession,
2929) -> Result<()>
2930where
2931    T: EntityMessage + RequestMessage,
2932{
2933    let project_id = ProjectId::from_proto(request.remote_entity_id());
2934    let host_connection_id = session
2935        .db()
2936        .await
2937        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2938        .await?;
2939    let payload = session
2940        .peer
2941        .forward_request(session.connection_id, host_connection_id, request)
2942        .await?;
2943    response.send(payload)?;
2944    Ok(())
2945}
2946
2947/// forward a project request to the dev server. Only allowed
2948/// if it's your dev server.
2949async fn forward_project_request_for_owner<T>(
2950    request: T,
2951    response: Response<T>,
2952    session: UserSession,
2953) -> Result<()>
2954where
2955    T: EntityMessage + RequestMessage,
2956{
2957    let project_id = ProjectId::from_proto(request.remote_entity_id());
2958
2959    let host_connection_id = session
2960        .db()
2961        .await
2962        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2963        .await?;
2964    let payload = session
2965        .peer
2966        .forward_request(session.connection_id, host_connection_id, request)
2967        .await?;
2968    response.send(payload)?;
2969    Ok(())
2970}
2971
2972/// forward a project request to the host. These requests are disallowed
2973/// for guests.
2974async fn forward_mutating_project_request<T>(
2975    request: T,
2976    response: Response<T>,
2977    session: UserSession,
2978) -> Result<()>
2979where
2980    T: EntityMessage + RequestMessage,
2981{
2982    let project_id = ProjectId::from_proto(request.remote_entity_id());
2983
2984    let host_connection_id = session
2985        .db()
2986        .await
2987        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2988        .await?;
2989    let payload = session
2990        .peer
2991        .forward_request(session.connection_id, host_connection_id, request)
2992        .await?;
2993    response.send(payload)?;
2994    Ok(())
2995}
2996
2997/// forward a project request to the host. These requests are disallowed
2998/// for guests.
2999async fn forward_versioned_mutating_project_request<T>(
3000    request: T,
3001    response: Response<T>,
3002    session: UserSession,
3003) -> Result<()>
3004where
3005    T: EntityMessage + RequestMessage + VersionedMessage,
3006{
3007    let project_id = ProjectId::from_proto(request.remote_entity_id());
3008
3009    let host_connection_id = session
3010        .db()
3011        .await
3012        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3013        .await?;
3014    if let Some(host_version) = session
3015        .connection_pool()
3016        .await
3017        .connection(host_connection_id)
3018        .map(|c| c.zed_version)
3019    {
3020        if let Some(min_required_version) = request.required_host_version() {
3021            if min_required_version > host_version {
3022                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3023                    .with_tag("required", &min_required_version.to_string())))?;
3024            }
3025        }
3026    }
3027
3028    let payload = session
3029        .peer
3030        .forward_request(session.connection_id, host_connection_id, request)
3031        .await?;
3032    response.send(payload)?;
3033    Ok(())
3034}
3035
3036/// Notify other participants that a new buffer has been created
3037async fn create_buffer_for_peer(
3038    request: proto::CreateBufferForPeer,
3039    session: Session,
3040) -> Result<()> {
3041    session
3042        .db()
3043        .await
3044        .check_user_is_project_host(
3045            ProjectId::from_proto(request.project_id),
3046            session.connection_id,
3047        )
3048        .await?;
3049    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3050    session
3051        .peer
3052        .forward_send(session.connection_id, peer_id.into(), request)?;
3053    Ok(())
3054}
3055
3056/// Notify other participants that a buffer has been updated. This is
3057/// allowed for guests as long as the update is limited to selections.
3058async fn update_buffer(
3059    request: proto::UpdateBuffer,
3060    response: Response<proto::UpdateBuffer>,
3061    session: Session,
3062) -> Result<()> {
3063    let project_id = ProjectId::from_proto(request.project_id);
3064    let mut capability = Capability::ReadOnly;
3065
3066    for op in request.operations.iter() {
3067        match op.variant {
3068            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3069            Some(_) => capability = Capability::ReadWrite,
3070        }
3071    }
3072
3073    let host = {
3074        let guard = session
3075            .db()
3076            .await
3077            .connections_for_buffer_update(
3078                project_id,
3079                session.principal_id(),
3080                session.connection_id,
3081                capability,
3082            )
3083            .await?;
3084
3085        let (host, guests) = &*guard;
3086
3087        broadcast(
3088            Some(session.connection_id),
3089            guests.clone(),
3090            |connection_id| {
3091                session
3092                    .peer
3093                    .forward_send(session.connection_id, connection_id, request.clone())
3094            },
3095        );
3096
3097        *host
3098    };
3099
3100    if host != session.connection_id {
3101        session
3102            .peer
3103            .forward_request(session.connection_id, host, request.clone())
3104            .await?;
3105    }
3106
3107    response.send(proto::Ack {})?;
3108    Ok(())
3109}
3110
3111async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3112    let project_id = ProjectId::from_proto(message.project_id);
3113
3114    let operation = message.operation.as_ref().context("invalid operation")?;
3115    let capability = match operation.variant.as_ref() {
3116        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3117            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3118                match buffer_op.variant {
3119                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3120                        Capability::ReadOnly
3121                    }
3122                    _ => Capability::ReadWrite,
3123                }
3124            } else {
3125                Capability::ReadWrite
3126            }
3127        }
3128        Some(_) => Capability::ReadWrite,
3129        None => Capability::ReadOnly,
3130    };
3131
3132    let guard = session
3133        .db()
3134        .await
3135        .connections_for_buffer_update(
3136            project_id,
3137            session.principal_id(),
3138            session.connection_id,
3139            capability,
3140        )
3141        .await?;
3142
3143    let (host, guests) = &*guard;
3144
3145    broadcast(
3146        Some(session.connection_id),
3147        guests.iter().chain([host]).copied(),
3148        |connection_id| {
3149            session
3150                .peer
3151                .forward_send(session.connection_id, connection_id, message.clone())
3152        },
3153    );
3154
3155    Ok(())
3156}
3157
3158/// Notify other participants that a project has been updated.
3159async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3160    request: T,
3161    session: Session,
3162) -> Result<()> {
3163    let project_id = ProjectId::from_proto(request.remote_entity_id());
3164    let project_connection_ids = session
3165        .db()
3166        .await
3167        .project_connection_ids(project_id, session.connection_id, false)
3168        .await?;
3169
3170    broadcast(
3171        Some(session.connection_id),
3172        project_connection_ids.iter().copied(),
3173        |connection_id| {
3174            session
3175                .peer
3176                .forward_send(session.connection_id, connection_id, request.clone())
3177        },
3178    );
3179    Ok(())
3180}
3181
3182/// Start following another user in a call.
3183async fn follow(
3184    request: proto::Follow,
3185    response: Response<proto::Follow>,
3186    session: UserSession,
3187) -> Result<()> {
3188    let room_id = RoomId::from_proto(request.room_id);
3189    let project_id = request.project_id.map(ProjectId::from_proto);
3190    let leader_id = request
3191        .leader_id
3192        .ok_or_else(|| anyhow!("invalid leader id"))?
3193        .into();
3194    let follower_id = session.connection_id;
3195
3196    session
3197        .db()
3198        .await
3199        .check_room_participants(room_id, leader_id, session.connection_id)
3200        .await?;
3201
3202    let response_payload = session
3203        .peer
3204        .forward_request(session.connection_id, leader_id, request)
3205        .await?;
3206    response.send(response_payload)?;
3207
3208    if let Some(project_id) = project_id {
3209        let room = session
3210            .db()
3211            .await
3212            .follow(room_id, project_id, leader_id, follower_id)
3213            .await?;
3214        room_updated(&room, &session.peer);
3215    }
3216
3217    Ok(())
3218}
3219
3220/// Stop following another user in a call.
3221async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3222    let room_id = RoomId::from_proto(request.room_id);
3223    let project_id = request.project_id.map(ProjectId::from_proto);
3224    let leader_id = request
3225        .leader_id
3226        .ok_or_else(|| anyhow!("invalid leader id"))?
3227        .into();
3228    let follower_id = session.connection_id;
3229
3230    session
3231        .db()
3232        .await
3233        .check_room_participants(room_id, leader_id, session.connection_id)
3234        .await?;
3235
3236    session
3237        .peer
3238        .forward_send(session.connection_id, leader_id, request)?;
3239
3240    if let Some(project_id) = project_id {
3241        let room = session
3242            .db()
3243            .await
3244            .unfollow(room_id, project_id, leader_id, follower_id)
3245            .await?;
3246        room_updated(&room, &session.peer);
3247    }
3248
3249    Ok(())
3250}
3251
3252/// Notify everyone following you of your current location.
3253async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3254    let room_id = RoomId::from_proto(request.room_id);
3255    let database = session.db.lock().await;
3256
3257    let connection_ids = if let Some(project_id) = request.project_id {
3258        let project_id = ProjectId::from_proto(project_id);
3259        database
3260            .project_connection_ids(project_id, session.connection_id, true)
3261            .await?
3262    } else {
3263        database
3264            .room_connection_ids(room_id, session.connection_id)
3265            .await?
3266    };
3267
3268    // For now, don't send view update messages back to that view's current leader.
3269    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3270        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3271        _ => None,
3272    });
3273
3274    for connection_id in connection_ids.iter().cloned() {
3275        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3276            session
3277                .peer
3278                .forward_send(session.connection_id, connection_id, request.clone())?;
3279        }
3280    }
3281    Ok(())
3282}
3283
3284/// Get public data about users.
3285async fn get_users(
3286    request: proto::GetUsers,
3287    response: Response<proto::GetUsers>,
3288    session: Session,
3289) -> Result<()> {
3290    let user_ids = request
3291        .user_ids
3292        .into_iter()
3293        .map(UserId::from_proto)
3294        .collect();
3295    let users = session
3296        .db()
3297        .await
3298        .get_users_by_ids(user_ids)
3299        .await?
3300        .into_iter()
3301        .map(|user| proto::User {
3302            id: user.id.to_proto(),
3303            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3304            github_login: user.github_login,
3305        })
3306        .collect();
3307    response.send(proto::UsersResponse { users })?;
3308    Ok(())
3309}
3310
3311/// Search for users (to invite) buy Github login
3312async fn fuzzy_search_users(
3313    request: proto::FuzzySearchUsers,
3314    response: Response<proto::FuzzySearchUsers>,
3315    session: UserSession,
3316) -> Result<()> {
3317    let query = request.query;
3318    let users = match query.len() {
3319        0 => vec![],
3320        1 | 2 => session
3321            .db()
3322            .await
3323            .get_user_by_github_login(&query)
3324            .await?
3325            .into_iter()
3326            .collect(),
3327        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3328    };
3329    let users = users
3330        .into_iter()
3331        .filter(|user| user.id != session.user_id())
3332        .map(|user| proto::User {
3333            id: user.id.to_proto(),
3334            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3335            github_login: user.github_login,
3336        })
3337        .collect();
3338    response.send(proto::UsersResponse { users })?;
3339    Ok(())
3340}
3341
3342/// Send a contact request to another user.
3343async fn request_contact(
3344    request: proto::RequestContact,
3345    response: Response<proto::RequestContact>,
3346    session: UserSession,
3347) -> Result<()> {
3348    let requester_id = session.user_id();
3349    let responder_id = UserId::from_proto(request.responder_id);
3350    if requester_id == responder_id {
3351        return Err(anyhow!("cannot add yourself as a contact"))?;
3352    }
3353
3354    let notifications = session
3355        .db()
3356        .await
3357        .send_contact_request(requester_id, responder_id)
3358        .await?;
3359
3360    // Update outgoing contact requests of requester
3361    let mut update = proto::UpdateContacts::default();
3362    update.outgoing_requests.push(responder_id.to_proto());
3363    for connection_id in session
3364        .connection_pool()
3365        .await
3366        .user_connection_ids(requester_id)
3367    {
3368        session.peer.send(connection_id, update.clone())?;
3369    }
3370
3371    // Update incoming contact requests of responder
3372    let mut update = proto::UpdateContacts::default();
3373    update
3374        .incoming_requests
3375        .push(proto::IncomingContactRequest {
3376            requester_id: requester_id.to_proto(),
3377        });
3378    let connection_pool = session.connection_pool().await;
3379    for connection_id in connection_pool.user_connection_ids(responder_id) {
3380        session.peer.send(connection_id, update.clone())?;
3381    }
3382
3383    send_notifications(&connection_pool, &session.peer, notifications);
3384
3385    response.send(proto::Ack {})?;
3386    Ok(())
3387}
3388
3389/// Accept or decline a contact request
3390async fn respond_to_contact_request(
3391    request: proto::RespondToContactRequest,
3392    response: Response<proto::RespondToContactRequest>,
3393    session: UserSession,
3394) -> Result<()> {
3395    let responder_id = session.user_id();
3396    let requester_id = UserId::from_proto(request.requester_id);
3397    let db = session.db().await;
3398    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3399        db.dismiss_contact_notification(responder_id, requester_id)
3400            .await?;
3401    } else {
3402        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3403
3404        let notifications = db
3405            .respond_to_contact_request(responder_id, requester_id, accept)
3406            .await?;
3407        let requester_busy = db.is_user_busy(requester_id).await?;
3408        let responder_busy = db.is_user_busy(responder_id).await?;
3409
3410        let pool = session.connection_pool().await;
3411        // Update responder with new contact
3412        let mut update = proto::UpdateContacts::default();
3413        if accept {
3414            update
3415                .contacts
3416                .push(contact_for_user(requester_id, requester_busy, &pool));
3417        }
3418        update
3419            .remove_incoming_requests
3420            .push(requester_id.to_proto());
3421        for connection_id in pool.user_connection_ids(responder_id) {
3422            session.peer.send(connection_id, update.clone())?;
3423        }
3424
3425        // Update requester with new contact
3426        let mut update = proto::UpdateContacts::default();
3427        if accept {
3428            update
3429                .contacts
3430                .push(contact_for_user(responder_id, responder_busy, &pool));
3431        }
3432        update
3433            .remove_outgoing_requests
3434            .push(responder_id.to_proto());
3435
3436        for connection_id in pool.user_connection_ids(requester_id) {
3437            session.peer.send(connection_id, update.clone())?;
3438        }
3439
3440        send_notifications(&pool, &session.peer, notifications);
3441    }
3442
3443    response.send(proto::Ack {})?;
3444    Ok(())
3445}
3446
3447/// Remove a contact.
3448async fn remove_contact(
3449    request: proto::RemoveContact,
3450    response: Response<proto::RemoveContact>,
3451    session: UserSession,
3452) -> Result<()> {
3453    let requester_id = session.user_id();
3454    let responder_id = UserId::from_proto(request.user_id);
3455    let db = session.db().await;
3456    let (contact_accepted, deleted_notification_id) =
3457        db.remove_contact(requester_id, responder_id).await?;
3458
3459    let pool = session.connection_pool().await;
3460    // Update outgoing contact requests of requester
3461    let mut update = proto::UpdateContacts::default();
3462    if contact_accepted {
3463        update.remove_contacts.push(responder_id.to_proto());
3464    } else {
3465        update
3466            .remove_outgoing_requests
3467            .push(responder_id.to_proto());
3468    }
3469    for connection_id in pool.user_connection_ids(requester_id) {
3470        session.peer.send(connection_id, update.clone())?;
3471    }
3472
3473    // Update incoming contact requests of responder
3474    let mut update = proto::UpdateContacts::default();
3475    if contact_accepted {
3476        update.remove_contacts.push(requester_id.to_proto());
3477    } else {
3478        update
3479            .remove_incoming_requests
3480            .push(requester_id.to_proto());
3481    }
3482    for connection_id in pool.user_connection_ids(responder_id) {
3483        session.peer.send(connection_id, update.clone())?;
3484        if let Some(notification_id) = deleted_notification_id {
3485            session.peer.send(
3486                connection_id,
3487                proto::DeleteNotification {
3488                    notification_id: notification_id.to_proto(),
3489                },
3490            )?;
3491        }
3492    }
3493
3494    response.send(proto::Ack {})?;
3495    Ok(())
3496}
3497
3498fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3499    version.0.minor() < 139
3500}
3501
3502async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3503    let plan = session.current_plan().await?;
3504
3505    session
3506        .peer
3507        .send(
3508            session.connection_id,
3509            proto::UpdateUserPlan { plan: plan.into() },
3510        )
3511        .trace_err();
3512
3513    Ok(())
3514}
3515
3516async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3517    subscribe_user_to_channels(
3518        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3519        &session,
3520    )
3521    .await?;
3522    Ok(())
3523}
3524
3525async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3526    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3527    let mut pool = session.connection_pool().await;
3528    for membership in &channels_for_user.channel_memberships {
3529        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3530    }
3531    session.peer.send(
3532        session.connection_id,
3533        build_update_user_channels(&channels_for_user),
3534    )?;
3535    session.peer.send(
3536        session.connection_id,
3537        build_channels_update(channels_for_user),
3538    )?;
3539    Ok(())
3540}
3541
3542/// Creates a new channel.
3543async fn create_channel(
3544    request: proto::CreateChannel,
3545    response: Response<proto::CreateChannel>,
3546    session: UserSession,
3547) -> Result<()> {
3548    let db = session.db().await;
3549
3550    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3551    let (channel, membership) = db
3552        .create_channel(&request.name, parent_id, session.user_id())
3553        .await?;
3554
3555    let root_id = channel.root_id();
3556    let channel = Channel::from_model(channel);
3557
3558    response.send(proto::CreateChannelResponse {
3559        channel: Some(channel.to_proto()),
3560        parent_id: request.parent_id,
3561    })?;
3562
3563    let mut connection_pool = session.connection_pool().await;
3564    if let Some(membership) = membership {
3565        connection_pool.subscribe_to_channel(
3566            membership.user_id,
3567            membership.channel_id,
3568            membership.role,
3569        );
3570        let update = proto::UpdateUserChannels {
3571            channel_memberships: vec![proto::ChannelMembership {
3572                channel_id: membership.channel_id.to_proto(),
3573                role: membership.role.into(),
3574            }],
3575            ..Default::default()
3576        };
3577        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3578            session.peer.send(connection_id, update.clone())?;
3579        }
3580    }
3581
3582    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3583        if !role.can_see_channel(channel.visibility) {
3584            continue;
3585        }
3586
3587        let update = proto::UpdateChannels {
3588            channels: vec![channel.to_proto()],
3589            ..Default::default()
3590        };
3591        session.peer.send(connection_id, update.clone())?;
3592    }
3593
3594    Ok(())
3595}
3596
3597/// Delete a channel
3598async fn delete_channel(
3599    request: proto::DeleteChannel,
3600    response: Response<proto::DeleteChannel>,
3601    session: UserSession,
3602) -> Result<()> {
3603    let db = session.db().await;
3604
3605    let channel_id = request.channel_id;
3606    let (root_channel, removed_channels) = db
3607        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3608        .await?;
3609    response.send(proto::Ack {})?;
3610
3611    // Notify members of removed channels
3612    let mut update = proto::UpdateChannels::default();
3613    update
3614        .delete_channels
3615        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3616
3617    let connection_pool = session.connection_pool().await;
3618    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3619        session.peer.send(connection_id, update.clone())?;
3620    }
3621
3622    Ok(())
3623}
3624
3625/// Invite someone to join a channel.
3626async fn invite_channel_member(
3627    request: proto::InviteChannelMember,
3628    response: Response<proto::InviteChannelMember>,
3629    session: UserSession,
3630) -> Result<()> {
3631    let db = session.db().await;
3632    let channel_id = ChannelId::from_proto(request.channel_id);
3633    let invitee_id = UserId::from_proto(request.user_id);
3634    let InviteMemberResult {
3635        channel,
3636        notifications,
3637    } = db
3638        .invite_channel_member(
3639            channel_id,
3640            invitee_id,
3641            session.user_id(),
3642            request.role().into(),
3643        )
3644        .await?;
3645
3646    let update = proto::UpdateChannels {
3647        channel_invitations: vec![channel.to_proto()],
3648        ..Default::default()
3649    };
3650
3651    let connection_pool = session.connection_pool().await;
3652    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3653        session.peer.send(connection_id, update.clone())?;
3654    }
3655
3656    send_notifications(&connection_pool, &session.peer, notifications);
3657
3658    response.send(proto::Ack {})?;
3659    Ok(())
3660}
3661
3662/// remove someone from a channel
3663async fn remove_channel_member(
3664    request: proto::RemoveChannelMember,
3665    response: Response<proto::RemoveChannelMember>,
3666    session: UserSession,
3667) -> Result<()> {
3668    let db = session.db().await;
3669    let channel_id = ChannelId::from_proto(request.channel_id);
3670    let member_id = UserId::from_proto(request.user_id);
3671
3672    let RemoveChannelMemberResult {
3673        membership_update,
3674        notification_id,
3675    } = db
3676        .remove_channel_member(channel_id, member_id, session.user_id())
3677        .await?;
3678
3679    let mut connection_pool = session.connection_pool().await;
3680    notify_membership_updated(
3681        &mut connection_pool,
3682        membership_update,
3683        member_id,
3684        &session.peer,
3685    );
3686    for connection_id in connection_pool.user_connection_ids(member_id) {
3687        if let Some(notification_id) = notification_id {
3688            session
3689                .peer
3690                .send(
3691                    connection_id,
3692                    proto::DeleteNotification {
3693                        notification_id: notification_id.to_proto(),
3694                    },
3695                )
3696                .trace_err();
3697        }
3698    }
3699
3700    response.send(proto::Ack {})?;
3701    Ok(())
3702}
3703
3704/// Toggle the channel between public and private.
3705/// Care is taken to maintain the invariant that public channels only descend from public channels,
3706/// (though members-only channels can appear at any point in the hierarchy).
3707async fn set_channel_visibility(
3708    request: proto::SetChannelVisibility,
3709    response: Response<proto::SetChannelVisibility>,
3710    session: UserSession,
3711) -> Result<()> {
3712    let db = session.db().await;
3713    let channel_id = ChannelId::from_proto(request.channel_id);
3714    let visibility = request.visibility().into();
3715
3716    let channel_model = db
3717        .set_channel_visibility(channel_id, visibility, session.user_id())
3718        .await?;
3719    let root_id = channel_model.root_id();
3720    let channel = Channel::from_model(channel_model);
3721
3722    let mut connection_pool = session.connection_pool().await;
3723    for (user_id, role) in connection_pool
3724        .channel_user_ids(root_id)
3725        .collect::<Vec<_>>()
3726        .into_iter()
3727    {
3728        let update = if role.can_see_channel(channel.visibility) {
3729            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3730            proto::UpdateChannels {
3731                channels: vec![channel.to_proto()],
3732                ..Default::default()
3733            }
3734        } else {
3735            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3736            proto::UpdateChannels {
3737                delete_channels: vec![channel.id.to_proto()],
3738                ..Default::default()
3739            }
3740        };
3741
3742        for connection_id in connection_pool.user_connection_ids(user_id) {
3743            session.peer.send(connection_id, update.clone())?;
3744        }
3745    }
3746
3747    response.send(proto::Ack {})?;
3748    Ok(())
3749}
3750
3751/// Alter the role for a user in the channel.
3752async fn set_channel_member_role(
3753    request: proto::SetChannelMemberRole,
3754    response: Response<proto::SetChannelMemberRole>,
3755    session: UserSession,
3756) -> Result<()> {
3757    let db = session.db().await;
3758    let channel_id = ChannelId::from_proto(request.channel_id);
3759    let member_id = UserId::from_proto(request.user_id);
3760    let result = db
3761        .set_channel_member_role(
3762            channel_id,
3763            session.user_id(),
3764            member_id,
3765            request.role().into(),
3766        )
3767        .await?;
3768
3769    match result {
3770        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3771            let mut connection_pool = session.connection_pool().await;
3772            notify_membership_updated(
3773                &mut connection_pool,
3774                membership_update,
3775                member_id,
3776                &session.peer,
3777            )
3778        }
3779        db::SetMemberRoleResult::InviteUpdated(channel) => {
3780            let update = proto::UpdateChannels {
3781                channel_invitations: vec![channel.to_proto()],
3782                ..Default::default()
3783            };
3784
3785            for connection_id in session
3786                .connection_pool()
3787                .await
3788                .user_connection_ids(member_id)
3789            {
3790                session.peer.send(connection_id, update.clone())?;
3791            }
3792        }
3793    }
3794
3795    response.send(proto::Ack {})?;
3796    Ok(())
3797}
3798
3799/// Change the name of a channel
3800async fn rename_channel(
3801    request: proto::RenameChannel,
3802    response: Response<proto::RenameChannel>,
3803    session: UserSession,
3804) -> Result<()> {
3805    let db = session.db().await;
3806    let channel_id = ChannelId::from_proto(request.channel_id);
3807    let channel_model = db
3808        .rename_channel(channel_id, session.user_id(), &request.name)
3809        .await?;
3810    let root_id = channel_model.root_id();
3811    let channel = Channel::from_model(channel_model);
3812
3813    response.send(proto::RenameChannelResponse {
3814        channel: Some(channel.to_proto()),
3815    })?;
3816
3817    let connection_pool = session.connection_pool().await;
3818    let update = proto::UpdateChannels {
3819        channels: vec![channel.to_proto()],
3820        ..Default::default()
3821    };
3822    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3823        if role.can_see_channel(channel.visibility) {
3824            session.peer.send(connection_id, update.clone())?;
3825        }
3826    }
3827
3828    Ok(())
3829}
3830
3831/// Move a channel to a new parent.
3832async fn move_channel(
3833    request: proto::MoveChannel,
3834    response: Response<proto::MoveChannel>,
3835    session: UserSession,
3836) -> Result<()> {
3837    let channel_id = ChannelId::from_proto(request.channel_id);
3838    let to = ChannelId::from_proto(request.to);
3839
3840    let (root_id, channels) = session
3841        .db()
3842        .await
3843        .move_channel(channel_id, to, session.user_id())
3844        .await?;
3845
3846    let connection_pool = session.connection_pool().await;
3847    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3848        let channels = channels
3849            .iter()
3850            .filter_map(|channel| {
3851                if role.can_see_channel(channel.visibility) {
3852                    Some(channel.to_proto())
3853                } else {
3854                    None
3855                }
3856            })
3857            .collect::<Vec<_>>();
3858        if channels.is_empty() {
3859            continue;
3860        }
3861
3862        let update = proto::UpdateChannels {
3863            channels,
3864            ..Default::default()
3865        };
3866
3867        session.peer.send(connection_id, update.clone())?;
3868    }
3869
3870    response.send(Ack {})?;
3871    Ok(())
3872}
3873
3874/// Get the list of channel members
3875async fn get_channel_members(
3876    request: proto::GetChannelMembers,
3877    response: Response<proto::GetChannelMembers>,
3878    session: UserSession,
3879) -> Result<()> {
3880    let db = session.db().await;
3881    let channel_id = ChannelId::from_proto(request.channel_id);
3882    let limit = if request.limit == 0 {
3883        u16::MAX as u64
3884    } else {
3885        request.limit
3886    };
3887    let (members, users) = db
3888        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3889        .await?;
3890    response.send(proto::GetChannelMembersResponse { members, users })?;
3891    Ok(())
3892}
3893
3894/// Accept or decline a channel invitation.
3895async fn respond_to_channel_invite(
3896    request: proto::RespondToChannelInvite,
3897    response: Response<proto::RespondToChannelInvite>,
3898    session: UserSession,
3899) -> Result<()> {
3900    let db = session.db().await;
3901    let channel_id = ChannelId::from_proto(request.channel_id);
3902    let RespondToChannelInvite {
3903        membership_update,
3904        notifications,
3905    } = db
3906        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3907        .await?;
3908
3909    let mut connection_pool = session.connection_pool().await;
3910    if let Some(membership_update) = membership_update {
3911        notify_membership_updated(
3912            &mut connection_pool,
3913            membership_update,
3914            session.user_id(),
3915            &session.peer,
3916        );
3917    } else {
3918        let update = proto::UpdateChannels {
3919            remove_channel_invitations: vec![channel_id.to_proto()],
3920            ..Default::default()
3921        };
3922
3923        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3924            session.peer.send(connection_id, update.clone())?;
3925        }
3926    };
3927
3928    send_notifications(&connection_pool, &session.peer, notifications);
3929
3930    response.send(proto::Ack {})?;
3931
3932    Ok(())
3933}
3934
3935/// Join the channels' room
3936async fn join_channel(
3937    request: proto::JoinChannel,
3938    response: Response<proto::JoinChannel>,
3939    session: UserSession,
3940) -> Result<()> {
3941    let channel_id = ChannelId::from_proto(request.channel_id);
3942    join_channel_internal(channel_id, Box::new(response), session).await
3943}
3944
/// Abstraction over the response handles that `join_channel_internal` can reply through.
///
/// Both implementations below reply with a `proto::JoinRoomResponse`, which lets the same
/// join logic serve `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Spelled as a fully-qualified path on purpose; presumably this resolves to
        // `Response`'s own `send` rather than back into this trait method — keep as is.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same fully-qualified form as the `JoinChannel` implementation above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3958
/// Joins the room of `channel_id` on behalf of the session's user and replies through
/// `response`.
///
/// Steps, in order:
/// 1. Clean up a stale room connection of the same user, if the database reports one.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Send the `JoinRoomResponse`.
/// 5. Propagate membership, room, channel and contact updates.
///
/// Fails with "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The inner block scopes the database handle and the first connection-pool guard so
    // both are released before `channel_updated` takes the pool again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The handle is released around `leave_room_for_session` and re-acquired
            // afterwards; presumably that helper takes the database handle itself —
            // TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a room
        // token and may publish. If there is no LiveKit client, or token creation fails
        // (`trace_err()?` yields `None`), no connection info is included in the response.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out updates to other connections.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // `joined_room.channel` is consumed here; a missing channel is reported as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4054
4055/// Start editing the channel notes
4056async fn join_channel_buffer(
4057    request: proto::JoinChannelBuffer,
4058    response: Response<proto::JoinChannelBuffer>,
4059    session: UserSession,
4060) -> Result<()> {
4061    let db = session.db().await;
4062    let channel_id = ChannelId::from_proto(request.channel_id);
4063
4064    let open_response = db
4065        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4066        .await?;
4067
4068    let collaborators = open_response.collaborators.clone();
4069    response.send(open_response)?;
4070
4071    let update = UpdateChannelBufferCollaborators {
4072        channel_id: channel_id.to_proto(),
4073        collaborators: collaborators.clone(),
4074    };
4075    channel_buffer_updated(
4076        session.connection_id,
4077        collaborators
4078            .iter()
4079            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4080        &update,
4081        &session.peer,
4082    );
4083
4084    Ok(())
4085}
4086
4087/// Edit the channel notes
4088async fn update_channel_buffer(
4089    request: proto::UpdateChannelBuffer,
4090    session: UserSession,
4091) -> Result<()> {
4092    let db = session.db().await;
4093    let channel_id = ChannelId::from_proto(request.channel_id);
4094
4095    let (collaborators, epoch, version) = db
4096        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4097        .await?;
4098
4099    channel_buffer_updated(
4100        session.connection_id,
4101        collaborators.clone(),
4102        &proto::UpdateChannelBuffer {
4103            channel_id: channel_id.to_proto(),
4104            operations: request.operations,
4105        },
4106        &session.peer,
4107    );
4108
4109    let pool = &*session.connection_pool().await;
4110
4111    let non_collaborators =
4112        pool.channel_connection_ids(channel_id)
4113            .filter_map(|(connection_id, _)| {
4114                if collaborators.contains(&connection_id) {
4115                    None
4116                } else {
4117                    Some(connection_id)
4118                }
4119            });
4120
4121    broadcast(None, non_collaborators, |peer_id| {
4122        session.peer.send(
4123            peer_id,
4124            proto::UpdateChannels {
4125                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4126                    channel_id: channel_id.to_proto(),
4127                    epoch: epoch as u64,
4128                    version: version.clone(),
4129                }],
4130                ..Default::default()
4131            },
4132        )
4133    });
4134
4135    Ok(())
4136}
4137
4138/// Rejoin the channel notes after a connection blip
4139async fn rejoin_channel_buffers(
4140    request: proto::RejoinChannelBuffers,
4141    response: Response<proto::RejoinChannelBuffers>,
4142    session: UserSession,
4143) -> Result<()> {
4144    let db = session.db().await;
4145    let buffers = db
4146        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4147        .await?;
4148
4149    for rejoined_buffer in &buffers {
4150        let collaborators_to_notify = rejoined_buffer
4151            .buffer
4152            .collaborators
4153            .iter()
4154            .filter_map(|c| Some(c.peer_id?.into()));
4155        channel_buffer_updated(
4156            session.connection_id,
4157            collaborators_to_notify,
4158            &proto::UpdateChannelBufferCollaborators {
4159                channel_id: rejoined_buffer.buffer.channel_id,
4160                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4161            },
4162            &session.peer,
4163        );
4164    }
4165
4166    response.send(proto::RejoinChannelBuffersResponse {
4167        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4168    })?;
4169
4170    Ok(())
4171}
4172
4173/// Stop editing the channel notes
4174async fn leave_channel_buffer(
4175    request: proto::LeaveChannelBuffer,
4176    response: Response<proto::LeaveChannelBuffer>,
4177    session: UserSession,
4178) -> Result<()> {
4179    let db = session.db().await;
4180    let channel_id = ChannelId::from_proto(request.channel_id);
4181
4182    let left_buffer = db
4183        .leave_channel_buffer(channel_id, session.connection_id)
4184        .await?;
4185
4186    response.send(Ack {})?;
4187
4188    channel_buffer_updated(
4189        session.connection_id,
4190        left_buffer.connections,
4191        &proto::UpdateChannelBufferCollaborators {
4192            channel_id: channel_id.to_proto(),
4193            collaborators: left_buffer.collaborators,
4194        },
4195        &session.peer,
4196    );
4197
4198    Ok(())
4199}
4200
4201fn channel_buffer_updated<T: EnvelopedMessage>(
4202    sender_id: ConnectionId,
4203    collaborators: impl IntoIterator<Item = ConnectionId>,
4204    message: &T,
4205    peer: &Peer,
4206) {
4207    broadcast(Some(sender_id), collaborators, |peer_id| {
4208        peer.send(peer_id, message.clone())
4209    });
4210}
4211
4212fn send_notifications(
4213    connection_pool: &ConnectionPool,
4214    peer: &Peer,
4215    notifications: db::NotificationBatch,
4216) {
4217    for (user_id, notification) in notifications {
4218        for connection_id in connection_pool.user_connection_ids(user_id) {
4219            if let Err(error) = peer.send(
4220                connection_id,
4221                proto::AddNotification {
4222                    notification: Some(notification.clone()),
4223                },
4224            ) {
4225                tracing::error!(
4226                    "failed to send notification to {:?} {}",
4227                    connection_id,
4228                    error
4229                );
4230            }
4231        }
4232    }
4233}
4234
4235/// Send a message to the channel
4236async fn send_channel_message(
4237    request: proto::SendChannelMessage,
4238    response: Response<proto::SendChannelMessage>,
4239    session: UserSession,
4240) -> Result<()> {
4241    // Validate the message body.
4242    let body = request.body.trim().to_string();
4243    if body.len() > MAX_MESSAGE_LEN {
4244        return Err(anyhow!("message is too long"))?;
4245    }
4246    if body.is_empty() {
4247        return Err(anyhow!("message can't be blank"))?;
4248    }
4249
4250    // TODO: adjust mentions if body is trimmed
4251
4252    let timestamp = OffsetDateTime::now_utc();
4253    let nonce = request
4254        .nonce
4255        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4256
4257    let channel_id = ChannelId::from_proto(request.channel_id);
4258    let CreatedChannelMessage {
4259        message_id,
4260        participant_connection_ids,
4261        notifications,
4262    } = session
4263        .db()
4264        .await
4265        .create_channel_message(
4266            channel_id,
4267            session.user_id(),
4268            &body,
4269            &request.mentions,
4270            timestamp,
4271            nonce.clone().into(),
4272            match request.reply_to_message_id {
4273                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4274                None => None,
4275            },
4276        )
4277        .await?;
4278
4279    let message = proto::ChannelMessage {
4280        sender_id: session.user_id().to_proto(),
4281        id: message_id.to_proto(),
4282        body,
4283        mentions: request.mentions,
4284        timestamp: timestamp.unix_timestamp() as u64,
4285        nonce: Some(nonce),
4286        reply_to_message_id: request.reply_to_message_id,
4287        edited_at: None,
4288    };
4289    broadcast(
4290        Some(session.connection_id),
4291        participant_connection_ids.clone(),
4292        |connection| {
4293            session.peer.send(
4294                connection,
4295                proto::ChannelMessageSent {
4296                    channel_id: channel_id.to_proto(),
4297                    message: Some(message.clone()),
4298                },
4299            )
4300        },
4301    );
4302    response.send(proto::SendChannelMessageResponse {
4303        message: Some(message),
4304    })?;
4305
4306    let pool = &*session.connection_pool().await;
4307    let non_participants =
4308        pool.channel_connection_ids(channel_id)
4309            .filter_map(|(connection_id, _)| {
4310                if participant_connection_ids.contains(&connection_id) {
4311                    None
4312                } else {
4313                    Some(connection_id)
4314                }
4315            });
4316    broadcast(None, non_participants, |peer_id| {
4317        session.peer.send(
4318            peer_id,
4319            proto::UpdateChannels {
4320                latest_channel_message_ids: vec![proto::ChannelMessageId {
4321                    channel_id: channel_id.to_proto(),
4322                    message_id: message_id.to_proto(),
4323                }],
4324                ..Default::default()
4325            },
4326        )
4327    });
4328    send_notifications(pool, &session.peer, notifications);
4329
4330    Ok(())
4331}
4332
4333/// Delete a channel message
4334async fn remove_channel_message(
4335    request: proto::RemoveChannelMessage,
4336    response: Response<proto::RemoveChannelMessage>,
4337    session: UserSession,
4338) -> Result<()> {
4339    let channel_id = ChannelId::from_proto(request.channel_id);
4340    let message_id = MessageId::from_proto(request.message_id);
4341    let (connection_ids, existing_notification_ids) = session
4342        .db()
4343        .await
4344        .remove_channel_message(channel_id, message_id, session.user_id())
4345        .await?;
4346
4347    broadcast(
4348        Some(session.connection_id),
4349        connection_ids,
4350        move |connection| {
4351            session.peer.send(connection, request.clone())?;
4352
4353            for notification_id in &existing_notification_ids {
4354                session.peer.send(
4355                    connection,
4356                    proto::DeleteNotification {
4357                        notification_id: (*notification_id).to_proto(),
4358                    },
4359                )?;
4360            }
4361
4362            Ok(())
4363        },
4364    );
4365    response.send(proto::Ack {})?;
4366    Ok(())
4367}
4368
4369async fn update_channel_message(
4370    request: proto::UpdateChannelMessage,
4371    response: Response<proto::UpdateChannelMessage>,
4372    session: UserSession,
4373) -> Result<()> {
4374    let channel_id = ChannelId::from_proto(request.channel_id);
4375    let message_id = MessageId::from_proto(request.message_id);
4376    let updated_at = OffsetDateTime::now_utc();
4377    let UpdatedChannelMessage {
4378        message_id,
4379        participant_connection_ids,
4380        notifications,
4381        reply_to_message_id,
4382        timestamp,
4383        deleted_mention_notification_ids,
4384        updated_mention_notifications,
4385    } = session
4386        .db()
4387        .await
4388        .update_channel_message(
4389            channel_id,
4390            message_id,
4391            session.user_id(),
4392            request.body.as_str(),
4393            &request.mentions,
4394            updated_at,
4395        )
4396        .await?;
4397
4398    let nonce = request
4399        .nonce
4400        .clone()
4401        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4402
4403    let message = proto::ChannelMessage {
4404        sender_id: session.user_id().to_proto(),
4405        id: message_id.to_proto(),
4406        body: request.body.clone(),
4407        mentions: request.mentions.clone(),
4408        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4409        nonce: Some(nonce),
4410        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4411        edited_at: Some(updated_at.unix_timestamp() as u64),
4412    };
4413
4414    response.send(proto::Ack {})?;
4415
4416    let pool = &*session.connection_pool().await;
4417    broadcast(
4418        Some(session.connection_id),
4419        participant_connection_ids,
4420        |connection| {
4421            session.peer.send(
4422                connection,
4423                proto::ChannelMessageUpdate {
4424                    channel_id: channel_id.to_proto(),
4425                    message: Some(message.clone()),
4426                },
4427            )?;
4428
4429            for notification_id in &deleted_mention_notification_ids {
4430                session.peer.send(
4431                    connection,
4432                    proto::DeleteNotification {
4433                        notification_id: (*notification_id).to_proto(),
4434                    },
4435                )?;
4436            }
4437
4438            for notification in &updated_mention_notifications {
4439                session.peer.send(
4440                    connection,
4441                    proto::UpdateNotification {
4442                        notification: Some(notification.clone()),
4443                    },
4444                )?;
4445            }
4446
4447            Ok(())
4448        },
4449    );
4450
4451    send_notifications(pool, &session.peer, notifications);
4452
4453    Ok(())
4454}
4455
4456/// Mark a channel message as read
4457async fn acknowledge_channel_message(
4458    request: proto::AckChannelMessage,
4459    session: UserSession,
4460) -> Result<()> {
4461    let channel_id = ChannelId::from_proto(request.channel_id);
4462    let message_id = MessageId::from_proto(request.message_id);
4463    let notifications = session
4464        .db()
4465        .await
4466        .observe_channel_message(channel_id, session.user_id(), message_id)
4467        .await?;
4468    send_notifications(
4469        &*session.connection_pool().await,
4470        &session.peer,
4471        notifications,
4472    );
4473    Ok(())
4474}
4475
4476/// Mark a buffer version as synced
4477async fn acknowledge_buffer_version(
4478    request: proto::AckBufferOperation,
4479    session: UserSession,
4480) -> Result<()> {
4481    let buffer_id = BufferId::from_proto(request.buffer_id);
4482    session
4483        .db()
4484        .await
4485        .observe_buffer_version(
4486            buffer_id,
4487            session.user_id(),
4488            request.epoch as i32,
4489            &request.version,
4490        )
4491        .await?;
4492    Ok(())
4493}
4494
4495async fn count_language_model_tokens(
4496    request: proto::CountLanguageModelTokens,
4497    response: Response<proto::CountLanguageModelTokens>,
4498    session: Session,
4499    config: &Config,
4500) -> Result<()> {
4501    let Some(session) = session.for_user() else {
4502        return Err(anyhow!("user not found"))?;
4503    };
4504    authorize_access_to_legacy_llm_endpoints(&session).await?;
4505
4506    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4507        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4508        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4509    };
4510
4511    session
4512        .app_state
4513        .rate_limiter
4514        .check(&*rate_limit, session.user_id())
4515        .await?;
4516
4517    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4518        Some(proto::LanguageModelProvider::Google) => {
4519            let api_key = config
4520                .google_ai_api_key
4521                .as_ref()
4522                .context("no Google AI API key configured on the server")?;
4523            google_ai::count_tokens(
4524                session.http_client.as_ref(),
4525                google_ai::API_URL,
4526                api_key,
4527                serde_json::from_str(&request.request)?,
4528            )
4529            .await?
4530        }
4531        _ => return Err(anyhow!("unsupported provider"))?,
4532    };
4533
4534    response.send(proto::CountLanguageModelTokensResponse {
4535        token_count: result.total_tokens as u32,
4536    })?;
4537
4538    Ok(())
4539}
4540
4541struct ZedProCountLanguageModelTokensRateLimit;
4542
4543impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4544    fn capacity(&self) -> usize {
4545        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4546            .ok()
4547            .and_then(|v| v.parse().ok())
4548            .unwrap_or(600) // Picked arbitrarily
4549    }
4550
4551    fn refill_duration(&self) -> chrono::Duration {
4552        chrono::Duration::hours(1)
4553    }
4554
4555    fn db_name(&self) -> &'static str {
4556        "zed-pro:count-language-model-tokens"
4557    }
4558}
4559
4560struct FreeCountLanguageModelTokensRateLimit;
4561
4562impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4563    fn capacity(&self) -> usize {
4564        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4565            .ok()
4566            .and_then(|v| v.parse().ok())
4567            .unwrap_or(600 / 10) // Picked arbitrarily
4568    }
4569
4570    fn refill_duration(&self) -> chrono::Duration {
4571        chrono::Duration::hours(1)
4572    }
4573
4574    fn db_name(&self) -> &'static str {
4575        "free:count-language-model-tokens"
4576    }
4577}
4578
4579struct ZedProComputeEmbeddingsRateLimit;
4580
4581impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4582    fn capacity(&self) -> usize {
4583        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4584            .ok()
4585            .and_then(|v| v.parse().ok())
4586            .unwrap_or(5000) // Picked arbitrarily
4587    }
4588
4589    fn refill_duration(&self) -> chrono::Duration {
4590        chrono::Duration::hours(1)
4591    }
4592
4593    fn db_name(&self) -> &'static str {
4594        "zed-pro:compute-embeddings"
4595    }
4596}
4597
4598struct FreeComputeEmbeddingsRateLimit;
4599
4600impl RateLimit for FreeComputeEmbeddingsRateLimit {
4601    fn capacity(&self) -> usize {
4602        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4603            .ok()
4604            .and_then(|v| v.parse().ok())
4605            .unwrap_or(5000 / 10) // Picked arbitrarily
4606    }
4607
4608    fn refill_duration(&self) -> chrono::Duration {
4609        chrono::Duration::hours(1)
4610    }
4611
4612    fn db_name(&self) -> &'static str {
4613        "free:compute-embeddings"
4614    }
4615}
4616
4617async fn compute_embeddings(
4618    request: proto::ComputeEmbeddings,
4619    response: Response<proto::ComputeEmbeddings>,
4620    session: UserSession,
4621    api_key: Option<Arc<str>>,
4622) -> Result<()> {
4623    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4624    authorize_access_to_legacy_llm_endpoints(&session).await?;
4625
4626    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4627        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4628        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4629    };
4630
4631    session
4632        .app_state
4633        .rate_limiter
4634        .check(&*rate_limit, session.user_id())
4635        .await?;
4636
4637    let embeddings = match request.model.as_str() {
4638        "openai/text-embedding-3-small" => {
4639            open_ai::embed(
4640                session.http_client.as_ref(),
4641                OPEN_AI_API_URL,
4642                &api_key,
4643                OpenAiEmbeddingModel::TextEmbedding3Small,
4644                request.texts.iter().map(|text| text.as_str()),
4645            )
4646            .await?
4647        }
4648        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4649    };
4650
4651    let embeddings = request
4652        .texts
4653        .iter()
4654        .map(|text| {
4655            let mut hasher = sha2::Sha256::new();
4656            hasher.update(text.as_bytes());
4657            let result = hasher.finalize();
4658            result.to_vec()
4659        })
4660        .zip(
4661            embeddings
4662                .data
4663                .into_iter()
4664                .map(|embedding| embedding.embedding),
4665        )
4666        .collect::<HashMap<_, _>>();
4667
4668    let db = session.db().await;
4669    db.save_embeddings(&request.model, &embeddings)
4670        .await
4671        .context("failed to save embeddings")
4672        .trace_err();
4673
4674    response.send(proto::ComputeEmbeddingsResponse {
4675        embeddings: embeddings
4676            .into_iter()
4677            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4678            .collect(),
4679    })?;
4680    Ok(())
4681}
4682
4683async fn get_cached_embeddings(
4684    request: proto::GetCachedEmbeddings,
4685    response: Response<proto::GetCachedEmbeddings>,
4686    session: UserSession,
4687) -> Result<()> {
4688    authorize_access_to_legacy_llm_endpoints(&session).await?;
4689
4690    let db = session.db().await;
4691    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4692
4693    response.send(proto::GetCachedEmbeddingsResponse {
4694        embeddings: embeddings
4695            .into_iter()
4696            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4697            .collect(),
4698    })?;
4699    Ok(())
4700}
4701
4702/// This is leftover from before the LLM service.
4703///
4704/// The endpoints protected by this check will be moved there eventually.
4705async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4706    if session.is_staff() {
4707        Ok(())
4708    } else {
4709        Err(anyhow!("permission denied"))?
4710    }
4711}
4712
4713/// Get a Supermaven API key for the user
4714async fn get_supermaven_api_key(
4715    _request: proto::GetSupermavenApiKey,
4716    response: Response<proto::GetSupermavenApiKey>,
4717    session: UserSession,
4718) -> Result<()> {
4719    let user_id: String = session.user_id().to_string();
4720    if !session.is_staff() {
4721        return Err(anyhow!("supermaven not enabled for this account"))?;
4722    }
4723
4724    let email = session
4725        .email()
4726        .ok_or_else(|| anyhow!("user must have an email"))?;
4727
4728    let supermaven_admin_api = session
4729        .supermaven_client
4730        .as_ref()
4731        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4732
4733    let result = supermaven_admin_api
4734        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4735        .await?;
4736
4737    response.send(proto::GetSupermavenApiKeyResponse {
4738        api_key: result.api_key,
4739    })?;
4740
4741    Ok(())
4742}
4743
4744/// Start receiving chat updates for a channel
4745async fn join_channel_chat(
4746    request: proto::JoinChannelChat,
4747    response: Response<proto::JoinChannelChat>,
4748    session: UserSession,
4749) -> Result<()> {
4750    let channel_id = ChannelId::from_proto(request.channel_id);
4751
4752    let db = session.db().await;
4753    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4754        .await?;
4755    let messages = db
4756        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4757        .await?;
4758    response.send(proto::JoinChannelChatResponse {
4759        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4760        messages,
4761    })?;
4762    Ok(())
4763}
4764
4765/// Stop receiving chat updates for a channel
4766async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4767    let channel_id = ChannelId::from_proto(request.channel_id);
4768    session
4769        .db()
4770        .await
4771        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4772        .await?;
4773    Ok(())
4774}
4775
4776/// Retrieve the chat history for a channel
4777async fn get_channel_messages(
4778    request: proto::GetChannelMessages,
4779    response: Response<proto::GetChannelMessages>,
4780    session: UserSession,
4781) -> Result<()> {
4782    let channel_id = ChannelId::from_proto(request.channel_id);
4783    let messages = session
4784        .db()
4785        .await
4786        .get_channel_messages(
4787            channel_id,
4788            session.user_id(),
4789            MESSAGE_COUNT_PER_PAGE,
4790            Some(MessageId::from_proto(request.before_message_id)),
4791        )
4792        .await?;
4793    response.send(proto::GetChannelMessagesResponse {
4794        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4795        messages,
4796    })?;
4797    Ok(())
4798}
4799
4800/// Retrieve specific chat messages
4801async fn get_channel_messages_by_id(
4802    request: proto::GetChannelMessagesById,
4803    response: Response<proto::GetChannelMessagesById>,
4804    session: UserSession,
4805) -> Result<()> {
4806    let message_ids = request
4807        .message_ids
4808        .iter()
4809        .map(|id| MessageId::from_proto(*id))
4810        .collect::<Vec<_>>();
4811    let messages = session
4812        .db()
4813        .await
4814        .get_channel_messages_by_id(session.user_id(), &message_ids)
4815        .await?;
4816    response.send(proto::GetChannelMessagesResponse {
4817        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4818        messages,
4819    })?;
4820    Ok(())
4821}
4822
4823/// Retrieve the current users notifications
4824async fn get_notifications(
4825    request: proto::GetNotifications,
4826    response: Response<proto::GetNotifications>,
4827    session: UserSession,
4828) -> Result<()> {
4829    let notifications = session
4830        .db()
4831        .await
4832        .get_notifications(
4833            session.user_id(),
4834            NOTIFICATION_COUNT_PER_PAGE,
4835            request
4836                .before_id
4837                .map(|id| db::NotificationId::from_proto(id)),
4838        )
4839        .await?;
4840    response.send(proto::GetNotificationsResponse {
4841        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4842        notifications,
4843    })?;
4844    Ok(())
4845}
4846
4847/// Mark notifications as read
4848async fn mark_notification_as_read(
4849    request: proto::MarkNotificationRead,
4850    response: Response<proto::MarkNotificationRead>,
4851    session: UserSession,
4852) -> Result<()> {
4853    let database = &session.db().await;
4854    let notifications = database
4855        .mark_notification_as_read_by_id(
4856            session.user_id(),
4857            NotificationId::from_proto(request.notification_id),
4858        )
4859        .await?;
4860    send_notifications(
4861        &*session.connection_pool().await,
4862        &session.peer,
4863        notifications,
4864    );
4865    response.send(proto::Ack {})?;
4866    Ok(())
4867}
4868
4869/// Get the current users information
4870async fn get_private_user_info(
4871    _request: proto::GetPrivateUserInfo,
4872    response: Response<proto::GetPrivateUserInfo>,
4873    session: UserSession,
4874) -> Result<()> {
4875    let db = session.db().await;
4876
4877    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4878    let user = db
4879        .get_user_by_id(session.user_id())
4880        .await?
4881        .ok_or_else(|| anyhow!("user not found"))?;
4882    let flags = db.get_user_flags(session.user_id()).await?;
4883
4884    response.send(proto::GetPrivateUserInfoResponse {
4885        metrics_id,
4886        staff: user.admin,
4887        flags,
4888        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4889    })?;
4890    Ok(())
4891}
4892
4893/// Accept the terms of service (tos) on behalf of the current user
4894async fn accept_terms_of_service(
4895    _request: proto::AcceptTermsOfService,
4896    response: Response<proto::AcceptTermsOfService>,
4897    session: UserSession,
4898) -> Result<()> {
4899    let db = session.db().await;
4900
4901    let accepted_tos_at = Utc::now();
4902    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4903        .await?;
4904
4905    response.send(proto::AcceptTermsOfServiceResponse {
4906        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4907    })?;
4908    Ok(())
4909}
4910
4911/// The minimum account age an account must have in order to use the LLM service.
4912const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4913
4914async fn get_llm_api_token(
4915    _request: proto::GetLlmToken,
4916    response: Response<proto::GetLlmToken>,
4917    session: UserSession,
4918) -> Result<()> {
4919    let db = session.db().await;
4920
4921    let flags = db.get_user_flags(session.user_id()).await?;
4922    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
4923        Err(anyhow!("permission denied"))?
4924    }
4925
4926    let user_id = session.user_id();
4927    let user = db
4928        .get_user_by_id(user_id)
4929        .await?
4930        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4931
4932    if user.accepted_tos_at.is_none() {
4933        Err(anyhow!("terms of service not accepted"))?
4934    }
4935
4936    let mut account_created_at = user.created_at;
4937    if let Some(github_created_at) = user.github_user_created_at {
4938        account_created_at = account_created_at.min(github_created_at);
4939    }
4940    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4941        Err(anyhow!("account too young"))?
4942    }
4943
4944    let token = LlmTokenClaims::create(
4945        user.id,
4946        session.is_staff(),
4947        session.current_plan().await?,
4948        &session.app_state.config,
4949    )?;
4950    response.send(proto::GetLlmTokenResponse { token })?;
4951    Ok(())
4952}
4953
4954fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4955    let message = match message {
4956        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4957        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4958        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4959        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4960        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4961            code: frame.code.into(),
4962            reason: frame.reason,
4963        })),
4964        // We should never receive a frame while reading the message, according
4965        // to the `tungstenite` maintainers:
4966        //
4967        // > It cannot occur when you read messages from the WebSocket, but it
4968        // > can be used when you want to send the raw frames (e.g. you want to
4969        // > send the frames to the WebSocket without composing the full message first).
4970        // >
4971        // > — https://github.com/snapview/tungstenite-rs/issues/268
4972        TungsteniteMessage::Frame(_) => {
4973            bail!("received an unexpected frame while reading the message")
4974        }
4975    };
4976
4977    Ok(message)
4978}
4979
4980fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4981    match message {
4982        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4983        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4984        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4985        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4986        AxumMessage::Close(frame) => {
4987            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4988                code: frame.code.into(),
4989                reason: frame.reason,
4990            }))
4991        }
4992    }
4993}
4994
4995fn notify_membership_updated(
4996    connection_pool: &mut ConnectionPool,
4997    result: MembershipUpdated,
4998    user_id: UserId,
4999    peer: &Peer,
5000) {
5001    for membership in &result.new_channels.channel_memberships {
5002        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5003    }
5004    for channel_id in &result.removed_channels {
5005        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5006    }
5007
5008    let user_channels_update = proto::UpdateUserChannels {
5009        channel_memberships: result
5010            .new_channels
5011            .channel_memberships
5012            .iter()
5013            .map(|cm| proto::ChannelMembership {
5014                channel_id: cm.channel_id.to_proto(),
5015                role: cm.role.into(),
5016            })
5017            .collect(),
5018        ..Default::default()
5019    };
5020
5021    let mut update = build_channels_update(result.new_channels);
5022    update.delete_channels = result
5023        .removed_channels
5024        .into_iter()
5025        .map(|id| id.to_proto())
5026        .collect();
5027    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5028
5029    for connection_id in connection_pool.user_connection_ids(user_id) {
5030        peer.send(connection_id, user_channels_update.clone())
5031            .trace_err();
5032        peer.send(connection_id, update.clone()).trace_err();
5033    }
5034}
5035
5036fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5037    proto::UpdateUserChannels {
5038        channel_memberships: channels
5039            .channel_memberships
5040            .iter()
5041            .map(|m| proto::ChannelMembership {
5042                channel_id: m.channel_id.to_proto(),
5043                role: m.role.into(),
5044            })
5045            .collect(),
5046        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5047        observed_channel_message_id: channels.observed_channel_messages.clone(),
5048    }
5049}
5050
5051fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5052    let mut update = proto::UpdateChannels::default();
5053
5054    for channel in channels.channels {
5055        update.channels.push(channel.to_proto());
5056    }
5057
5058    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5059    update.latest_channel_message_ids = channels.latest_channel_messages;
5060
5061    for (channel_id, participants) in channels.channel_participants {
5062        update
5063            .channel_participants
5064            .push(proto::ChannelParticipants {
5065                channel_id: channel_id.to_proto(),
5066                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5067            });
5068    }
5069
5070    for channel in channels.invited_channels {
5071        update.channel_invitations.push(channel.to_proto());
5072    }
5073
5074    update.hosted_projects = channels.hosted_projects;
5075    update
5076}
5077
5078fn build_initial_contacts_update(
5079    contacts: Vec<db::Contact>,
5080    pool: &ConnectionPool,
5081) -> proto::UpdateContacts {
5082    let mut update = proto::UpdateContacts::default();
5083
5084    for contact in contacts {
5085        match contact {
5086            db::Contact::Accepted { user_id, busy } => {
5087                update.contacts.push(contact_for_user(user_id, busy, &pool));
5088            }
5089            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5090            db::Contact::Incoming { user_id } => {
5091                update
5092                    .incoming_requests
5093                    .push(proto::IncomingContactRequest {
5094                        requester_id: user_id.to_proto(),
5095                    })
5096            }
5097        }
5098    }
5099
5100    update
5101}
5102
5103fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5104    proto::Contact {
5105        user_id: user_id.to_proto(),
5106        online: pool.is_user_online(user_id),
5107        busy,
5108    }
5109}
5110
5111fn room_updated(room: &proto::Room, peer: &Peer) {
5112    broadcast(
5113        None,
5114        room.participants
5115            .iter()
5116            .filter_map(|participant| Some(participant.peer_id?.into())),
5117        |peer_id| {
5118            peer.send(
5119                peer_id,
5120                proto::RoomUpdated {
5121                    room: Some(room.clone()),
5122                },
5123            )
5124        },
5125    );
5126}
5127
5128fn channel_updated(
5129    channel: &db::channel::Model,
5130    room: &proto::Room,
5131    peer: &Peer,
5132    pool: &ConnectionPool,
5133) {
5134    let participants = room
5135        .participants
5136        .iter()
5137        .map(|p| p.user_id)
5138        .collect::<Vec<_>>();
5139
5140    broadcast(
5141        None,
5142        pool.channel_connection_ids(channel.root_id())
5143            .filter_map(|(channel_id, role)| {
5144                role.can_see_channel(channel.visibility).then(|| channel_id)
5145            }),
5146        |peer_id| {
5147            peer.send(
5148                peer_id,
5149                proto::UpdateChannels {
5150                    channel_participants: vec![proto::ChannelParticipants {
5151                        channel_id: channel.id.to_proto(),
5152                        participant_user_ids: participants.clone(),
5153                    }],
5154                    ..Default::default()
5155                },
5156            )
5157        },
5158    );
5159}
5160
5161async fn send_dev_server_projects_update(
5162    user_id: UserId,
5163    mut status: proto::DevServerProjectsUpdate,
5164    session: &Session,
5165) {
5166    let pool = session.connection_pool().await;
5167    for dev_server in &mut status.dev_servers {
5168        dev_server.status =
5169            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5170    }
5171    let connections = pool.user_connection_ids(user_id);
5172    for connection_id in connections {
5173        session.peer.send(connection_id, status.clone()).trace_err();
5174    }
5175}
5176
5177async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5178    let db = session.db().await;
5179
5180    let contacts = db.get_contacts(user_id).await?;
5181    let busy = db.is_user_busy(user_id).await?;
5182
5183    let pool = session.connection_pool().await;
5184    let updated_contact = contact_for_user(user_id, busy, &pool);
5185    for contact in contacts {
5186        if let db::Contact::Accepted {
5187            user_id: contact_user_id,
5188            ..
5189        } = contact
5190        {
5191            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5192                session
5193                    .peer
5194                    .send(
5195                        contact_conn_id,
5196                        proto::UpdateContacts {
5197                            contacts: vec![updated_contact.clone()],
5198                            remove_contacts: Default::default(),
5199                            incoming_requests: Default::default(),
5200                            remove_incoming_requests: Default::default(),
5201                            outgoing_requests: Default::default(),
5202                            remove_outgoing_requests: Default::default(),
5203                        },
5204                    )
5205                    .trace_err();
5206            }
5207        }
5208    }
5209    Ok(())
5210}
5211
5212async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5213    log::info!("lost dev server connection, unsharing projects");
5214    let project_ids = session
5215        .db()
5216        .await
5217        .get_stale_dev_server_projects(session.connection_id)
5218        .await?;
5219
5220    for project_id in project_ids {
5221        // not unshare re-checks the connection ids match, so we get away with no transaction
5222        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5223    }
5224
5225    let user_id = session.dev_server().user_id;
5226    let update = session
5227        .db()
5228        .await
5229        .dev_server_projects_update(user_id)
5230        .await?;
5231
5232    send_dev_server_projects_update(user_id, update, session).await;
5233
5234    Ok(())
5235}
5236
5237async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5238    let mut contacts_to_update = HashSet::default();
5239
5240    let room_id;
5241    let canceled_calls_to_user_ids;
5242    let live_kit_room;
5243    let delete_live_kit_room;
5244    let room;
5245    let channel;
5246
5247    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5248        contacts_to_update.insert(session.user_id());
5249
5250        for project in left_room.left_projects.values() {
5251            project_left(project, session);
5252        }
5253
5254        room_id = RoomId::from_proto(left_room.room.id);
5255        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5256        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5257        delete_live_kit_room = left_room.deleted;
5258        room = mem::take(&mut left_room.room);
5259        channel = mem::take(&mut left_room.channel);
5260
5261        room_updated(&room, &session.peer);
5262    } else {
5263        return Ok(());
5264    }
5265
5266    if let Some(channel) = channel {
5267        channel_updated(
5268            &channel,
5269            &room,
5270            &session.peer,
5271            &*session.connection_pool().await,
5272        );
5273    }
5274
5275    {
5276        let pool = session.connection_pool().await;
5277        for canceled_user_id in canceled_calls_to_user_ids {
5278            for connection_id in pool.user_connection_ids(canceled_user_id) {
5279                session
5280                    .peer
5281                    .send(
5282                        connection_id,
5283                        proto::CallCanceled {
5284                            room_id: room_id.to_proto(),
5285                        },
5286                    )
5287                    .trace_err();
5288            }
5289            contacts_to_update.insert(canceled_user_id);
5290        }
5291    }
5292
5293    for contact_user_id in contacts_to_update {
5294        update_user_contacts(contact_user_id, &session).await?;
5295    }
5296
5297    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5298        live_kit
5299            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5300            .await
5301            .trace_err();
5302
5303        if delete_live_kit_room {
5304            live_kit.delete_room(live_kit_room).await.trace_err();
5305        }
5306    }
5307
5308    Ok(())
5309}
5310
5311async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5312    let left_channel_buffers = session
5313        .db()
5314        .await
5315        .leave_channel_buffers(session.connection_id)
5316        .await?;
5317
5318    for left_buffer in left_channel_buffers {
5319        channel_buffer_updated(
5320            session.connection_id,
5321            left_buffer.connections,
5322            &proto::UpdateChannelBufferCollaborators {
5323                channel_id: left_buffer.channel_id.to_proto(),
5324                collaborators: left_buffer.collaborators,
5325            },
5326            &session.peer,
5327        );
5328    }
5329
5330    Ok(())
5331}
5332
5333fn project_left(project: &db::LeftProject, session: &UserSession) {
5334    for connection_id in &project.connection_ids {
5335        if project.should_unshare {
5336            session
5337                .peer
5338                .send(
5339                    *connection_id,
5340                    proto::UnshareProject {
5341                        project_id: project.id.to_proto(),
5342                    },
5343                )
5344                .trace_err();
5345        } else {
5346            session
5347                .peer
5348                .send(
5349                    *connection_id,
5350                    proto::RemoveProjectCollaborator {
5351                        project_id: project.id.to_proto(),
5352                        peer_id: Some(session.connection_id.into()),
5353                    },
5354                )
5355                .trace_err();
5356        }
5357    }
5358}
5359
5360pub trait ResultExt {
5361    type Ok;
5362
5363    fn trace_err(self) -> Option<Self::Ok>;
5364}
5365
5366impl<T, E> ResultExt for Result<T, E>
5367where
5368    E: std::fmt::Debug,
5369{
5370    type Ok = T;
5371
5372    #[track_caller]
5373    fn trace_err(self) -> Option<T> {
5374        match self {
5375            Ok(value) => Some(value),
5376            Err(error) => {
5377                tracing::error!("{:?}", error);
5378                None
5379            }
5380        }
5381    }
5382}