rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  38use sha2::Digest;
  39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  40
  41use futures::{
  42    channel::oneshot,
  43    future::{self, BoxFuture},
  44    stream::FuturesUnordered,
  45    FutureExt, SinkExt, StreamExt, TryStreamExt,
  46};
  47use http_client::IsahcHttpClient;
  48use prometheus::{register_int_gauge, IntGauge};
  49use rpc::{
  50    proto::{
  51        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  52        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  53    },
  54    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  55};
  56use semantic_version::SemanticVersion;
  57use serde::{Serialize, Serializer};
  58use std::{
  59    any::TypeId,
  60    future::Future,
  61    marker::PhantomData,
  62    mem,
  63    net::SocketAddr,
  64    ops::{Deref, DerefMut},
  65    rc::Rc,
  66    sync::{
  67        atomic::{AtomicBool, Ordering::SeqCst},
  68        Arc, OnceLock,
  69    },
  70    time::{Duration, Instant},
  71};
  72use time::OffsetDateTime;
  73use tokio::sync::{watch, Semaphore};
  74use tower::ServiceBuilder;
  75use tracing::{
  76    field::{self},
  77    info_span, instrument, Instrument,
  78};
  79
  80use self::connection_pool::VersionedMessage;
  81
  82pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  83
  84// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  85pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  86
  87const MESSAGE_COUNT_PER_PAGE: usize = 100;
  88const MAX_MESSAGE_LEN: usize = 1024;
  89const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  90
  91type MessageHandler =
  92    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  93
  94struct Response<R> {
  95    peer: Arc<Peer>,
  96    receipt: Receipt<R>,
  97    responded: Arc<AtomicBool>,
  98}
  99
 100impl<R: RequestMessage> Response<R> {
 101    fn send(self, payload: R::Response) -> Result<()> {
 102        self.responded.store(true, SeqCst);
 103        self.peer.respond(self.receipt, payload)?;
 104        Ok(())
 105    }
 106}
 107
 108struct StreamingResponse<R: RequestMessage> {
 109    peer: Arc<Peer>,
 110    receipt: Receipt<R>,
 111}
 112
 113impl<R: RequestMessage> StreamingResponse<R> {
 114    fn send(&self, payload: R::Response) -> Result<()> {
 115        self.peer.respond(self.receipt, payload)?;
 116        Ok(())
 117    }
 118}
 119
 120#[derive(Clone, Debug)]
 121pub enum Principal {
 122    User(User),
 123    Impersonated { user: User, admin: User },
 124    DevServer(dev_server::Model),
 125}
 126
 127impl Principal {
 128    fn update_span(&self, span: &tracing::Span) {
 129        match &self {
 130            Principal::User(user) => {
 131                span.record("user_id", &user.id.0);
 132                span.record("login", &user.github_login);
 133            }
 134            Principal::Impersonated { user, admin } => {
 135                span.record("user_id", &user.id.0);
 136                span.record("login", &user.github_login);
 137                span.record("impersonator", &admin.github_login);
 138            }
 139            Principal::DevServer(dev_server) => {
 140                span.record("dev_server_id", &dev_server.id.0);
 141            }
 142        }
 143    }
 144}
 145
/// Per-connection state passed to every RPC handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end: a user, an admin impersonating a user, or a dev server.
    principal: Principal,
    /// The peer connection this session was created for.
    connection_id: ConnectionId,
    /// Shared database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it via `Session::connection_pool`, which wraps the
    /// lock guard in a `!Send` `ConnectionPoolGuard`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Optional Supermaven admin API client (presumably `None` when Supermaven is not
    /// configured — confirm where sessions are constructed).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // Not read in the visible part of this file, hence the lint allowance.
    // NOTE(review): confirm the field is still needed.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Never read directly (leading underscore); presumably held so the session keeps an
    // executor handle alive — confirm.
    _executor: Executor,
}
 161
 162impl Session {
 163    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 164        #[cfg(test)]
 165        tokio::task::yield_now().await;
 166        let guard = self.db.lock().await;
 167        #[cfg(test)]
 168        tokio::task::yield_now().await;
 169        guard
 170    }
 171
 172    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 173        #[cfg(test)]
 174        tokio::task::yield_now().await;
 175        let guard = self.connection_pool.lock();
 176        ConnectionPoolGuard {
 177            guard,
 178            _not_send: PhantomData,
 179        }
 180    }
 181
 182    fn for_user(self) -> Option<UserSession> {
 183        UserSession::new(self)
 184    }
 185
 186    fn for_dev_server(self) -> Option<DevServerSession> {
 187        DevServerSession::new(self)
 188    }
 189
 190    fn user_id(&self) -> Option<UserId> {
 191        match &self.principal {
 192            Principal::User(user) => Some(user.id),
 193            Principal::Impersonated { user, .. } => Some(user.id),
 194            Principal::DevServer(_) => None,
 195        }
 196    }
 197
 198    fn is_staff(&self) -> bool {
 199        match &self.principal {
 200            Principal::User(user) => user.admin,
 201            Principal::Impersonated { .. } => true,
 202            Principal::DevServer(_) => false,
 203        }
 204    }
 205
 206    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
 207        if self.is_staff() {
 208            return Ok(proto::Plan::ZedPro);
 209        }
 210
 211        let Some(user_id) = self.user_id() else {
 212            return Ok(proto::Plan::Free);
 213        };
 214
 215        let db = self.db().await;
 216        if db.has_active_billing_subscription(user_id).await? {
 217            Ok(proto::Plan::ZedPro)
 218        } else {
 219            Ok(proto::Plan::Free)
 220        }
 221    }
 222
 223    fn dev_server_id(&self) -> Option<DevServerId> {
 224        match &self.principal {
 225            Principal::User(_) | Principal::Impersonated { .. } => None,
 226            Principal::DevServer(dev_server) => Some(dev_server.id),
 227        }
 228    }
 229
 230    fn principal_id(&self) -> PrincipalId {
 231        match &self.principal {
 232            Principal::User(user) => PrincipalId::UserId(user.id),
 233            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 234            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 235        }
 236    }
 237}
 238
 239impl Debug for Session {
 240    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 241        let mut result = f.debug_struct("Session");
 242        match &self.principal {
 243            Principal::User(user) => {
 244                result.field("user", &user.github_login);
 245            }
 246            Principal::Impersonated { user, admin } => {
 247                result.field("user", &user.github_login);
 248                result.field("impersonator", &admin.github_login);
 249            }
 250            Principal::DevServer(dev_server) => {
 251                result.field("dev_server", &dev_server.id);
 252            }
 253        }
 254        result.field("connection_id", &self.connection_id).finish()
 255    }
 256}
 257
 258struct UserSession(Session);
 259
 260impl UserSession {
 261    pub fn new(s: Session) -> Option<Self> {
 262        s.user_id().map(|_| UserSession(s))
 263    }
 264    pub fn user_id(&self) -> UserId {
 265        self.0.user_id().unwrap()
 266    }
 267
 268    pub fn email(&self) -> Option<String> {
 269        match &self.0.principal {
 270            Principal::User(user) => user.email_address.clone(),
 271            Principal::Impersonated { user, .. } => user.email_address.clone(),
 272            Principal::DevServer(..) => None,
 273        }
 274    }
 275}
 276
 277impl Deref for UserSession {
 278    type Target = Session;
 279
 280    fn deref(&self) -> &Self::Target {
 281        &self.0
 282    }
 283}
 284impl DerefMut for UserSession {
 285    fn deref_mut(&mut self) -> &mut Self::Target {
 286        &mut self.0
 287    }
 288}
 289
 290struct DevServerSession(Session);
 291
 292impl DevServerSession {
 293    pub fn new(s: Session) -> Option<Self> {
 294        s.dev_server_id().map(|_| DevServerSession(s))
 295    }
 296    pub fn dev_server_id(&self) -> DevServerId {
 297        self.0.dev_server_id().unwrap()
 298    }
 299
 300    fn dev_server(&self) -> &dev_server::Model {
 301        match &self.0.principal {
 302            Principal::DevServer(dev_server) => dev_server,
 303            _ => unreachable!(),
 304        }
 305    }
 306}
 307
 308impl Deref for DevServerSession {
 309    type Target = Session;
 310
 311    fn deref(&self) -> &Self::Target {
 312        &self.0
 313    }
 314}
 315impl DerefMut for DevServerSession {
 316    fn deref_mut(&mut self) -> &mut Self::Target {
 317        &mut self.0
 318    }
 319}
 320
 321fn user_handler<M: RequestMessage, Fut>(
 322    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 324where
 325    Fut: Send + Future<Output = Result<()>>,
 326{
 327    let handler = Arc::new(handler);
 328    move |message, response, session| {
 329        let handler = handler.clone();
 330        Box::pin(async move {
 331            if let Some(user_session) = session.for_user() {
 332                Ok(handler(message, response, user_session).await?)
 333            } else {
 334                Err(Error::Internal(anyhow!(
 335                    "must be a user to call {}",
 336                    M::NAME
 337                )))
 338            }
 339        })
 340    }
 341}
 342
 343fn dev_server_handler<M: RequestMessage, Fut>(
 344    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 346where
 347    Fut: Send + Future<Output = Result<()>>,
 348{
 349    let handler = Arc::new(handler);
 350    move |message, response, session| {
 351        let handler = handler.clone();
 352        Box::pin(async move {
 353            if let Some(dev_server_session) = session.for_dev_server() {
 354                Ok(handler(message, response, dev_server_session).await?)
 355            } else {
 356                Err(Error::Internal(anyhow!(
 357                    "must be a dev server to call {}",
 358                    M::NAME
 359                )))
 360            }
 361        })
 362    }
 363}
 364
 365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 366    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 368where
 369    InnertRetFut: Send + Future<Output = Result<()>>,
 370{
 371    let handler = Arc::new(handler);
 372    move |message, session| {
 373        let handler = handler.clone();
 374        Box::pin(async move {
 375            if let Some(user_session) = session.for_user() {
 376                Ok(handler(message, user_session).await?)
 377            } else {
 378                Err(Error::Internal(anyhow!(
 379                    "must be a user to call {}",
 380                    M::NAME
 381                )))
 382            }
 383        })
 384    }
 385}
 386
 387struct DbHandle(Arc<Database>);
 388
 389impl Deref for DbHandle {
 390    type Target = Database;
 391
 392    fn deref(&self) -> &Self::Target {
 393        self.0.as_ref()
 394    }
 395}
 396
/// The collaboration RPC server.
///
/// Owns the peer, the shared connection pool, and the table of handlers registered in
/// `Server::new`.
pub struct Server {
    /// This server's id, stored behind a mutex (presumably so it can be updated after
    /// construction — confirm).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, presumably keyed by the `TypeId` of the message type each one
    /// accepts — confirm in the `add_*_handler` methods.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a watch channel that `Server::new` initialises to `false`; presumably
    /// set to `true` to signal teardown.
    teardown: watch::Sender<bool>,
}
 405
/// Lock guard over the shared `ConnectionPool`, returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`. The compiler
    // then rejects holding the pool lock across an `.await` in a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 410
/// Serializable view of the server's peer and connection pool.
///
/// The pool is stored as its lock guard, so the pool stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard points at, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 417
 418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 419where
 420    S: Serializer,
 421    T: Deref<Target = U>,
 422    U: Serialize,
 423{
 424    Serialize::serialize(value.deref(), serializer)
 425}
 426
 427impl Server {
    /// Creates the server identified by `id` and registers every RPC handler it serves.
    ///
    /// The wrappers enforce the principal kind before the inner handler runs:
    /// - `user_handler` and `user_message_handler` reject sessions without a user principal.
    /// - `dev_server_handler` rejects sessions without a dev-server principal.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if `ServerId`'s inner integer is wider
            // than 32 bits — confirm ids stay within range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev servers and dev server projects, managed by users.
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests that only a dev server may send.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project, worktree, and language-server state.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed by the `forward_*` helpers.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host-to-guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account info, tokens, and acknowledgements.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language model and embedding handlers. Each closure clones the `app_state` Arc
            // per call, so the returned `async move` future owns its own reference.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // NOTE(review): this closure only forwards its arguments; passing
            // `get_cached_embeddings` to `user_handler` directly is likely equivalent.
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 687
 688    pub async fn start(&self) -> Result<()> {
 689        let server_id = *self.id.lock();
 690        let app_state = self.app_state.clone();
 691        let peer = self.peer.clone();
 692        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 693        let pool = self.connection_pool.clone();
 694        let live_kit_client = self.app_state.live_kit_client.clone();
 695
 696        let span = info_span!("start server");
 697        self.app_state.executor.spawn_detached(
 698            async move {
 699                tracing::info!("waiting for cleanup timeout");
 700                timeout.await;
 701                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 702                if let Some((room_ids, channel_ids)) = app_state
 703                    .db
 704                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 705                    .await
 706                    .trace_err()
 707                {
 708                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 709                    tracing::info!(
 710                        stale_channel_buffer_count = channel_ids.len(),
 711                        "retrieved stale channel buffers"
 712                    );
 713
 714                    for channel_id in channel_ids {
 715                        if let Some(refreshed_channel_buffer) = app_state
 716                            .db
 717                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 718                            .await
 719                            .trace_err()
 720                        {
 721                            for connection_id in refreshed_channel_buffer.connection_ids {
 722                                peer.send(
 723                                    connection_id,
 724                                    proto::UpdateChannelBufferCollaborators {
 725                                        channel_id: channel_id.to_proto(),
 726                                        collaborators: refreshed_channel_buffer
 727                                            .collaborators
 728                                            .clone(),
 729                                    },
 730                                )
 731                                .trace_err();
 732                            }
 733                        }
 734                    }
 735
 736                    for room_id in room_ids {
 737                        let mut contacts_to_update = HashSet::default();
 738                        let mut canceled_calls_to_user_ids = Vec::new();
 739                        let mut live_kit_room = String::new();
 740                        let mut delete_live_kit_room = false;
 741
 742                        if let Some(mut refreshed_room) = app_state
 743                            .db
 744                            .clear_stale_room_participants(room_id, server_id)
 745                            .await
 746                            .trace_err()
 747                        {
 748                            tracing::info!(
 749                                room_id = room_id.0,
 750                                new_participant_count = refreshed_room.room.participants.len(),
 751                                "refreshed room"
 752                            );
 753                            room_updated(&refreshed_room.room, &peer);
 754                            if let Some(channel) = refreshed_room.channel.as_ref() {
 755                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 756                            }
 757                            contacts_to_update
 758                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 759                            contacts_to_update
 760                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 761                            canceled_calls_to_user_ids =
 762                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 763                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 764                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 765                        }
 766
 767                        {
 768                            let pool = pool.lock();
 769                            for canceled_user_id in canceled_calls_to_user_ids {
 770                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 771                                    peer.send(
 772                                        connection_id,
 773                                        proto::CallCanceled {
 774                                            room_id: room_id.to_proto(),
 775                                        },
 776                                    )
 777                                    .trace_err();
 778                                }
 779                            }
 780                        }
 781
 782                        for user_id in contacts_to_update {
 783                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 784                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 785                            if let Some((busy, contacts)) = busy.zip(contacts) {
 786                                let pool = pool.lock();
 787                                let updated_contact = contact_for_user(user_id, busy, &pool);
 788                                for contact in contacts {
 789                                    if let db::Contact::Accepted {
 790                                        user_id: contact_user_id,
 791                                        ..
 792                                    } = contact
 793                                    {
 794                                        for contact_conn_id in
 795                                            pool.user_connection_ids(contact_user_id)
 796                                        {
 797                                            peer.send(
 798                                                contact_conn_id,
 799                                                proto::UpdateContacts {
 800                                                    contacts: vec![updated_contact.clone()],
 801                                                    remove_contacts: Default::default(),
 802                                                    incoming_requests: Default::default(),
 803                                                    remove_incoming_requests: Default::default(),
 804                                                    outgoing_requests: Default::default(),
 805                                                    remove_outgoing_requests: Default::default(),
 806                                                },
 807                                            )
 808                                            .trace_err();
 809                                        }
 810                                    }
 811                                }
 812                            }
 813                        }
 814
 815                        if let Some(live_kit) = live_kit_client.as_ref() {
 816                            if delete_live_kit_room {
 817                                live_kit.delete_room(live_kit_room).await.trace_err();
 818                            }
 819                        }
 820                    }
 821                }
 822
 823                app_state
 824                    .db
 825                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 826                    .await
 827                    .trace_err();
 828            }
 829            .instrument(span),
 830        );
 831        Ok(())
 832    }
 833
 834    pub fn teardown(&self) {
 835        self.peer.teardown();
 836        self.connection_pool.lock().reset();
 837        let _ = self.teardown.send(true);
 838    }
 839
 840    #[cfg(test)]
 841    pub fn reset(&self, id: ServerId) {
 842        self.teardown();
 843        *self.id.lock() = id;
 844        self.peer.reset(id.0 as u32);
 845        let _ = self.teardown.send(false);
 846    }
 847
 848    #[cfg(test)]
 849    pub fn id(&self) -> ServerId {
 850        *self.id.lock()
 851    }
 852
 853    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 854    where
 855        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 856        Fut: 'static + Send + Future<Output = Result<()>>,
 857        M: EnvelopedMessage,
 858    {
 859        let prev_handler = self.handlers.insert(
 860            TypeId::of::<M>(),
 861            Box::new(move |envelope, session| {
 862                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 863                let received_at = envelope.received_at;
 864                tracing::info!("message received");
 865                let start_time = Instant::now();
 866                let future = (handler)(*envelope, session);
 867                async move {
 868                    let result = future.await;
 869                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 870                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 871                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 872                    let payload_type = M::NAME;
 873
 874                    match result {
 875                        Err(error) => {
 876                            tracing::error!(
 877                                ?error,
 878                                total_duration_ms,
 879                                processing_duration_ms,
 880                                queue_duration_ms,
 881                                payload_type,
 882                                "error handling message"
 883                            )
 884                        }
 885                        Ok(()) => tracing::info!(
 886                            total_duration_ms,
 887                            processing_duration_ms,
 888                            queue_duration_ms,
 889                            "finished handling message"
 890                        ),
 891                    }
 892                }
 893                .boxed()
 894            }),
 895        );
 896        if prev_handler.is_some() {
 897            panic!("registered a handler for the same message twice");
 898        }
 899        self
 900    }
 901
 902    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 903    where
 904        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 905        Fut: 'static + Send + Future<Output = Result<()>>,
 906        M: EnvelopedMessage,
 907    {
 908        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 909        self
 910    }
 911
 912    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 913    where
 914        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 915        Fut: Send + Future<Output = Result<()>>,
 916        M: RequestMessage,
 917    {
 918        let handler = Arc::new(handler);
 919        self.add_handler(move |envelope, session| {
 920            let receipt = envelope.receipt();
 921            let handler = handler.clone();
 922            async move {
 923                let peer = session.peer.clone();
 924                let responded = Arc::new(AtomicBool::default());
 925                let response = Response {
 926                    peer: peer.clone(),
 927                    responded: responded.clone(),
 928                    receipt,
 929                };
 930                match (handler)(envelope.payload, response, session).await {
 931                    Ok(()) => {
 932                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 933                            Ok(())
 934                        } else {
 935                            Err(anyhow!("handler did not send a response"))?
 936                        }
 937                    }
 938                    Err(error) => {
 939                        let proto_err = match &error {
 940                            Error::Internal(err) => err.to_proto(),
 941                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 942                        };
 943                        peer.respond_with_error(receipt, proto_err)?;
 944                        Err(error)
 945                    }
 946                }
 947            }
 948        })
 949    }
 950
 951    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 952    where
 953        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 954        Fut: Send + Future<Output = Result<()>>,
 955        M: RequestMessage,
 956    {
 957        let handler = Arc::new(handler);
 958        self.add_handler(move |envelope, session| {
 959            let receipt = envelope.receipt();
 960            let handler = handler.clone();
 961            async move {
 962                let peer = session.peer.clone();
 963                let response = StreamingResponse {
 964                    peer: peer.clone(),
 965                    receipt,
 966                };
 967                match (handler)(envelope.payload, response, session).await {
 968                    Ok(()) => {
 969                        peer.end_stream(receipt)?;
 970                        Ok(())
 971                    }
 972                    Err(error) => {
 973                        let proto_err = match &error {
 974                            Error::Internal(err) => err.to_proto(),
 975                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 976                        };
 977                        peer.respond_with_error(receipt, proto_err)?;
 978                        Err(error)
 979                    }
 980                }
 981            }
 982        })
 983    }
 984
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The sequence is:
    ///
    /// 1. Register the connection with the peer.
    /// 2. Build the per-connection `Session`. This includes an HTTP client, plus a
    ///    Supermaven admin client when `supermaven_admin_api_key` is configured.
    /// 3. Send the initial client update.
    /// 4. Dispatch incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O future completes, the incoming stream ends, or
    /// the teardown channel changes. In the first two cases `connection_lost` runs
    /// afterwards.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`,
    /// which uses it to report the assigned `ConnectionId`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed eagerly, outside the returned future. The `borrow()` check below
        // refuses the connection if teardown is already flagged.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // A fresh HTTP client is built for every connection and shared with the
            // Supermaven admin client below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): the connection was already added to the peer above and
                    // nothing here disconnects it — confirm whether it is cleaned up elsewhere.
                    return;
                }
            };

            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): returning here skips `connection_lost`. If the failure came
                // after `send_initial_client_update` added this connection to the pool, the
                // pool entry is presumably left behind — verify.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers per connection hold a permit at a time. A permit is
            // acquired before the next message is read and released when its handler
            // finishes, so reading pauses while all permits are in use.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in priority order:
                //   1. teardown — returns without running `connection_lost`;
                //   2. I/O completion;
                //   3. progress on foreground handlers;
                //   4. reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor. All
                                // others are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1138
