rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use http_client::HttpClient;
  39use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  40use reqwest_client::ReqwestClient;
  41use sha2::Digest;
  42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  43
  44use futures::{
  45    channel::oneshot,
  46    future::{self, BoxFuture},
  47    stream::FuturesUnordered,
  48    FutureExt, SinkExt, StreamExt, TryStreamExt,
  49};
  50use prometheus::{register_int_gauge, IntGauge};
  51use rpc::{
  52    proto::{
  53        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  54        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  55    },
  56    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  57};
  58use semantic_version::SemanticVersion;
  59use serde::{Serialize, Serializer};
  60use std::{
  61    any::TypeId,
  62    future::Future,
  63    marker::PhantomData,
  64    mem,
  65    net::SocketAddr,
  66    ops::{Deref, DerefMut},
  67    rc::Rc,
  68    sync::{
  69        atomic::{AtomicBool, Ordering::SeqCst},
  70        Arc, OnceLock,
  71    },
  72    time::{Duration, Instant},
  73};
  74use time::OffsetDateTime;
  75use tokio::sync::{watch, MutexGuard, Semaphore};
  76use tower::ServiceBuilder;
  77use tracing::{
  78    field::{self},
  79    info_span, instrument, Instrument,
  80};
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108#[derive(Clone, Debug)]
 109pub enum Principal {
 110    User(User),
 111    Impersonated { user: User, admin: User },
 112    DevServer(dev_server::Model),
 113}
 114
 115impl Principal {
 116    fn update_span(&self, span: &tracing::Span) {
 117        match &self {
 118            Principal::User(user) => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121            }
 122            Principal::Impersonated { user, admin } => {
 123                span.record("user_id", user.id.0);
 124                span.record("login", &user.github_login);
 125                span.record("impersonator", &admin.github_login);
 126            }
 127            Principal::DevServer(dev_server) => {
 128                span.record("dev_server_id", dev_server.id.0);
 129            }
 130        }
 131    }
 132}
 133
/// Per-connection state passed to every registered message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; acquire it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Optional Supermaven admin client (presumably `None` when Supermaven is
    /// not configured — TODO confirm).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // `allow(unused)` suppresses the dead-code lint, which suggests this field is
    // not read yet — confirm before removing the attribute.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Underscore-prefixed: stored but not read through this field (presumably kept
    // so the session carries its executor — TODO confirm).
    _executor: Executor,
}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn for_dev_server(self) -> Option<DevServerSession> {
 175        DevServerSession::new(self)
 176    }
 177
 178    fn user_id(&self) -> Option<UserId> {
 179        match &self.principal {
 180            Principal::User(user) => Some(user.id),
 181            Principal::Impersonated { user, .. } => Some(user.id),
 182            Principal::DevServer(_) => None,
 183        }
 184    }
 185
 186    fn is_staff(&self) -> bool {
 187        match &self.principal {
 188            Principal::User(user) => user.admin,
 189            Principal::Impersonated { .. } => true,
 190            Principal::DevServer(_) => false,
 191        }
 192    }
 193
 194    pub async fn has_llm_subscription(
 195        &self,
 196        db: &MutexGuard<'_, DbHandle>,
 197    ) -> anyhow::Result<bool> {
 198        if self.is_staff() {
 199            return Ok(true);
 200        }
 201
 202        let Some(user_id) = self.user_id() else {
 203            return Ok(false);
 204        };
 205
 206        Ok(db.has_active_billing_subscription(user_id).await?)
 207    }
 208
 209    pub async fn current_plan(
 210        &self,
 211        _db: &MutexGuard<'_, DbHandle>,
 212    ) -> anyhow::Result<proto::Plan> {
 213        if self.is_staff() {
 214            Ok(proto::Plan::ZedPro)
 215        } else {
 216            Ok(proto::Plan::Free)
 217        }
 218    }
 219
 220    fn dev_server_id(&self) -> Option<DevServerId> {
 221        match &self.principal {
 222            Principal::User(_) | Principal::Impersonated { .. } => None,
 223            Principal::DevServer(dev_server) => Some(dev_server.id),
 224        }
 225    }
 226
 227    fn principal_id(&self) -> PrincipalId {
 228        match &self.principal {
 229            Principal::User(user) => PrincipalId::UserId(user.id),
 230            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 231            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 232        }
 233    }
 234}
 235
 236impl Debug for Session {
 237    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 238        let mut result = f.debug_struct("Session");
 239        match &self.principal {
 240            Principal::User(user) => {
 241                result.field("user", &user.github_login);
 242            }
 243            Principal::Impersonated { user, admin } => {
 244                result.field("user", &user.github_login);
 245                result.field("impersonator", &admin.github_login);
 246            }
 247            Principal::DevServer(dev_server) => {
 248                result.field("dev_server", &dev_server.id);
 249            }
 250        }
 251        result.field("connection_id", &self.connection_id).finish()
 252    }
 253}
 254
 255struct UserSession(Session);
 256
 257impl UserSession {
 258    pub fn new(s: Session) -> Option<Self> {
 259        s.user_id().map(|_| UserSession(s))
 260    }
 261    pub fn user_id(&self) -> UserId {
 262        self.0.user_id().unwrap()
 263    }
 264
 265    pub fn email(&self) -> Option<String> {
 266        match &self.0.principal {
 267            Principal::User(user) => user.email_address.clone(),
 268            Principal::Impersonated { user, .. } => user.email_address.clone(),
 269            Principal::DevServer(..) => None,
 270        }
 271    }
 272}
 273
 274impl Deref for UserSession {
 275    type Target = Session;
 276
 277    fn deref(&self) -> &Self::Target {
 278        &self.0
 279    }
 280}
 281impl DerefMut for UserSession {
 282    fn deref_mut(&mut self) -> &mut Self::Target {
 283        &mut self.0
 284    }
 285}
 286
 287struct DevServerSession(Session);
 288
 289impl DevServerSession {
 290    pub fn new(s: Session) -> Option<Self> {
 291        s.dev_server_id().map(|_| DevServerSession(s))
 292    }
 293    pub fn dev_server_id(&self) -> DevServerId {
 294        self.0.dev_server_id().unwrap()
 295    }
 296
 297    fn dev_server(&self) -> &dev_server::Model {
 298        match &self.0.principal {
 299            Principal::DevServer(dev_server) => dev_server,
 300            _ => unreachable!(),
 301        }
 302    }
 303}
 304
 305impl Deref for DevServerSession {
 306    type Target = Session;
 307
 308    fn deref(&self) -> &Self::Target {
 309        &self.0
 310    }
 311}
 312impl DerefMut for DevServerSession {
 313    fn deref_mut(&mut self) -> &mut Self::Target {
 314        &mut self.0
 315    }
 316}
 317
 318fn user_handler<M: RequestMessage, Fut>(
 319    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 320) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 321where
 322    Fut: Send + Future<Output = Result<()>>,
 323{
 324    let handler = Arc::new(handler);
 325    move |message, response, session| {
 326        let handler = handler.clone();
 327        Box::pin(async move {
 328            if let Some(user_session) = session.for_user() {
 329                Ok(handler(message, response, user_session).await?)
 330            } else {
 331                Err(Error::Internal(anyhow!(
 332                    "must be a user to call {}",
 333                    M::NAME
 334                )))
 335            }
 336        })
 337    }
 338}
 339
 340fn dev_server_handler<M: RequestMessage, Fut>(
 341    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 342) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 343where
 344    Fut: Send + Future<Output = Result<()>>,
 345{
 346    let handler = Arc::new(handler);
 347    move |message, response, session| {
 348        let handler = handler.clone();
 349        Box::pin(async move {
 350            if let Some(dev_server_session) = session.for_dev_server() {
 351                Ok(handler(message, response, dev_server_session).await?)
 352            } else {
 353                Err(Error::Internal(anyhow!(
 354                    "must be a dev server to call {}",
 355                    M::NAME
 356                )))
 357            }
 358        })
 359    }
 360}
 361
 362fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 363    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 364) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 365where
 366    InnertRetFut: Send + Future<Output = Result<()>>,
 367{
 368    let handler = Arc::new(handler);
 369    move |message, session| {
 370        let handler = handler.clone();
 371        Box::pin(async move {
 372            if let Some(user_session) = session.for_user() {
 373                Ok(handler(message, user_session).await?)
 374            } else {
 375                Err(Error::Internal(anyhow!(
 376                    "must be a user to call {}",
 377                    M::NAME
 378                )))
 379            }
 380        })
 381    }
 382}
 383
 384struct DbHandle(Arc<Database>);
 385
 386impl Deref for DbHandle {
 387    type Target = Database;
 388
 389    fn deref(&self) -> &Self::Target {
 390        self.0.as_ref()
 391    }
 392}
 393