1139    async fn send_initial_client_update(
1140        &self,
1141        connection_id: ConnectionId,
1142        principal: &Principal,
1143        zed_version: ZedVersion,
1144        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1145        session: &Session,
1146    ) -> Result<()> {
1147        self.peer.send(
1148            connection_id,
1149            proto::Hello {
1150                peer_id: Some(connection_id.into()),
1151            },
1152        )?;
1153        tracing::info!("sent hello message");
1154        if let Some(send_connection_id) = send_connection_id.take() {
1155            let _ = send_connection_id.send(connection_id);
1156        }
1157
1158        match principal {
1159            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1160                if !user.connected_once {
1161                    self.peer.send(connection_id, proto::ShowContacts {})?;
1162                    self.app_state
1163                        .db
1164                        .set_user_connected_once(user.id, true)
1165                        .await?;
1166                }
1167
1168                update_user_plan(user.id, session).await?;
1169
1170                let (contacts, dev_server_projects) = future::try_join(
1171                    self.app_state.db.get_contacts(user.id),
1172                    self.app_state.db.dev_server_projects_update(user.id),
1173                )
1174                .await?;
1175
1176                {
1177                    let mut pool = self.connection_pool.lock();
1178                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1179                    self.peer.send(
1180                        connection_id,
1181                        build_initial_contacts_update(contacts, &pool),
1182                    )?;
1183                }
1184
1185                if should_auto_subscribe_to_channels(zed_version) {
1186                    subscribe_user_to_channels(user.id, session).await?;
1187                }
1188
1189                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1190
1191                if let Some(incoming_call) =
1192                    self.app_state.db.incoming_call_for_user(user.id).await?
1193                {
1194                    self.peer.send(connection_id, incoming_call)?;
1195                }
1196
1197                update_user_contacts(user.id, &session).await?;
1198            }
1199            Principal::DevServer(dev_server) => {
1200                {
1201                    let mut pool = self.connection_pool.lock();
1202                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1203                    {
1204                        self.peer.send(
1205                            stale_connection_id,
1206                            proto::ShutdownDevServer {
1207                                reason: Some(
1208                                    "another dev server connected with the same token".to_string(),
1209                                ),
1210                            },
1211                        )?;
1212                        pool.remove_connection(stale_connection_id)?;
1213                    };
1214                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1215                }
1216
1217                let projects = self
1218                    .app_state
1219                    .db
1220                    .get_projects_for_dev_server(dev_server.id)
1221                    .await?;
1222                self.peer
1223                    .send(connection_id, proto::DevServerInstructions { projects })?;
1224
1225                let status = self
1226                    .app_state
1227                    .db
1228                    .dev_server_projects_update(dev_server.user_id)
1229                    .await?;
1230                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1231            }
1232        }
1233
1234        Ok(())
1235    }
1236
1237    pub async fn invite_code_redeemed(
1238        self: &Arc<Self>,
1239        inviter_id: UserId,
1240        invitee_id: UserId,
1241    ) -> Result<()> {
1242        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1243            if let Some(code) = &user.invite_code {
1244                let pool = self.connection_pool.lock();
1245                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1246                for connection_id in pool.user_connection_ids(inviter_id) {
1247                    self.peer.send(
1248                        connection_id,
1249                        proto::UpdateContacts {
1250                            contacts: vec![invitee_contact.clone()],
1251                            ..Default::default()
1252                        },
1253                    )?;
1254                    self.peer.send(
1255                        connection_id,
1256                        proto::UpdateInviteInfo {
1257                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1258                            count: user.invite_count as u32,
1259                        },
1260                    )?;
1261                }
1262            }
1263        }
1264        Ok(())
1265    }
1266
1267    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1268        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1269            if let Some(invite_code) = &user.invite_code {
1270                let pool = self.connection_pool.lock();
1271                for connection_id in pool.user_connection_ids(user_id) {
1272                    self.peer.send(
1273                        connection_id,
1274                        proto::UpdateInviteInfo {
1275                            url: format!(
1276                                "{}{}",
1277                                self.app_state.config.invite_link_prefix, invite_code
1278                            ),
1279                            count: user.invite_count as u32,
1280                        },
1281                    )?;
1282                }
1283            }
1284        }
1285        Ok(())
1286    }
1287
1288    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1289        ServerSnapshot {
1290            connection_pool: ConnectionPoolGuard {
1291                guard: self.connection_pool.lock(),
1292                _not_send: PhantomData,
1293            },
1294            peer: &self.peer,
1295        }
1296    }
1297}
1298
1299impl<'a> Deref for ConnectionPoolGuard<'a> {
1300    type Target = ConnectionPool;
1301
1302    fn deref(&self) -> &Self::Target {
1303        &self.guard
1304    }
1305}
1306
1307impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1308    fn deref_mut(&mut self) -> &mut Self::Target {
1309        &mut self.guard
1310    }
1311}
1312
1313impl<'a> Drop for ConnectionPoolGuard<'a> {
1314    fn drop(&mut self) {
1315        #[cfg(test)]
1316        self.check_invariants();
1317    }
1318}
1319
1320fn broadcast<F>(
1321    sender_id: Option<ConnectionId>,
1322    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1323    mut f: F,
1324) where
1325    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1326{
1327    for receiver_id in receiver_ids {
1328        if Some(receiver_id) != sender_id {
1329            if let Err(error) = f(receiver_id) {
1330                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1331            }
1332        }
1333    }
1334}
1335
1336pub struct ProtocolVersion(u32);
1337
1338impl Header for ProtocolVersion {
1339    fn name() -> &'static HeaderName {
1340        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1341        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1342    }
1343
1344    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1345    where
1346        Self: Sized,
1347        I: Iterator<Item = &'i axum::http::HeaderValue>,
1348    {
1349        let version = values
1350            .next()
1351            .ok_or_else(axum::headers::Error::invalid)?
1352            .to_str()
1353            .map_err(|_| axum::headers::Error::invalid())?
1354            .parse()
1355            .map_err(|_| axum::headers::Error::invalid())?;
1356        Ok(Self(version))
1357    }
1358
1359    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1360        values.extend([self.0.to_string().parse().unwrap()]);
1361    }
1362}
1363
1364pub struct AppVersionHeader(SemanticVersion);
1365impl Header for AppVersionHeader {
1366    fn name() -> &'static HeaderName {
1367        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1368        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1369    }
1370
1371    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1372    where
1373        Self: Sized,
1374        I: Iterator<Item = &'i axum::http::HeaderValue>,
1375    {
1376        let version = values
1377            .next()
1378            .ok_or_else(axum::headers::Error::invalid)?
1379            .to_str()
1380            .map_err(|_| axum::headers::Error::invalid())?
1381            .parse()
1382            .map_err(|_| axum::headers::Error::invalid())?;
1383        Ok(Self(version))
1384    }
1385
1386    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1387        values.extend([self.0.to_string().parse().unwrap()]);
1388    }
1389}
1390
1391pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1392    Router::new()
1393        .route("/rpc", get(handle_websocket_request))
1394        .layer(
1395            ServiceBuilder::new()
1396                .layer(Extension(server.app_state.clone()))
1397                .layer(middleware::from_fn(auth::validate_header)),
1398        )
1399        .route("/metrics", get(handle_metrics))
1400        .layer(Extension(server))
1401}
1402
1403pub async fn handle_websocket_request(
1404    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1405    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1406    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1407    Extension(server): Extension<Arc<Server>>,
1408    Extension(principal): Extension<Principal>,
1409    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1410    ws: WebSocketUpgrade,
1411) -> axum::response::Response {
1412    if protocol_version != rpc::PROTOCOL_VERSION {
1413        return (
1414            StatusCode::UPGRADE_REQUIRED,
1415            "client must be upgraded".to_string(),
1416        )
1417            .into_response();
1418    }
1419
1420    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1421        return (
1422            StatusCode::UPGRADE_REQUIRED,
1423            "no version header found".to_string(),
1424        )
1425            .into_response();
1426    };
1427
1428    if !version.can_collaborate() {
1429        return (
1430            StatusCode::UPGRADE_REQUIRED,
1431            "client must be upgraded".to_string(),
1432        )
1433            .into_response();
1434    }
1435
1436    let socket_address = socket_address.to_string();
1437    ws.on_upgrade(move |socket| {
1438        let socket = socket
1439            .map_ok(to_tungstenite_message)
1440            .err_into()
1441            .with(|message| async move { to_axum_message(message) });
1442        let connection = Connection::new(Box::pin(socket));
1443        async move {
1444            server
1445                .handle_connection(
1446                    connection,
1447                    socket_address,
1448                    principal,
1449                    version,
1450                    country_code_header.map(|header| header.to_string()),
1451                    None,
1452                    Executor::Production,
1453                )
1454                .await;
1455        }
1456    })
1457}
1458
1459pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1460    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1461    let connections_metric = CONNECTIONS_METRIC
1462        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1463
1464    let connections = server
1465        .connection_pool
1466        .lock()
1467        .connections()
1468        .filter(|connection| !connection.admin)
1469        .count();
1470    connections_metric.set(connections as _);
1471
1472    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1473    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1474        register_int_gauge!(
1475            "shared_projects",
1476            "number of open projects with one or more guests"
1477        )
1478        .unwrap()
1479    });
1480
1481    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1482    shared_projects_metric.set(shared_projects as _);
1483
1484    let encoder = prometheus::TextEncoder::new();
1485    let metric_families = prometheus::gather();
1486    let encoded_metrics = encoder
1487        .encode_to_string(&metric_families)
1488        .map_err(|err| anyhow!("{}", err))?;
1489    Ok(encoded_metrics)
1490}
1491
/// Cleans up after a connection that has dropped.
///
/// Immediately disconnects the peer, removes the connection from the connection pool, and
/// records the loss in the database (a database error here is only logged). It then waits for
/// whichever comes first, `RECONNECT_TIMEOUT` elapsing or `teardown` changing
/// (`select_biased!` polls the timeout branch first):
///
/// - On timeout, the remaining resources are released. For users (including impersonated
///   sessions), the room and channel buffers are left. If the user has no other connection
///   online, their incoming call is declined. Then their contacts are updated. For dev
///   servers, `lost_dev_server_connection` is invoked.
/// - If `teardown` changes first, nothing further is done.
///
/// # Errors
/// Returns an error if the connection cannot be removed from the pool, or if updating the
/// user's contacts / handling the lost dev server connection fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Presumably gives the client a window to reconnect before its resources are released.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Both arms of this pattern are user principals, so `for_user` is Some.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline the incoming call when no other connection of this user
                    // remains online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown requested before the timeout: skip resource cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1546
1547/// Acknowledges a ping from a client, used to keep the connection alive.
1548async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1549    response.send(proto::Ack {})?;
1550    Ok(())
1551}
1552
1553/// Creates a new room for calling (outside of channels)
1554async fn create_room(
1555    _request: proto::CreateRoom,
1556    response: Response<proto::CreateRoom>,
1557    session: UserSession,
1558) -> Result<()> {
1559    let live_kit_room = nanoid::nanoid!(30);
1560
1561    let live_kit_connection_info = util::maybe!(async {
1562        let live_kit = session.app_state.live_kit_client.as_ref();
1563        let live_kit = live_kit?;
1564        let user_id = session.user_id().to_string();
1565
1566        let token = live_kit
1567            .room_token(&live_kit_room, &user_id.to_string())
1568            .trace_err()?;
1569
1570        Some(proto::LiveKitConnectionInfo {
1571            server_url: live_kit.url().into(),
1572            token,
1573            can_publish: true,
1574        })
1575    })
1576    .await;
1577
1578    let room = session
1579        .db()
1580        .await
1581        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1582        .await?;
1583
1584    response.send(proto::CreateRoomResponse {
1585        room: Some(room.clone()),
1586        live_kit_connection_info,
1587    })?;
1588
1589    update_user_contacts(session.user_id(), &session).await?;
1590    Ok(())
1591}
1592
1593/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1594async fn join_room(
1595    request: proto::JoinRoom,
1596    response: Response<proto::JoinRoom>,
1597    session: UserSession,
1598) -> Result<()> {
1599    let room_id = RoomId::from_proto(request.id);
1600
1601    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1602
1603    if let Some(channel_id) = channel_id {
1604        return join_channel_internal(channel_id, Box::new(response), session).await;
1605    }
1606
1607    let joined_room = {
1608        let room = session
1609            .db()
1610            .await
1611            .join_room(room_id, session.user_id(), session.connection_id)
1612            .await?;
1613        room_updated(&room.room, &session.peer);
1614        room.into_inner()
1615    };
1616
1617    for connection_id in session
1618        .connection_pool()
1619        .await
1620        .user_connection_ids(session.user_id())
1621    {
1622        session
1623            .peer
1624            .send(
1625                connection_id,
1626                proto::CallCanceled {
1627                    room_id: room_id.to_proto(),
1628                },
1629            )
1630            .trace_err();
1631    }
1632
1633    let live_kit_connection_info =
1634        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1635            if let Some(token) = live_kit
1636                .room_token(
1637                    &joined_room.room.live_kit_room,
1638                    &session.user_id().to_string(),
1639                )
1640                .trace_err()
1641            {
1642                Some(proto::LiveKitConnectionInfo {
1643                    server_url: live_kit.url().into(),
1644                    token,
1645                    can_publish: true,
1646                })
1647            } else {
1648                None
1649            }
1650        } else {
1651            None
1652        };
1653
1654    response.send(proto::JoinRoomResponse {
1655        room: Some(joined_room.room),
1656        channel_id: None,
1657        live_kit_connection_info,
1658    })?;
1659
1660    update_user_contacts(session.user_id(), &session).await?;
1661    Ok(())
1662}
1663
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's membership through the database and replies with the room, the
/// projects it re-shared, and the projects it rejoined. It then:
/// - broadcasts the room update;
/// - for each re-shared project, tells its collaborators the caller's peer id changed and
///   forwards the project's worktree list to them;
/// - streams rejoined-project state back to the caller via `notify_rejoined_projects`;
/// - calls `channel_updated` if the room belongs to a channel;
/// - refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Only the room and channel outlive the block below; the rest of the rejoin result is
    // consumed inside it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Replace the caller's old peer id with its new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to its collaborators, with the
            // caller's connection passed to `forward_send` as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // The connection pool is only locked after the rejoin result above has been released.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1755
1756fn notify_rejoined_projects(
1757    rejoined_projects: &mut Vec<RejoinedProject>,
1758    session: &UserSession,
1759) -> Result<()> {
1760    for project in rejoined_projects.iter() {
1761        for collaborator in &project.collaborators {
1762            session
1763                .peer
1764                .send(
1765                    collaborator.connection_id,
1766                    proto::UpdateProjectCollaborator {
1767                        project_id: project.id.to_proto(),
1768                        old_peer_id: Some(project.old_connection_id.into()),
1769                        new_peer_id: Some(session.connection_id.into()),
1770                    },
1771                )
1772                .trace_err();
1773        }
1774    }
1775
1776    for project in rejoined_projects {
1777        for worktree in mem::take(&mut project.worktrees) {
1778            #[cfg(any(test, feature = "test-support"))]
1779            const MAX_CHUNK_SIZE: usize = 2;
1780            #[cfg(not(any(test, feature = "test-support")))]
1781            const MAX_CHUNK_SIZE: usize = 256;
1782
1783            // Stream this worktree's entries.
1784            let message = proto::UpdateWorktree {
1785                project_id: project.id.to_proto(),
1786                worktree_id: worktree.id,
1787                abs_path: worktree.abs_path.clone(),
1788                root_name: worktree.root_name,
1789                updated_entries: worktree.updated_entries,
1790                removed_entries: worktree.removed_entries,
1791                scan_id: worktree.scan_id,
1792                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1793                updated_repositories: worktree.updated_repositories,
1794                removed_repositories: worktree.removed_repositories,
1795            };
1796            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1797                session.peer.send(session.connection_id, update.clone())?;
1798            }
1799
1800            // Stream this worktree's diagnostics.
1801            for summary in worktree.diagnostic_summaries {
1802                session.peer.send(
1803                    session.connection_id,
1804                    proto::UpdateDiagnosticSummary {
1805                        project_id: project.id.to_proto(),
1806                        worktree_id: worktree.id,
1807                        summary: Some(summary),
1808                    },
1809                )?;
1810            }
1811
1812            for settings_file in worktree.settings_files {
1813                session.peer.send(
1814                    session.connection_id,
1815                    proto::UpdateWorktreeSettings {
1816                        project_id: project.id.to_proto(),
1817                        worktree_id: worktree.id,
1818                        path: settings_file.path,
1819                        content: Some(settings_file.content),
1820                    },
1821                )?;
1822            }
1823        }
1824
1825        for language_server in &project.language_servers {
1826            session.peer.send(
1827                session.connection_id,
1828                proto::UpdateLanguageServer {
1829                    project_id: project.id.to_proto(),
1830                    language_server_id: language_server.id,
1831                    variant: Some(
1832                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1833                            proto::LspDiskBasedDiagnosticsUpdated {},
1834                        ),
1835                    ),
1836                },
1837            )?;
1838        }
1839    }
1840    Ok(())
1841}
1842
1843/// leave room disconnects from the room.
1844async fn leave_room(
1845    _: proto::LeaveRoom,
1846    response: Response<proto::LeaveRoom>,
1847    session: UserSession,
1848) -> Result<()> {
1849    leave_room_for_session(&session, session.connection_id).await?;
1850    response.send(proto::Ack {})?;
1851    Ok(())
1852}
1853
1854/// Updates the permissions of someone else in the room.
1855async fn set_room_participant_role(
1856    request: proto::SetRoomParticipantRole,
1857    response: Response<proto::SetRoomParticipantRole>,
1858    session: UserSession,
1859) -> Result<()> {
1860    let user_id = UserId::from_proto(request.user_id);
1861    let role = ChannelRole::from(request.role());
1862
1863    let (live_kit_room, can_publish) = {
1864        let room = session
1865            .db()
1866            .await
1867            .set_room_participant_role(
1868                session.user_id(),
1869                RoomId::from_proto(request.room_id),
1870                user_id,
1871                role,
1872            )
1873            .await?;
1874
1875        let live_kit_room = room.live_kit_room.clone();
1876        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1877        room_updated(&room, &session.peer);
1878        (live_kit_room, can_publish)
1879    };
1880
1881    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1882        live_kit
1883            .update_participant(
1884                live_kit_room.clone(),
1885                request.user_id.to_string(),
1886                live_kit_server::proto::ParticipantPermission {
1887                    can_subscribe: true,
1888                    can_publish,
1889                    can_publish_data: can_publish,
1890                    hidden: false,
1891                    recorder: false,
1892                },
1893            )
1894            .await
1895            .trace_err();
1896    }
1897
1898    response.send(proto::Ack {})?;
1899    Ok(())
1900}
1901
1902/// Call someone else into the current room
1903async fn call(
1904    request: proto::Call,
1905    response: Response<proto::Call>,
1906    session: UserSession,
1907) -> Result<()> {
1908    let room_id = RoomId::from_proto(request.room_id);
1909    let calling_user_id = session.user_id();
1910    let calling_connection_id = session.connection_id;
1911    let called_user_id = UserId::from_proto(request.called_user_id);
1912    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1913    if !session
1914        .db()
1915        .await
1916        .has_contact(calling_user_id, called_user_id)
1917        .await?
1918    {
1919        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1920    }
1921
1922    let incoming_call = {
1923        let (room, incoming_call) = &mut *session
1924            .db()
1925            .await
1926            .call(
1927                room_id,
1928                calling_user_id,
1929                calling_connection_id,
1930                called_user_id,
1931                initial_project_id,
1932            )
1933            .await?;
1934        room_updated(&room, &session.peer);
1935        mem::take(incoming_call)
1936    };
1937    update_user_contacts(called_user_id, &session).await?;
1938
1939    let mut calls = session
1940        .connection_pool()
1941        .await
1942        .user_connection_ids(called_user_id)
1943        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1944        .collect::<FuturesUnordered<_>>();
1945
1946    while let Some(call_response) = calls.next().await {
1947        match call_response.as_ref() {
1948            Ok(_) => {
1949                response.send(proto::Ack {})?;
1950                return Ok(());
1951            }
1952            Err(_) => {
1953                call_response.trace_err();
1954            }
1955        }
1956    }
1957
1958    {
1959        let room = session
1960            .db()
1961            .await
1962            .call_failed(room_id, called_user_id)
1963            .await?;
1964        room_updated(&room, &session.peer);
1965    }
1966    update_user_contacts(called_user_id, &session).await?;
1967
1968    Err(anyhow!("failed to ring user"))?
1969}
1970
1971/// Cancel an outgoing call.
1972async fn cancel_call(
1973    request: proto::CancelCall,
1974    response: Response<proto::CancelCall>,
1975    session: UserSession,
1976) -> Result<()> {
1977    let called_user_id = UserId::from_proto(request.called_user_id);
1978    let room_id = RoomId::from_proto(request.room_id);
1979    {
1980        let room = session
1981            .db()
1982            .await
1983            .cancel_call(room_id, session.connection_id, called_user_id)
1984            .await?;
1985        room_updated(&room, &session.peer);
1986    }
1987
1988    for connection_id in session
1989        .connection_pool()
1990        .await
1991        .user_connection_ids(called_user_id)
1992    {
1993        session
1994            .peer
1995            .send(
1996                connection_id,
1997                proto::CallCanceled {
1998                    room_id: room_id.to_proto(),
1999                },
2000            )
2001            .trace_err();
2002    }
2003    response.send(proto::Ack {})?;
2004
2005    update_user_contacts(called_user_id, &session).await?;
2006    Ok(())
2007}
2008
2009/// Decline an incoming call.
2010async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2011    let room_id = RoomId::from_proto(message.room_id);
2012    {
2013        let room = session
2014            .db()
2015            .await
2016            .decline_call(Some(room_id), session.user_id())
2017            .await?
2018            .ok_or_else(|| anyhow!("failed to decline call"))?;
2019        room_updated(&room, &session.peer);
2020    }
2021
2022    for connection_id in session
2023        .connection_pool()
2024        .await
2025        .user_connection_ids(session.user_id())
2026    {
2027        session
2028            .peer
2029            .send(
2030                connection_id,
2031                proto::CallCanceled {
2032                    room_id: room_id.to_proto(),
2033                },
2034            )
2035            .trace_err();
2036    }
2037    update_user_contacts(session.user_id(), &session).await?;
2038    Ok(())
2039}
2040
2041/// Updates other participants in the room with your current location.
2042async fn update_participant_location(
2043    request: proto::UpdateParticipantLocation,
2044    response: Response<proto::UpdateParticipantLocation>,
2045    session: UserSession,
2046) -> Result<()> {
2047    let room_id = RoomId::from_proto(request.room_id);
2048    let location = request
2049        .location
2050        .ok_or_else(|| anyhow!("invalid location"))?;
2051
2052    let db = session.db().await;
2053    let room = db
2054        .update_room_participant_location(room_id, session.connection_id, location)
2055        .await?;
2056
2057    room_updated(&room, &session.peer);
2058    response.send(proto::Ack {})?;
2059    Ok(())
2060}
2061
2062/// Share a project into the room.
2063async fn share_project(
2064    request: proto::ShareProject,
2065    response: Response<proto::ShareProject>,
2066    session: UserSession,
2067) -> Result<()> {
2068    let (project_id, room) = &*session
2069        .db()
2070        .await
2071        .share_project(
2072            RoomId::from_proto(request.room_id),
2073            session.connection_id,
2074            &request.worktrees,
2075            request
2076                .dev_server_project_id
2077                .map(|id| DevServerProjectId::from_proto(id)),
2078        )
2079        .await?;
2080    response.send(proto::ShareProjectResponse {
2081        project_id: project_id.to_proto(),
2082    })?;
2083    room_updated(&room, &session.peer);
2084
2085    Ok(())
2086}
2087
2088/// Unshare a project from the room.
2089async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2090    let project_id = ProjectId::from_proto(message.project_id);
2091    unshare_project_internal(
2092        project_id,
2093        session.connection_id,
2094        session.user_id(),
2095        &session,
2096    )
2097    .await
2098}
2099
/// Stops sharing `project_id` on behalf of `connection_id` / `user_id`.
///
/// Sends `UnshareProject` to the guest connections returned by the database, broadcasts the
/// room update when the database returns a room, and deletes the project afterwards if the
/// database reports that it should be deleted.
///
/// # Errors
/// Fails if the database unshare or the subsequent delete fails; failures to notify guests
/// are left to `broadcast`.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // The database result is scoped to this block so it is dropped before `session.db()` is
    // acquired again for the delete below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // `Some(connection_id)` is presumably the sender excluded from the broadcast.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2138
2139/// DevServer makes a project available online
2140async fn share_dev_server_project(
2141    request: proto::ShareDevServerProject,
2142    response: Response<proto::ShareDevServerProject>,
2143    session: DevServerSession,
2144) -> Result<()> {
2145    let (dev_server_project, user_id, status) = session
2146        .db()
2147        .await
2148        .share_dev_server_project(
2149            DevServerProjectId::from_proto(request.dev_server_project_id),
2150            session.dev_server_id(),
2151            session.connection_id,
2152            &request.worktrees,
2153        )
2154        .await?;
2155    let Some(project_id) = dev_server_project.project_id else {
2156        return Err(anyhow!("failed to share remote project"))?;
2157    };
2158
2159    send_dev_server_projects_update(user_id, status, &session).await;
2160
2161    response.send(proto::ShareProjectResponse { project_id })?;
2162
2163    Ok(())
2164}
2165
/// Join someone else's shared project.
///
/// Adds the caller's connection to `request.project_id` in the database, then delegates to
/// `join_project_internal`. That function notifies the existing collaborators and streams
/// the project's state to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming; `project` and `replica_id` stay borrowed
    // from the value returned by `join_project`, not from `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2184