/// The collab RPC server. It owns the [`Peer`], the shared connection pool, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Behind a mutex, which suggests it is reassigned after construction —
    // TODO confirm where it is written.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Cloned into background work in `start` (and presumably into each session).
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by a `TypeId`, presumably of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Sender half of a teardown signal; `Server::new` initializes it to `false`.
    teardown: watch::Sender<bool>,
}
 402
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the guard `!Send`. Handler futures
    // must be `Send`, so the guard cannot be held across an `.await` inside them.
    _not_send: PhantomData<Rc<()>>,
}
 407
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 414
 415pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 416where
 417    S: Serializer,
 418    T: Deref<Target = U>,
 419    U: Serialize,
 420{
 421    Serialize::serialize(value.deref(), serializer)
 422}
 423
 424impl Server {
    /// Creates a server identified by `id` and registers every RPC handler it
    /// serves.
    ///
    /// Wrappers control which principals may reach each handler:
    /// - `user_handler` / `user_message_handler` reject sessions whose
    ///   principal is not a user.
    /// - `dev_server_handler` rejects sessions whose principal is not a dev
    ///   server.
    /// - Handlers passed without a wrapper are handed to `add_request_handler` /
    ///   `add_message_handler` directly.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/wraps if the id's inner
            // integer is wider or signed — confirm server ids stay in u32 range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial receiver is dropped here.
            teardown: watch::channel(false).0,
        };

        server
            // Liveness.
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Sharing and joining projects.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev server / dev server project management requested by users.
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests only a dev server principal may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project, worktree and language-server state updates.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the forwarding helpers
            // (`forward_project_request_for_owner`, `forward_read_only_project_request`,
            // `forward_find_search_candidates_request`).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            // Project requests relayed through `forward_mutating_project_request`.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer creation and edits, plus messages broadcast from the host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            // Channel buffers.
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            // Channel membership queries, invites and joining.
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            // Channel chat.
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, LLM token and terms of service.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            // Acknowledgements.
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model token counting and embeddings. Two of these closures
            // capture `app_state` to pass configuration through. Token counting is
            // not wrapped in `user_handler`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 658
 659    pub async fn start(&self) -> Result<()> {
 660        let server_id = *self.id.lock();
 661        let app_state = self.app_state.clone();
 662        let peer = self.peer.clone();
 663        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 664        let pool = self.connection_pool.clone();
 665        let live_kit_client = self.app_state.live_kit_client.clone();
 666
 667        let span = info_span!("start server");
 668        self.app_state.executor.spawn_detached(
 669            async move {
 670                tracing::info!("waiting for cleanup timeout");
 671                timeout.await;
 672                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 673                if let Some((room_ids, channel_ids)) = app_state
 674                    .db
 675                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 676                    .await
 677                    .trace_err()
 678                {
 679                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 680                    tracing::info!(
 681                        stale_channel_buffer_count = channel_ids.len(),
 682                        "retrieved stale channel buffers"
 683                    );
 684
 685                    for channel_id in channel_ids {
 686                        if let Some(refreshed_channel_buffer) = app_state
 687                            .db
 688                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 689                            .await
 690                            .trace_err()
 691                        {
 692                            for connection_id in refreshed_channel_buffer.connection_ids {
 693                                peer.send(
 694                                    connection_id,
 695                                    proto::UpdateChannelBufferCollaborators {
 696                                        channel_id: channel_id.to_proto(),
 697                                        collaborators: refreshed_channel_buffer
 698                                            .collaborators
 699                                            .clone(),
 700                                    },
 701                                )
 702                                .trace_err();
 703                            }
 704                        }
 705                    }
 706
 707                    for room_id in room_ids {
 708                        let mut contacts_to_update = HashSet::default();
 709                        let mut canceled_calls_to_user_ids = Vec::new();
 710                        let mut live_kit_room = String::new();
 711                        let mut delete_live_kit_room = false;
 712
 713                        if let Some(mut refreshed_room) = app_state
 714                            .db
 715                            .clear_stale_room_participants(room_id, server_id)
 716                            .await
 717                            .trace_err()
 718                        {
 719                            tracing::info!(
 720                                room_id = room_id.0,
 721                                new_participant_count = refreshed_room.room.participants.len(),
 722                                "refreshed room"
 723                            );
 724                            room_updated(&refreshed_room.room, &peer);
 725                            if let Some(channel) = refreshed_room.channel.as_ref() {
 726                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 727                            }
 728                            contacts_to_update
 729                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 730                            contacts_to_update
 731                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 732                            canceled_calls_to_user_ids =
 733                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 734                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 735                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 736                        }
 737
 738                        {
 739                            let pool = pool.lock();
 740                            for canceled_user_id in canceled_calls_to_user_ids {
 741                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 742                                    peer.send(
 743                                        connection_id,
 744                                        proto::CallCanceled {
 745                                            room_id: room_id.to_proto(),
 746                                        },
 747                                    )
 748                                    .trace_err();
 749                                }
 750                            }
 751                        }
 752
 753                        for user_id in contacts_to_update {
 754                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 755                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 756                            if let Some((busy, contacts)) = busy.zip(contacts) {
 757                                let pool = pool.lock();
 758                                let updated_contact = contact_for_user(user_id, busy, &pool);
 759                                for contact in contacts {
 760                                    if let db::Contact::Accepted {
 761                                        user_id: contact_user_id,
 762                                        ..
 763                                    } = contact
 764                                    {
 765                                        for contact_conn_id in
 766                                            pool.user_connection_ids(contact_user_id)
 767                                        {
 768                                            peer.send(
 769                                                contact_conn_id,
 770                                                proto::UpdateContacts {
 771                                                    contacts: vec![updated_contact.clone()],
 772                                                    remove_contacts: Default::default(),
 773                                                    incoming_requests: Default::default(),
 774                                                    remove_incoming_requests: Default::default(),
 775                                                    outgoing_requests: Default::default(),
 776                                                    remove_outgoing_requests: Default::default(),
 777                                                },
 778                                            )
 779                                            .trace_err();
 780                                        }
 781                                    }
 782                                }
 783                            }
 784                        }
 785
 786                        if let Some(live_kit) = live_kit_client.as_ref() {
 787                            if delete_live_kit_room {
 788                                live_kit.delete_room(live_kit_room).await.trace_err();
 789                            }
 790                        }
 791                    }
 792                }
 793
 794                app_state
 795                    .db
 796                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 797                    .await
 798                    .trace_err();
 799            }
 800            .instrument(span),
 801        );
 802        Ok(())
 803    }
 804
 805    pub fn teardown(&self) {
 806        self.peer.teardown();
 807        self.connection_pool.lock().reset();
 808        let _ = self.teardown.send(true);
 809    }
 810
 811    #[cfg(test)]
 812    pub fn reset(&self, id: ServerId) {
 813        self.teardown();
 814        *self.id.lock() = id;
 815        self.peer.reset(id.0 as u32);
 816        let _ = self.teardown.send(false);
 817    }
 818
 819    #[cfg(test)]
 820    pub fn id(&self) -> ServerId {
 821        *self.id.lock()
 822    }
 823
 824    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 825    where
 826        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 827        Fut: 'static + Send + Future<Output = Result<()>>,
 828        M: EnvelopedMessage,
 829    {
 830        let prev_handler = self.handlers.insert(
 831            TypeId::of::<M>(),
 832            Box::new(move |envelope, session| {
 833                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 834                let received_at = envelope.received_at;
 835                tracing::info!("message received");
 836                let start_time = Instant::now();
 837                let future = (handler)(*envelope, session);
 838                async move {
 839                    let result = future.await;
 840                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 841                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 842                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 843                    let payload_type = M::NAME;
 844
 845                    match result {
 846                        Err(error) => {
 847                            tracing::error!(
 848                                ?error,
 849                                total_duration_ms,
 850                                processing_duration_ms,
 851                                queue_duration_ms,
 852                                payload_type,
 853                                "error handling message"
 854                            )
 855                        }
 856                        Ok(()) => tracing::info!(
 857                            total_duration_ms,
 858                            processing_duration_ms,
 859                            queue_duration_ms,
 860                            "finished handling message"
 861                        ),
 862                    }
 863                }
 864                .boxed()
 865            }),
 866        );
 867        if prev_handler.is_some() {
 868            panic!("registered a handler for the same message twice");
 869        }
 870        self
 871    }
 872
 873    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 874    where
 875        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 876        Fut: 'static + Send + Future<Output = Result<()>>,
 877        M: EnvelopedMessage,
 878    {
 879        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 880        self
 881    }
 882
 883    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 884    where
 885        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 886        Fut: Send + Future<Output = Result<()>>,
 887        M: RequestMessage,
 888    {
 889        let handler = Arc::new(handler);
 890        self.add_handler(move |envelope, session| {
 891            let receipt = envelope.receipt();
 892            let handler = handler.clone();
 893            async move {
 894                let peer = session.peer.clone();
 895                let responded = Arc::new(AtomicBool::default());
 896                let response = Response {
 897                    peer: peer.clone(),
 898                    responded: responded.clone(),
 899                    receipt,
 900                };
 901                match (handler)(envelope.payload, response, session).await {
 902                    Ok(()) => {
 903                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 904                            Ok(())
 905                        } else {
 906                            Err(anyhow!("handler did not send a response"))?
 907                        }
 908                    }
 909                    Err(error) => {
 910                        let proto_err = match &error {
 911                            Error::Internal(err) => err.to_proto(),
 912                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 913                        };
 914                        peer.respond_with_error(receipt, proto_err)?;
 915                        Err(error)
 916                    }
 917                }
 918            }
 919        })
 920    }
 921
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future runs inside a "handle connection" span carrying the
    /// peer address, principal-specific fields (filled in by
    /// `principal.update_span`), and the GeoIP country code when one is given.
    /// When polled it:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds the [`Session`];
    /// 3. sends the initial client update, forwarding `send_connection_id`;
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    connection closes, I/O fails, or teardown is signalled;
    /// 5. unless teardown was signalled, runs `connection_lost` for cleanup.
    ///
    /// Failing to build the HTTP client, or to send the initial update, ends the
    /// future early (after logging) without running `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Watched below both to refuse new connections during teardown and to
        // stop the message loop when teardown is signalled.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): this early return skips `connection_lost`, so anything
            // `send_initial_client_update` registered before failing (e.g. the
            // connection-pool entry) is not removed here. Confirm it is cleaned up
            // elsewhere.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers for this connection at 256. A permit is taken
            // before each message is read. It is released when that message's
            // handler completes, or dropped right away if no handler exists.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in the order written: teardown first,
                // then connection I/O, then in-flight foreground handlers, and only
                // then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before the handler future is polled;
                                // that future is instrumented with `span` below instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop through
                                // `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that are still pending are dropped (not polled
            // further) before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1071
    /// Sends what a freshly connected client needs before regular message
    /// handling starts, and registers the connection in the pool.
    ///
    /// Every connection first gets a `Hello` carrying its peer id. After that,
    /// `send_connection_id` (if provided) receives the connection id. Then:
    /// - for users (including impersonated sessions):
    ///   - `ShowContacts` on the first-ever connection;
    ///   - `update_user_plan`;
    ///   - the initial contacts update;
    ///   - channel subscriptions, when `zed_version` supports auto-subscribing;
    ///   - dev-server project status;
    ///   - any pending incoming call;
    ///   - finally `update_user_contacts`.
    /// - for dev servers: any existing connection for the same dev server id is
    ///   told to shut down and removed from the pool. The new connection then
    ///   gets its project instructions, and the owning user's dev-server project
    ///   status is re-sent.
    ///
    /// # Errors
    ///
    /// Returns the first send or database error; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiving side may have gone away; that is not treated as an error.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // Register the connection and build the initial contacts update
                // under one pool lock, so the update is computed from a pool that
                // already contains this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one connection per dev server id is kept: evict the stale
                // one before registering this connection.
                {
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1169
1170    pub async fn invite_code_redeemed(
1171        self: &Arc<Self>,
1172        inviter_id: UserId,
1173        invitee_id: UserId,
1174    ) -> Result<()> {
1175        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1176            if let Some(code) = &user.invite_code {
1177                let pool = self.connection_pool.lock();
1178                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1179                for connection_id in pool.user_connection_ids(inviter_id) {
1180                    self.peer.send(
1181                        connection_id,
1182                        proto::UpdateContacts {
1183                            contacts: vec![invitee_contact.clone()],
1184                            ..Default::default()
1185                        },
1186                    )?;
1187                    self.peer.send(
1188                        connection_id,
1189                        proto::UpdateInviteInfo {
1190                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1191                            count: user.invite_count as u32,
1192                        },
1193                    )?;
1194                }
1195            }
1196        }
1197        Ok(())
1198    }
1199
1200    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1201        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1202            if let Some(invite_code) = &user.invite_code {
1203                let pool = self.connection_pool.lock();
1204                for connection_id in pool.user_connection_ids(user_id) {
1205                    self.peer.send(
1206                        connection_id,
1207                        proto::UpdateInviteInfo {
1208                            url: format!(
1209                                "{}{}",
1210                                self.app_state.config.invite_link_prefix, invite_code
1211                            ),
1212                            count: user.invite_count as u32,
1213                        },
1214                    )?;
1215                }
1216            }
1217        }
1218        Ok(())
1219    }
1220
1221    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1222        let pool = self.connection_pool.lock();
1223        for connection_id in pool.user_connection_ids(user_id) {
1224            self.peer
1225                .send(connection_id, proto::RefreshLlmToken {})
1226                .trace_err();
1227        }
1228    }
1229
1230    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1231        ServerSnapshot {
1232            connection_pool: ConnectionPoolGuard {
1233                guard: self.connection_pool.lock(),
1234                _not_send: PhantomData,
1235            },
1236            peer: &self.peer,
1237        }
1238    }
1239}
1240
1241impl<'a> Deref for ConnectionPoolGuard<'a> {
1242    type Target = ConnectionPool;
1243
1244    fn deref(&self) -> &Self::Target {
1245        &self.guard
1246    }
1247}
1248
1249impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1250    fn deref_mut(&mut self) -> &mut Self::Target {
1251        &mut self.guard
1252    }
1253}
1254
1255impl<'a> Drop for ConnectionPoolGuard<'a> {
1256    fn drop(&mut self) {
1257        #[cfg(test)]
1258        self.check_invariants();
1259    }
1260}
1261
1262fn broadcast<F>(
1263    sender_id: Option<ConnectionId>,
1264    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1265    mut f: F,
1266) where
1267    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1268{
1269    for receiver_id in receiver_ids {
1270        if Some(receiver_id) != sender_id {
1271            if let Err(error) = f(receiver_id) {
1272                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1273            }
1274        }
1275    }
1276}
1277
1278pub struct ProtocolVersion(u32);
1279
1280impl Header for ProtocolVersion {
1281    fn name() -> &'static HeaderName {
1282        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1283        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1284    }
1285
1286    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1287    where
1288        Self: Sized,
1289        I: Iterator<Item = &'i axum::http::HeaderValue>,
1290    {
1291        let version = values
1292            .next()
1293            .ok_or_else(axum::headers::Error::invalid)?
1294            .to_str()
1295            .map_err(|_| axum::headers::Error::invalid())?
1296            .parse()
1297            .map_err(|_| axum::headers::Error::invalid())?;
1298        Ok(Self(version))
1299    }
1300
1301    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1302        values.extend([self.0.to_string().parse().unwrap()]);
1303    }
1304}
1305
1306pub struct AppVersionHeader(SemanticVersion);
1307impl Header for AppVersionHeader {
1308    fn name() -> &'static HeaderName {
1309        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1311    }
1312
1313    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314    where
1315        Self: Sized,
1316        I: Iterator<Item = &'i axum::http::HeaderValue>,
1317    {
1318        let version = values
1319            .next()
1320            .ok_or_else(axum::headers::Error::invalid)?
1321            .to_str()
1322            .map_err(|_| axum::headers::Error::invalid())?
1323            .parse()
1324            .map_err(|_| axum::headers::Error::invalid())?;
1325        Ok(Self(version))
1326    }
1327
1328    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329        values.extend([self.0.to_string().parse().unwrap()]);
1330    }
1331}
1332
1333pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1334    Router::new()
1335        .route("/rpc", get(handle_websocket_request))
1336        .layer(
1337            ServiceBuilder::new()
1338                .layer(Extension(server.app_state.clone()))
1339                .layer(middleware::from_fn(auth::validate_header)),
1340        )
1341        .route("/metrics", get(handle_metrics))
1342        .layer(Extension(server))
1343}
1344
1345pub async fn handle_websocket_request(
1346    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1347    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1348    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1349    Extension(server): Extension<Arc<Server>>,
1350    Extension(principal): Extension<Principal>,
1351    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1352    ws: WebSocketUpgrade,
1353) -> axum::response::Response {
1354    if protocol_version != rpc::PROTOCOL_VERSION {
1355        return (
1356            StatusCode::UPGRADE_REQUIRED,
1357            "client must be upgraded".to_string(),
1358        )
1359            .into_response();
1360    }
1361
1362    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1363        return (
1364            StatusCode::UPGRADE_REQUIRED,
1365            "no version header found".to_string(),
1366        )
1367            .into_response();
1368    };
1369
1370    if !version.can_collaborate() {
1371        return (
1372            StatusCode::UPGRADE_REQUIRED,
1373            "client must be upgraded".to_string(),
1374        )
1375            .into_response();
1376    }
1377
1378    let socket_address = socket_address.to_string();
1379    ws.on_upgrade(move |socket| {
1380        let socket = socket
1381            .map_ok(to_tungstenite_message)
1382            .err_into()
1383            .with(|message| async move { to_axum_message(message) });
1384        let connection = Connection::new(Box::pin(socket));
1385        async move {
1386            server
1387                .handle_connection(
1388                    connection,
1389                    socket_address,
1390                    principal,
1391                    version,
1392                    country_code_header.map(|header| header.to_string()),
1393                    None,
1394                    Executor::Production,
1395                )
1396                .await;
1397        }
1398    })
1399}
1400
1401pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1402    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let connections_metric = CONNECTIONS_METRIC
1404        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1405
1406    let connections = server
1407        .connection_pool
1408        .lock()
1409        .connections()
1410        .filter(|connection| !connection.admin)
1411        .count();
1412    connections_metric.set(connections as _);
1413
1414    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1416        register_int_gauge!(
1417            "shared_projects",
1418            "number of open projects with one or more guests"
1419        )
1420        .unwrap()
1421    });
1422
1423    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1424    shared_projects_metric.set(shared_projects as _);
1425
1426    let encoder = prometheus::TextEncoder::new();
1427    let metric_families = prometheus::gather();
1428    let encoded_metrics = encoder
1429        .encode_to_string(&metric_families)
1430        .map_err(|err| anyhow!("{}", err))?;
1431    Ok(encoded_metrics)
1432}
1433
/// Cleans up after a connection has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// reports the loss to the database (database errors are only traced). It then
/// waits for `RECONNECT_TIMEOUT`, presumably to give the client a window to
/// reconnect (TODO confirm). If teardown is signalled first, it returns
/// `Ok(())` with no further cleanup. Once the timeout elapses:
/// - users (including impersonated sessions):
///   - leave their room and channel buffers;
///   - have `decline_call` run if they are no longer online on any connection;
///   - then `update_user_contacts` runs.
/// - dev servers are handed to `lost_dev_server_connection`.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool, or if
/// `update_user_contacts` or `lost_dev_server_connection` fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and teardown are both ready, the timeout arm is
    // taken first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal matched a user variant above, so `for_user` is
                    // expected to succeed here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    // Failures here are only traced so the remaining cleanup still runs.
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1488
1489/// Acknowledges a ping from a client, used to keep the connection alive.
1490async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1491    response.send(proto::Ack {})?;
1492    Ok(())
1493}
1494
1495/// Creates a new room for calling (outside of channels)
1496async fn create_room(
1497    _request: proto::CreateRoom,
1498    response: Response<proto::CreateRoom>,
1499    session: UserSession,
1500) -> Result<()> {
1501    let live_kit_room = nanoid::nanoid!(30);
1502
1503    let live_kit_connection_info = util::maybe!(async {
1504        let live_kit = session.app_state.live_kit_client.as_ref();
1505        let live_kit = live_kit?;
1506        let user_id = session.user_id().to_string();
1507
1508        let token = live_kit
1509            .room_token(&live_kit_room, &user_id.to_string())
1510            .trace_err()?;
1511
1512        Some(proto::LiveKitConnectionInfo {
1513            server_url: live_kit.url().into(),
1514            token,
1515            can_publish: true,
1516        })
1517    })
1518    .await;
1519
1520    let room = session
1521        .db()
1522        .await
1523        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1524        .await?;
1525
1526    response.send(proto::CreateRoomResponse {
1527        room: Some(room.clone()),
1528        live_kit_connection_info,
1529    })?;
1530
1531    update_user_contacts(session.user_id(), &session).await?;
1532    Ok(())
1533}
1534
1535/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1536async fn join_room(
1537    request: proto::JoinRoom,
1538    response: Response<proto::JoinRoom>,
1539    session: UserSession,
1540) -> Result<()> {
1541    let room_id = RoomId::from_proto(request.id);
1542
1543    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1544
1545    if let Some(channel_id) = channel_id {
1546        return join_channel_internal(channel_id, Box::new(response), session).await;
1547    }
1548
1549    let joined_room = {
1550        let room = session
1551            .db()
1552            .await
1553            .join_room(room_id, session.user_id(), session.connection_id)
1554            .await?;
1555        room_updated(&room.room, &session.peer);
1556        room.into_inner()
1557    };
1558
1559    for connection_id in session
1560        .connection_pool()
1561        .await
1562        .user_connection_ids(session.user_id())
1563    {
1564        session
1565            .peer
1566            .send(
1567                connection_id,
1568                proto::CallCanceled {
1569                    room_id: room_id.to_proto(),
1570                },
1571            )
1572            .trace_err();
1573    }
1574
1575    let live_kit_connection_info =
1576        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1577            live_kit
1578                .room_token(
1579                    &joined_room.room.live_kit_room,
1580                    &session.user_id().to_string(),
1581                )
1582                .trace_err()
1583                .map(|token| proto::LiveKitConnectionInfo {
1584                    server_url: live_kit.url().into(),
1585                    token,
1586                    can_publish: true,
1587                })
1588        } else {
1589            None
1590        };
1591
1592    response.send(proto::JoinRoomResponse {
1593        room: Some(joined_room.room),
1594        channel_id: None,
1595        live_kit_connection_info,
1596    })?;
1597
1598    update_user_contacts(session.user_id(), &session).await?;
1599    Ok(())
1600}
1601
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// `Database::rejoin_room` is given the request together with the caller's
/// user id and *new* `session.connection_id`. It returns the room, the
/// caller's `reshared_projects` and `rejoined_projects`, and the room's
/// `channel` (if any). Presumably the reshared projects are those the caller
/// was sharing and the rejoined ones are those it had joined; confirm in
/// `Database::rejoin_room`.
///
/// The caller is answered with the room and per-project state, and the room is
/// refreshed for its participants. For every reshared project, each
/// collaborator is told the caller's peer id changed from `old_connection_id`
/// to the current connection, and is forwarded the project's worktree list.
/// Rejoined projects are brought up to date by `notify_rejoined_projects`.
/// After the database result has been consumed, the channel (if any) is
/// refreshed and the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database result so it is consumed by `into_inner` before the
    // connection pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this peer's id moved from the old
            // connection to the new one (best-effort: failures are only logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators; the
            // `Some(session.connection_id)` argument presumably excludes the
            // sender itself (see `broadcast`).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694fn notify_rejoined_projects(