/// Abstracts over the responses to `JoinProject` and `JoinHostedProject`, which both reply
/// with a `proto::JoinProjectResponse`, so `join_project_internal` can serve either request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully qualified `Response::<_>::send` calls below resolve to the inherent
// `Response::send` (inherent associated functions take precedence over trait methods), so
// these impls delegate to it rather than recursing into themselves.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2198
2199fn join_project_internal(
2200    response: impl JoinProjectInternalResponse,
2201    session: UserSession,
2202    project: &mut Project,
2203    replica_id: &ReplicaId,
2204) -> Result<()> {
2205    let collaborators = project
2206        .collaborators
2207        .iter()
2208        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2209        .map(|collaborator| collaborator.to_proto())
2210        .collect::<Vec<_>>();
2211    let project_id = project.id;
2212    let guest_user_id = session.user_id();
2213
2214    let worktrees = project
2215        .worktrees
2216        .iter()
2217        .map(|(id, worktree)| proto::WorktreeMetadata {
2218            id: *id,
2219            root_name: worktree.root_name.clone(),
2220            visible: worktree.visible,
2221            abs_path: worktree.abs_path.clone(),
2222        })
2223        .collect::<Vec<_>>();
2224
2225    let add_project_collaborator = proto::AddProjectCollaborator {
2226        project_id: project_id.to_proto(),
2227        collaborator: Some(proto::Collaborator {
2228            peer_id: Some(session.connection_id.into()),
2229            replica_id: replica_id.0 as u32,
2230            user_id: guest_user_id.to_proto(),
2231        }),
2232    };
2233
2234    for collaborator in &collaborators {
2235        session
2236            .peer
2237            .send(
2238                collaborator.peer_id.unwrap().into(),
2239                add_project_collaborator.clone(),
2240            )
2241            .trace_err();
2242    }
2243
2244    // First, we send the metadata associated with each worktree.
2245    response.send(proto::JoinProjectResponse {
2246        project_id: project.id.0 as u64,
2247        worktrees: worktrees.clone(),
2248        replica_id: replica_id.0 as u32,
2249        collaborators: collaborators.clone(),
2250        language_servers: project.language_servers.clone(),
2251        role: project.role.into(),
2252        dev_server_project_id: project
2253            .dev_server_project_id
2254            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2255    })?;
2256
2257    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2258        #[cfg(any(test, feature = "test-support"))]
2259        const MAX_CHUNK_SIZE: usize = 2;
2260        #[cfg(not(any(test, feature = "test-support")))]
2261        const MAX_CHUNK_SIZE: usize = 256;
2262
2263        // Stream this worktree's entries.
2264        let message = proto::UpdateWorktree {
2265            project_id: project_id.to_proto(),
2266            worktree_id,
2267            abs_path: worktree.abs_path.clone(),
2268            root_name: worktree.root_name,
2269            updated_entries: worktree.entries,
2270            removed_entries: Default::default(),
2271            scan_id: worktree.scan_id,
2272            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2273            updated_repositories: worktree.repository_entries.into_values().collect(),
2274            removed_repositories: Default::default(),
2275        };
2276        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2277            session.peer.send(session.connection_id, update.clone())?;
2278        }
2279
2280        // Stream this worktree's diagnostics.
2281        for summary in worktree.diagnostic_summaries {
2282            session.peer.send(
2283                session.connection_id,
2284                proto::UpdateDiagnosticSummary {
2285                    project_id: project_id.to_proto(),
2286                    worktree_id: worktree.id,
2287                    summary: Some(summary),
2288                },
2289            )?;
2290        }
2291
2292        for settings_file in worktree.settings_files {
2293            session.peer.send(
2294                session.connection_id,
2295                proto::UpdateWorktreeSettings {
2296                    project_id: project_id.to_proto(),
2297                    worktree_id: worktree.id,
2298                    path: settings_file.path,
2299                    content: Some(settings_file.content),
2300                },
2301            )?;
2302        }
2303    }
2304
2305    for language_server in &project.language_servers {
2306        session.peer.send(
2307            session.connection_id,
2308            proto::UpdateLanguageServer {
2309                project_id: project_id.to_proto(),
2310                language_server_id: language_server.id,
2311                variant: Some(
2312                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2313                        proto::LspDiskBasedDiagnosticsUpdated {},
2314                    ),
2315                ),
2316            },
2317        )?;
2318    }
2319
2320    Ok(())
2321}
2322
2323/// Leave someone elses shared project.
2324async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2325    let sender_id = session.connection_id;
2326    let project_id = ProjectId::from_proto(request.project_id);
2327    let db = session.db().await;
2328    if db.is_hosted_project(project_id).await? {
2329        let project = db.leave_hosted_project(project_id, sender_id).await?;
2330        project_left(&project, &session);
2331        return Ok(());
2332    }
2333
2334    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2335    tracing::info!(
2336        %project_id,
2337        "leave project"
2338    );
2339
2340    project_left(&project, &session);
2341    if let Some(room) = room {
2342        room_updated(&room, &session.peer);
2343    }
2344
2345    Ok(())
2346}
2347
2348async fn join_hosted_project(
2349    request: proto::JoinHostedProject,
2350    response: Response<proto::JoinHostedProject>,
2351    session: UserSession,
2352) -> Result<()> {
2353    let (mut project, replica_id) = session
2354        .db()
2355        .await
2356        .join_hosted_project(
2357            ProjectId(request.project_id as i32),
2358            session.user_id(),
2359            session.connection_id,
2360        )
2361        .await?;
2362
2363    join_project_internal(response, session, &mut project, &replica_id)
2364}
2365
2366async fn list_remote_directory(
2367    request: proto::ListRemoteDirectory,
2368    response: Response<proto::ListRemoteDirectory>,
2369    session: UserSession,
2370) -> Result<()> {
2371    let dev_server_id = DevServerId(request.dev_server_id as i32);
2372    let dev_server_connection_id = session
2373        .connection_pool()
2374        .await
2375        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2376
2377    session
2378        .db()
2379        .await
2380        .get_dev_server_for_user(dev_server_id, session.user_id())
2381        .await?;
2382
2383    response.send(
2384        session
2385            .peer
2386            .forward_request(session.connection_id, dev_server_connection_id, request)
2387            .await?,
2388    )?;
2389    Ok(())
2390}
2391
2392async fn update_dev_server_project(
2393    request: proto::UpdateDevServerProject,
2394    response: Response<proto::UpdateDevServerProject>,
2395    session: UserSession,
2396) -> Result<()> {
2397    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2398
2399    let (dev_server_project, update) = session
2400        .db()
2401        .await
2402        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2403        .await?;
2404
2405    let projects = session
2406        .db()
2407        .await
2408        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2409        .await?;
2410
2411    let dev_server_connection_id = session
2412        .connection_pool()
2413        .await
2414        .dev_server_connection_id_supporting(
2415            dev_server_project.dev_server_id,
2416            ZedVersion::with_list_directory(),
2417        )?;
2418
2419    session.peer.send(
2420        dev_server_connection_id,
2421        proto::DevServerInstructions { projects },
2422    )?;
2423
2424    send_dev_server_projects_update(session.user_id(), update, &session).await;
2425
2426    response.send(proto::Ack {})
2427}
2428
2429async fn create_dev_server_project(
2430    request: proto::CreateDevServerProject,
2431    response: Response<proto::CreateDevServerProject>,
2432    session: UserSession,
2433) -> Result<()> {
2434    let dev_server_id = DevServerId(request.dev_server_id as i32);
2435    let dev_server_connection_id = session
2436        .connection_pool()
2437        .await
2438        .dev_server_connection_id(dev_server_id);
2439    let Some(dev_server_connection_id) = dev_server_connection_id else {
2440        Err(ErrorCode::DevServerOffline
2441            .message("Cannot create a remote project when the dev server is offline".to_string())
2442            .anyhow())?
2443    };
2444
2445    let path = request.path.clone();
2446    //Check that the path exists on the dev server
2447    session
2448        .peer
2449        .forward_request(
2450            session.connection_id,
2451            dev_server_connection_id,
2452            proto::ValidateDevServerProjectRequest { path: path.clone() },
2453        )
2454        .await?;
2455
2456    let (dev_server_project, update) = session
2457        .db()
2458        .await
2459        .create_dev_server_project(
2460            DevServerId(request.dev_server_id as i32),
2461            &request.path,
2462            session.user_id(),
2463        )
2464        .await?;
2465
2466    let projects = session
2467        .db()
2468        .await
2469        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2470        .await?;
2471
2472    session.peer.send(
2473        dev_server_connection_id,
2474        proto::DevServerInstructions { projects },
2475    )?;
2476
2477    send_dev_server_projects_update(session.user_id(), update, &session).await;
2478
2479    response.send(proto::CreateDevServerProjectResponse {
2480        dev_server_project: Some(dev_server_project.to_proto(None)),
2481    })?;
2482    Ok(())
2483}
2484
2485async fn create_dev_server(
2486    request: proto::CreateDevServer,
2487    response: Response<proto::CreateDevServer>,
2488    session: UserSession,
2489) -> Result<()> {
2490    let access_token = auth::random_token();
2491    let hashed_access_token = auth::hash_access_token(&access_token);
2492
2493    if request.name.is_empty() {
2494        return Err(proto::ErrorCode::Forbidden
2495            .message("Dev server name cannot be empty".to_string())
2496            .anyhow())?;
2497    }
2498
2499    let (dev_server, status) = session
2500        .db()
2501        .await
2502        .create_dev_server(
2503            &request.name,
2504            request.ssh_connection_string.as_deref(),
2505            &hashed_access_token,
2506            session.user_id(),
2507        )
2508        .await?;
2509
2510    send_dev_server_projects_update(session.user_id(), status, &session).await;
2511
2512    response.send(proto::CreateDevServerResponse {
2513        dev_server_id: dev_server.id.0 as u64,
2514        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2515        name: request.name,
2516    })?;
2517    Ok(())
2518}
2519
2520async fn regenerate_dev_server_token(
2521    request: proto::RegenerateDevServerToken,
2522    response: Response<proto::RegenerateDevServerToken>,
2523    session: UserSession,
2524) -> Result<()> {
2525    let dev_server_id = DevServerId(request.dev_server_id as i32);
2526    let access_token = auth::random_token();
2527    let hashed_access_token = auth::hash_access_token(&access_token);
2528
2529    let connection_id = session
2530        .connection_pool()
2531        .await
2532        .dev_server_connection_id(dev_server_id);
2533    if let Some(connection_id) = connection_id {
2534        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2535        session.peer.send(
2536            connection_id,
2537            proto::ShutdownDevServer {
2538                reason: Some("dev server token was regenerated".to_string()),
2539            },
2540        )?;
2541        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2542    }
2543
2544    let status = session
2545        .db()
2546        .await
2547        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2548        .await?;
2549
2550    send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552    response.send(proto::RegenerateDevServerTokenResponse {
2553        dev_server_id: dev_server_id.to_proto(),
2554        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2555    })?;
2556    Ok(())
2557}
2558
2559async fn rename_dev_server(
2560    request: proto::RenameDevServer,
2561    response: Response<proto::RenameDevServer>,
2562    session: UserSession,
2563) -> Result<()> {
2564    if request.name.trim().is_empty() {
2565        return Err(proto::ErrorCode::Forbidden
2566            .message("Dev server name cannot be empty".to_string())
2567            .anyhow())?;
2568    }
2569
2570    let dev_server_id = DevServerId(request.dev_server_id as i32);
2571    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2572    if dev_server.user_id != session.user_id() {
2573        return Err(anyhow!(ErrorCode::Forbidden))?;
2574    }
2575
2576    let status = session
2577        .db()
2578        .await
2579        .rename_dev_server(
2580            dev_server_id,
2581            &request.name,
2582            request.ssh_connection_string.as_deref(),
2583            session.user_id(),
2584        )
2585        .await?;
2586
2587    send_dev_server_projects_update(session.user_id(), status, &session).await;
2588
2589    response.send(proto::Ack {})?;
2590    Ok(())
2591}
2592
2593async fn delete_dev_server(
2594    request: proto::DeleteDevServer,
2595    response: Response<proto::DeleteDevServer>,
2596    session: UserSession,
2597) -> Result<()> {
2598    let dev_server_id = DevServerId(request.dev_server_id as i32);
2599    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2600    if dev_server.user_id != session.user_id() {
2601        return Err(anyhow!(ErrorCode::Forbidden))?;
2602    }
2603
2604    let connection_id = session
2605        .connection_pool()
2606        .await
2607        .dev_server_connection_id(dev_server_id);
2608    if let Some(connection_id) = connection_id {
2609        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2610        session.peer.send(
2611            connection_id,
2612            proto::ShutdownDevServer {
2613                reason: Some("dev server was deleted".to_string()),
2614            },
2615        )?;
2616        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2617    }
2618
2619    let status = session
2620        .db()
2621        .await
2622        .delete_dev_server(dev_server_id, session.user_id())
2623        .await?;
2624
2625    send_dev_server_projects_update(session.user_id(), status, &session).await;
2626
2627    response.send(proto::Ack {})?;
2628    Ok(())
2629}
2630
2631async fn delete_dev_server_project(
2632    request: proto::DeleteDevServerProject,
2633    response: Response<proto::DeleteDevServerProject>,
2634    session: UserSession,
2635) -> Result<()> {
2636    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2637    let dev_server_project = session
2638        .db()
2639        .await
2640        .get_dev_server_project(dev_server_project_id)
2641        .await?;
2642
2643    let dev_server = session
2644        .db()
2645        .await
2646        .get_dev_server(dev_server_project.dev_server_id)
2647        .await?;
2648    if dev_server.user_id != session.user_id() {
2649        return Err(anyhow!(ErrorCode::Forbidden))?;
2650    }
2651
2652    let dev_server_connection_id = session
2653        .connection_pool()
2654        .await
2655        .dev_server_connection_id(dev_server.id);
2656
2657    if let Some(dev_server_connection_id) = dev_server_connection_id {
2658        let project = session
2659            .db()
2660            .await
2661            .find_dev_server_project(dev_server_project_id)
2662            .await;
2663        if let Ok(project) = project {
2664            unshare_project_internal(
2665                project.id,
2666                dev_server_connection_id,
2667                Some(session.user_id()),
2668                &session,
2669            )
2670            .await?;
2671        }
2672    }
2673
2674    let (projects, status) = session
2675        .db()
2676        .await
2677        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2678        .await?;
2679
2680    if let Some(dev_server_connection_id) = dev_server_connection_id {
2681        session.peer.send(
2682            dev_server_connection_id,
2683            proto::DevServerInstructions { projects },
2684        )?;
2685    }
2686
2687    send_dev_server_projects_update(session.user_id(), status, &session).await;
2688
2689    response.send(proto::Ack {})?;
2690    Ok(())
2691}
2692
2693async fn rejoin_dev_server_projects(
2694    request: proto::RejoinRemoteProjects,
2695    response: Response<proto::RejoinRemoteProjects>,
2696    session: UserSession,
2697) -> Result<()> {
2698    let mut rejoined_projects = {
2699        let db = session.db().await;
2700        db.rejoin_dev_server_projects(
2701            &request.rejoined_projects,
2702            session.user_id(),
2703            session.0.connection_id,
2704        )
2705        .await?
2706    };
2707    response.send(proto::RejoinRemoteProjectsResponse {
2708        rejoined_projects: rejoined_projects
2709            .iter()
2710            .map(|project| project.to_proto())
2711            .collect(),
2712    })?;
2713    notify_rejoined_projects(&mut rejoined_projects, &session)
2714}
2715
2716async fn reconnect_dev_server(
2717    request: proto::ReconnectDevServer,
2718    response: Response<proto::ReconnectDevServer>,
2719    session: DevServerSession,
2720) -> Result<()> {
2721    let reshared_projects = {
2722        let db = session.db().await;
2723        db.reshare_dev_server_projects(
2724            &request.reshared_projects,
2725            session.dev_server_id(),
2726            session.0.connection_id,
2727        )
2728        .await?
2729    };
2730
2731    for project in &reshared_projects {
2732        for collaborator in &project.collaborators {
2733            session
2734                .peer
2735                .send(
2736                    collaborator.connection_id,
2737                    proto::UpdateProjectCollaborator {
2738                        project_id: project.id.to_proto(),
2739                        old_peer_id: Some(project.old_connection_id.into()),
2740                        new_peer_id: Some(session.connection_id.into()),
2741                    },
2742                )
2743                .trace_err();
2744        }
2745
2746        broadcast(
2747            Some(session.connection_id),
2748            project
2749                .collaborators
2750                .iter()
2751                .map(|collaborator| collaborator.connection_id),
2752            |connection_id| {
2753                session.peer.forward_send(
2754                    session.connection_id,
2755                    connection_id,
2756                    proto::UpdateProject {
2757                        project_id: project.id.to_proto(),
2758                        worktrees: project.worktrees.clone(),
2759                    },
2760                )
2761            },
2762        );
2763    }
2764
2765    response.send(proto::ReconnectDevServerResponse {
2766        reshared_projects: reshared_projects
2767            .iter()
2768            .map(|project| proto::ResharedProject {
2769                id: project.id.to_proto(),
2770                collaborators: project
2771                    .collaborators
2772                    .iter()
2773                    .map(|collaborator| collaborator.to_proto())
2774                    .collect(),
2775            })
2776            .collect(),
2777    })?;
2778
2779    Ok(())
2780}
2781
2782async fn shutdown_dev_server(
2783    _: proto::ShutdownDevServer,
2784    response: Response<proto::ShutdownDevServer>,
2785    session: DevServerSession,
2786) -> Result<()> {
2787    response.send(proto::Ack {})?;
2788    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2789    remove_dev_server_connection(session.dev_server_id(), &session).await
2790}
2791
2792async fn shutdown_dev_server_internal(
2793    dev_server_id: DevServerId,
2794    connection_id: ConnectionId,
2795    session: &Session,
2796) -> Result<()> {
2797    let (dev_server_projects, dev_server) = {
2798        let db = session.db().await;
2799        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2800        let dev_server = db.get_dev_server(dev_server_id).await?;
2801        (dev_server_projects, dev_server)
2802    };
2803
2804    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2805        unshare_project_internal(
2806            ProjectId::from_proto(project_id),
2807            connection_id,
2808            None,
2809            session,
2810        )
2811        .await?;
2812    }
2813
2814    session
2815        .connection_pool()
2816        .await
2817        .set_dev_server_offline(dev_server_id);
2818
2819    let status = session
2820        .db()
2821        .await
2822        .dev_server_projects_update(dev_server.user_id)
2823        .await?;
2824    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2825
2826    Ok(())
2827}
2828
2829async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2830    let dev_server_connection = session
2831        .connection_pool()
2832        .await
2833        .dev_server_connection_id(dev_server_id);
2834
2835    if let Some(dev_server_connection) = dev_server_connection {
2836        session
2837            .connection_pool()
2838            .await
2839            .remove_connection(dev_server_connection)?;
2840    }
2841    Ok(())
2842}
2843
2844/// Updates other participants with changes to the project
2845async fn update_project(
2846    request: proto::UpdateProject,
2847    response: Response<proto::UpdateProject>,
2848    session: Session,
2849) -> Result<()> {
2850    let project_id = ProjectId::from_proto(request.project_id);
2851    let (room, guest_connection_ids) = &*session
2852        .db()
2853        .await
2854        .update_project(project_id, session.connection_id, &request.worktrees)
2855        .await?;
2856    broadcast(
2857        Some(session.connection_id),
2858        guest_connection_ids.iter().copied(),
2859        |connection_id| {
2860            session
2861                .peer
2862                .forward_send(session.connection_id, connection_id, request.clone())
2863        },
2864    );
2865    if let Some(room) = room {
2866        room_updated(&room, &session.peer);
2867    }
2868    response.send(proto::Ack {})?;
2869
2870    Ok(())
2871}
2872
2873/// Updates other participants with changes to the worktree
2874async fn update_worktree(
2875    request: proto::UpdateWorktree,
2876    response: Response<proto::UpdateWorktree>,
2877    session: Session,
2878) -> Result<()> {
2879    let guest_connection_ids = session
2880        .db()
2881        .await
2882        .update_worktree(&request, session.connection_id)
2883        .await?;
2884
2885    broadcast(
2886        Some(session.connection_id),
2887        guest_connection_ids.iter().copied(),
2888        |connection_id| {
2889            session
2890                .peer
2891                .forward_send(session.connection_id, connection_id, request.clone())
2892        },
2893    );
2894    response.send(proto::Ack {})?;
2895    Ok(())
2896}
2897
2898/// Updates other participants with changes to the diagnostics
2899async fn update_diagnostic_summary(
2900    message: proto::UpdateDiagnosticSummary,
2901    session: Session,
2902) -> Result<()> {
2903    let guest_connection_ids = session
2904        .db()
2905        .await
2906        .update_diagnostic_summary(&message, session.connection_id)
2907        .await?;
2908
2909    broadcast(
2910        Some(session.connection_id),
2911        guest_connection_ids.iter().copied(),
2912        |connection_id| {
2913            session
2914                .peer
2915                .forward_send(session.connection_id, connection_id, message.clone())
2916        },
2917    );
2918
2919    Ok(())
2920}
2921
2922/// Updates other participants with changes to the worktree settings
2923async fn update_worktree_settings(
2924    message: proto::UpdateWorktreeSettings,
2925    session: Session,
2926) -> Result<()> {
2927    let guest_connection_ids = session
2928        .db()
2929        .await
2930        .update_worktree_settings(&message, session.connection_id)
2931        .await?;
2932
2933    broadcast(
2934        Some(session.connection_id),
2935        guest_connection_ids.iter().copied(),
2936        |connection_id| {
2937            session
2938                .peer
2939                .forward_send(session.connection_id, connection_id, message.clone())
2940        },
2941    );
2942
2943    Ok(())
2944}
2945
2946/// Notify other participants that a  language server has started.
2947async fn start_language_server(
2948    request: proto::StartLanguageServer,
2949    session: Session,
2950) -> Result<()> {
2951    let guest_connection_ids = session
2952        .db()
2953        .await
2954        .start_language_server(&request, session.connection_id)
2955        .await?;
2956
2957    broadcast(
2958        Some(session.connection_id),
2959        guest_connection_ids.iter().copied(),
2960        |connection_id| {
2961            session
2962                .peer
2963                .forward_send(session.connection_id, connection_id, request.clone())
2964        },
2965    );
2966    Ok(())
2967}
2968
2969/// Notify other participants that a language server has changed.
2970async fn update_language_server(
2971    request: proto::UpdateLanguageServer,
2972    session: Session,
2973) -> Result<()> {
2974    let project_id = ProjectId::from_proto(request.project_id);
2975    let project_connection_ids = session
2976        .db()
2977        .await
2978        .project_connection_ids(project_id, session.connection_id, true)
2979        .await?;
2980    broadcast(
2981        Some(session.connection_id),
2982        project_connection_ids.iter().copied(),
2983        |connection_id| {
2984            session
2985                .peer
2986                .forward_send(session.connection_id, connection_id, request.clone())
2987        },
2988    );
2989    Ok(())
2990}
2991
2992/// forward a project request to the host. These requests should be read only
2993/// as guests are allowed to send them.
2994async fn forward_read_only_project_request<T>(
2995    request: T,
2996    response: Response<T>,
2997    session: UserSession,
2998) -> Result<()>
2999where
3000    T: EntityMessage + RequestMessage,
3001{
3002    let project_id = ProjectId::from_proto(request.remote_entity_id());
3003    let host_connection_id = session
3004        .db()
3005        .await
3006        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3007        .await?;
3008    let payload = session
3009        .peer
3010        .forward_request(session.connection_id, host_connection_id, request)
3011        .await?;
3012    response.send(payload)?;
3013    Ok(())
3014}
3015
3016/// forward a project request to the dev server. Only allowed
3017/// if it's your dev server.
3018async fn forward_project_request_for_owner<T>(
3019    request: T,
3020    response: Response<T>,
3021    session: UserSession,
3022) -> Result<()>
3023where
3024    T: EntityMessage + RequestMessage,
3025{
3026    let project_id = ProjectId::from_proto(request.remote_entity_id());
3027
3028    let host_connection_id = session
3029        .db()
3030        .await
3031        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3032        .await?;
3033    let payload = session
3034        .peer
3035        .forward_request(session.connection_id, host_connection_id, request)
3036        .await?;
3037    response.send(payload)?;
3038    Ok(())
3039}
3040
3041/// forward a project request to the host. These requests are disallowed
3042/// for guests.
3043async fn forward_mutating_project_request<T>(
3044    request: T,
3045    response: Response<T>,
3046    session: UserSession,
3047) -> Result<()>
3048where
3049    T: EntityMessage + RequestMessage,
3050{
3051    let project_id = ProjectId::from_proto(request.remote_entity_id());
3052
3053    let host_connection_id = session
3054        .db()
3055        .await
3056        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3057        .await?;
3058    let payload = session
3059        .peer
3060        .forward_request(session.connection_id, host_connection_id, request)
3061        .await?;
3062    response.send(payload)?;
3063    Ok(())
3064}
3065
3066/// forward a project request to the host. These requests are disallowed
3067/// for guests.
3068async fn forward_versioned_mutating_project_request<T>(
3069    request: T,
3070    response: Response<T>,
3071    session: UserSession,
3072) -> Result<()>
3073where
3074    T: EntityMessage + RequestMessage + VersionedMessage,
3075{
3076    let project_id = ProjectId::from_proto(request.remote_entity_id());
3077
3078    let host_connection_id = session
3079        .db()
3080        .await
3081        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3082        .await?;
3083    if let Some(host_version) = session
3084        .connection_pool()
3085        .await
3086        .connection(host_connection_id)
3087        .map(|c| c.zed_version)
3088    {
3089        if let Some(min_required_version) = request.required_host_version() {
3090            if min_required_version > host_version {
3091                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3092                    .with_tag("required", &min_required_version.to_string())))?;
3093            }
3094        }
3095    }
3096
3097    let payload = session
3098        .peer
3099        .forward_request(session.connection_id, host_connection_id, request)
3100        .await?;
3101    response.send(payload)?;
3102    Ok(())
3103}
3104
3105/// Notify other participants that a new buffer has been created
3106async fn create_buffer_for_peer(
3107    request: proto::CreateBufferForPeer,
3108    session: Session,
3109) -> Result<()> {
3110    session
3111        .db()
3112        .await
3113        .check_user_is_project_host(
3114            ProjectId::from_proto(request.project_id),
3115            session.connection_id,
3116        )
3117        .await?;
3118    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3119    session
3120        .peer
3121        .forward_send(session.connection_id, peer_id.into(), request)?;
3122    Ok(())
3123}
3124
3125/// Notify other participants that a buffer has been updated. This is
3126/// allowed for guests as long as the update is limited to selections.
3127async fn update_buffer(
3128    request: proto::UpdateBuffer,
3129    response: Response<proto::UpdateBuffer>,
3130    session: Session,
3131) -> Result<()> {
3132    let project_id = ProjectId::from_proto(request.project_id);
3133    let mut capability = Capability::ReadOnly;
3134
3135    for op in request.operations.iter() {
3136        match op.variant {
3137            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3138            Some(_) => capability = Capability::ReadWrite,
3139        }
3140    }
3141
3142    let host = {
3143        let guard = session
3144            .db()
3145            .await
3146            .connections_for_buffer_update(
3147                project_id,
3148                session.principal_id(),
3149                session.connection_id,
3150                capability,
3151            )
3152            .await?;
3153
3154        let (host, guests) = &*guard;
3155
3156        broadcast(
3157            Some(session.connection_id),
3158            guests.clone(),
3159            |connection_id| {
3160                session
3161                    .peer
3162                    .forward_send(session.connection_id, connection_id, request.clone())
3163            },
3164        );
3165
3166        *host
3167    };
3168
3169    if host != session.connection_id {
3170        session
3171            .peer
3172            .forward_request(session.connection_id, host, request.clone())
3173            .await?;
3174    }
3175
3176    response.send(proto::Ack {})?;
3177    Ok(())
3178}
3179
3180async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3181    let project_id = ProjectId::from_proto(message.project_id);
3182
3183    let operation = message.operation.as_ref().context("invalid operation")?;
3184    let capability = match operation.variant.as_ref() {
3185        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3186            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3187                match buffer_op.variant {
3188                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3189                        Capability::ReadOnly
3190                    }
3191                    _ => Capability::ReadWrite,
3192                }
3193            } else {
3194                Capability::ReadWrite
3195            }
3196        }
3197        Some(_) => Capability::ReadWrite,
3198        None => Capability::ReadOnly,
3199    };
3200
3201    let guard = session
3202        .db()
3203        .await
3204        .connections_for_buffer_update(
3205            project_id,
3206            session.principal_id(),
3207            session.connection_id,
3208            capability,
3209        )
3210        .await?;
3211
3212    let (host, guests) = &*guard;
3213
3214    broadcast(
3215        Some(session.connection_id),
3216        guests.iter().chain([host]).copied(),
3217        |connection_id| {
3218            session
3219                .peer
3220                .forward_send(session.connection_id, connection_id, message.clone())
3221        },
3222    );
3223
3224    Ok(())
3225}
3226
3227/// Notify other participants that a project has been updated.
3228async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3229    request: T,
3230    session: Session,
3231) -> Result<()> {
3232    let project_id = ProjectId::from_proto(request.remote_entity_id());
3233    let project_connection_ids = session
3234        .db()
3235        .await
3236        .project_connection_ids(project_id, session.connection_id, false)
3237        .await?;
3238
3239    broadcast(
3240        Some(session.connection_id),
3241        project_connection_ids.iter().copied(),
3242        |connection_id| {
3243            session
3244                .peer
3245                .forward_send(session.connection_id, connection_id, request.clone())
3246        },
3247    );
3248    Ok(())
3249}
3250
3251/// Start following another user in a call.
3252async fn follow(
3253    request: proto::Follow,
3254    response: Response<proto::Follow>,
3255    session: UserSession,
3256) -> Result<()> {
3257    let room_id = RoomId::from_proto(request.room_id);
3258    let project_id = request.project_id.map(ProjectId::from_proto);
3259    let leader_id = request
3260        .leader_id
3261        .ok_or_else(|| anyhow!("invalid leader id"))?
3262        .into();
3263    let follower_id = session.connection_id;
3264
3265    session
3266        .db()
3267        .await
3268        .check_room_participants(room_id, leader_id, session.connection_id)
3269        .await?;
3270
3271    let response_payload = session
3272        .peer
3273        .forward_request(session.connection_id, leader_id, request)
3274        .await?;
3275    response.send(response_payload)?;
3276
3277    if let Some(project_id) = project_id {
3278        let room = session
3279            .db()
3280            .await
3281            .follow(room_id, project_id, leader_id, follower_id)
3282            .await?;
3283        room_updated(&room, &session.peer);
3284    }
3285
3286    Ok(())
3287}
3288
3289/// Stop following another user in a call.
3290async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3291    let room_id = RoomId::from_proto(request.room_id);
3292    let project_id = request.project_id.map(ProjectId::from_proto);
3293    let leader_id = request
3294        .leader_id
3295        .ok_or_else(|| anyhow!("invalid leader id"))?
3296        .into();
3297    let follower_id = session.connection_id;
3298
3299    session
3300        .db()
3301        .await
3302        .check_room_participants(room_id, leader_id, session.connection_id)
3303        .await?;
3304
3305    session
3306        .peer
3307        .forward_send(session.connection_id, leader_id, request)?;
3308
3309    if let Some(project_id) = project_id {
3310        let room = session
3311            .db()
3312            .await
3313            .unfollow(room_id, project_id, leader_id, follower_id)
3314            .await?;
3315        room_updated(&room, &session.peer);
3316    }
3317
3318    Ok(())
3319}
3320
3321/// Notify everyone following you of your current location.
3322async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3323    let room_id = RoomId::from_proto(request.room_id);
3324    let database = session.db.lock().await;
3325
3326    let connection_ids = if let Some(project_id) = request.project_id {
3327        let project_id = ProjectId::from_proto(project_id);
3328        database
3329            .project_connection_ids(project_id, session.connection_id, true)
3330            .await?
3331    } else {
3332        database
3333            .room_connection_ids(room_id, session.connection_id)
3334            .await?
3335    };
3336
3337    // For now, don't send view update messages back to that view's current leader.
3338    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3339        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3340        _ => None,
3341    });
3342
3343    for connection_id in connection_ids.iter().cloned() {
3344        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3345            session
3346                .peer
3347                .forward_send(session.connection_id, connection_id, request.clone())?;
3348        }
3349    }
3350    Ok(())
3351}
3352
3353/// Get public data about users.
3354async fn get_users(
3355    request: proto::GetUsers,
3356    response: Response<proto::GetUsers>,
3357    session: Session,
3358) -> Result<()> {
3359    let user_ids = request
3360        .user_ids
3361        .into_iter()
3362        .map(UserId::from_proto)
3363        .collect();
3364    let users = session
3365        .db()
3366        .await
3367        .get_users_by_ids(user_ids)
3368        .await?
3369        .into_iter()
3370        .map(|user| proto::User {
3371            id: user.id.to_proto(),
3372            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3373            github_login: user.github_login,
3374        })
3375        .collect();
3376    response.send(proto::UsersResponse { users })?;
3377    Ok(())
3378}
3379
3380/// Search for users (to invite) buy Github login
3381async fn fuzzy_search_users(
3382    request: proto::FuzzySearchUsers,
3383    response: Response<proto::FuzzySearchUsers>,
3384    session: UserSession,
3385) -> Result<()> {
3386    let query = request.query;
3387    let users = match query.len() {
3388        0 => vec![],
3389        1 | 2 => session
3390            .db()
3391            .await
3392            .get_user_by_github_login(&query)
3393            .await?
3394            .into_iter()
3395            .collect(),
3396        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3397    };
3398    let users = users
3399        .into_iter()
3400        .filter(|user| user.id != session.user_id())
3401        .map(|user| proto::User {
3402            id: user.id.to_proto(),
3403            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3404            github_login: user.github_login,
3405        })
3406        .collect();
3407    response.send(proto::UsersResponse { users })?;
3408    Ok(())
3409}
3410
3411/// Send a contact request to another user.
3412async fn request_contact(
3413    request: proto::RequestContact,
3414    response: Response<proto::RequestContact>,
3415    session: UserSession,
3416) -> Result<()> {
3417    let requester_id = session.user_id();
3418    let responder_id = UserId::from_proto(request.responder_id);
3419    if requester_id == responder_id {
3420        return Err(anyhow!("cannot add yourself as a contact"))?;
3421    }
3422
3423    let notifications = session
3424        .db()
3425        .await
3426        .send_contact_request(requester_id, responder_id)
3427        .await?;
3428
3429    // Update outgoing contact requests of requester
3430    let mut update = proto::UpdateContacts::default();
3431    update.outgoing_requests.push(responder_id.to_proto());
3432    for connection_id in session
3433        .connection_pool()
3434        .await
3435        .user_connection_ids(requester_id)
3436    {
3437        session.peer.send(connection_id, update.clone())?;
3438    }
3439
3440    // Update incoming contact requests of responder
3441    let mut update = proto::UpdateContacts::default();
3442    update
3443        .incoming_requests
3444        .push(proto::IncomingContactRequest {
3445            requester_id: requester_id.to_proto(),
3446        });
3447    let connection_pool = session.connection_pool().await;
3448    for connection_id in connection_pool.user_connection_ids(responder_id) {
3449        session.peer.send(connection_id, update.clone())?;
3450    }
3451
3452    send_notifications(&connection_pool, &session.peer, notifications);
3453
3454    response.send(proto::Ack {})?;
3455    Ok(())
3456}
3457
3458/// Accept or decline a contact request
3459async fn respond_to_contact_request(
3460    request: proto::RespondToContactRequest,
3461    response: Response<proto::RespondToContactRequest>,
3462    session: UserSession,
3463) -> Result<()> {
3464    let responder_id = session.user_id();
3465    let requester_id = UserId::from_proto(request.requester_id);
3466    let db = session.db().await;
3467    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3468        db.dismiss_contact_notification(responder_id, requester_id)
3469            .await?;
3470    } else {
3471        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3472
3473        let notifications = db
3474            .respond_to_contact_request(responder_id, requester_id, accept)
3475            .await?;
3476        let requester_busy = db.is_user_busy(requester_id).await?;
3477        let responder_busy = db.is_user_busy(responder_id).await?;
3478
3479        let pool = session.connection_pool().await;
3480        // Update responder with new contact
3481        let mut update = proto::UpdateContacts::default();
3482        if accept {
3483            update
3484                .contacts
3485                .push(contact_for_user(requester_id, requester_busy, &pool));
3486        }
3487        update
3488            .remove_incoming_requests
3489            .push(requester_id.to_proto());
3490        for connection_id in pool.user_connection_ids(responder_id) {
3491            session.peer.send(connection_id, update.clone())?;
3492        }
3493
3494        // Update requester with new contact
3495        let mut update = proto::UpdateContacts::default();
3496        if accept {
3497            update
3498                .contacts
3499                .push(contact_for_user(responder_id, responder_busy, &pool));
3500        }
3501        update
3502            .remove_outgoing_requests
3503            .push(responder_id.to_proto());
3504
3505        for connection_id in pool.user_connection_ids(requester_id) {
3506            session.peer.send(connection_id, update.clone())?;
3507        }
3508
3509        send_notifications(&pool, &session.peer, notifications);
3510    }
3511
3512    response.send(proto::Ack {})?;
3513    Ok(())
3514}
3515
3516/// Remove a contact.
3517async fn remove_contact(
3518    request: proto::RemoveContact,
3519    response: Response<proto::RemoveContact>,
3520    session: UserSession,
3521) -> Result<()> {
3522    let requester_id = session.user_id();
3523    let responder_id = UserId::from_proto(request.user_id);
3524    let db = session.db().await;
3525    let (contact_accepted, deleted_notification_id) =
3526        db.remove_contact(requester_id, responder_id).await?;
3527
3528    let pool = session.connection_pool().await;
3529    // Update outgoing contact requests of requester
3530    let mut update = proto::UpdateContacts::default();
3531    if contact_accepted {
3532        update.remove_contacts.push(responder_id.to_proto());
3533    } else {
3534        update
3535            .remove_outgoing_requests
3536            .push(responder_id.to_proto());
3537    }
3538    for connection_id in pool.user_connection_ids(requester_id) {
3539        session.peer.send(connection_id, update.clone())?;
3540    }
3541
3542    // Update incoming contact requests of responder
3543    let mut update = proto::UpdateContacts::default();
3544    if contact_accepted {
3545        update.remove_contacts.push(requester_id.to_proto());
3546    } else {
3547        update
3548            .remove_incoming_requests
3549            .push(requester_id.to_proto());
3550    }
3551    for connection_id in pool.user_connection_ids(responder_id) {
3552        session.peer.send(connection_id, update.clone())?;
3553        if let Some(notification_id) = deleted_notification_id {
3554            session.peer.send(
3555                connection_id,
3556                proto::DeleteNotification {
3557                    notification_id: notification_id.to_proto(),
3558                },
3559            )?;
3560        }
3561    }
3562
3563    response.send(proto::Ack {})?;
3564    Ok(())
3565}
3566
3567fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3568    version.0.minor() < 139
3569}
3570
3571async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3572    let plan = session.current_plan().await?;
3573
3574    session
3575        .peer
3576        .send(
3577            session.connection_id,
3578            proto::UpdateUserPlan { plan: plan.into() },
3579        )
3580        .trace_err();
3581
3582    Ok(())
3583}
3584
3585async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3586    subscribe_user_to_channels(
3587        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3588        &session,
3589    )
3590    .await?;
3591    Ok(())
3592}
3593
3594async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3595    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3596    let mut pool = session.connection_pool().await;
3597    for membership in &channels_for_user.channel_memberships {
3598        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3599    }
3600    session.peer.send(
3601        session.connection_id,
3602        build_update_user_channels(&channels_for_user),
3603    )?;
3604    session.peer.send(
3605        session.connection_id,
3606        build_channels_update(channels_for_user),
3607    )?;
3608    Ok(())
3609}
3610
3611/// Creates a new channel.
3612async fn create_channel(
3613    request: proto::CreateChannel,
3614    response: Response<proto::CreateChannel>,
3615    session: UserSession,
3616) -> Result<()> {
3617    let db = session.db().await;
3618
3619    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3620    let (channel, membership) = db
3621        .create_channel(&request.name, parent_id, session.user_id())
3622        .await?;
3623
3624    let root_id = channel.root_id();
3625    let channel = Channel::from_model(channel);
3626
3627    response.send(proto::CreateChannelResponse {
3628        channel: Some(channel.to_proto()),
3629        parent_id: request.parent_id,
3630    })?;
3631
3632    let mut connection_pool = session.connection_pool().await;
3633    if let Some(membership) = membership {
3634        connection_pool.subscribe_to_channel(
3635            membership.user_id,
3636            membership.channel_id,
3637            membership.role,
3638        );
3639        let update = proto::UpdateUserChannels {
3640            channel_memberships: vec![proto::ChannelMembership {
3641                channel_id: membership.channel_id.to_proto(),
3642                role: membership.role.into(),
3643            }],
3644            ..Default::default()
3645        };
3646        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3647            session.peer.send(connection_id, update.clone())?;
3648        }
3649    }
3650
3651    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3652        if !role.can_see_channel(channel.visibility) {
3653            continue;
3654        }
3655
3656        let update = proto::UpdateChannels {
3657            channels: vec![channel.to_proto()],
3658            ..Default::default()
3659        };
3660        session.peer.send(connection_id, update.clone())?;
3661    }
3662
3663    Ok(())
3664}
3665
3666/// Delete a channel
3667async fn delete_channel(
3668    request: proto::DeleteChannel,
3669    response: Response<proto::DeleteChannel>,
3670    session: UserSession,
3671) -> Result<()> {
3672    let db = session.db().await;
3673
3674    let channel_id = request.channel_id;
3675    let (root_channel, removed_channels) = db
3676        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3677        .await?;
3678    response.send(proto::Ack {})?;
3679
3680    // Notify members of removed channels
3681    let mut update = proto::UpdateChannels::default();
3682    update
3683        .delete_channels
3684        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3685
3686    let connection_pool = session.connection_pool().await;
3687    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3688        session.peer.send(connection_id, update.clone())?;
3689    }
3690
3691    Ok(())
3692}
3693
3694/// Invite someone to join a channel.
3695async fn invite_channel_member(
3696    request: proto::InviteChannelMember,
3697    response: Response<proto::InviteChannelMember>,
3698    session: UserSession,
3699) -> Result<()> {
3700    let db = session.db().await;
3701    let channel_id = ChannelId::from_proto(request.channel_id);
3702    let invitee_id = UserId::from_proto(request.user_id);
3703    let InviteMemberResult {
3704        channel,
3705        notifications,
3706    } = db
3707        .invite_channel_member(
3708            channel_id,
3709            invitee_id,
3710            session.user_id(),
3711            request.role().into(),
3712        )
3713        .await?;
3714
3715    let update = proto::UpdateChannels {
3716        channel_invitations: vec![channel.to_proto()],
3717        ..Default::default()
3718    };
3719
3720    let connection_pool = session.connection_pool().await;
3721    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3722        session.peer.send(connection_id, update.clone())?;
3723    }
3724
3725    send_notifications(&connection_pool, &session.peer, notifications);
3726
3727    response.send(proto::Ack {})?;
3728    Ok(())
3729}
3730
3731/// remove someone from a channel
3732async fn remove_channel_member(
3733    request: proto::RemoveChannelMember,
3734    response: Response<proto::RemoveChannelMember>,
3735    session: UserSession,
3736) -> Result<()> {
3737    let db = session.db().await;
3738    let channel_id = ChannelId::from_proto(request.channel_id);
3739    let member_id = UserId::from_proto(request.user_id);
3740
3741    let RemoveChannelMemberResult {
3742        membership_update,
3743        notification_id,
3744    } = db
3745        .remove_channel_member(channel_id, member_id, session.user_id())
3746        .await?;
3747
3748    let mut connection_pool = session.connection_pool().await;
3749    notify_membership_updated(
3750        &mut connection_pool,
3751        membership_update,
3752        member_id,
3753        &session.peer,
3754    );
3755    for connection_id in connection_pool.user_connection_ids(member_id) {
3756        if let Some(notification_id) = notification_id {
3757            session
3758                .peer
3759                .send(
3760                    connection_id,
3761                    proto::DeleteNotification {
3762                        notification_id: notification_id.to_proto(),
3763                    },
3764                )
3765                .trace_err();
3766        }
3767    }
3768
3769    response.send(proto::Ack {})?;
3770    Ok(())
3771}
3772
3773/// Toggle the channel between public and private.
3774/// Care is taken to maintain the invariant that public channels only descend from public channels,
3775/// (though members-only channels can appear at any point in the hierarchy).
3776async fn set_channel_visibility(
3777    request: proto::SetChannelVisibility,
3778    response: Response<proto::SetChannelVisibility>,
3779    session: UserSession,
3780) -> Result<()> {
3781    let db = session.db().await;
3782    let channel_id = ChannelId::from_proto(request.channel_id);
3783    let visibility = request.visibility().into();
3784
3785    let channel_model = db
3786        .set_channel_visibility(channel_id, visibility, session.user_id())
3787        .await?;
3788    let root_id = channel_model.root_id();
3789    let channel = Channel::from_model(channel_model);
3790
3791    let mut connection_pool = session.connection_pool().await;
3792    for (user_id, role) in connection_pool
3793        .channel_user_ids(root_id)
3794        .collect::<Vec<_>>()
3795        .into_iter()
3796    {
3797        let update = if role.can_see_channel(channel.visibility) {
3798            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3799            proto::UpdateChannels {
3800                channels: vec![channel.to_proto()],
3801                ..Default::default()
3802            }
3803        } else {
3804            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3805            proto::UpdateChannels {
3806                delete_channels: vec![channel.id.to_proto()],
3807                ..Default::default()
3808            }
3809        };
3810
3811        for connection_id in connection_pool.user_connection_ids(user_id) {
3812            session.peer.send(connection_id, update.clone())?;
3813        }
3814    }
3815
3816    response.send(proto::Ack {})?;
3817    Ok(())
3818}
3819
3820/// Alter the role for a user in the channel.
3821async fn set_channel_member_role(
3822    request: proto::SetChannelMemberRole,
3823    response: Response<proto::SetChannelMemberRole>,
3824    session: UserSession,
3825) -> Result<()> {
3826    let db = session.db().await;
3827    let channel_id = ChannelId::from_proto(request.channel_id);
3828    let member_id = UserId::from_proto(request.user_id);
3829    let result = db
3830        .set_channel_member_role(
3831            channel_id,
3832            session.user_id(),
3833            member_id,
3834            request.role().into(),
3835        )
3836        .await?;
3837
3838    match result {
3839        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3840            let mut connection_pool = session.connection_pool().await;
3841            notify_membership_updated(
3842                &mut connection_pool,
3843                membership_update,
3844                member_id,
3845                &session.peer,
3846            )
3847        }
3848        db::SetMemberRoleResult::InviteUpdated(channel) => {
3849            let update = proto::UpdateChannels {
3850                channel_invitations: vec![channel.to_proto()],
3851                ..Default::default()
3852            };
3853
3854            for connection_id in session
3855                .connection_pool()
3856                .await
3857                .user_connection_ids(member_id)
3858            {
3859                session.peer.send(connection_id, update.clone())?;
3860            }
3861        }
3862    }
3863
3864    response.send(proto::Ack {})?;
3865    Ok(())
3866}
3867
3868/// Change the name of a channel
3869async fn rename_channel(
3870    request: proto::RenameChannel,
3871    response: Response<proto::RenameChannel>,
3872    session: UserSession,
3873) -> Result<()> {
3874    let db = session.db().await;
3875    let channel_id = ChannelId::from_proto(request.channel_id);
3876    let channel_model = db
3877        .rename_channel(channel_id, session.user_id(), &request.name)
3878        .await?;
3879    let root_id = channel_model.root_id();
3880    let channel = Channel::from_model(channel_model);
3881
3882    response.send(proto::RenameChannelResponse {
3883        channel: Some(channel.to_proto()),
3884    })?;
3885
3886    let connection_pool = session.connection_pool().await;
3887    let update = proto::UpdateChannels {
3888        channels: vec![channel.to_proto()],
3889        ..Default::default()
3890    };
3891    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3892        if role.can_see_channel(channel.visibility) {
3893            session.peer.send(connection_id, update.clone())?;
3894        }
3895    }
3896
3897    Ok(())
3898}
3899
3900/// Move a channel to a new parent.
3901async fn move_channel(
3902    request: proto::MoveChannel,
3903    response: Response<proto::MoveChannel>,
3904    session: UserSession,
3905) -> Result<()> {
3906    let channel_id = ChannelId::from_proto(request.channel_id);
3907    let to = ChannelId::from_proto(request.to);
3908
3909    let (root_id, channels) = session
3910        .db()
3911        .await
3912        .move_channel(channel_id, to, session.user_id())
3913        .await?;
3914
3915    let connection_pool = session.connection_pool().await;
3916    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3917        let channels = channels
3918            .iter()
3919            .filter_map(|channel| {
3920                if role.can_see_channel(channel.visibility) {
3921                    Some(channel.to_proto())
3922                } else {
3923                    None
3924                }
3925            })
3926            .collect::<Vec<_>>();
3927        if channels.is_empty() {
3928            continue;
3929        }
3930
3931        let update = proto::UpdateChannels {
3932            channels,
3933            ..Default::default()
3934        };
3935
3936        session.peer.send(connection_id, update.clone())?;
3937    }
3938
3939    response.send(Ack {})?;
3940    Ok(())
3941}
3942
3943/// Get the list of channel members
3944async fn get_channel_members(
3945    request: proto::GetChannelMembers,
3946    response: Response<proto::GetChannelMembers>,
3947    session: UserSession,
3948) -> Result<()> {
3949    let db = session.db().await;
3950    let channel_id = ChannelId::from_proto(request.channel_id);
3951    let limit = if request.limit == 0 {
3952        u16::MAX as u64
3953    } else {
3954        request.limit
3955    };
3956    let (members, users) = db
3957        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3958        .await?;
3959    response.send(proto::GetChannelMembersResponse { members, users })?;
3960    Ok(())
3961}
3962
3963/// Accept or decline a channel invitation.
3964async fn respond_to_channel_invite(
3965    request: proto::RespondToChannelInvite,
3966    response: Response<proto::RespondToChannelInvite>,
3967    session: UserSession,
3968) -> Result<()> {
3969    let db = session.db().await;
3970    let channel_id = ChannelId::from_proto(request.channel_id);
3971    let RespondToChannelInvite {
3972        membership_update,
3973        notifications,
3974    } = db
3975        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3976        .await?;
3977
3978    let mut connection_pool = session.connection_pool().await;
3979    if let Some(membership_update) = membership_update {
3980        notify_membership_updated(
3981            &mut connection_pool,
3982            membership_update,
3983            session.user_id(),
3984            &session.peer,
3985        );
3986    } else {
3987        let update = proto::UpdateChannels {
3988            remove_channel_invitations: vec![channel_id.to_proto()],
3989            ..Default::default()
3990        };
3991
3992        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3993            session.peer.send(connection_id, update.clone())?;
3994        }
3995    };
3996
3997    send_notifications(&connection_pool, &session.peer, notifications);
3998
3999    response.send(proto::Ack {})?;
4000
4001    Ok(())
4002}
4003
4004/// Join the channels' room
4005async fn join_channel(
4006    request: proto::JoinChannel,
4007    response: Response<proto::JoinChannel>,
4008    session: UserSession,
4009) -> Result<()> {
4010    let channel_id = ChannelId::from_proto(request.channel_id);
4011    join_channel_internal(channel_id, Box::new(response), session).await
4012}
4013
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` answer both `JoinChannel` and `JoinRoom`
/// requests, since both impls below send the same `JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified path, presumably so the call clearly targets `Response`'s
        // own `send` rather than this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request by delegating to `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4027