1695    rejoined_projects: &mut Vec<RejoinedProject>,
1696    session: &UserSession,
1697) -> Result<()> {
1698    for project in rejoined_projects.iter() {
1699        for collaborator in &project.collaborators {
1700            session
1701                .peer
1702                .send(
1703                    collaborator.connection_id,
1704                    proto::UpdateProjectCollaborator {
1705                        project_id: project.id.to_proto(),
1706                        old_peer_id: Some(project.old_connection_id.into()),
1707                        new_peer_id: Some(session.connection_id.into()),
1708                    },
1709                )
1710                .trace_err();
1711        }
1712    }
1713
1714    for project in rejoined_projects {
1715        for worktree in mem::take(&mut project.worktrees) {
1716            // Stream this worktree's entries.
1717            let message = proto::UpdateWorktree {
1718                project_id: project.id.to_proto(),
1719                worktree_id: worktree.id,
1720                abs_path: worktree.abs_path.clone(),
1721                root_name: worktree.root_name,
1722                updated_entries: worktree.updated_entries,
1723                removed_entries: worktree.removed_entries,
1724                scan_id: worktree.scan_id,
1725                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1726                updated_repositories: worktree.updated_repositories,
1727                removed_repositories: worktree.removed_repositories,
1728            };
1729            for update in proto::split_worktree_update(message) {
1730                session.peer.send(session.connection_id, update.clone())?;
1731            }
1732
1733            // Stream this worktree's diagnostics.
1734            for summary in worktree.diagnostic_summaries {
1735                session.peer.send(
1736                    session.connection_id,
1737                    proto::UpdateDiagnosticSummary {
1738                        project_id: project.id.to_proto(),
1739                        worktree_id: worktree.id,
1740                        summary: Some(summary),
1741                    },
1742                )?;
1743            }
1744
1745            for settings_file in worktree.settings_files {
1746                session.peer.send(
1747                    session.connection_id,
1748                    proto::UpdateWorktreeSettings {
1749                        project_id: project.id.to_proto(),
1750                        worktree_id: worktree.id,
1751                        path: settings_file.path,
1752                        content: Some(settings_file.content),
1753                        kind: Some(settings_file.kind.to_proto().into()),
1754                    },
1755                )?;
1756            }
1757        }
1758
1759        for language_server in &project.language_servers {
1760            session.peer.send(
1761                session.connection_id,
1762                proto::UpdateLanguageServer {
1763                    project_id: project.id.to_proto(),
1764                    language_server_id: language_server.id,
1765                    variant: Some(
1766                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1767                            proto::LspDiskBasedDiagnosticsUpdated {},
1768                        ),
1769                    ),
1770                },
1771            )?;
1772        }
1773    }
1774    Ok(())
1775}
1776
1777/// leave room disconnects from the room.
1778async fn leave_room(
1779    _: proto::LeaveRoom,
1780    response: Response<proto::LeaveRoom>,
1781    session: UserSession,
1782) -> Result<()> {
1783    leave_room_for_session(&session, session.connection_id).await?;
1784    response.send(proto::Ack {})?;
1785    Ok(())
1786}
1787
1788/// Updates the permissions of someone else in the room.
1789async fn set_room_participant_role(
1790    request: proto::SetRoomParticipantRole,
1791    response: Response<proto::SetRoomParticipantRole>,
1792    session: UserSession,
1793) -> Result<()> {
1794    let user_id = UserId::from_proto(request.user_id);
1795    let role = ChannelRole::from(request.role());
1796
1797    let (live_kit_room, can_publish) = {
1798        let room = session
1799            .db()
1800            .await
1801            .set_room_participant_role(
1802                session.user_id(),
1803                RoomId::from_proto(request.room_id),
1804                user_id,
1805                role,
1806            )
1807            .await?;
1808
1809        let live_kit_room = room.live_kit_room.clone();
1810        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1811        room_updated(&room, &session.peer);
1812        (live_kit_room, can_publish)
1813    };
1814
1815    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1816        live_kit
1817            .update_participant(
1818                live_kit_room.clone(),
1819                request.user_id.to_string(),
1820                live_kit_server::proto::ParticipantPermission {
1821                    can_subscribe: true,
1822                    can_publish,
1823                    can_publish_data: can_publish,
1824                    hidden: false,
1825                    recorder: false,
1826                },
1827            )
1828            .await
1829            .trace_err();
1830    }
1831
1832    response.send(proto::Ack {})?;
1833    Ok(())
1834}
1835
1836/// Call someone else into the current room
1837async fn call(
1838    request: proto::Call,
1839    response: Response<proto::Call>,
1840    session: UserSession,
1841) -> Result<()> {
1842    let room_id = RoomId::from_proto(request.room_id);
1843    let calling_user_id = session.user_id();
1844    let calling_connection_id = session.connection_id;
1845    let called_user_id = UserId::from_proto(request.called_user_id);
1846    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1847    if !session
1848        .db()
1849        .await
1850        .has_contact(calling_user_id, called_user_id)
1851        .await?
1852    {
1853        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1854    }
1855
1856    let incoming_call = {
1857        let (room, incoming_call) = &mut *session
1858            .db()
1859            .await
1860            .call(
1861                room_id,
1862                calling_user_id,
1863                calling_connection_id,
1864                called_user_id,
1865                initial_project_id,
1866            )
1867            .await?;
1868        room_updated(room, &session.peer);
1869        mem::take(incoming_call)
1870    };
1871    update_user_contacts(called_user_id, &session).await?;
1872
1873    let mut calls = session
1874        .connection_pool()
1875        .await
1876        .user_connection_ids(called_user_id)
1877        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1878        .collect::<FuturesUnordered<_>>();
1879
1880    while let Some(call_response) = calls.next().await {
1881        match call_response.as_ref() {
1882            Ok(_) => {
1883                response.send(proto::Ack {})?;
1884                return Ok(());
1885            }
1886            Err(_) => {
1887                call_response.trace_err();
1888            }
1889        }
1890    }
1891
1892    {
1893        let room = session
1894            .db()
1895            .await
1896            .call_failed(room_id, called_user_id)
1897            .await?;
1898        room_updated(&room, &session.peer);
1899    }
1900    update_user_contacts(called_user_id, &session).await?;
1901
1902    Err(anyhow!("failed to ring user"))?
1903}
1904
1905/// Cancel an outgoing call.
1906async fn cancel_call(
1907    request: proto::CancelCall,
1908    response: Response<proto::CancelCall>,
1909    session: UserSession,
1910) -> Result<()> {
1911    let called_user_id = UserId::from_proto(request.called_user_id);
1912    let room_id = RoomId::from_proto(request.room_id);
1913    {
1914        let room = session
1915            .db()
1916            .await
1917            .cancel_call(room_id, session.connection_id, called_user_id)
1918            .await?;
1919        room_updated(&room, &session.peer);
1920    }
1921
1922    for connection_id in session
1923        .connection_pool()
1924        .await
1925        .user_connection_ids(called_user_id)
1926    {
1927        session
1928            .peer
1929            .send(
1930                connection_id,
1931                proto::CallCanceled {
1932                    room_id: room_id.to_proto(),
1933                },
1934            )
1935            .trace_err();
1936    }
1937    response.send(proto::Ack {})?;
1938
1939    update_user_contacts(called_user_id, &session).await?;
1940    Ok(())
1941}
1942
1943/// Decline an incoming call.
1944async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1945    let room_id = RoomId::from_proto(message.room_id);
1946    {
1947        let room = session
1948            .db()
1949            .await
1950            .decline_call(Some(room_id), session.user_id())
1951            .await?
1952            .ok_or_else(|| anyhow!("failed to decline call"))?;
1953        room_updated(&room, &session.peer);
1954    }
1955
1956    for connection_id in session
1957        .connection_pool()
1958        .await
1959        .user_connection_ids(session.user_id())
1960    {
1961        session
1962            .peer
1963            .send(
1964                connection_id,
1965                proto::CallCanceled {
1966                    room_id: room_id.to_proto(),
1967                },
1968            )
1969            .trace_err();
1970    }
1971    update_user_contacts(session.user_id(), &session).await?;
1972    Ok(())
1973}
1974
1975/// Updates other participants in the room with your current location.
1976async fn update_participant_location(
1977    request: proto::UpdateParticipantLocation,
1978    response: Response<proto::UpdateParticipantLocation>,
1979    session: UserSession,
1980) -> Result<()> {
1981    let room_id = RoomId::from_proto(request.room_id);
1982    let location = request
1983        .location
1984        .ok_or_else(|| anyhow!("invalid location"))?;
1985
1986    let db = session.db().await;
1987    let room = db
1988        .update_room_participant_location(room_id, session.connection_id, location)
1989        .await?;
1990
1991    room_updated(&room, &session.peer);
1992    response.send(proto::Ack {})?;
1993    Ok(())
1994}
1995
1996/// Share a project into the room.
1997async fn share_project(
1998    request: proto::ShareProject,
1999    response: Response<proto::ShareProject>,
2000    session: UserSession,
2001) -> Result<()> {
2002    let (project_id, room) = &*session
2003        .db()
2004        .await
2005        .share_project(
2006            RoomId::from_proto(request.room_id),
2007            session.connection_id,
2008            &request.worktrees,
2009            request.is_ssh_project,
2010            request
2011                .dev_server_project_id
2012                .map(DevServerProjectId::from_proto),
2013        )
2014        .await?;
2015    response.send(proto::ShareProjectResponse {
2016        project_id: project_id.to_proto(),
2017    })?;
2018    room_updated(room, &session.peer);
2019
2020    Ok(())
2021}
2022
2023/// Unshare a project from the room.
2024async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2025    let project_id = ProjectId::from_proto(message.project_id);
2026    unshare_project_internal(
2027        project_id,
2028        session.connection_id,
2029        session.user_id(),
2030        &session,
2031    )
2032    .await
2033}
2034
2035async fn unshare_project_internal(
2036    project_id: ProjectId,
2037    connection_id: ConnectionId,
2038    user_id: Option<UserId>,
2039    session: &Session,
2040) -> Result<()> {
2041    let delete = {
2042        let room_guard = session
2043            .db()
2044            .await
2045            .unshare_project(project_id, connection_id, user_id)
2046            .await?;
2047
2048        let (delete, room, guest_connection_ids) = &*room_guard;
2049
2050        let message = proto::UnshareProject {
2051            project_id: project_id.to_proto(),
2052        };
2053
2054        broadcast(
2055            Some(connection_id),
2056            guest_connection_ids.iter().copied(),
2057            |conn_id| session.peer.send(conn_id, message.clone()),
2058        );
2059        if let Some(room) = room {
2060            room_updated(room, &session.peer);
2061        }
2062
2063        *delete
2064    };
2065
2066    if delete {
2067        let db = session.db().await;
2068        db.delete_project(project_id).await?;
2069    }
2070
2071    Ok(())
2072}
2073
2074/// DevServer makes a project available online
2075async fn share_dev_server_project(
2076    request: proto::ShareDevServerProject,
2077    response: Response<proto::ShareDevServerProject>,
2078    session: DevServerSession,
2079) -> Result<()> {
2080    let (dev_server_project, user_id, status) = session
2081        .db()
2082        .await
2083        .share_dev_server_project(
2084            DevServerProjectId::from_proto(request.dev_server_project_id),
2085            session.dev_server_id(),
2086            session.connection_id,
2087            &request.worktrees,
2088        )
2089        .await?;
2090    let Some(project_id) = dev_server_project.project_id else {
2091        return Err(anyhow!("failed to share remote project"))?;
2092    };
2093
2094    send_dev_server_projects_update(user_id, status, &session).await;
2095
2096    response.send(proto::ShareProjectResponse { project_id })?;
2097
2098    Ok(())
2099}
2100
/// Join someone else's shared project.
///
/// Records the join in the database (which yields the project and the
/// replica id assigned to this connection), then hands off to
/// `join_project_internal` to notify collaborators and stream the project to
/// the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the `session.db()` handle (presumably a lock guard) before the
    // potentially long stream of messages in `join_project_internal`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2119
/// A response handle that can answer with a `proto::JoinProjectResponse`.
///
/// Lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully-qualified paths below call the inherent `Response::send`, not this
// trait's `send`, so the impls do not recurse.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2133
2134fn join_project_internal(
2135    response: impl JoinProjectInternalResponse,
2136    session: UserSession,
2137    project: &mut Project,
2138    replica_id: &ReplicaId,
2139) -> Result<()> {
2140    let collaborators = project
2141        .collaborators
2142        .iter()
2143        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2144        .map(|collaborator| collaborator.to_proto())
2145        .collect::<Vec<_>>();
2146    let project_id = project.id;
2147    let guest_user_id = session.user_id();
2148
2149    let worktrees = project
2150        .worktrees
2151        .iter()
2152        .map(|(id, worktree)| proto::WorktreeMetadata {
2153            id: *id,
2154            root_name: worktree.root_name.clone(),
2155            visible: worktree.visible,
2156            abs_path: worktree.abs_path.clone(),
2157        })
2158        .collect::<Vec<_>>();
2159
2160    let add_project_collaborator = proto::AddProjectCollaborator {
2161        project_id: project_id.to_proto(),
2162        collaborator: Some(proto::Collaborator {
2163            peer_id: Some(session.connection_id.into()),
2164            replica_id: replica_id.0 as u32,
2165            user_id: guest_user_id.to_proto(),
2166        }),
2167    };
2168
2169    for collaborator in &collaborators {
2170        session
2171            .peer
2172            .send(
2173                collaborator.peer_id.unwrap().into(),
2174                add_project_collaborator.clone(),
2175            )
2176            .trace_err();
2177    }
2178
2179    // First, we send the metadata associated with each worktree.
2180    response.send(proto::JoinProjectResponse {
2181        project_id: project.id.0 as u64,
2182        worktrees: worktrees.clone(),
2183        replica_id: replica_id.0 as u32,
2184        collaborators: collaborators.clone(),
2185        language_servers: project.language_servers.clone(),
2186        role: project.role.into(),
2187        dev_server_project_id: project
2188            .dev_server_project_id
2189            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2190    })?;
2191
2192    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2193        // Stream this worktree's entries.
2194        let message = proto::UpdateWorktree {
2195            project_id: project_id.to_proto(),
2196            worktree_id,
2197            abs_path: worktree.abs_path.clone(),
2198            root_name: worktree.root_name,
2199            updated_entries: worktree.entries,
2200            removed_entries: Default::default(),
2201            scan_id: worktree.scan_id,
2202            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2203            updated_repositories: worktree.repository_entries.into_values().collect(),
2204            removed_repositories: Default::default(),
2205        };
2206        for update in proto::split_worktree_update(message) {
2207            session.peer.send(session.connection_id, update.clone())?;
2208        }
2209
2210        // Stream this worktree's diagnostics.
2211        for summary in worktree.diagnostic_summaries {
2212            session.peer.send(
2213                session.connection_id,
2214                proto::UpdateDiagnosticSummary {
2215                    project_id: project_id.to_proto(),
2216                    worktree_id: worktree.id,
2217                    summary: Some(summary),
2218                },
2219            )?;
2220        }
2221
2222        for settings_file in worktree.settings_files {
2223            session.peer.send(
2224                session.connection_id,
2225                proto::UpdateWorktreeSettings {
2226                    project_id: project_id.to_proto(),
2227                    worktree_id: worktree.id,
2228                    path: settings_file.path,
2229                    content: Some(settings_file.content),
2230                    kind: Some(settings_file.kind.to_proto() as i32),
2231                },
2232            )?;
2233        }
2234    }
2235
2236    for language_server in &project.language_servers {
2237        session.peer.send(
2238            session.connection_id,
2239            proto::UpdateLanguageServer {
2240                project_id: project_id.to_proto(),
2241                language_server_id: language_server.id,
2242                variant: Some(
2243                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2244                        proto::LspDiskBasedDiagnosticsUpdated {},
2245                    ),
2246                ),
2247            },
2248        )?;
2249    }
2250
2251    Ok(())
2252}
2253
2254/// Leave someone elses shared project.
2255async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2256    let sender_id = session.connection_id;
2257    let project_id = ProjectId::from_proto(request.project_id);
2258    let db = session.db().await;
2259    if db.is_hosted_project(project_id).await? {
2260        let project = db.leave_hosted_project(project_id, sender_id).await?;
2261        project_left(&project, &session);
2262        return Ok(());
2263    }
2264
2265    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2266    tracing::info!(
2267        %project_id,
2268        "leave project"
2269    );
2270
2271    project_left(project, &session);
2272    if let Some(room) = room {
2273        room_updated(room, &session.peer);
2274    }
2275
2276    Ok(())
2277}
2278
2279async fn join_hosted_project(
2280    request: proto::JoinHostedProject,
2281    response: Response<proto::JoinHostedProject>,
2282    session: UserSession,
2283) -> Result<()> {
2284    let (mut project, replica_id) = session
2285        .db()
2286        .await
2287        .join_hosted_project(
2288            ProjectId(request.project_id as i32),
2289            session.user_id(),
2290            session.connection_id,
2291        )
2292        .await?;
2293
2294    join_project_internal(response, session, &mut project, &replica_id)
2295}
2296
2297async fn list_remote_directory(
2298    request: proto::ListRemoteDirectory,
2299    response: Response<proto::ListRemoteDirectory>,
2300    session: UserSession,
2301) -> Result<()> {
2302    let dev_server_id = DevServerId(request.dev_server_id as i32);
2303    let dev_server_connection_id = session
2304        .connection_pool()
2305        .await
2306        .online_dev_server_connection_id(dev_server_id)?;
2307
2308    session
2309        .db()
2310        .await
2311        .get_dev_server_for_user(dev_server_id, session.user_id())
2312        .await?;
2313
2314    response.send(
2315        session
2316            .peer
2317            .forward_request(session.connection_id, dev_server_connection_id, request)
2318            .await?,
2319    )?;
2320    Ok(())
2321}
2322
2323async fn update_dev_server_project(
2324    request: proto::UpdateDevServerProject,
2325    response: Response<proto::UpdateDevServerProject>,
2326    session: UserSession,
2327) -> Result<()> {
2328    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2329
2330    let (dev_server_project, update) = session
2331        .db()
2332        .await
2333        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2334        .await?;
2335
2336    let projects = session
2337        .db()
2338        .await
2339        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2340        .await?;
2341
2342    let dev_server_connection_id = session
2343        .connection_pool()
2344        .await
2345        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2346
2347    session.peer.send(
2348        dev_server_connection_id,
2349        proto::DevServerInstructions { projects },
2350    )?;
2351
2352    send_dev_server_projects_update(session.user_id(), update, &session).await;
2353
2354    response.send(proto::Ack {})
2355}
2356
2357async fn create_dev_server_project(
2358    request: proto::CreateDevServerProject,
2359    response: Response<proto::CreateDevServerProject>,
2360    session: UserSession,
2361) -> Result<()> {
2362    let dev_server_id = DevServerId(request.dev_server_id as i32);
2363    let dev_server_connection_id = session
2364        .connection_pool()
2365        .await
2366        .dev_server_connection_id(dev_server_id);
2367    let Some(dev_server_connection_id) = dev_server_connection_id else {
2368        Err(ErrorCode::DevServerOffline
2369            .message("Cannot create a remote project when the dev server is offline".to_string())
2370            .anyhow())?
2371    };
2372
2373    let path = request.path.clone();
2374    //Check that the path exists on the dev server
2375    session
2376        .peer
2377        .forward_request(
2378            session.connection_id,
2379            dev_server_connection_id,
2380            proto::ValidateDevServerProjectRequest { path: path.clone() },
2381        )
2382        .await?;
2383
2384    let (dev_server_project, update) = session
2385        .db()
2386        .await
2387        .create_dev_server_project(
2388            DevServerId(request.dev_server_id as i32),
2389            &request.path,
2390            session.user_id(),
2391        )
2392        .await?;
2393
2394    let projects = session
2395        .db()
2396        .await
2397        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2398        .await?;
2399
2400    session.peer.send(
2401        dev_server_connection_id,
2402        proto::DevServerInstructions { projects },
2403    )?;
2404
2405    send_dev_server_projects_update(session.user_id(), update, &session).await;
2406
2407    response.send(proto::CreateDevServerProjectResponse {
2408        dev_server_project: Some(dev_server_project.to_proto(None)),
2409    })?;
2410    Ok(())
2411}
2412
2413async fn create_dev_server(
2414    request: proto::CreateDevServer,
2415    response: Response<proto::CreateDevServer>,
2416    session: UserSession,
2417) -> Result<()> {
2418    let access_token = auth::random_token();
2419    let hashed_access_token = auth::hash_access_token(&access_token);
2420
2421    if request.name.is_empty() {
2422        return Err(proto::ErrorCode::Forbidden
2423            .message("Dev server name cannot be empty".to_string())
2424            .anyhow())?;
2425    }
2426
2427    let (dev_server, status) = session
2428        .db()
2429        .await
2430        .create_dev_server(
2431            &request.name,
2432            request.ssh_connection_string.as_deref(),
2433            &hashed_access_token,
2434            session.user_id(),
2435        )
2436        .await?;
2437
2438    send_dev_server_projects_update(session.user_id(), status, &session).await;
2439
2440    response.send(proto::CreateDevServerResponse {
2441        dev_server_id: dev_server.id.0 as u64,
2442        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2443        name: request.name,
2444    })?;
2445    Ok(())
2446}
2447
2448async fn regenerate_dev_server_token(
2449    request: proto::RegenerateDevServerToken,
2450    response: Response<proto::RegenerateDevServerToken>,
2451    session: UserSession,
2452) -> Result<()> {
2453    let dev_server_id = DevServerId(request.dev_server_id as i32);
2454    let access_token = auth::random_token();
2455    let hashed_access_token = auth::hash_access_token(&access_token);
2456
2457    let connection_id = session
2458        .connection_pool()
2459        .await
2460        .dev_server_connection_id(dev_server_id);
2461    if let Some(connection_id) = connection_id {
2462        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2463        session.peer.send(
2464            connection_id,
2465            proto::ShutdownDevServer {
2466                reason: Some("dev server token was regenerated".to_string()),
2467            },
2468        )?;
2469        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2470    }
2471
2472    let status = session
2473        .db()
2474        .await
2475        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2476        .await?;
2477
2478    send_dev_server_projects_update(session.user_id(), status, &session).await;
2479
2480    response.send(proto::RegenerateDevServerTokenResponse {
2481        dev_server_id: dev_server_id.to_proto(),
2482        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2483    })?;
2484    Ok(())
2485}
2486
2487async fn rename_dev_server(
2488    request: proto::RenameDevServer,
2489    response: Response<proto::RenameDevServer>,
2490    session: UserSession,
2491) -> Result<()> {
2492    if request.name.trim().is_empty() {
2493        return Err(proto::ErrorCode::Forbidden
2494            .message("Dev server name cannot be empty".to_string())
2495            .anyhow())?;
2496    }
2497
2498    let dev_server_id = DevServerId(request.dev_server_id as i32);
2499    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2500    if dev_server.user_id != session.user_id() {
2501        return Err(anyhow!(ErrorCode::Forbidden))?;
2502    }
2503
2504    let status = session
2505        .db()
2506        .await
2507        .rename_dev_server(
2508            dev_server_id,
2509            &request.name,
2510            request.ssh_connection_string.as_deref(),
2511            session.user_id(),
2512        )
2513        .await?;
2514
2515    send_dev_server_projects_update(session.user_id(), status, &session).await;
2516
2517    response.send(proto::Ack {})?;
2518    Ok(())
2519}
2520
2521async fn delete_dev_server(
2522    request: proto::DeleteDevServer,
2523    response: Response<proto::DeleteDevServer>,
2524    session: UserSession,
2525) -> Result<()> {
2526    let dev_server_id = DevServerId(request.dev_server_id as i32);
2527    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2528    if dev_server.user_id != session.user_id() {
2529        return Err(anyhow!(ErrorCode::Forbidden))?;
2530    }
2531
2532    let connection_id = session
2533        .connection_pool()
2534        .await
2535        .dev_server_connection_id(dev_server_id);
2536    if let Some(connection_id) = connection_id {
2537        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2538        session.peer.send(
2539            connection_id,
2540            proto::ShutdownDevServer {
2541                reason: Some("dev server was deleted".to_string()),
2542            },
2543        )?;
2544        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2545    }
2546
2547    let status = session
2548        .db()
2549        .await
2550        .delete_dev_server(dev_server_id, session.user_id())
2551        .await?;
2552
2553    send_dev_server_projects_update(session.user_id(), status, &session).await;
2554
2555    response.send(proto::Ack {})?;
2556    Ok(())
2557}
2558
2559async fn delete_dev_server_project(
2560    request: proto::DeleteDevServerProject,
2561    response: Response<proto::DeleteDevServerProject>,
2562    session: UserSession,
2563) -> Result<()> {
2564    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2565    let dev_server_project = session
2566        .db()
2567        .await
2568        .get_dev_server_project(dev_server_project_id)
2569        .await?;
2570
2571    let dev_server = session
2572        .db()
2573        .await
2574        .get_dev_server(dev_server_project.dev_server_id)
2575        .await?;
2576    if dev_server.user_id != session.user_id() {
2577        return Err(anyhow!(ErrorCode::Forbidden))?;
2578    }
2579
2580    let dev_server_connection_id = session
2581        .connection_pool()
2582        .await
2583        .dev_server_connection_id(dev_server.id);
2584
2585    if let Some(dev_server_connection_id) = dev_server_connection_id {
2586        let project = session
2587            .db()
2588            .await
2589            .find_dev_server_project(dev_server_project_id)
2590            .await;
2591        if let Ok(project) = project {
2592            unshare_project_internal(
2593                project.id,
2594                dev_server_connection_id,
2595                Some(session.user_id()),
2596                &session,
2597            )
2598            .await?;
2599        }
2600    }
2601
2602    let (projects, status) = session
2603        .db()
2604        .await
2605        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2606        .await?;
2607
2608    if let Some(dev_server_connection_id) = dev_server_connection_id {
2609        session.peer.send(
2610            dev_server_connection_id,
2611            proto::DevServerInstructions { projects },
2612        )?;
2613    }
2614
2615    send_dev_server_projects_update(session.user_id(), status, &session).await;
2616
2617    response.send(proto::Ack {})?;
2618    Ok(())
2619}
2620
2621async fn rejoin_dev_server_projects(
2622    request: proto::RejoinRemoteProjects,
2623    response: Response<proto::RejoinRemoteProjects>,
2624    session: UserSession,
2625) -> Result<()> {
2626    let mut rejoined_projects = {
2627        let db = session.db().await;
2628        db.rejoin_dev_server_projects(
2629            &request.rejoined_projects,
2630            session.user_id(),
2631            session.0.connection_id,
2632        )
2633        .await?
2634    };
2635    response.send(proto::RejoinRemoteProjectsResponse {
2636        rejoined_projects: rejoined_projects
2637            .iter()
2638            .map(|project| project.to_proto())
2639            .collect(),
2640    })?;
2641    notify_rejoined_projects(&mut rejoined_projects, &session)
2642}
2643
2644async fn reconnect_dev_server(
2645    request: proto::ReconnectDevServer,
2646    response: Response<proto::ReconnectDevServer>,
2647    session: DevServerSession,
2648) -> Result<()> {
2649    let reshared_projects = {
2650        let db = session.db().await;
2651        db.reshare_dev_server_projects(
2652            &request.reshared_projects,
2653            session.dev_server_id(),
2654            session.0.connection_id,
2655        )
2656        .await?
2657    };
2658
2659    for project in &reshared_projects {
2660        for collaborator in &project.collaborators {
2661            session
2662                .peer
2663                .send(
2664                    collaborator.connection_id,
2665                    proto::UpdateProjectCollaborator {
2666                        project_id: project.id.to_proto(),
2667                        old_peer_id: Some(project.old_connection_id.into()),
2668                        new_peer_id: Some(session.connection_id.into()),
2669                    },
2670                )
2671                .trace_err();
2672        }
2673
2674        broadcast(
2675            Some(session.connection_id),
2676            project
2677                .collaborators
2678                .iter()
2679                .map(|collaborator| collaborator.connection_id),
2680            |connection_id| {
2681                session.peer.forward_send(
2682                    session.connection_id,
2683                    connection_id,
2684                    proto::UpdateProject {
2685                        project_id: project.id.to_proto(),
2686                        worktrees: project.worktrees.clone(),
2687                    },
2688                )
2689            },
2690        );
2691    }
2692
2693    response.send(proto::ReconnectDevServerResponse {
2694        reshared_projects: reshared_projects
2695            .iter()
2696            .map(|project| proto::ResharedProject {
2697                id: project.id.to_proto(),
2698                collaborators: project
2699                    .collaborators
2700                    .iter()
2701                    .map(|collaborator| collaborator.to_proto())
2702                    .collect(),
2703            })
2704            .collect(),
2705    })?;
2706
2707    Ok(())
2708}
2709
2710async fn shutdown_dev_server(
2711    _: proto::ShutdownDevServer,
2712    response: Response<proto::ShutdownDevServer>,
2713    session: DevServerSession,
2714) -> Result<()> {
2715    response.send(proto::Ack {})?;
2716    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2717    remove_dev_server_connection(session.dev_server_id(), &session).await
2718}
2719
2720async fn shutdown_dev_server_internal(
2721    dev_server_id: DevServerId,
2722    connection_id: ConnectionId,
2723    session: &Session,
2724) -> Result<()> {
2725    let (dev_server_projects, dev_server) = {
2726        let db = session.db().await;
2727        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2728        let dev_server = db.get_dev_server(dev_server_id).await?;
2729        (dev_server_projects, dev_server)
2730    };
2731
2732    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2733        unshare_project_internal(
2734            ProjectId::from_proto(project_id),
2735            connection_id,
2736            None,
2737            session,
2738        )
2739        .await?;
2740    }
2741
2742    session
2743        .connection_pool()
2744        .await
2745        .set_dev_server_offline(dev_server_id);
2746
2747    let status = session
2748        .db()
2749        .await
2750        .dev_server_projects_update(dev_server.user_id)
2751        .await?;
2752    send_dev_server_projects_update(dev_server.user_id, status, session).await;
2753
2754    Ok(())
2755}
2756
2757async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2758    let dev_server_connection = session
2759        .connection_pool()
2760        .await
2761        .dev_server_connection_id(dev_server_id);
2762
2763    if let Some(dev_server_connection) = dev_server_connection {
2764        session
2765            .connection_pool()
2766            .await
2767            .remove_connection(dev_server_connection)?;
2768    }
2769    Ok(())
2770}
2771
2772/// Updates other participants with changes to the project
2773async fn update_project(
2774    request: proto::UpdateProject,
2775    response: Response<proto::UpdateProject>,
2776    session: Session,
2777) -> Result<()> {
2778    let project_id = ProjectId::from_proto(request.project_id);
2779    let (room, guest_connection_ids) = &*session
2780        .db()
2781        .await
2782        .update_project(project_id, session.connection_id, &request.worktrees)
2783        .await?;
2784    broadcast(
2785        Some(session.connection_id),
2786        guest_connection_ids.iter().copied(),
2787        |connection_id| {
2788            session
2789                .peer
2790                .forward_send(session.connection_id, connection_id, request.clone())
2791        },
2792    );
2793    if let Some(room) = room {
2794        room_updated(room, &session.peer);
2795    }
2796    response.send(proto::Ack {})?;
2797
2798    Ok(())
2799}
2800
2801/// Updates other participants with changes to the worktree
2802async fn update_worktree(
2803    request: proto::UpdateWorktree,
2804    response: Response<proto::UpdateWorktree>,
2805    session: Session,
2806) -> Result<()> {
2807    let guest_connection_ids = session
2808        .db()
2809        .await
2810        .update_worktree(&request, session.connection_id)
2811        .await?;
2812
2813    broadcast(
2814        Some(session.connection_id),
2815        guest_connection_ids.iter().copied(),
2816        |connection_id| {
2817            session
2818                .peer
2819                .forward_send(session.connection_id, connection_id, request.clone())
2820        },
2821    );
2822    response.send(proto::Ack {})?;
2823    Ok(())
2824}
2825
2826/// Updates other participants with changes to the diagnostics
2827async fn update_diagnostic_summary(
2828    message: proto::UpdateDiagnosticSummary,
2829    session: Session,
2830) -> Result<()> {
2831    let guest_connection_ids = session
2832        .db()
2833        .await
2834        .update_diagnostic_summary(&message, session.connection_id)
2835        .await?;
2836
2837    broadcast(
2838        Some(session.connection_id),
2839        guest_connection_ids.iter().copied(),
2840        |connection_id| {
2841            session
2842                .peer
2843                .forward_send(session.connection_id, connection_id, message.clone())
2844        },
2845    );
2846
2847    Ok(())
2848}
2849
2850/// Updates other participants with changes to the worktree settings
2851async fn update_worktree_settings(
2852    message: proto::UpdateWorktreeSettings,
2853    session: Session,
2854) -> Result<()> {
2855    let guest_connection_ids = session
2856        .db()
2857        .await
2858        .update_worktree_settings(&message, session.connection_id)
2859        .await?;
2860
2861    broadcast(
2862        Some(session.connection_id),
2863        guest_connection_ids.iter().copied(),
2864        |connection_id| {
2865            session
2866                .peer
2867                .forward_send(session.connection_id, connection_id, message.clone())
2868        },
2869    );
2870
2871    Ok(())
2872}
2873
2874/// Notify other participants that a  language server has started.
2875async fn start_language_server(
2876    request: proto::StartLanguageServer,
2877    session: Session,
2878) -> Result<()> {
2879    let guest_connection_ids = session
2880        .db()
2881        .await
2882        .start_language_server(&request, session.connection_id)
2883        .await?;
2884
2885    broadcast(
2886        Some(session.connection_id),
2887        guest_connection_ids.iter().copied(),
2888        |connection_id| {
2889            session
2890                .peer
2891                .forward_send(session.connection_id, connection_id, request.clone())
2892        },
2893    );
2894    Ok(())
2895}
2896
2897/// Notify other participants that a language server has changed.
2898async fn update_language_server(
2899    request: proto::UpdateLanguageServer,
2900    session: Session,
2901) -> Result<()> {
2902    let project_id = ProjectId::from_proto(request.project_id);
2903    let project_connection_ids = session
2904        .db()
2905        .await
2906        .project_connection_ids(project_id, session.connection_id, true)
2907        .await?;
2908    broadcast(
2909        Some(session.connection_id),
2910        project_connection_ids.iter().copied(),
2911        |connection_id| {
2912            session
2913                .peer
2914                .forward_send(session.connection_id, connection_id, request.clone())
2915        },
2916    );
2917    Ok(())
2918}
2919
2920/// forward a project request to the host. These requests should be read only
2921/// as guests are allowed to send them.
2922async fn forward_read_only_project_request<T>(
2923    request: T,
2924    response: Response<T>,
2925    session: UserSession,
2926) -> Result<()>
2927where
2928    T: EntityMessage + RequestMessage,
2929{
2930    let project_id = ProjectId::from_proto(request.remote_entity_id());
2931    let host_connection_id = session
2932        .db()
2933        .await
2934        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2935        .await?;
2936    let payload = session
2937        .peer
2938        .forward_request(session.connection_id, host_connection_id, request)
2939        .await?;
2940    response.send(payload)?;
2941    Ok(())
2942}
2943
2944async fn forward_find_search_candidates_request(
2945    request: proto::FindSearchCandidates,
2946    response: Response<proto::FindSearchCandidates>,
2947    session: UserSession,
2948) -> Result<()> {
2949    let project_id = ProjectId::from_proto(request.remote_entity_id());
2950    let host_connection_id = session
2951        .db()
2952        .await
2953        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2954        .await?;
2955    let payload = session
2956        .peer
2957        .forward_request(session.connection_id, host_connection_id, request)
2958        .await?;
2959    response.send(payload)?;
2960    Ok(())
2961}
2962
2963/// forward a project request to the dev server. Only allowed
2964/// if it's your dev server.
2965async fn forward_project_request_for_owner<T>(
2966    request: T,
2967    response: Response<T>,
2968    session: UserSession,
2969) -> Result<()>
2970where
2971    T: EntityMessage + RequestMessage,
2972{
2973    let project_id = ProjectId::from_proto(request.remote_entity_id());
2974
2975    let host_connection_id = session
2976        .db()
2977        .await
2978        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2979        .await?;
2980    let payload = session
2981        .peer
2982        .forward_request(session.connection_id, host_connection_id, request)
2983        .await?;
2984    response.send(payload)?;
2985    Ok(())
2986}
2987
2988/// forward a project request to the host. These requests are disallowed
2989/// for guests.
2990async fn forward_mutating_project_request<T>(
2991    request: T,
2992    response: Response<T>,
2993    session: UserSession,
2994) -> Result<()>
2995where
2996    T: EntityMessage + RequestMessage,
2997{
2998    let project_id = ProjectId::from_proto(request.remote_entity_id());
2999
3000    let host_connection_id = session
3001        .db()
3002        .await
3003        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3004        .await?;
3005    let payload = session
3006        .peer
3007        .forward_request(session.connection_id, host_connection_id, request)
3008        .await?;
3009    response.send(payload)?;
3010    Ok(())
3011}
3012
3013/// Notify other participants that a new buffer has been created
3014async fn create_buffer_for_peer(
3015    request: proto::CreateBufferForPeer,
3016    session: Session,
3017) -> Result<()> {
3018    session
3019        .db()
3020        .await
3021        .check_user_is_project_host(
3022            ProjectId::from_proto(request.project_id),
3023            session.connection_id,
3024        )
3025        .await?;
3026    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3027    session
3028        .peer
3029        .forward_send(session.connection_id, peer_id.into(), request)?;
3030    Ok(())
3031}
3032
3033/// Notify other participants that a buffer has been updated. This is
3034/// allowed for guests as long as the update is limited to selections.
3035async fn update_buffer(
3036    request: proto::UpdateBuffer,
3037    response: Response<proto::UpdateBuffer>,
3038    session: Session,
3039) -> Result<()> {
3040    let project_id = ProjectId::from_proto(request.project_id);
3041    let mut capability = Capability::ReadOnly;
3042
3043    for op in request.operations.iter() {
3044        match op.variant {
3045            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3046            Some(_) => capability = Capability::ReadWrite,
3047        }
3048    }
3049
3050    let host = {
3051        let guard = session
3052            .db()
3053            .await
3054            .connections_for_buffer_update(
3055                project_id,
3056                session.principal_id(),
3057                session.connection_id,
3058                capability,
3059            )
3060            .await?;
3061
3062        let (host, guests) = &*guard;
3063
3064        broadcast(
3065            Some(session.connection_id),
3066            guests.clone(),
3067            |connection_id| {
3068                session
3069                    .peer
3070                    .forward_send(session.connection_id, connection_id, request.clone())
3071            },
3072        );
3073
3074        *host
3075    };
3076
3077    if host != session.connection_id {
3078        session
3079            .peer
3080            .forward_request(session.connection_id, host, request.clone())
3081            .await?;
3082    }
3083
3084    response.send(proto::Ack {})?;
3085    Ok(())
3086}
3087
3088async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3089    let project_id = ProjectId::from_proto(message.project_id);
3090
3091    let operation = message.operation.as_ref().context("invalid operation")?;
3092    let capability = match operation.variant.as_ref() {
3093        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3094            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3095                match buffer_op.variant {
3096                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3097                        Capability::ReadOnly
3098                    }
3099                    _ => Capability::ReadWrite,
3100                }
3101            } else {
3102                Capability::ReadWrite
3103            }
3104        }
3105        Some(_) => Capability::ReadWrite,
3106        None => Capability::ReadOnly,
3107    };
3108
3109    let guard = session
3110        .db()
3111        .await
3112        .connections_for_buffer_update(
3113            project_id,
3114            session.principal_id(),
3115            session.connection_id,
3116            capability,
3117        )
3118        .await?;
3119
3120    let (host, guests) = &*guard;
3121
3122    broadcast(
3123        Some(session.connection_id),
3124        guests.iter().chain([host]).copied(),
3125        |connection_id| {
3126            session
3127                .peer
3128                .forward_send(session.connection_id, connection_id, message.clone())
3129        },
3130    );
3131
3132    Ok(())
3133}
3134
3135/// Notify other participants that a project has been updated.
3136async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3137    request: T,
3138    session: Session,
3139) -> Result<()> {
3140    let project_id = ProjectId::from_proto(request.remote_entity_id());
3141    let project_connection_ids = session
3142        .db()
3143        .await
3144        .project_connection_ids(project_id, session.connection_id, false)
3145        .await?;
3146
3147    broadcast(
3148        Some(session.connection_id),
3149        project_connection_ids.iter().copied(),
3150        |connection_id| {
3151            session
3152                .peer
3153                .forward_send(session.connection_id, connection_id, request.clone())
3154        },
3155    );
3156    Ok(())
3157}
3158
3159/// Start following another user in a call.
3160async fn follow(
3161    request: proto::Follow,
3162    response: Response<proto::Follow>,
3163    session: UserSession,
3164) -> Result<()> {
3165    let room_id = RoomId::from_proto(request.room_id);
3166    let project_id = request.project_id.map(ProjectId::from_proto);
3167    let leader_id = request
3168        .leader_id
3169        .ok_or_else(|| anyhow!("invalid leader id"))?
3170        .into();
3171    let follower_id = session.connection_id;
3172
3173    session
3174        .db()
3175        .await
3176        .check_room_participants(room_id, leader_id, session.connection_id)
3177        .await?;
3178
3179    let response_payload = session
3180        .peer
3181        .forward_request(session.connection_id, leader_id, request)
3182        .await?;
3183    response.send(response_payload)?;
3184
3185    if let Some(project_id) = project_id {
3186        let room = session
3187            .db()
3188            .await
3189            .follow(room_id, project_id, leader_id, follower_id)
3190            .await?;
3191        room_updated(&room, &session.peer);
3192    }
3193
3194    Ok(())
3195}
3196
3197/// Stop following another user in a call.
3198async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3199    let room_id = RoomId::from_proto(request.room_id);
3200    let project_id = request.project_id.map(ProjectId::from_proto);
3201    let leader_id = request
3202        .leader_id
3203        .ok_or_else(|| anyhow!("invalid leader id"))?
3204        .into();
3205    let follower_id = session.connection_id;
3206
3207    session
3208        .db()
3209        .await
3210        .check_room_participants(room_id, leader_id, session.connection_id)
3211        .await?;
3212
3213    session
3214        .peer
3215        .forward_send(session.connection_id, leader_id, request)?;
3216
3217    if let Some(project_id) = project_id {
3218        let room = session
3219            .db()
3220            .await
3221            .unfollow(room_id, project_id, leader_id, follower_id)
3222            .await?;
3223        room_updated(&room, &session.peer);
3224    }
3225
3226    Ok(())
3227}
3228
3229/// Notify everyone following you of your current location.
3230async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3231    let room_id = RoomId::from_proto(request.room_id);
3232    let database = session.db.lock().await;
3233
3234    let connection_ids = if let Some(project_id) = request.project_id {
3235        let project_id = ProjectId::from_proto(project_id);
3236        database
3237            .project_connection_ids(project_id, session.connection_id, true)
3238            .await?
3239    } else {
3240        database
3241            .room_connection_ids(room_id, session.connection_id)
3242            .await?
3243    };
3244
3245    // For now, don't send view update messages back to that view's current leader.
3246    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3247        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3248        _ => None,
3249    });
3250
3251    for connection_id in connection_ids.iter().cloned() {
3252        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3253            session
3254                .peer
3255                .forward_send(session.connection_id, connection_id, request.clone())?;
3256        }
3257    }
3258    Ok(())
3259}
3260
3261/// Get public data about users.
3262async fn get_users(
3263    request: proto::GetUsers,
3264    response: Response<proto::GetUsers>,
3265    session: Session,
3266) -> Result<()> {
3267    let user_ids = request
3268        .user_ids
3269        .into_iter()
3270        .map(UserId::from_proto)
3271        .collect();
3272    let users = session
3273        .db()
3274        .await
3275        .get_users_by_ids(user_ids)
3276        .await?
3277        .into_iter()
3278        .map(|user| proto::User {
3279            id: user.id.to_proto(),
3280            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3281            github_login: user.github_login,
3282        })
3283        .collect();
3284    response.send(proto::UsersResponse { users })?;
3285    Ok(())
3286}
3287
3288/// Search for users (to invite) buy Github login
3289async fn fuzzy_search_users(
3290    request: proto::FuzzySearchUsers,
3291    response: Response<proto::FuzzySearchUsers>,
3292    session: UserSession,
3293) -> Result<()> {
3294    let query = request.query;
3295    let users = match query.len() {
3296        0 => vec![],
3297        1 | 2 => session
3298            .db()
3299            .await
3300            .get_user_by_github_login(&query)
3301            .await?
3302            .into_iter()
3303            .collect(),
3304        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3305    };
3306    let users = users
3307        .into_iter()
3308        .filter(|user| user.id != session.user_id())
3309        .map(|user| proto::User {
3310            id: user.id.to_proto(),
3311            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3312            github_login: user.github_login,
3313        })
3314        .collect();
3315    response.send(proto::UsersResponse { users })?;
3316    Ok(())
3317}
3318
3319/// Send a contact request to another user.
3320async fn request_contact(
3321    request: proto::RequestContact,
3322    response: Response<proto::RequestContact>,
3323    session: UserSession,
3324) -> Result<()> {
3325    let requester_id = session.user_id();
3326    let responder_id = UserId::from_proto(request.responder_id);
3327    if requester_id == responder_id {
3328        return Err(anyhow!("cannot add yourself as a contact"))?;
3329    }
3330
3331    let notifications = session
3332        .db()
3333        .await
3334        .send_contact_request(requester_id, responder_id)
3335        .await?;
3336
3337    // Update outgoing contact requests of requester
3338    let mut update = proto::UpdateContacts::default();
3339    update.outgoing_requests.push(responder_id.to_proto());
3340    for connection_id in session
3341        .connection_pool()
3342        .await
3343        .user_connection_ids(requester_id)
3344    {
3345        session.peer.send(connection_id, update.clone())?;
3346    }
3347
3348    // Update incoming contact requests of responder
3349    let mut update = proto::UpdateContacts::default();
3350    update
3351        .incoming_requests
3352        .push(proto::IncomingContactRequest {
3353            requester_id: requester_id.to_proto(),
3354        });
3355    let connection_pool = session.connection_pool().await;
3356    for connection_id in connection_pool.user_connection_ids(responder_id) {
3357        session.peer.send(connection_id, update.clone())?;
3358    }
3359
3360    send_notifications(&connection_pool, &session.peer, notifications);
3361
3362    response.send(proto::Ack {})?;
3363    Ok(())
3364}
3365
3366/// Accept or decline a contact request
3367async fn respond_to_contact_request(
3368    request: proto::RespondToContactRequest,
3369    response: Response<proto::RespondToContactRequest>,
3370    session: UserSession,
3371) -> Result<()> {
3372    let responder_id = session.user_id();
3373    let requester_id = UserId::from_proto(request.requester_id);
3374    let db = session.db().await;
3375    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3376        db.dismiss_contact_notification(responder_id, requester_id)
3377            .await?;
3378    } else {
3379        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3380
3381        let notifications = db
3382            .respond_to_contact_request(responder_id, requester_id, accept)
3383            .await?;
3384        let requester_busy = db.is_user_busy(requester_id).await?;
3385        let responder_busy = db.is_user_busy(responder_id).await?;
3386
3387        let pool = session.connection_pool().await;
3388        // Update responder with new contact
3389        let mut update = proto::UpdateContacts::default();
3390        if accept {
3391            update
3392                .contacts
3393                .push(contact_for_user(requester_id, requester_busy, &pool));
3394        }
3395        update
3396            .remove_incoming_requests
3397            .push(requester_id.to_proto());
3398        for connection_id in pool.user_connection_ids(responder_id) {
3399            session.peer.send(connection_id, update.clone())?;
3400        }
3401
3402        // Update requester with new contact
3403        let mut update = proto::UpdateContacts::default();
3404        if accept {
3405            update
3406                .contacts
3407                .push(contact_for_user(responder_id, responder_busy, &pool));
3408        }
3409        update
3410            .remove_outgoing_requests
3411            .push(responder_id.to_proto());
3412
3413        for connection_id in pool.user_connection_ids(requester_id) {
3414            session.peer.send(connection_id, update.clone())?;
3415        }
3416
3417        send_notifications(&pool, &session.peer, notifications);
3418    }
3419
3420    response.send(proto::Ack {})?;
3421    Ok(())
3422}
3423
3424/// Remove a contact.
3425async fn remove_contact(
3426    request: proto::RemoveContact,
3427    response: Response<proto::RemoveContact>,
3428    session: UserSession,
3429) -> Result<()> {
3430    let requester_id = session.user_id();
3431    let responder_id = UserId::from_proto(request.user_id);
3432    let db = session.db().await;
3433    let (contact_accepted, deleted_notification_id) =
3434        db.remove_contact(requester_id, responder_id).await?;
3435
3436    let pool = session.connection_pool().await;
3437    // Update outgoing contact requests of requester
3438    let mut update = proto::UpdateContacts::default();
3439    if contact_accepted {
3440        update.remove_contacts.push(responder_id.to_proto());
3441    } else {
3442        update
3443            .remove_outgoing_requests
3444            .push(responder_id.to_proto());
3445    }
3446    for connection_id in pool.user_connection_ids(requester_id) {
3447        session.peer.send(connection_id, update.clone())?;
3448    }
3449
3450    // Update incoming contact requests of responder
3451    let mut update = proto::UpdateContacts::default();
3452    if contact_accepted {
3453        update.remove_contacts.push(requester_id.to_proto());
3454    } else {
3455        update
3456            .remove_incoming_requests
3457            .push(requester_id.to_proto());
3458    }
3459    for connection_id in pool.user_connection_ids(responder_id) {
3460        session.peer.send(connection_id, update.clone())?;
3461        if let Some(notification_id) = deleted_notification_id {
3462            session.peer.send(
3463                connection_id,
3464                proto::DeleteNotification {
3465                    notification_id: notification_id.to_proto(),
3466                },
3467            )?;
3468        }
3469    }
3470
3471    response.send(proto::Ack {})?;
3472    Ok(())
3473}
3474
3475fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3476    version.0.minor() < 139
3477}
3478
3479async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3480    let plan = session.current_plan(&session.db().await).await?;
3481
3482    session
3483        .peer
3484        .send(
3485            session.connection_id,
3486            proto::UpdateUserPlan { plan: plan.into() },
3487        )
3488        .trace_err();
3489
3490    Ok(())
3491}
3492
3493async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3494    subscribe_user_to_channels(
3495        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3496        &session,
3497    )
3498    .await?;
3499    Ok(())
3500}
3501
3502async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3503    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3504    let mut pool = session.connection_pool().await;
3505    for membership in &channels_for_user.channel_memberships {
3506        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3507    }
3508    session.peer.send(
3509        session.connection_id,
3510        build_update_user_channels(&channels_for_user),
3511    )?;
3512    session.peer.send(
3513        session.connection_id,
3514        build_channels_update(channels_for_user),
3515    )?;
3516    Ok(())
3517}
3518
3519/// Creates a new channel.
3520async fn create_channel(
3521    request: proto::CreateChannel,
3522    response: Response<proto::CreateChannel>,
3523    session: UserSession,
3524) -> Result<()> {
3525    let db = session.db().await;
3526
3527    let parent_id = request.parent_id.map(ChannelId::from_proto);
3528    let (channel, membership) = db
3529        .create_channel(&request.name, parent_id, session.user_id())
3530        .await?;
3531
3532    let root_id = channel.root_id();
3533    let channel = Channel::from_model(channel);
3534
3535    response.send(proto::CreateChannelResponse {
3536        channel: Some(channel.to_proto()),
3537        parent_id: request.parent_id,
3538    })?;
3539
3540    let mut connection_pool = session.connection_pool().await;
3541    if let Some(membership) = membership {
3542        connection_pool.subscribe_to_channel(
3543            membership.user_id,
3544            membership.channel_id,
3545            membership.role,
3546        );
3547        let update = proto::UpdateUserChannels {
3548            channel_memberships: vec![proto::ChannelMembership {
3549                channel_id: membership.channel_id.to_proto(),
3550                role: membership.role.into(),
3551            }],
3552            ..Default::default()
3553        };
3554        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3555            session.peer.send(connection_id, update.clone())?;
3556        }
3557    }
3558
3559    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3560        if !role.can_see_channel(channel.visibility) {
3561            continue;
3562        }
3563
3564        let update = proto::UpdateChannels {
3565            channels: vec![channel.to_proto()],
3566            ..Default::default()
3567        };
3568        session.peer.send(connection_id, update.clone())?;
3569    }
3570
3571    Ok(())
3572}
3573
3574/// Delete a channel
3575async fn delete_channel(
3576    request: proto::DeleteChannel,
3577    response: Response<proto::DeleteChannel>,
3578    session: UserSession,
3579) -> Result<()> {
3580    let db = session.db().await;
3581
3582    let channel_id = request.channel_id;
3583    let (root_channel, removed_channels) = db
3584        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3585        .await?;
3586    response.send(proto::Ack {})?;
3587
3588    // Notify members of removed channels
3589    let mut update = proto::UpdateChannels::default();
3590    update
3591        .delete_channels
3592        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3593
3594    let connection_pool = session.connection_pool().await;
3595    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3596        session.peer.send(connection_id, update.clone())?;
3597    }
3598
3599    Ok(())
3600}
3601
3602/// Invite someone to join a channel.
3603async fn invite_channel_member(
3604    request: proto::InviteChannelMember,
3605    response: Response<proto::InviteChannelMember>,
3606    session: UserSession,
3607) -> Result<()> {
3608    let db = session.db().await;
3609    let channel_id = ChannelId::from_proto(request.channel_id);
3610    let invitee_id = UserId::from_proto(request.user_id);
3611    let InviteMemberResult {
3612        channel,
3613        notifications,
3614    } = db
3615        .invite_channel_member(
3616            channel_id,
3617            invitee_id,
3618            session.user_id(),
3619            request.role().into(),
3620        )
3621        .await?;
3622
3623    let update = proto::UpdateChannels {
3624        channel_invitations: vec![channel.to_proto()],
3625        ..Default::default()
3626    };
3627
3628    let connection_pool = session.connection_pool().await;
3629    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3630        session.peer.send(connection_id, update.clone())?;
3631    }
3632
3633    send_notifications(&connection_pool, &session.peer, notifications);
3634
3635    response.send(proto::Ack {})?;
3636    Ok(())
3637}
3638
3639/// remove someone from a channel
3640async fn remove_channel_member(
3641    request: proto::RemoveChannelMember,
3642    response: Response<proto::RemoveChannelMember>,
3643    session: UserSession,
3644) -> Result<()> {
3645    let db = session.db().await;
3646    let channel_id = ChannelId::from_proto(request.channel_id);
3647    let member_id = UserId::from_proto(request.user_id);
3648
3649    let RemoveChannelMemberResult {
3650        membership_update,
3651        notification_id,
3652    } = db
3653        .remove_channel_member(channel_id, member_id, session.user_id())
3654        .await?;
3655
3656    let mut connection_pool = session.connection_pool().await;
3657    notify_membership_updated(
3658        &mut connection_pool,
3659        membership_update,
3660        member_id,
3661        &session.peer,
3662    );
3663    for connection_id in connection_pool.user_connection_ids(member_id) {
3664        if let Some(notification_id) = notification_id {
3665            session
3666                .peer
3667                .send(
3668                    connection_id,
3669                    proto::DeleteNotification {
3670                        notification_id: notification_id.to_proto(),
3671                    },
3672                )
3673                .trace_err();
3674        }
3675    }
3676
3677    response.send(proto::Ack {})?;
3678    Ok(())
3679}
3680
3681/// Toggle the channel between public and private.
3682/// Care is taken to maintain the invariant that public channels only descend from public channels,
3683/// (though members-only channels can appear at any point in the hierarchy).
3684async fn set_channel_visibility(
3685    request: proto::SetChannelVisibility,
3686    response: Response<proto::SetChannelVisibility>,
3687    session: UserSession,
3688) -> Result<()> {
3689    let db = session.db().await;
3690    let channel_id = ChannelId::from_proto(request.channel_id);
3691    let visibility = request.visibility().into();
3692
3693    let channel_model = db
3694        .set_channel_visibility(channel_id, visibility, session.user_id())
3695        .await?;
3696    let root_id = channel_model.root_id();
3697    let channel = Channel::from_model(channel_model);
3698
3699    let mut connection_pool = session.connection_pool().await;
3700    for (user_id, role) in connection_pool
3701        .channel_user_ids(root_id)
3702        .collect::<Vec<_>>()
3703        .into_iter()
3704    {
3705        let update = if role.can_see_channel(channel.visibility) {
3706            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3707            proto::UpdateChannels {
3708                channels: vec![channel.to_proto()],
3709                ..Default::default()
3710            }
3711        } else {
3712            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3713            proto::UpdateChannels {
3714                delete_channels: vec![channel.id.to_proto()],
3715                ..Default::default()
3716            }
3717        };
3718
3719        for connection_id in connection_pool.user_connection_ids(user_id) {
3720            session.peer.send(connection_id, update.clone())?;
3721        }
3722    }
3723
3724    response.send(proto::Ack {})?;
3725    Ok(())
3726}
3727
3728/// Alter the role for a user in the channel.
3729async fn set_channel_member_role(
3730    request: proto::SetChannelMemberRole,
3731    response: Response<proto::SetChannelMemberRole>,
3732    session: UserSession,
3733) -> Result<()> {
3734    let db = session.db().await;
3735    let channel_id = ChannelId::from_proto(request.channel_id);
3736    let member_id = UserId::from_proto(request.user_id);
3737    let result = db
3738        .set_channel_member_role(
3739            channel_id,
3740            session.user_id(),
3741            member_id,
3742            request.role().into(),
3743        )
3744        .await?;
3745
3746    match result {
3747        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3748            let mut connection_pool = session.connection_pool().await;
3749            notify_membership_updated(
3750                &mut connection_pool,
3751                membership_update,
3752                member_id,
3753                &session.peer,
3754            )
3755        }
3756        db::SetMemberRoleResult::InviteUpdated(channel) => {
3757            let update = proto::UpdateChannels {
3758                channel_invitations: vec![channel.to_proto()],
3759                ..Default::default()
3760            };
3761
3762            for connection_id in session
3763                .connection_pool()
3764                .await
3765                .user_connection_ids(member_id)
3766            {
3767                session.peer.send(connection_id, update.clone())?;
3768            }
3769        }
3770    }
3771
3772    response.send(proto::Ack {})?;
3773    Ok(())
3774}
3775
3776/// Change the name of a channel
3777async fn rename_channel(
3778    request: proto::RenameChannel,
3779    response: Response<proto::RenameChannel>,
3780    session: UserSession,
3781) -> Result<()> {
3782    let db = session.db().await;
3783    let channel_id = ChannelId::from_proto(request.channel_id);
3784    let channel_model = db
3785        .rename_channel(channel_id, session.user_id(), &request.name)
3786        .await?;
3787    let root_id = channel_model.root_id();
3788    let channel = Channel::from_model(channel_model);
3789
3790    response.send(proto::RenameChannelResponse {
3791        channel: Some(channel.to_proto()),
3792    })?;
3793
3794    let connection_pool = session.connection_pool().await;
3795    let update = proto::UpdateChannels {
3796        channels: vec![channel.to_proto()],
3797        ..Default::default()
3798    };
3799    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3800        if role.can_see_channel(channel.visibility) {
3801            session.peer.send(connection_id, update.clone())?;
3802        }
3803    }
3804
3805    Ok(())
3806}
3807
3808/// Move a channel to a new parent.
3809async fn move_channel(
3810    request: proto::MoveChannel,
3811    response: Response<proto::MoveChannel>,
3812    session: UserSession,
3813) -> Result<()> {
3814    let channel_id = ChannelId::from_proto(request.channel_id);
3815    let to = ChannelId::from_proto(request.to);
3816
3817    let (root_id, channels) = session
3818        .db()
3819        .await
3820        .move_channel(channel_id, to, session.user_id())
3821        .await?;
3822
3823    let connection_pool = session.connection_pool().await;
3824    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3825        let channels = channels
3826            .iter()
3827            .filter_map(|channel| {
3828                if role.can_see_channel(channel.visibility) {
3829                    Some(channel.to_proto())
3830                } else {
3831                    None
3832                }
3833            })
3834            .collect::<Vec<_>>();
3835        if channels.is_empty() {
3836            continue;
3837        }
3838
3839        let update = proto::UpdateChannels {
3840            channels,
3841            ..Default::default()
3842        };
3843
3844        session.peer.send(connection_id, update.clone())?;
3845    }
3846
3847    response.send(Ack {})?;
3848    Ok(())
3849}
3850
3851/// Get the list of channel members
3852async fn get_channel_members(
3853    request: proto::GetChannelMembers,
3854    response: Response<proto::GetChannelMembers>,
3855    session: UserSession,
3856) -> Result<()> {
3857    let db = session.db().await;
3858    let channel_id = ChannelId::from_proto(request.channel_id);
3859    let limit = if request.limit == 0 {
3860        u16::MAX as u64
3861    } else {
3862        request.limit
3863    };
3864    let (members, users) = db
3865        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3866        .await?;
3867    response.send(proto::GetChannelMembersResponse { members, users })?;
3868    Ok(())
3869}
3870
3871/// Accept or decline a channel invitation.
3872async fn respond_to_channel_invite(
3873    request: proto::RespondToChannelInvite,
3874    response: Response<proto::RespondToChannelInvite>,
3875    session: UserSession,
3876) -> Result<()> {
3877    let db = session.db().await;
3878    let channel_id = ChannelId::from_proto(request.channel_id);
3879    let RespondToChannelInvite {
3880        membership_update,
3881        notifications,
3882    } = db
3883        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3884        .await?;
3885
3886    let mut connection_pool = session.connection_pool().await;
3887    if let Some(membership_update) = membership_update {
3888        notify_membership_updated(
3889            &mut connection_pool,
3890            membership_update,
3891            session.user_id(),
3892            &session.peer,
3893        );
3894    } else {
3895        let update = proto::UpdateChannels {
3896            remove_channel_invitations: vec![channel_id.to_proto()],
3897            ..Default::default()
3898        };
3899
3900        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3901            session.peer.send(connection_id, update.clone())?;
3902        }
3903    };
3904
3905    send_notifications(&connection_pool, &session.peer, notifications);
3906
3907    response.send(proto::Ack {})?;
3908
3909    Ok(())
3910}
3911
3912/// Join the channels' room
3913async fn join_channel(
3914    request: proto::JoinChannel,
3915    response: Response<proto::JoinChannel>,
3916    session: UserSession,
3917) -> Result<()> {
3918    let channel_id = ChannelId::from_proto(request.channel_id);
3919    join_channel_internal(channel_id, Box::new(response), session).await
3920}
3921
/// Abstracts over the `Response` handles that can answer a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`. Implementing this trait for both response types
/// lets `join_channel_internal` reply to either one.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to `Response`'s own (inherent) `send`. Inherent
        // associated functions are resolved before trait ones, so this does
        // not recurse into this impl.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as the `JoinChannel` impl, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3935