/// Join the room belonging to `channel_id` on behalf of `session`, replying
/// through `response`. Used by `join_channel`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database guard is dropped around `leave_room_for_session`
///    and re-acquired afterwards (presumably because that call takes the database
///    lock itself).
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, create a token. Guests get a guest token
///    and `can_publish: false`; every other role gets a room token and
///    `can_publish: true`. A token error passes through `trace_err`, so the
///    response then carries no LiveKit connection info.
/// 4. Send the `JoinRoomResponse`, notify any membership change, and broadcast the
///    updated room.
/// 5. With the database guard released, broadcast the channel update and refresh
///    the user's contacts.
///
/// Returns an error if the joined room has no channel. Note that this check runs
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard lives only inside this block; steps after it run unlocked.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard while leaving the stale room, then take it back.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may not publish, so they get a guest token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Runs after the response was sent, so a missing channel is reported to the
    // caller only as this handler's error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4123
4124/// Start editing the channel notes
4125async fn join_channel_buffer(
4126    request: proto::JoinChannelBuffer,
4127    response: Response<proto::JoinChannelBuffer>,
4128    session: UserSession,
4129) -> Result<()> {
4130    let db = session.db().await;
4131    let channel_id = ChannelId::from_proto(request.channel_id);
4132
4133    let open_response = db
4134        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4135        .await?;
4136
4137    let collaborators = open_response.collaborators.clone();
4138    response.send(open_response)?;
4139
4140    let update = UpdateChannelBufferCollaborators {
4141        channel_id: channel_id.to_proto(),
4142        collaborators: collaborators.clone(),
4143    };
4144    channel_buffer_updated(
4145        session.connection_id,
4146        collaborators
4147            .iter()
4148            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4149        &update,
4150        &session.peer,
4151    );
4152
4153    Ok(())
4154}
4155
4156/// Edit the channel notes
4157async fn update_channel_buffer(
4158    request: proto::UpdateChannelBuffer,
4159    session: UserSession,
4160) -> Result<()> {
4161    let db = session.db().await;
4162    let channel_id = ChannelId::from_proto(request.channel_id);
4163
4164    let (collaborators, epoch, version) = db
4165        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4166        .await?;
4167
4168    channel_buffer_updated(
4169        session.connection_id,
4170        collaborators.clone(),
4171        &proto::UpdateChannelBuffer {
4172            channel_id: channel_id.to_proto(),
4173            operations: request.operations,
4174        },
4175        &session.peer,
4176    );
4177
4178    let pool = &*session.connection_pool().await;
4179
4180    let non_collaborators =
4181        pool.channel_connection_ids(channel_id)
4182            .filter_map(|(connection_id, _)| {
4183                if collaborators.contains(&connection_id) {
4184                    None
4185                } else {
4186                    Some(connection_id)
4187                }
4188            });
4189
4190    broadcast(None, non_collaborators, |peer_id| {
4191        session.peer.send(
4192            peer_id,
4193            proto::UpdateChannels {
4194                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4195                    channel_id: channel_id.to_proto(),
4196                    epoch: epoch as u64,
4197                    version: version.clone(),
4198                }],
4199                ..Default::default()
4200            },
4201        )
4202    });
4203
4204    Ok(())
4205}
4206
4207/// Rejoin the channel notes after a connection blip
4208async fn rejoin_channel_buffers(
4209    request: proto::RejoinChannelBuffers,
4210    response: Response<proto::RejoinChannelBuffers>,
4211    session: UserSession,
4212) -> Result<()> {
4213    let db = session.db().await;
4214    let buffers = db
4215        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4216        .await?;
4217
4218    for rejoined_buffer in &buffers {
4219        let collaborators_to_notify = rejoined_buffer
4220            .buffer
4221            .collaborators
4222            .iter()
4223            .filter_map(|c| Some(c.peer_id?.into()));
4224        channel_buffer_updated(
4225            session.connection_id,
4226            collaborators_to_notify,
4227            &proto::UpdateChannelBufferCollaborators {
4228                channel_id: rejoined_buffer.buffer.channel_id,
4229                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4230            },
4231            &session.peer,
4232        );
4233    }
4234
4235    response.send(proto::RejoinChannelBuffersResponse {
4236        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4237    })?;
4238
4239    Ok(())
4240}
4241
4242/// Stop editing the channel notes
4243async fn leave_channel_buffer(
4244    request: proto::LeaveChannelBuffer,
4245    response: Response<proto::LeaveChannelBuffer>,
4246    session: UserSession,
4247) -> Result<()> {
4248    let db = session.db().await;
4249    let channel_id = ChannelId::from_proto(request.channel_id);
4250
4251    let left_buffer = db
4252        .leave_channel_buffer(channel_id, session.connection_id)
4253        .await?;
4254
4255    response.send(Ack {})?;
4256
4257    channel_buffer_updated(
4258        session.connection_id,
4259        left_buffer.connections,
4260        &proto::UpdateChannelBufferCollaborators {
4261            channel_id: channel_id.to_proto(),
4262            collaborators: left_buffer.collaborators,
4263        },
4264        &session.peer,
4265    );
4266
4267    Ok(())
4268}
4269
4270fn channel_buffer_updated<T: EnvelopedMessage>(
4271    sender_id: ConnectionId,
4272    collaborators: impl IntoIterator<Item = ConnectionId>,
4273    message: &T,
4274    peer: &Peer,
4275) {
4276    broadcast(Some(sender_id), collaborators, |peer_id| {
4277        peer.send(peer_id, message.clone())
4278    });
4279}
4280
4281fn send_notifications(
4282    connection_pool: &ConnectionPool,
4283    peer: &Peer,
4284    notifications: db::NotificationBatch,
4285) {
4286    for (user_id, notification) in notifications {
4287        for connection_id in connection_pool.user_connection_ids(user_id) {
4288            if let Err(error) = peer.send(
4289                connection_id,
4290                proto::AddNotification {
4291                    notification: Some(notification.clone()),
4292                },
4293            ) {
4294                tracing::error!(
4295                    "failed to send notification to {:?} {}",
4296                    connection_id,
4297                    error
4298                );
4299            }
4300        }
4301    }
4302}
4303
4304/// Send a message to the channel
4305async fn send_channel_message(
4306    request: proto::SendChannelMessage,
4307    response: Response<proto::SendChannelMessage>,
4308    session: UserSession,
4309) -> Result<()> {
4310    // Validate the message body.
4311    let body = request.body.trim().to_string();
4312    if body.len() > MAX_MESSAGE_LEN {
4313        return Err(anyhow!("message is too long"))?;
4314    }
4315    if body.is_empty() {
4316        return Err(anyhow!("message can't be blank"))?;
4317    }
4318
4319    // TODO: adjust mentions if body is trimmed
4320
4321    let timestamp = OffsetDateTime::now_utc();
4322    let nonce = request
4323        .nonce
4324        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4325
4326    let channel_id = ChannelId::from_proto(request.channel_id);
4327    let CreatedChannelMessage {
4328        message_id,
4329        participant_connection_ids,
4330        notifications,
4331    } = session
4332        .db()
4333        .await
4334        .create_channel_message(
4335            channel_id,
4336            session.user_id(),
4337            &body,
4338            &request.mentions,
4339            timestamp,
4340            nonce.clone().into(),
4341            match request.reply_to_message_id {
4342                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4343                None => None,
4344            },
4345        )
4346        .await?;
4347
4348    let message = proto::ChannelMessage {
4349        sender_id: session.user_id().to_proto(),
4350        id: message_id.to_proto(),
4351        body,
4352        mentions: request.mentions,
4353        timestamp: timestamp.unix_timestamp() as u64,
4354        nonce: Some(nonce),
4355        reply_to_message_id: request.reply_to_message_id,
4356        edited_at: None,
4357    };
4358    broadcast(
4359        Some(session.connection_id),
4360        participant_connection_ids.clone(),
4361        |connection| {
4362            session.peer.send(
4363                connection,
4364                proto::ChannelMessageSent {
4365                    channel_id: channel_id.to_proto(),
4366                    message: Some(message.clone()),
4367                },
4368            )
4369        },
4370    );
4371    response.send(proto::SendChannelMessageResponse {
4372        message: Some(message),
4373    })?;
4374
4375    let pool = &*session.connection_pool().await;
4376    let non_participants =
4377        pool.channel_connection_ids(channel_id)
4378            .filter_map(|(connection_id, _)| {
4379                if participant_connection_ids.contains(&connection_id) {
4380                    None
4381                } else {
4382                    Some(connection_id)
4383                }
4384            });
4385    broadcast(None, non_participants, |peer_id| {
4386        session.peer.send(
4387            peer_id,
4388            proto::UpdateChannels {
4389                latest_channel_message_ids: vec![proto::ChannelMessageId {
4390                    channel_id: channel_id.to_proto(),
4391                    message_id: message_id.to_proto(),
4392                }],
4393                ..Default::default()
4394            },
4395        )
4396    });
4397    send_notifications(pool, &session.peer, notifications);
4398
4399    Ok(())
4400}
4401
4402/// Delete a channel message
4403async fn remove_channel_message(
4404    request: proto::RemoveChannelMessage,
4405    response: Response<proto::RemoveChannelMessage>,
4406    session: UserSession,
4407) -> Result<()> {
4408    let channel_id = ChannelId::from_proto(request.channel_id);
4409    let message_id = MessageId::from_proto(request.message_id);
4410    let (connection_ids, existing_notification_ids) = session
4411        .db()
4412        .await
4413        .remove_channel_message(channel_id, message_id, session.user_id())
4414        .await?;
4415
4416    broadcast(
4417        Some(session.connection_id),
4418        connection_ids,
4419        move |connection| {
4420            session.peer.send(connection, request.clone())?;
4421
4422            for notification_id in &existing_notification_ids {
4423                session.peer.send(
4424                    connection,
4425                    proto::DeleteNotification {
4426                        notification_id: (*notification_id).to_proto(),
4427                    },
4428                )?;
4429            }
4430
4431            Ok(())
4432        },
4433    );
4434    response.send(proto::Ack {})?;
4435    Ok(())
4436}
4437
4438async fn update_channel_message(
4439    request: proto::UpdateChannelMessage,
4440    response: Response<proto::UpdateChannelMessage>,
4441    session: UserSession,
4442) -> Result<()> {
4443    let channel_id = ChannelId::from_proto(request.channel_id);
4444    let message_id = MessageId::from_proto(request.message_id);
4445    let updated_at = OffsetDateTime::now_utc();
4446    let UpdatedChannelMessage {
4447        message_id,
4448        participant_connection_ids,
4449        notifications,
4450        reply_to_message_id,
4451        timestamp,
4452        deleted_mention_notification_ids,
4453        updated_mention_notifications,
4454    } = session
4455        .db()
4456        .await
4457        .update_channel_message(
4458            channel_id,
4459            message_id,
4460            session.user_id(),
4461            request.body.as_str(),
4462            &request.mentions,
4463            updated_at,
4464        )
4465        .await?;
4466
4467    let nonce = request
4468        .nonce
4469        .clone()
4470        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4471
4472    let message = proto::ChannelMessage {
4473        sender_id: session.user_id().to_proto(),
4474        id: message_id.to_proto(),
4475        body: request.body.clone(),
4476        mentions: request.mentions.clone(),
4477        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4478        nonce: Some(nonce),
4479        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4480        edited_at: Some(updated_at.unix_timestamp() as u64),
4481    };
4482
4483    response.send(proto::Ack {})?;
4484
4485    let pool = &*session.connection_pool().await;
4486    broadcast(
4487        Some(session.connection_id),
4488        participant_connection_ids,
4489        |connection| {
4490            session.peer.send(
4491                connection,
4492                proto::ChannelMessageUpdate {
4493                    channel_id: channel_id.to_proto(),
4494                    message: Some(message.clone()),
4495                },
4496            )?;
4497
4498            for notification_id in &deleted_mention_notification_ids {
4499                session.peer.send(
4500                    connection,
4501                    proto::DeleteNotification {
4502                        notification_id: (*notification_id).to_proto(),
4503                    },
4504                )?;
4505            }
4506
4507            for notification in &updated_mention_notifications {
4508                session.peer.send(
4509                    connection,
4510                    proto::UpdateNotification {
4511                        notification: Some(notification.clone()),
4512                    },
4513                )?;
4514            }
4515
4516            Ok(())
4517        },
4518    );
4519
4520    send_notifications(pool, &session.peer, notifications);
4521
4522    Ok(())
4523}
4524
4525/// Mark a channel message as read
4526async fn acknowledge_channel_message(
4527    request: proto::AckChannelMessage,
4528    session: UserSession,
4529) -> Result<()> {
4530    let channel_id = ChannelId::from_proto(request.channel_id);
4531    let message_id = MessageId::from_proto(request.message_id);
4532    let notifications = session
4533        .db()
4534        .await
4535        .observe_channel_message(channel_id, session.user_id(), message_id)
4536        .await?;
4537    send_notifications(
4538        &*session.connection_pool().await,
4539        &session.peer,
4540        notifications,
4541    );
4542    Ok(())
4543}
4544
4545/// Mark a buffer version as synced
4546async fn acknowledge_buffer_version(
4547    request: proto::AckBufferOperation,
4548    session: UserSession,
4549) -> Result<()> {
4550    let buffer_id = BufferId::from_proto(request.buffer_id);
4551    session
4552        .db()
4553        .await
4554        .observe_buffer_version(
4555            buffer_id,
4556            session.user_id(),
4557            request.epoch as i32,
4558            &request.version,
4559        )
4560        .await?;
4561    Ok(())
4562}
4563
4564struct ZedProCompleteWithLanguageModelRateLimit;
4565
4566impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4567    fn capacity(&self) -> usize {
4568        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4569            .ok()
4570            .and_then(|v| v.parse().ok())
4571            .unwrap_or(120) // Picked arbitrarily
4572    }
4573
4574    fn refill_duration(&self) -> chrono::Duration {
4575        chrono::Duration::hours(1)
4576    }
4577
4578    fn db_name(&self) -> &'static str {
4579        "zed-pro:complete-with-language-model"
4580    }
4581}
4582
4583struct FreeCompleteWithLanguageModelRateLimit;
4584
4585impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4586    fn capacity(&self) -> usize {
4587        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4588            .ok()
4589            .and_then(|v| v.parse().ok())
4590            .unwrap_or(120 / 10) // Picked arbitrarily
4591    }
4592
4593    fn refill_duration(&self) -> chrono::Duration {
4594        chrono::Duration::hours(1)
4595    }
4596
4597    fn db_name(&self) -> &'static str {
4598        "free:complete-with-language-model"
4599    }
4600}
4601
4602async fn complete_with_language_model(
4603    request: proto::CompleteWithLanguageModel,
4604    response: Response<proto::CompleteWithLanguageModel>,
4605    session: Session,
4606    config: &Config,
4607) -> Result<()> {
4608    let Some(session) = session.for_user() else {
4609        return Err(anyhow!("user not found"))?;
4610    };
4611    authorize_access_to_language_models(&session).await?;
4612
4613    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4614        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4615        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4616    };
4617
4618    session
4619        .app_state
4620        .rate_limiter
4621        .check(&*rate_limit, session.user_id())
4622        .await?;
4623
4624    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4625        Some(proto::LanguageModelProvider::Anthropic) => {
4626            let api_key = config
4627                .anthropic_api_key
4628                .as_ref()
4629                .context("no Anthropic AI API key configured on the server")?;
4630            anthropic::complete(
4631                session.http_client.as_ref(),
4632                anthropic::ANTHROPIC_API_URL,
4633                api_key,
4634                serde_json::from_str(&request.request)?,
4635            )
4636            .await?
4637        }
4638        _ => return Err(anyhow!("unsupported provider"))?,
4639    };
4640
4641    response.send(proto::CompleteWithLanguageModelResponse {
4642        completion: serde_json::to_string(&result)?,
4643    })?;
4644
4645    Ok(())
4646}
4647
/// Stream a completion from a language model provider on behalf of the requesting user.
///
/// The session must belong to a user with the `language-models` flag. The request is checked
/// against the same per-plan rate-limit types as `complete_with_language_model`. The
/// JSON-encoded `request.request` is then forwarded to the provider selected by
/// `request.provider`, and every streamed event is relayed to the client as a JSON string.
/// The first upstream or serialization error ends the stream by returning that error.
async fn stream_complete_with_language_model(
    request: proto::StreamCompleteWithLanguageModel,
    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
    session: Session,
    config: &Config,
) -> Result<()> {
    let Some(session) = session.for_user() else {
        return Err(anyhow!("user not found"))?;
    };
    authorize_access_to_language_models(&session).await?;

    // Same rate-limit types as the non-streaming endpoint, so (presumably) the same buckets.
    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
    };

    session
        .app_state
        .rate_limiter
        .check(&*rate_limit, session.user_id())
        .await?;

    // Each provider yields its own event type; events are re-serialized to JSON unchanged.
    match proto::LanguageModelProvider::from_i32(request.provider) {
        Some(proto::LanguageModelProvider::Anthropic) => {
            let api_key = config
                .anthropic_api_key
                .as_ref()
                .context("no Anthropic AI API key configured on the server")?;
            let mut chunks = anthropic::stream_completion(
                session.http_client.as_ref(),
                anthropic::ANTHROPIC_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = chunks.next().await {
                let chunk = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&chunk)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::OpenAi) => {
            let api_key = config
                .openai_api_key
                .as_ref()
                .context("no OpenAI API key configured on the server")?;
            let mut events = open_ai::stream_completion(
                session.http_client.as_ref(),
                open_ai::OPEN_AI_API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Google) => {
            let api_key = config
                .google_ai_api_key
                .as_ref()
                .context("no Google AI API key configured on the server")?;
            let mut events = google_ai::stream_generate_content(
                session.http_client.as_ref(),
                google_ai::API_URL,
                api_key,
                serde_json::from_str(&request.request)?,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        Some(proto::LanguageModelProvider::Zed) => {
            // The "Zed" provider is a Qwen2-7B deployment reached through the OpenAI-style
            // streaming client, using the URL and key from the server config.
            let api_key = config
                .qwen2_7b_api_key
                .as_ref()
                .context("no Qwen2-7B API key configured on the server")?;
            let api_url = config
                .qwen2_7b_api_url
                .as_ref()
                .context("no Qwen2-7B URL configured on the server")?;
            let mut events = open_ai::stream_completion(
                session.http_client.as_ref(),
                &api_url,
                api_key,
                serde_json::from_str(&request.request)?,
                None,
            )
            .await?;
            while let Some(event) = events.next().await {
                let event = event?;
                response.send(proto::StreamCompleteWithLanguageModelResponse {
                    event: serde_json::to_string(&event)?,
                })?;
            }
        }
        // The client sent a provider value this server does not recognize.
        None => return Err(anyhow!("unknown provider"))?,
    }

    Ok(())
}
4759
4760async fn count_language_model_tokens(
4761    request: proto::CountLanguageModelTokens,
4762    response: Response<proto::CountLanguageModelTokens>,
4763    session: Session,
4764    config: &Config,
4765) -> Result<()> {
4766    let Some(session) = session.for_user() else {
4767        return Err(anyhow!("user not found"))?;
4768    };
4769    authorize_access_to_language_models(&session).await?;
4770
4771    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4772        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4773        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4774    };
4775
4776    session
4777        .app_state
4778        .rate_limiter
4779        .check(&*rate_limit, session.user_id())
4780        .await?;
4781
4782    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4783        Some(proto::LanguageModelProvider::Google) => {
4784            let api_key = config
4785                .google_ai_api_key
4786                .as_ref()
4787                .context("no Google AI API key configured on the server")?;
4788            google_ai::count_tokens(
4789                session.http_client.as_ref(),
4790                google_ai::API_URL,
4791                api_key,
4792                serde_json::from_str(&request.request)?,
4793            )
4794            .await?
4795        }
4796        _ => return Err(anyhow!("unsupported provider"))?,
4797    };
4798
4799    response.send(proto::CountLanguageModelTokensResponse {
4800        token_count: result.total_tokens as u32,
4801    })?;
4802
4803    Ok(())
4804}
4805
4806struct ZedProCountLanguageModelTokensRateLimit;
4807
4808impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4809    fn capacity(&self) -> usize {
4810        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4811            .ok()
4812            .and_then(|v| v.parse().ok())
4813            .unwrap_or(600) // Picked arbitrarily
4814    }
4815
4816    fn refill_duration(&self) -> chrono::Duration {
4817        chrono::Duration::hours(1)
4818    }
4819
4820    fn db_name(&self) -> &'static str {
4821        "zed-pro:count-language-model-tokens"
4822    }
4823}
4824
4825struct FreeCountLanguageModelTokensRateLimit;
4826
4827impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4828    fn capacity(&self) -> usize {
4829        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4830            .ok()
4831            .and_then(|v| v.parse().ok())
4832            .unwrap_or(600 / 10) // Picked arbitrarily
4833    }
4834
4835    fn refill_duration(&self) -> chrono::Duration {
4836        chrono::Duration::hours(1)
4837    }
4838
4839    fn db_name(&self) -> &'static str {
4840        "free:count-language-model-tokens"
4841    }
4842}
4843
4844struct ZedProComputeEmbeddingsRateLimit;
4845
4846impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4847    fn capacity(&self) -> usize {
4848        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4849            .ok()
4850            .and_then(|v| v.parse().ok())
4851            .unwrap_or(5000) // Picked arbitrarily
4852    }
4853
4854    fn refill_duration(&self) -> chrono::Duration {
4855        chrono::Duration::hours(1)
4856    }
4857
4858    fn db_name(&self) -> &'static str {
4859        "zed-pro:compute-embeddings"
4860    }
4861}
4862
4863struct FreeComputeEmbeddingsRateLimit;
4864
4865impl RateLimit for FreeComputeEmbeddingsRateLimit {
4866    fn capacity(&self) -> usize {
4867        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4868            .ok()
4869            .and_then(|v| v.parse().ok())
4870            .unwrap_or(5000 / 10) // Picked arbitrarily
4871    }
4872
4873    fn refill_duration(&self) -> chrono::Duration {
4874        chrono::Duration::hours(1)
4875    }
4876
4877    fn db_name(&self) -> &'static str {
4878        "free:compute-embeddings"
4879    }
4880}
4881
4882async fn compute_embeddings(
4883    request: proto::ComputeEmbeddings,
4884    response: Response<proto::ComputeEmbeddings>,
4885    session: UserSession,
4886    api_key: Option<Arc<str>>,
4887) -> Result<()> {
4888    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4889    authorize_access_to_language_models(&session).await?;
4890
4891    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4892        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4893        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4894    };
4895
4896    session
4897        .app_state
4898        .rate_limiter
4899        .check(&*rate_limit, session.user_id())
4900        .await?;
4901
4902    let embeddings = match request.model.as_str() {
4903        "openai/text-embedding-3-small" => {
4904            open_ai::embed(
4905                session.http_client.as_ref(),
4906                OPEN_AI_API_URL,
4907                &api_key,
4908                OpenAiEmbeddingModel::TextEmbedding3Small,
4909                request.texts.iter().map(|text| text.as_str()),
4910            )
4911            .await?
4912        }
4913        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4914    };
4915
4916    let embeddings = request
4917        .texts
4918        .iter()
4919        .map(|text| {
4920            let mut hasher = sha2::Sha256::new();
4921            hasher.update(text.as_bytes());
4922            let result = hasher.finalize();
4923            result.to_vec()
4924        })
4925        .zip(
4926            embeddings
4927                .data
4928                .into_iter()
4929                .map(|embedding| embedding.embedding),
4930        )
4931        .collect::<HashMap<_, _>>();
4932
4933    let db = session.db().await;
4934    db.save_embeddings(&request.model, &embeddings)
4935        .await
4936        .context("failed to save embeddings")
4937        .trace_err();
4938
4939    response.send(proto::ComputeEmbeddingsResponse {
4940        embeddings: embeddings
4941            .into_iter()
4942            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4943            .collect(),
4944    })?;
4945    Ok(())
4946}
4947
4948async fn get_cached_embeddings(
4949    request: proto::GetCachedEmbeddings,
4950    response: Response<proto::GetCachedEmbeddings>,
4951    session: UserSession,
4952) -> Result<()> {
4953    authorize_access_to_language_models(&session).await?;
4954
4955    let db = session.db().await;
4956    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4957
4958    response.send(proto::GetCachedEmbeddingsResponse {
4959        embeddings: embeddings
4960            .into_iter()
4961            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4962            .collect(),
4963    })?;
4964    Ok(())
4965}
4966
4967async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4968    let db = session.db().await;
4969    let flags = db.get_user_flags(session.user_id()).await?;
4970    if flags.iter().any(|flag| flag == "language-models") {
4971        return Ok(());
4972    }
4973
4974    Err(anyhow!("permission denied"))?
4975}
4976
4977/// Get a Supermaven API key for the user
4978async fn get_supermaven_api_key(
4979    _request: proto::GetSupermavenApiKey,
4980    response: Response<proto::GetSupermavenApiKey>,
4981    session: UserSession,
4982) -> Result<()> {
4983    let user_id: String = session.user_id().to_string();
4984    if !session.is_staff() {
4985        return Err(anyhow!("supermaven not enabled for this account"))?;
4986    }
4987
4988    let email = session
4989        .email()
4990        .ok_or_else(|| anyhow!("user must have an email"))?;
4991
4992    let supermaven_admin_api = session
4993        .supermaven_client
4994        .as_ref()
4995        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4996
4997    let result = supermaven_admin_api
4998        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4999        .await?;
5000
5001    response.send(proto::GetSupermavenApiKeyResponse {
5002        api_key: result.api_key,
5003    })?;
5004
5005    Ok(())
5006}
5007
5008/// Start receiving chat updates for a channel
5009async fn join_channel_chat(
5010    request: proto::JoinChannelChat,
5011    response: Response<proto::JoinChannelChat>,
5012    session: UserSession,
5013) -> Result<()> {
5014    let channel_id = ChannelId::from_proto(request.channel_id);
5015
5016    let db = session.db().await;
5017    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5018        .await?;
5019    let messages = db
5020        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5021        .await?;
5022    response.send(proto::JoinChannelChatResponse {
5023        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5024        messages,
5025    })?;
5026    Ok(())
5027}
5028
5029/// Stop receiving chat updates for a channel
5030async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5031    let channel_id = ChannelId::from_proto(request.channel_id);
5032    session
5033        .db()
5034        .await
5035        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5036        .await?;
5037    Ok(())
5038}
5039
5040/// Retrieve the chat history for a channel
5041async fn get_channel_messages(
5042    request: proto::GetChannelMessages,
5043    response: Response<proto::GetChannelMessages>,
5044    session: UserSession,
5045) -> Result<()> {
5046    let channel_id = ChannelId::from_proto(request.channel_id);
5047    let messages = session
5048        .db()
5049        .await
5050        .get_channel_messages(
5051            channel_id,
5052            session.user_id(),
5053            MESSAGE_COUNT_PER_PAGE,
5054            Some(MessageId::from_proto(request.before_message_id)),
5055        )
5056        .await?;
5057    response.send(proto::GetChannelMessagesResponse {
5058        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5059        messages,
5060    })?;
5061    Ok(())
5062}
5063
5064/// Retrieve specific chat messages
5065async fn get_channel_messages_by_id(
5066    request: proto::GetChannelMessagesById,
5067    response: Response<proto::GetChannelMessagesById>,
5068    session: UserSession,
5069) -> Result<()> {
5070    let message_ids = request
5071        .message_ids
5072        .iter()
5073        .map(|id| MessageId::from_proto(*id))
5074        .collect::<Vec<_>>();
5075    let messages = session
5076        .db()
5077        .await
5078        .get_channel_messages_by_id(session.user_id(), &message_ids)
5079        .await?;
5080    response.send(proto::GetChannelMessagesResponse {
5081        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5082        messages,
5083    })?;
5084    Ok(())
5085}
5086
5087/// Retrieve the current users notifications
5088async fn get_notifications(
5089    request: proto::GetNotifications,
5090    response: Response<proto::GetNotifications>,
5091    session: UserSession,
5092) -> Result<()> {
5093    let notifications = session
5094        .db()
5095        .await
5096        .get_notifications(
5097            session.user_id(),
5098            NOTIFICATION_COUNT_PER_PAGE,
5099            request
5100                .before_id
5101                .map(|id| db::NotificationId::from_proto(id)),
5102        )
5103        .await?;
5104    response.send(proto::GetNotificationsResponse {
5105        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5106        notifications,
5107    })?;
5108    Ok(())
5109}
5110
5111/// Mark notifications as read
5112async fn mark_notification_as_read(
5113    request: proto::MarkNotificationRead,
5114    response: Response<proto::MarkNotificationRead>,
5115    session: UserSession,
5116) -> Result<()> {
5117    let database = &session.db().await;
5118    let notifications = database
5119        .mark_notification_as_read_by_id(
5120            session.user_id(),
5121            NotificationId::from_proto(request.notification_id),
5122        )
5123        .await?;
5124    send_notifications(
5125        &*session.connection_pool().await,
5126        &session.peer,
5127        notifications,
5128    );
5129    response.send(proto::Ack {})?;
5130    Ok(())
5131}
5132
5133/// Get the current users information
5134async fn get_private_user_info(
5135    _request: proto::GetPrivateUserInfo,
5136    response: Response<proto::GetPrivateUserInfo>,
5137    session: UserSession,
5138) -> Result<()> {
5139    let db = session.db().await;
5140
5141    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5142    let user = db
5143        .get_user_by_id(session.user_id())
5144        .await?
5145        .ok_or_else(|| anyhow!("user not found"))?;
5146    let flags = db.get_user_flags(session.user_id()).await?;
5147
5148    response.send(proto::GetPrivateUserInfoResponse {
5149        metrics_id,
5150        staff: user.admin,
5151        flags,
5152    })?;
5153    Ok(())
5154}
5155
5156async fn get_llm_api_token(
5157    _request: proto::GetLlmToken,
5158    response: Response<proto::GetLlmToken>,
5159    session: UserSession,
5160) -> Result<()> {
5161    if !session.is_staff() {
5162        Err(anyhow!("permission denied"))?
5163    }
5164
5165    let token = LlmTokenClaims::create(
5166        session.user_id(),
5167        session.is_staff(),
5168        session.current_plan().await?,
5169        &session.app_state.config,
5170    )?;
5171    response.send(proto::GetLlmTokenResponse { token })?;
5172    Ok(())
5173}
5174
5175fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5176    let message = match message {
5177        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5178        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5179        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5180        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5181        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5182            code: frame.code.into(),
5183            reason: frame.reason,
5184        })),
5185        // We should never receive a frame while reading the message, according
5186        // to the `tungstenite` maintainers:
5187        //
5188        // > It cannot occur when you read messages from the WebSocket, but it
5189        // > can be used when you want to send the raw frames (e.g. you want to
5190        // > send the frames to the WebSocket without composing the full message first).
5191        // >
5192        // > — https://github.com/snapview/tungstenite-rs/issues/268
5193        TungsteniteMessage::Frame(_) => {
5194            bail!("received an unexpected frame while reading the message")
5195        }
5196    };
5197
5198    Ok(message)
5199}
5200
5201fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5202    match message {
5203        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5204        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5205        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5206        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5207        AxumMessage::Close(frame) => {
5208            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5209                code: frame.code.into(),
5210                reason: frame.reason,
5211            }))
5212        }
5213    }
5214}
5215
5216fn notify_membership_updated(
5217    connection_pool: &mut ConnectionPool,
5218    result: MembershipUpdated,
5219    user_id: UserId,
5220    peer: &Peer,
5221) {
5222    for membership in &result.new_channels.channel_memberships {
5223        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5224    }
5225    for channel_id in &result.removed_channels {
5226        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5227    }
5228
5229    let user_channels_update = proto::UpdateUserChannels {
5230        channel_memberships: result
5231            .new_channels
5232            .channel_memberships
5233            .iter()
5234            .map(|cm| proto::ChannelMembership {
5235                channel_id: cm.channel_id.to_proto(),
5236                role: cm.role.into(),
5237            })
5238            .collect(),
5239        ..Default::default()
5240    };
5241
5242    let mut update = build_channels_update(result.new_channels);
5243    update.delete_channels = result
5244        .removed_channels
5245        .into_iter()
5246        .map(|id| id.to_proto())
5247        .collect();
5248    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5249
5250    for connection_id in connection_pool.user_connection_ids(user_id) {
5251        peer.send(connection_id, user_channels_update.clone())
5252            .trace_err();
5253        peer.send(connection_id, update.clone()).trace_err();
5254    }
5255}
5256
5257fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5258    proto::UpdateUserChannels {
5259        channel_memberships: channels
5260            .channel_memberships
5261            .iter()
5262            .map(|m| proto::ChannelMembership {
5263                channel_id: m.channel_id.to_proto(),
5264                role: m.role.into(),
5265            })
5266            .collect(),
5267        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5268        observed_channel_message_id: channels.observed_channel_messages.clone(),
5269    }
5270}
5271
5272fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5273    let mut update = proto::UpdateChannels::default();
5274
5275    for channel in channels.channels {
5276        update.channels.push(channel.to_proto());
5277    }
5278
5279    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5280    update.latest_channel_message_ids = channels.latest_channel_messages;
5281
5282    for (channel_id, participants) in channels.channel_participants {
5283        update
5284            .channel_participants
5285            .push(proto::ChannelParticipants {
5286                channel_id: channel_id.to_proto(),
5287                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5288            });
5289    }
5290
5291    for channel in channels.invited_channels {
5292        update.channel_invitations.push(channel.to_proto());
5293    }
5294
5295    update.hosted_projects = channels.hosted_projects;
5296    update
5297}
5298
5299fn build_initial_contacts_update(
5300    contacts: Vec<db::Contact>,
5301    pool: &ConnectionPool,
5302) -> proto::UpdateContacts {
5303    let mut update = proto::UpdateContacts::default();
5304
5305    for contact in contacts {
5306        match contact {
5307            db::Contact::Accepted { user_id, busy } => {
5308                update.contacts.push(contact_for_user(user_id, busy, &pool));
5309            }
5310            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5311            db::Contact::Incoming { user_id } => {
5312                update
5313                    .incoming_requests
5314                    .push(proto::IncomingContactRequest {
5315                        requester_id: user_id.to_proto(),
5316                    })
5317            }
5318        }
5319    }
5320
5321    update
5322}
5323
5324fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5325    proto::Contact {
5326        user_id: user_id.to_proto(),
5327        online: pool.is_user_online(user_id),
5328        busy,
5329    }
5330}
5331
5332fn room_updated(room: &proto::Room, peer: &Peer) {
5333    broadcast(
5334        None,
5335        room.participants
5336            .iter()
5337            .filter_map(|participant| Some(participant.peer_id?.into())),
5338        |peer_id| {
5339            peer.send(
5340                peer_id,
5341                proto::RoomUpdated {
5342                    room: Some(room.clone()),
5343                },
5344            )
5345        },
5346    );
5347}
5348
5349fn channel_updated(
5350    channel: &db::channel::Model,
5351    room: &proto::Room,
5352    peer: &Peer,
5353    pool: &ConnectionPool,
5354) {
5355    let participants = room
5356        .participants
5357        .iter()
5358        .map(|p| p.user_id)
5359        .collect::<Vec<_>>();
5360
5361    broadcast(
5362        None,
5363        pool.channel_connection_ids(channel.root_id())
5364            .filter_map(|(channel_id, role)| {
5365                role.can_see_channel(channel.visibility).then(|| channel_id)
5366            }),
5367        |peer_id| {
5368            peer.send(
5369                peer_id,
5370                proto::UpdateChannels {
5371                    channel_participants: vec![proto::ChannelParticipants {
5372                        channel_id: channel.id.to_proto(),
5373                        participant_user_ids: participants.clone(),
5374                    }],
5375                    ..Default::default()
5376                },
5377            )
5378        },
5379    );
5380}
5381
5382async fn send_dev_server_projects_update(
5383    user_id: UserId,
5384    mut status: proto::DevServerProjectsUpdate,
5385    session: &Session,
5386) {
5387    let pool = session.connection_pool().await;
5388    for dev_server in &mut status.dev_servers {
5389        dev_server.status =
5390            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5391    }
5392    let connections = pool.user_connection_ids(user_id);
5393    for connection_id in connections {
5394        session.peer.send(connection_id, status.clone()).trace_err();
5395    }
5396}
5397
5398async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5399    let db = session.db().await;
5400
5401    let contacts = db.get_contacts(user_id).await?;
5402    let busy = db.is_user_busy(user_id).await?;
5403
5404    let pool = session.connection_pool().await;
5405    let updated_contact = contact_for_user(user_id, busy, &pool);
5406    for contact in contacts {
5407        if let db::Contact::Accepted {
5408            user_id: contact_user_id,
5409            ..
5410        } = contact
5411        {
5412            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5413                session
5414                    .peer
5415                    .send(
5416                        contact_conn_id,
5417                        proto::UpdateContacts {
5418                            contacts: vec![updated_contact.clone()],
5419                            remove_contacts: Default::default(),
5420                            incoming_requests: Default::default(),
5421                            remove_incoming_requests: Default::default(),
5422                            outgoing_requests: Default::default(),
5423                            remove_outgoing_requests: Default::default(),
5424                        },
5425                    )
5426                    .trace_err();
5427            }
5428        }
5429    }
5430    Ok(())
5431}
5432
5433async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5434    log::info!("lost dev server connection, unsharing projects");
5435    let project_ids = session
5436        .db()
5437        .await
5438        .get_stale_dev_server_projects(session.connection_id)
5439        .await?;
5440
5441    for project_id in project_ids {
5442        // not unshare re-checks the connection ids match, so we get away with no transaction
5443        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5444    }
5445
5446    let user_id = session.dev_server().user_id;
5447    let update = session
5448        .db()
5449        .await
5450        .dev_server_projects_update(user_id)
5451        .await?;
5452
5453    send_dev_server_projects_update(user_id, update, session).await;
5454
5455    Ok(())
5456}
5457
/// Removes `connection_id` from the room it is in and propagates the consequences.
///
/// If the database reports that the connection left a room, this:
/// - notifies the other connections of each project the user left (`project_left`);
/// - broadcasts the updated room to its participants (`room_updated`);
/// - if the room belongs to a channel, broadcasts the updated channel
///   participants (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user whose pending call
///   was canceled;
/// - refreshes the contact entries of this user and of those canceled callees;
/// - if a LiveKit client is configured, removes the user from the LiveKit room,
///   and deletes the room when `left_room.deleted` is set.
///
/// Returns `Ok(())` without doing anything else if the database reports that
/// the connection was not in a room.
///
/// # Errors
///
/// Propagates database errors from leaving the room and from refreshing
/// contacts. LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized locals: each is assigned only in the branch where the
    // connection actually left a room; the other branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // Refresh the leaving user's own entry in their contacts' lists.
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` *before*
        // the room itself is taken below, so the `room` broadcast by
        // `room_updated` carries a default (empty) `live_kit_room`. Confirm
        // clients do not rely on that field in `RoomUpdated`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the loop below;
    // presumably this matters because `update_user_contacts` acquires the pool
    // again — confirm.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        // NOTE(review): an error here returns early and skips the LiveKit cleanup
        // below — confirm that is intended.
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5531
5532async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5533    let left_channel_buffers = session
5534        .db()
5535        .await
5536        .leave_channel_buffers(session.connection_id)
5537        .await?;
5538
5539    for left_buffer in left_channel_buffers {
5540        channel_buffer_updated(
5541            session.connection_id,
5542            left_buffer.connections,
5543            &proto::UpdateChannelBufferCollaborators {
5544                channel_id: left_buffer.channel_id.to_proto(),
5545                collaborators: left_buffer.collaborators,
5546            },
5547            &session.peer,
5548        );
5549    }
5550
5551    Ok(())
5552}
5553
5554fn project_left(project: &db::LeftProject, session: &UserSession) {
5555    for connection_id in &project.connection_ids {
5556        if project.should_unshare {
5557            session
5558                .peer
5559                .send(
5560                    *connection_id,
5561                    proto::UnshareProject {
5562                        project_id: project.id.to_proto(),
5563                    },
5564                )
5565                .trace_err();
5566        } else {
5567            session
5568                .peer
5569                .send(
5570                    *connection_id,
5571                    proto::RemoveProjectCollaborator {
5572                        project_id: project.id.to_proto(),
5573                        peer_id: Some(session.connection_id.into()),
5574                    },
5575                )
5576                .trace_err();
5577        }
5578    }
5579}
5580
5581pub trait ResultExt {
5582    type Ok;
5583
5584    fn trace_err(self) -> Option<Self::Ok>;
5585}
5586
5587impl<T, E> ResultExt for Result<T, E>
5588where
5589    E: std::fmt::Debug,
5590{
5591    type Ok = T;
5592
5593    #[track_caller]
5594    fn trace_err(self) -> Option<T> {
5595        match self {
5596            Ok(value) => Some(value),
5597            Err(error) => {
5598                tracing::error!("{:?}", error);
5599                None
5600            }
5601        }
5602    }
5603}