/// Join the room that belongs to `channel_id` on behalf of the session's
/// user, and answer the caller with a `proto::JoinRoomResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first via `leave_room_for_session`.
/// 2. Join the channel's room in the database, which also yields an optional
///    membership update and the user's role.
/// 3. Build LiveKit credentials when a LiveKit client is configured. Guests
///    get a guest token with `can_publish = false`; everyone else gets a room
///    token with `can_publish = true`. If token creation fails, the error goes
///    through `trace_err` and the response carries no LiveKit info.
/// 4. Send the response, notify any membership change, and broadcast the
///    updated room.
/// 5. With the database and connection-pool guards released, broadcast the
///    channel update and refresh the user's contacts.
///
/// `response` may be either request's `Response` (see
/// `JoinChannelInternalResponse`).
///
/// Errors if any database call, the stale-room cleanup, or the response send
/// fails, or if the database does not return the joined channel
/// ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room
            // (presumably `leave_room_for_session` acquires it itself), then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may not publish media and get a guest token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops `db` and `connection_pool`; the pool is
        // re-acquired below only for the `channel_updated` broadcast.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4031
4032/// Start editing the channel notes
4033async fn join_channel_buffer(
4034    request: proto::JoinChannelBuffer,
4035    response: Response<proto::JoinChannelBuffer>,
4036    session: UserSession,
4037) -> Result<()> {
4038    let db = session.db().await;
4039    let channel_id = ChannelId::from_proto(request.channel_id);
4040
4041    let open_response = db
4042        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4043        .await?;
4044
4045    let collaborators = open_response.collaborators.clone();
4046    response.send(open_response)?;
4047
4048    let update = UpdateChannelBufferCollaborators {
4049        channel_id: channel_id.to_proto(),
4050        collaborators: collaborators.clone(),
4051    };
4052    channel_buffer_updated(
4053        session.connection_id,
4054        collaborators
4055            .iter()
4056            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4057        &update,
4058        &session.peer,
4059    );
4060
4061    Ok(())
4062}
4063
4064/// Edit the channel notes
4065async fn update_channel_buffer(
4066    request: proto::UpdateChannelBuffer,
4067    session: UserSession,
4068) -> Result<()> {
4069    let db = session.db().await;
4070    let channel_id = ChannelId::from_proto(request.channel_id);
4071
4072    let (collaborators, epoch, version) = db
4073        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4074        .await?;
4075
4076    channel_buffer_updated(
4077        session.connection_id,
4078        collaborators.clone(),
4079        &proto::UpdateChannelBuffer {
4080            channel_id: channel_id.to_proto(),
4081            operations: request.operations,
4082        },
4083        &session.peer,
4084    );
4085
4086    let pool = &*session.connection_pool().await;
4087
4088    let non_collaborators =
4089        pool.channel_connection_ids(channel_id)
4090            .filter_map(|(connection_id, _)| {
4091                if collaborators.contains(&connection_id) {
4092                    None
4093                } else {
4094                    Some(connection_id)
4095                }
4096            });
4097
4098    broadcast(None, non_collaborators, |peer_id| {
4099        session.peer.send(
4100            peer_id,
4101            proto::UpdateChannels {
4102                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4103                    channel_id: channel_id.to_proto(),
4104                    epoch: epoch as u64,
4105                    version: version.clone(),
4106                }],
4107                ..Default::default()
4108            },
4109        )
4110    });
4111
4112    Ok(())
4113}
4114
4115/// Rejoin the channel notes after a connection blip
4116async fn rejoin_channel_buffers(
4117    request: proto::RejoinChannelBuffers,
4118    response: Response<proto::RejoinChannelBuffers>,
4119    session: UserSession,
4120) -> Result<()> {
4121    let db = session.db().await;
4122    let buffers = db
4123        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4124        .await?;
4125
4126    for rejoined_buffer in &buffers {
4127        let collaborators_to_notify = rejoined_buffer
4128            .buffer
4129            .collaborators
4130            .iter()
4131            .filter_map(|c| Some(c.peer_id?.into()));
4132        channel_buffer_updated(
4133            session.connection_id,
4134            collaborators_to_notify,
4135            &proto::UpdateChannelBufferCollaborators {
4136                channel_id: rejoined_buffer.buffer.channel_id,
4137                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4138            },
4139            &session.peer,
4140        );
4141    }
4142
4143    response.send(proto::RejoinChannelBuffersResponse {
4144        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4145    })?;
4146
4147    Ok(())
4148}
4149
4150/// Stop editing the channel notes
4151async fn leave_channel_buffer(
4152    request: proto::LeaveChannelBuffer,
4153    response: Response<proto::LeaveChannelBuffer>,
4154    session: UserSession,
4155) -> Result<()> {
4156    let db = session.db().await;
4157    let channel_id = ChannelId::from_proto(request.channel_id);
4158
4159    let left_buffer = db
4160        .leave_channel_buffer(channel_id, session.connection_id)
4161        .await?;
4162
4163    response.send(Ack {})?;
4164
4165    channel_buffer_updated(
4166        session.connection_id,
4167        left_buffer.connections,
4168        &proto::UpdateChannelBufferCollaborators {
4169            channel_id: channel_id.to_proto(),
4170            collaborators: left_buffer.collaborators,
4171        },
4172        &session.peer,
4173    );
4174
4175    Ok(())
4176}
4177
4178fn channel_buffer_updated<T: EnvelopedMessage>(
4179    sender_id: ConnectionId,
4180    collaborators: impl IntoIterator<Item = ConnectionId>,
4181    message: &T,
4182    peer: &Peer,
4183) {
4184    broadcast(Some(sender_id), collaborators, |peer_id| {
4185        peer.send(peer_id, message.clone())
4186    });
4187}
4188
4189fn send_notifications(
4190    connection_pool: &ConnectionPool,
4191    peer: &Peer,
4192    notifications: db::NotificationBatch,
4193) {
4194    for (user_id, notification) in notifications {
4195        for connection_id in connection_pool.user_connection_ids(user_id) {
4196            if let Err(error) = peer.send(
4197                connection_id,
4198                proto::AddNotification {
4199                    notification: Some(notification.clone()),
4200                },
4201            ) {
4202                tracing::error!(
4203                    "failed to send notification to {:?} {}",
4204                    connection_id,
4205                    error
4206                );
4207            }
4208        }
4209    }
4210}
4211
/// Send a chat message to a channel.
///
/// Flow:
/// 1. Trim the body. Reject it if its UTF-8 byte length exceeds
///    `MAX_MESSAGE_LEN` or it is empty after trimming. A `nonce` is required.
/// 2. Persist the message. The database returns the new message id,
///    `participant_connection_ids`, and notifications to deliver.
/// 3. Broadcast `ChannelMessageSent` to the participants (with this connection
///    passed as the sender), then return the message to the caller.
/// 4. Send every other connection subscribed to the channel an
///    `UpdateChannels` carrying only the latest message id, then deliver the
///    notifications.
///
/// Errors: "message is too long", "message can't be blank", "nonce can't be
/// blank", or any database / response-send failure.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed back unchanged
    // below. If mentions carry offsets into the untrimmed body, trimming
    // leading whitespace shifts them. Confirm against the mention format.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Participants get the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel that did not receive the full
    // message above only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4306
4307/// Delete a channel message
4308async fn remove_channel_message(
4309    request: proto::RemoveChannelMessage,
4310    response: Response<proto::RemoveChannelMessage>,
4311    session: UserSession,
4312) -> Result<()> {
4313    let channel_id = ChannelId::from_proto(request.channel_id);
4314    let message_id = MessageId::from_proto(request.message_id);
4315    let (connection_ids, existing_notification_ids) = session
4316        .db()
4317        .await
4318        .remove_channel_message(channel_id, message_id, session.user_id())
4319        .await?;
4320
4321    broadcast(
4322        Some(session.connection_id),
4323        connection_ids,
4324        move |connection| {
4325            session.peer.send(connection, request.clone())?;
4326
4327            for notification_id in &existing_notification_ids {
4328                session.peer.send(
4329                    connection,
4330                    proto::DeleteNotification {
4331                        notification_id: (*notification_id).to_proto(),
4332                    },
4333                )?;
4334            }
4335
4336            Ok(())
4337        },
4338    );
4339    response.send(proto::Ack {})?;
4340    Ok(())
4341}
4342
4343async fn update_channel_message(
4344    request: proto::UpdateChannelMessage,
4345    response: Response<proto::UpdateChannelMessage>,
4346    session: UserSession,
4347) -> Result<()> {
4348    let channel_id = ChannelId::from_proto(request.channel_id);
4349    let message_id = MessageId::from_proto(request.message_id);
4350    let updated_at = OffsetDateTime::now_utc();
4351    let UpdatedChannelMessage {
4352        message_id,
4353        participant_connection_ids,
4354        notifications,
4355        reply_to_message_id,
4356        timestamp,
4357        deleted_mention_notification_ids,
4358        updated_mention_notifications,
4359    } = session
4360        .db()
4361        .await
4362        .update_channel_message(
4363            channel_id,
4364            message_id,
4365            session.user_id(),
4366            request.body.as_str(),
4367            &request.mentions,
4368            updated_at,
4369        )
4370        .await?;
4371
4372    let nonce = request
4373        .nonce
4374        .clone()
4375        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4376
4377    let message = proto::ChannelMessage {
4378        sender_id: session.user_id().to_proto(),
4379        id: message_id.to_proto(),
4380        body: request.body.clone(),
4381        mentions: request.mentions.clone(),
4382        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4383        nonce: Some(nonce),
4384        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4385        edited_at: Some(updated_at.unix_timestamp() as u64),
4386    };
4387
4388    response.send(proto::Ack {})?;
4389
4390    let pool = &*session.connection_pool().await;
4391    broadcast(
4392        Some(session.connection_id),
4393        participant_connection_ids,
4394        |connection| {
4395            session.peer.send(
4396                connection,
4397                proto::ChannelMessageUpdate {
4398                    channel_id: channel_id.to_proto(),
4399                    message: Some(message.clone()),
4400                },
4401            )?;
4402
4403            for notification_id in &deleted_mention_notification_ids {
4404                session.peer.send(
4405                    connection,
4406                    proto::DeleteNotification {
4407                        notification_id: (*notification_id).to_proto(),
4408                    },
4409                )?;
4410            }
4411
4412            for notification in &updated_mention_notifications {
4413                session.peer.send(
4414                    connection,
4415                    proto::UpdateNotification {
4416                        notification: Some(notification.clone()),
4417                    },
4418                )?;
4419            }
4420
4421            Ok(())
4422        },
4423    );
4424
4425    send_notifications(pool, &session.peer, notifications);
4426
4427    Ok(())
4428}
4429
4430/// Mark a channel message as read
4431async fn acknowledge_channel_message(
4432    request: proto::AckChannelMessage,
4433    session: UserSession,
4434) -> Result<()> {
4435    let channel_id = ChannelId::from_proto(request.channel_id);
4436    let message_id = MessageId::from_proto(request.message_id);
4437    let notifications = session
4438        .db()
4439        .await
4440        .observe_channel_message(channel_id, session.user_id(), message_id)
4441        .await?;
4442    send_notifications(
4443        &*session.connection_pool().await,
4444        &session.peer,
4445        notifications,
4446    );
4447    Ok(())
4448}
4449
4450/// Mark a buffer version as synced
4451async fn acknowledge_buffer_version(
4452    request: proto::AckBufferOperation,
4453    session: UserSession,
4454) -> Result<()> {
4455    let buffer_id = BufferId::from_proto(request.buffer_id);
4456    session
4457        .db()
4458        .await
4459        .observe_buffer_version(
4460            buffer_id,
4461            session.user_id(),
4462            request.epoch as i32,
4463            &request.version,
4464        )
4465        .await?;
4466    Ok(())
4467}
4468
4469async fn count_language_model_tokens(
4470    request: proto::CountLanguageModelTokens,
4471    response: Response<proto::CountLanguageModelTokens>,
4472    session: Session,
4473    config: &Config,
4474) -> Result<()> {
4475    let Some(session) = session.for_user() else {
4476        return Err(anyhow!("user not found"))?;
4477    };
4478    authorize_access_to_legacy_llm_endpoints(&session).await?;
4479
4480    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4481        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4482        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4483    };
4484
4485    session
4486        .app_state
4487        .rate_limiter
4488        .check(&*rate_limit, session.user_id())
4489        .await?;
4490
4491    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4492        Some(proto::LanguageModelProvider::Google) => {
4493            let api_key = config
4494                .google_ai_api_key
4495                .as_ref()
4496                .context("no Google AI API key configured on the server")?;
4497            google_ai::count_tokens(
4498                session.http_client.as_ref(),
4499                google_ai::API_URL,
4500                api_key,
4501                serde_json::from_str(&request.request)?,
4502                None,
4503            )
4504            .await?
4505        }
4506        _ => return Err(anyhow!("unsupported provider"))?,
4507    };
4508
4509    response.send(proto::CountLanguageModelTokensResponse {
4510        token_count: result.total_tokens as u32,
4511    })?;
4512
4513    Ok(())
4514}
4515
4516struct ZedProCountLanguageModelTokensRateLimit;
4517
4518impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4519    fn capacity(&self) -> usize {
4520        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4521            .ok()
4522            .and_then(|v| v.parse().ok())
4523            .unwrap_or(600) // Picked arbitrarily
4524    }
4525
4526    fn refill_duration(&self) -> chrono::Duration {
4527        chrono::Duration::hours(1)
4528    }
4529
4530    fn db_name(&self) -> &'static str {
4531        "zed-pro:count-language-model-tokens"
4532    }
4533}
4534
4535struct FreeCountLanguageModelTokensRateLimit;
4536
4537impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4538    fn capacity(&self) -> usize {
4539        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4540            .ok()
4541            .and_then(|v| v.parse().ok())
4542            .unwrap_or(600 / 10) // Picked arbitrarily
4543    }
4544
4545    fn refill_duration(&self) -> chrono::Duration {
4546        chrono::Duration::hours(1)
4547    }
4548
4549    fn db_name(&self) -> &'static str {
4550        "free:count-language-model-tokens"
4551    }
4552}
4553
4554struct ZedProComputeEmbeddingsRateLimit;
4555
4556impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4557    fn capacity(&self) -> usize {
4558        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4559            .ok()
4560            .and_then(|v| v.parse().ok())
4561            .unwrap_or(5000) // Picked arbitrarily
4562    }
4563
4564    fn refill_duration(&self) -> chrono::Duration {
4565        chrono::Duration::hours(1)
4566    }
4567
4568    fn db_name(&self) -> &'static str {
4569        "zed-pro:compute-embeddings"
4570    }
4571}
4572
4573struct FreeComputeEmbeddingsRateLimit;
4574
4575impl RateLimit for FreeComputeEmbeddingsRateLimit {
4576    fn capacity(&self) -> usize {
4577        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4578            .ok()
4579            .and_then(|v| v.parse().ok())
4580            .unwrap_or(5000 / 10) // Picked arbitrarily
4581    }
4582
4583    fn refill_duration(&self) -> chrono::Duration {
4584        chrono::Duration::hours(1)
4585    }
4586
4587    fn db_name(&self) -> &'static str {
4588        "free:compute-embeddings"
4589    }
4590}
4591
4592async fn compute_embeddings(
4593    request: proto::ComputeEmbeddings,
4594    response: Response<proto::ComputeEmbeddings>,
4595    session: UserSession,
4596    api_key: Option<Arc<str>>,
4597) -> Result<()> {
4598    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4599    authorize_access_to_legacy_llm_endpoints(&session).await?;
4600
4601    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
4602        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4603        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4604    };
4605
4606    session
4607        .app_state
4608        .rate_limiter
4609        .check(&*rate_limit, session.user_id())
4610        .await?;
4611
4612    let embeddings = match request.model.as_str() {
4613        "openai/text-embedding-3-small" => {
4614            open_ai::embed(
4615                session.http_client.as_ref(),
4616                OPEN_AI_API_URL,
4617                &api_key,
4618                OpenAiEmbeddingModel::TextEmbedding3Small,
4619                request.texts.iter().map(|text| text.as_str()),
4620            )
4621            .await?
4622        }
4623        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4624    };
4625
4626    let embeddings = request
4627        .texts
4628        .iter()
4629        .map(|text| {
4630            let mut hasher = sha2::Sha256::new();
4631            hasher.update(text.as_bytes());
4632            let result = hasher.finalize();
4633            result.to_vec()
4634        })
4635        .zip(
4636            embeddings
4637                .data
4638                .into_iter()
4639                .map(|embedding| embedding.embedding),
4640        )
4641        .collect::<HashMap<_, _>>();
4642
4643    let db = session.db().await;
4644    db.save_embeddings(&request.model, &embeddings)
4645        .await
4646        .context("failed to save embeddings")
4647        .trace_err();
4648
4649    response.send(proto::ComputeEmbeddingsResponse {
4650        embeddings: embeddings
4651            .into_iter()
4652            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4653            .collect(),
4654    })?;
4655    Ok(())
4656}
4657
4658async fn get_cached_embeddings(
4659    request: proto::GetCachedEmbeddings,
4660    response: Response<proto::GetCachedEmbeddings>,
4661    session: UserSession,
4662) -> Result<()> {
4663    authorize_access_to_legacy_llm_endpoints(&session).await?;
4664
4665    let db = session.db().await;
4666    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4667
4668    response.send(proto::GetCachedEmbeddingsResponse {
4669        embeddings: embeddings
4670            .into_iter()
4671            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4672            .collect(),
4673    })?;
4674    Ok(())
4675}
4676
4677/// This is leftover from before the LLM service.
4678///
4679/// The endpoints protected by this check will be moved there eventually.
4680async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4681    if session.is_staff() {
4682        Ok(())
4683    } else {
4684        Err(anyhow!("permission denied"))?
4685    }
4686}
4687
4688/// Get a Supermaven API key for the user
4689async fn get_supermaven_api_key(
4690    _request: proto::GetSupermavenApiKey,
4691    response: Response<proto::GetSupermavenApiKey>,
4692    session: UserSession,
4693) -> Result<()> {
4694    let user_id: String = session.user_id().to_string();
4695    if !session.is_staff() {
4696        return Err(anyhow!("supermaven not enabled for this account"))?;
4697    }
4698
4699    let email = session
4700        .email()
4701        .ok_or_else(|| anyhow!("user must have an email"))?;
4702
4703    let supermaven_admin_api = session
4704        .supermaven_client
4705        .as_ref()
4706        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4707
4708    let result = supermaven_admin_api
4709        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4710        .await?;
4711
4712    response.send(proto::GetSupermavenApiKeyResponse {
4713        api_key: result.api_key,
4714    })?;
4715
4716    Ok(())
4717}
4718
4719/// Start receiving chat updates for a channel
4720async fn join_channel_chat(
4721    request: proto::JoinChannelChat,
4722    response: Response<proto::JoinChannelChat>,
4723    session: UserSession,
4724) -> Result<()> {
4725    let channel_id = ChannelId::from_proto(request.channel_id);
4726
4727    let db = session.db().await;
4728    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4729        .await?;
4730    let messages = db
4731        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4732        .await?;
4733    response.send(proto::JoinChannelChatResponse {
4734        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4735        messages,
4736    })?;
4737    Ok(())
4738}
4739
4740/// Stop receiving chat updates for a channel
4741async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4742    let channel_id = ChannelId::from_proto(request.channel_id);
4743    session
4744        .db()
4745        .await
4746        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4747        .await?;
4748    Ok(())
4749}
4750
4751/// Retrieve the chat history for a channel
4752async fn get_channel_messages(
4753    request: proto::GetChannelMessages,
4754    response: Response<proto::GetChannelMessages>,
4755    session: UserSession,
4756) -> Result<()> {
4757    let channel_id = ChannelId::from_proto(request.channel_id);
4758    let messages = session
4759        .db()
4760        .await
4761        .get_channel_messages(
4762            channel_id,
4763            session.user_id(),
4764            MESSAGE_COUNT_PER_PAGE,
4765            Some(MessageId::from_proto(request.before_message_id)),
4766        )
4767        .await?;
4768    response.send(proto::GetChannelMessagesResponse {
4769        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4770        messages,
4771    })?;
4772    Ok(())
4773}
4774
4775/// Retrieve specific chat messages
4776async fn get_channel_messages_by_id(
4777    request: proto::GetChannelMessagesById,
4778    response: Response<proto::GetChannelMessagesById>,
4779    session: UserSession,
4780) -> Result<()> {
4781    let message_ids = request
4782        .message_ids
4783        .iter()
4784        .map(|id| MessageId::from_proto(*id))
4785        .collect::<Vec<_>>();
4786    let messages = session
4787        .db()
4788        .await
4789        .get_channel_messages_by_id(session.user_id(), &message_ids)
4790        .await?;
4791    response.send(proto::GetChannelMessagesResponse {
4792        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4793        messages,
4794    })?;
4795    Ok(())
4796}
4797
4798/// Retrieve the current users notifications
4799async fn get_notifications(
4800    request: proto::GetNotifications,
4801    response: Response<proto::GetNotifications>,
4802    session: UserSession,
4803) -> Result<()> {
4804    let notifications = session
4805        .db()
4806        .await
4807        .get_notifications(
4808            session.user_id(),
4809            NOTIFICATION_COUNT_PER_PAGE,
4810            request.before_id.map(db::NotificationId::from_proto),
4811        )
4812        .await?;
4813    response.send(proto::GetNotificationsResponse {
4814        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4815        notifications,
4816    })?;
4817    Ok(())
4818}
4819
4820/// Mark notifications as read
4821async fn mark_notification_as_read(
4822    request: proto::MarkNotificationRead,
4823    response: Response<proto::MarkNotificationRead>,
4824    session: UserSession,
4825) -> Result<()> {
4826    let database = &session.db().await;
4827    let notifications = database
4828        .mark_notification_as_read_by_id(
4829            session.user_id(),
4830            NotificationId::from_proto(request.notification_id),
4831        )
4832        .await?;
4833    send_notifications(
4834        &*session.connection_pool().await,
4835        &session.peer,
4836        notifications,
4837    );
4838    response.send(proto::Ack {})?;
4839    Ok(())
4840}
4841
4842/// Get the current users information
4843async fn get_private_user_info(
4844    _request: proto::GetPrivateUserInfo,
4845    response: Response<proto::GetPrivateUserInfo>,
4846    session: UserSession,
4847) -> Result<()> {
4848    let db = session.db().await;
4849
4850    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4851    let user = db
4852        .get_user_by_id(session.user_id())
4853        .await?
4854        .ok_or_else(|| anyhow!("user not found"))?;
4855    let flags = db.get_user_flags(session.user_id()).await?;
4856
4857    response.send(proto::GetPrivateUserInfoResponse {
4858        metrics_id,
4859        staff: user.admin,
4860        flags,
4861        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4862    })?;
4863    Ok(())
4864}
4865
4866/// Accept the terms of service (tos) on behalf of the current user
4867async fn accept_terms_of_service(
4868    _request: proto::AcceptTermsOfService,
4869    response: Response<proto::AcceptTermsOfService>,
4870    session: UserSession,
4871) -> Result<()> {
4872    let db = session.db().await;
4873
4874    let accepted_tos_at = Utc::now();
4875    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4876        .await?;
4877
4878    response.send(proto::AcceptTermsOfServiceResponse {
4879        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4880    })?;
4881    Ok(())
4882}
4883
4884/// The minimum account age an account must have in order to use the LLM service.
4885const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4886
4887async fn get_llm_api_token(
4888    _request: proto::GetLlmToken,
4889    response: Response<proto::GetLlmToken>,
4890    session: UserSession,
4891) -> Result<()> {
4892    let db = session.db().await;
4893
4894    let flags = db.get_user_flags(session.user_id()).await?;
4895    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4896    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4897
4898    if !session.is_staff() && !has_language_models_feature_flag {
4899        Err(anyhow!("permission denied"))?
4900    }
4901
4902    let user_id = session.user_id();
4903    let user = db
4904        .get_user_by_id(user_id)
4905        .await?
4906        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4907
4908    if user.accepted_tos_at.is_none() {
4909        Err(anyhow!("terms of service not accepted"))?
4910    }
4911
4912    let mut account_created_at = user.created_at;
4913    if let Some(github_created_at) = user.github_user_created_at {
4914        account_created_at = account_created_at.min(github_created_at);
4915    }
4916    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4917        Err(anyhow!("account too young"))?
4918    }
4919
4920    let billing_preferences = db.get_billing_preferences(user.id).await?;
4921
4922    let token = LlmTokenClaims::create(
4923        &user,
4924        session.is_staff(),
4925        billing_preferences,
4926        has_llm_closed_beta_feature_flag,
4927        session.has_llm_subscription(&db).await?,
4928        session.current_plan(&db).await?,
4929        &session.app_state.config,
4930    )?;
4931    response.send(proto::GetLlmTokenResponse { token })?;
4932    Ok(())
4933}
4934
4935fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4936    let message = match message {
4937        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4938        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4939        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4940        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4941        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4942            code: frame.code.into(),
4943            reason: frame.reason,
4944        })),
4945        // We should never receive a frame while reading the message, according
4946        // to the `tungstenite` maintainers:
4947        //
4948        // > It cannot occur when you read messages from the WebSocket, but it
4949        // > can be used when you want to send the raw frames (e.g. you want to
4950        // > send the frames to the WebSocket without composing the full message first).
4951        // >
4952        // > — https://github.com/snapview/tungstenite-rs/issues/268
4953        TungsteniteMessage::Frame(_) => {
4954            bail!("received an unexpected frame while reading the message")
4955        }
4956    };
4957
4958    Ok(message)
4959}
4960
4961fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4962    match message {
4963        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4964        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4965        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4966        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4967        AxumMessage::Close(frame) => {
4968            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4969                code: frame.code.into(),
4970                reason: frame.reason,
4971            }))
4972        }
4973    }
4974}
4975
4976fn notify_membership_updated(
4977    connection_pool: &mut ConnectionPool,
4978    result: MembershipUpdated,
4979    user_id: UserId,
4980    peer: &Peer,
4981) {
4982    for membership in &result.new_channels.channel_memberships {
4983        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4984    }
4985    for channel_id in &result.removed_channels {
4986        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4987    }
4988
4989    let user_channels_update = proto::UpdateUserChannels {
4990        channel_memberships: result
4991            .new_channels
4992            .channel_memberships
4993            .iter()
4994            .map(|cm| proto::ChannelMembership {
4995                channel_id: cm.channel_id.to_proto(),
4996                role: cm.role.into(),
4997            })
4998            .collect(),
4999        ..Default::default()
5000    };
5001
5002    let mut update = build_channels_update(result.new_channels);
5003    update.delete_channels = result
5004        .removed_channels
5005        .into_iter()
5006        .map(|id| id.to_proto())
5007        .collect();
5008    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5009
5010    for connection_id in connection_pool.user_connection_ids(user_id) {
5011        peer.send(connection_id, user_channels_update.clone())
5012            .trace_err();
5013        peer.send(connection_id, update.clone()).trace_err();
5014    }
5015}
5016
5017fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5018    proto::UpdateUserChannels {
5019        channel_memberships: channels
5020            .channel_memberships
5021            .iter()
5022            .map(|m| proto::ChannelMembership {
5023                channel_id: m.channel_id.to_proto(),
5024                role: m.role.into(),
5025            })
5026            .collect(),
5027        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5028        observed_channel_message_id: channels.observed_channel_messages.clone(),
5029    }
5030}
5031
5032fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5033    let mut update = proto::UpdateChannels::default();
5034
5035    for channel in channels.channels {
5036        update.channels.push(channel.to_proto());
5037    }
5038
5039    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5040    update.latest_channel_message_ids = channels.latest_channel_messages;
5041
5042    for (channel_id, participants) in channels.channel_participants {
5043        update
5044            .channel_participants
5045            .push(proto::ChannelParticipants {
5046                channel_id: channel_id.to_proto(),
5047                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5048            });
5049    }
5050
5051    for channel in channels.invited_channels {
5052        update.channel_invitations.push(channel.to_proto());
5053    }
5054
5055    update.hosted_projects = channels.hosted_projects;
5056    update
5057}
5058
5059fn build_initial_contacts_update(
5060    contacts: Vec<db::Contact>,
5061    pool: &ConnectionPool,
5062) -> proto::UpdateContacts {
5063    let mut update = proto::UpdateContacts::default();
5064
5065    for contact in contacts {
5066        match contact {
5067            db::Contact::Accepted { user_id, busy } => {
5068                update.contacts.push(contact_for_user(user_id, busy, pool));
5069            }
5070            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5071            db::Contact::Incoming { user_id } => {
5072                update
5073                    .incoming_requests
5074                    .push(proto::IncomingContactRequest {
5075                        requester_id: user_id.to_proto(),
5076                    })
5077            }
5078        }
5079    }
5080
5081    update
5082}
5083
5084fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5085    proto::Contact {
5086        user_id: user_id.to_proto(),
5087        online: pool.is_user_online(user_id),
5088        busy,
5089    }
5090}
5091
5092fn room_updated(room: &proto::Room, peer: &Peer) {
5093    broadcast(
5094        None,
5095        room.participants
5096            .iter()
5097            .filter_map(|participant| Some(participant.peer_id?.into())),
5098        |peer_id| {
5099            peer.send(
5100                peer_id,
5101                proto::RoomUpdated {
5102                    room: Some(room.clone()),
5103                },
5104            )
5105        },
5106    );
5107}
5108
5109fn channel_updated(
5110    channel: &db::channel::Model,
5111    room: &proto::Room,
5112    peer: &Peer,
5113    pool: &ConnectionPool,
5114) {
5115    let participants = room
5116        .participants
5117        .iter()
5118        .map(|p| p.user_id)
5119        .collect::<Vec<_>>();
5120
5121    broadcast(
5122        None,
5123        pool.channel_connection_ids(channel.root_id())
5124            .filter_map(|(channel_id, role)| {
5125                role.can_see_channel(channel.visibility)
5126                    .then_some(channel_id)
5127            }),
5128        |peer_id| {
5129            peer.send(
5130                peer_id,
5131                proto::UpdateChannels {
5132                    channel_participants: vec![proto::ChannelParticipants {
5133                        channel_id: channel.id.to_proto(),
5134                        participant_user_ids: participants.clone(),
5135                    }],
5136                    ..Default::default()
5137                },
5138            )
5139        },
5140    );
5141}
5142
5143async fn send_dev_server_projects_update(
5144    user_id: UserId,
5145    mut status: proto::DevServerProjectsUpdate,
5146    session: &Session,
5147) {
5148    let pool = session.connection_pool().await;
5149    for dev_server in &mut status.dev_servers {
5150        dev_server.status =
5151            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5152    }
5153    let connections = pool.user_connection_ids(user_id);
5154    for connection_id in connections {
5155        session.peer.send(connection_id, status.clone()).trace_err();
5156    }
5157}
5158
5159async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5160    let db = session.db().await;
5161
5162    let contacts = db.get_contacts(user_id).await?;
5163    let busy = db.is_user_busy(user_id).await?;
5164
5165    let pool = session.connection_pool().await;
5166    let updated_contact = contact_for_user(user_id, busy, &pool);
5167    for contact in contacts {
5168        if let db::Contact::Accepted {
5169            user_id: contact_user_id,
5170            ..
5171        } = contact
5172        {
5173            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5174                session
5175                    .peer
5176                    .send(
5177                        contact_conn_id,
5178                        proto::UpdateContacts {
5179                            contacts: vec![updated_contact.clone()],
5180                            remove_contacts: Default::default(),
5181                            incoming_requests: Default::default(),
5182                            remove_incoming_requests: Default::default(),
5183                            outgoing_requests: Default::default(),
5184                            remove_outgoing_requests: Default::default(),
5185                        },
5186                    )
5187                    .trace_err();
5188            }
5189        }
5190    }
5191    Ok(())
5192}
5193
5194async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5195    log::info!("lost dev server connection, unsharing projects");
5196    let project_ids = session
5197        .db()
5198        .await
5199        .get_stale_dev_server_projects(session.connection_id)
5200        .await?;
5201
5202    for project_id in project_ids {
5203        // not unshare re-checks the connection ids match, so we get away with no transaction
5204        unshare_project_internal(project_id, session.connection_id, None, session).await?;
5205    }
5206
5207    let user_id = session.dev_server().user_id;
5208    let update = session
5209        .db()
5210        .await
5211        .dev_server_projects_update(user_id)
5212        .await?;
5213
5214    send_dev_server_projects_update(user_id, update, session).await;
5215
5216    Ok(())
5217}
5218
5219async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5220    let mut contacts_to_update = HashSet::default();
5221
5222    let room_id;
5223    let canceled_calls_to_user_ids;
5224    let live_kit_room;
5225    let delete_live_kit_room;
5226    let room;
5227    let channel;
5228
5229    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5230        contacts_to_update.insert(session.user_id());
5231
5232        for project in left_room.left_projects.values() {
5233            project_left(project, session);
5234        }
5235
5236        room_id = RoomId::from_proto(left_room.room.id);
5237        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5238        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5239        delete_live_kit_room = left_room.deleted;
5240        room = mem::take(&mut left_room.room);
5241        channel = mem::take(&mut left_room.channel);
5242
5243        room_updated(&room, &session.peer);
5244    } else {
5245        return Ok(());
5246    }
5247
5248    if let Some(channel) = channel {
5249        channel_updated(
5250            &channel,
5251            &room,
5252            &session.peer,
5253            &*session.connection_pool().await,
5254        );
5255    }
5256
5257    {
5258        let pool = session.connection_pool().await;
5259        for canceled_user_id in canceled_calls_to_user_ids {
5260            for connection_id in pool.user_connection_ids(canceled_user_id) {
5261                session
5262                    .peer
5263                    .send(
5264                        connection_id,
5265                        proto::CallCanceled {
5266                            room_id: room_id.to_proto(),
5267                        },
5268                    )
5269                    .trace_err();
5270            }
5271            contacts_to_update.insert(canceled_user_id);
5272        }
5273    }
5274
5275    for contact_user_id in contacts_to_update {
5276        update_user_contacts(contact_user_id, session).await?;
5277    }
5278
5279    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5280        live_kit
5281            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5282            .await
5283            .trace_err();
5284
5285        if delete_live_kit_room {
5286            live_kit.delete_room(live_kit_room).await.trace_err();
5287        }
5288    }
5289
5290    Ok(())
5291}
5292
5293async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5294    let left_channel_buffers = session
5295        .db()
5296        .await
5297        .leave_channel_buffers(session.connection_id)
5298        .await?;
5299
5300    for left_buffer in left_channel_buffers {
5301        channel_buffer_updated(
5302            session.connection_id,
5303            left_buffer.connections,
5304            &proto::UpdateChannelBufferCollaborators {
5305                channel_id: left_buffer.channel_id.to_proto(),
5306                collaborators: left_buffer.collaborators,
5307            },
5308            &session.peer,
5309        );
5310    }
5311
5312    Ok(())
5313}
5314
5315fn project_left(project: &db::LeftProject, session: &UserSession) {
5316    for connection_id in &project.connection_ids {
5317        if project.should_unshare {
5318            session
5319                .peer
5320                .send(
5321                    *connection_id,
5322                    proto::UnshareProject {
5323                        project_id: project.id.to_proto(),
5324                    },
5325                )
5326                .trace_err();
5327        } else {
5328            session
5329                .peer
5330                .send(
5331                    *connection_id,
5332                    proto::RemoveProjectCollaborator {
5333                        project_id: project.id.to_proto(),
5334                        peer_id: Some(session.connection_id.into()),
5335                    },
5336                )
5337                .trace_err();
5338        }
5339    }
5340}
5341
5342pub trait ResultExt {
5343    type Ok;
5344
5345    fn trace_err(self) -> Option<Self::Ok>;
5346}
5347
5348impl<T, E> ResultExt for Result<T, E>
5349where
5350    E: std::fmt::Debug,
5351{
5352    type Ok = T;
5353
5354    #[track_caller]
5355    fn trace_err(self) -> Option<T> {
5356        match self {
5357            Ok(value) => Some(value),
5358            Err(error) => {
5359                tracing::error!("{:?}", error);
5360                None
5361            }
5362        }
5363    }
5364}