rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::{
   5    auth,
   6    db::{
   7        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   9        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  10        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  11        ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, RateLimiter, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  37use sha2::Digest;
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39
  40use futures::{
  41    channel::oneshot,
  42    future::{self, BoxFuture},
  43    stream::FuturesUnordered,
  44    FutureExt, SinkExt, StreamExt, TryStreamExt,
  45};
  46use http_client::IsahcHttpClient;
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79use self::connection_pool::VersionedMessage;
  80
  // A 30-second duration. NOTE(review): presumably the window a disconnected client has to
  // reconnect; the use site is outside this view, confirm before relying on it.
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
  // `Server::start` sleeps for this long before it retrieves stale rooms and channel buffers.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  // Paging and length limits. NOTE(review): the handlers that consume these are not visible
  // here; the names suggest chat-message page size, maximum chat-message length and
  // notification page size. Confirm against the use sites.
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  /// Type-erased handler as stored in `Server::handlers`: a thread-safe (`Send + Sync`)
  /// callable that takes a boxed envelope plus the `Session` and returns a boxed future
  /// resolving to `()`.
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107struct StreamingResponse<R: RequestMessage> {
 108    peer: Arc<Peer>,
 109    receipt: Receipt<R>,
 110}
 111
 112impl<R: RequestMessage> StreamingResponse<R> {
 113    fn send(&self, payload: R::Response) -> Result<()> {
 114        self.peer.respond(self.receipt, payload)?;
 115        Ok(())
 116    }
 117}
 118
 119#[derive(Clone, Debug)]
 120pub enum Principal {
 121    User(User),
 122    Impersonated { user: User, admin: User },
 123    DevServer(dev_server::Model),
 124}
 125
 126impl Principal {
 127    fn update_span(&self, span: &tracing::Span) {
 128        match &self {
 129            Principal::User(user) => {
 130                span.record("user_id", &user.id.0);
 131                span.record("login", &user.github_login);
 132            }
 133            Principal::Impersonated { user, admin } => {
 134                span.record("user_id", &user.id.0);
 135                span.record("login", &user.github_login);
 136                span.record("impersonator", &admin.github_login);
 137            }
 138            Principal::DevServer(dev_server) => {
 139                span.record("dev_server_id", &dev_server.id.0);
 140            }
 141        }
 142    }
 143}
 144
 145#[derive(Clone)]
 146struct Session {
 147    principal: Principal,
 148    connection_id: ConnectionId,
 149    db: Arc<tokio::sync::Mutex<DbHandle>>,
 150    peer: Arc<Peer>,
 151    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 152    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 153    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 154    http_client: Arc<IsahcHttpClient>,
 155    rate_limiter: Arc<RateLimiter>,
 156    /// The GeoIP country code for the user.
 157    #[allow(unused)]
 158    geoip_country_code: Option<String>,
 159    _executor: Executor,
 160}
 161
 162impl Session {
 163    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 164        #[cfg(test)]
 165        tokio::task::yield_now().await;
 166        let guard = self.db.lock().await;
 167        #[cfg(test)]
 168        tokio::task::yield_now().await;
 169        guard
 170    }
 171
 172    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 173        #[cfg(test)]
 174        tokio::task::yield_now().await;
 175        let guard = self.connection_pool.lock();
 176        ConnectionPoolGuard {
 177            guard,
 178            _not_send: PhantomData,
 179        }
 180    }
 181
 182    fn for_user(self) -> Option<UserSession> {
 183        UserSession::new(self)
 184    }
 185
 186    fn for_dev_server(self) -> Option<DevServerSession> {
 187        DevServerSession::new(self)
 188    }
 189
 190    fn user_id(&self) -> Option<UserId> {
 191        match &self.principal {
 192            Principal::User(user) => Some(user.id),
 193            Principal::Impersonated { user, .. } => Some(user.id),
 194            Principal::DevServer(_) => None,
 195        }
 196    }
 197
 198    fn is_staff(&self) -> bool {
 199        match &self.principal {
 200            Principal::User(user) => user.admin,
 201            Principal::Impersonated { .. } => true,
 202            Principal::DevServer(_) => false,
 203        }
 204    }
 205
 206    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
 207        if self.is_staff() {
 208            return Ok(proto::Plan::ZedPro);
 209        }
 210
 211        let Some(user_id) = self.user_id() else {
 212            return Ok(proto::Plan::Free);
 213        };
 214
 215        let db = self.db().await;
 216        if db.has_active_billing_subscription(user_id).await? {
 217            Ok(proto::Plan::ZedPro)
 218        } else {
 219            Ok(proto::Plan::Free)
 220        }
 221    }
 222
 223    fn dev_server_id(&self) -> Option<DevServerId> {
 224        match &self.principal {
 225            Principal::User(_) | Principal::Impersonated { .. } => None,
 226            Principal::DevServer(dev_server) => Some(dev_server.id),
 227        }
 228    }
 229
 230    fn principal_id(&self) -> PrincipalId {
 231        match &self.principal {
 232            Principal::User(user) => PrincipalId::UserId(user.id),
 233            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 234            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 235        }
 236    }
 237}
 238
 239impl Debug for Session {
 240    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 241        let mut result = f.debug_struct("Session");
 242        match &self.principal {
 243            Principal::User(user) => {
 244                result.field("user", &user.github_login);
 245            }
 246            Principal::Impersonated { user, admin } => {
 247                result.field("user", &user.github_login);
 248                result.field("impersonator", &admin.github_login);
 249            }
 250            Principal::DevServer(dev_server) => {
 251                result.field("dev_server", &dev_server.id);
 252            }
 253        }
 254        result.field("connection_id", &self.connection_id).finish()
 255    }
 256}
 257
 258struct UserSession(Session);
 259
 260impl UserSession {
 261    pub fn new(s: Session) -> Option<Self> {
 262        s.user_id().map(|_| UserSession(s))
 263    }
 264    pub fn user_id(&self) -> UserId {
 265        self.0.user_id().unwrap()
 266    }
 267
 268    pub fn email(&self) -> Option<String> {
 269        match &self.0.principal {
 270            Principal::User(user) => user.email_address.clone(),
 271            Principal::Impersonated { user, .. } => user.email_address.clone(),
 272            Principal::DevServer(..) => None,
 273        }
 274    }
 275}
 276
 277impl Deref for UserSession {
 278    type Target = Session;
 279
 280    fn deref(&self) -> &Self::Target {
 281        &self.0
 282    }
 283}
 284impl DerefMut for UserSession {
 285    fn deref_mut(&mut self) -> &mut Self::Target {
 286        &mut self.0
 287    }
 288}
 289
 290struct DevServerSession(Session);
 291
 292impl DevServerSession {
 293    pub fn new(s: Session) -> Option<Self> {
 294        s.dev_server_id().map(|_| DevServerSession(s))
 295    }
 296    pub fn dev_server_id(&self) -> DevServerId {
 297        self.0.dev_server_id().unwrap()
 298    }
 299
 300    fn dev_server(&self) -> &dev_server::Model {
 301        match &self.0.principal {
 302            Principal::DevServer(dev_server) => dev_server,
 303            _ => unreachable!(),
 304        }
 305    }
 306}
 307
 308impl Deref for DevServerSession {
 309    type Target = Session;
 310
 311    fn deref(&self) -> &Self::Target {
 312        &self.0
 313    }
 314}
 315impl DerefMut for DevServerSession {
 316    fn deref_mut(&mut self) -> &mut Self::Target {
 317        &mut self.0
 318    }
 319}
 320
 321fn user_handler<M: RequestMessage, Fut>(
 322    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 324where
 325    Fut: Send + Future<Output = Result<()>>,
 326{
 327    let handler = Arc::new(handler);
 328    move |message, response, session| {
 329        let handler = handler.clone();
 330        Box::pin(async move {
 331            if let Some(user_session) = session.for_user() {
 332                Ok(handler(message, response, user_session).await?)
 333            } else {
 334                Err(Error::Internal(anyhow!(
 335                    "must be a user to call {}",
 336                    M::NAME
 337                )))
 338            }
 339        })
 340    }
 341}
 342
 343fn dev_server_handler<M: RequestMessage, Fut>(
 344    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 346where
 347    Fut: Send + Future<Output = Result<()>>,
 348{
 349    let handler = Arc::new(handler);
 350    move |message, response, session| {
 351        let handler = handler.clone();
 352        Box::pin(async move {
 353            if let Some(dev_server_session) = session.for_dev_server() {
 354                Ok(handler(message, response, dev_server_session).await?)
 355            } else {
 356                Err(Error::Internal(anyhow!(
 357                    "must be a dev server to call {}",
 358                    M::NAME
 359                )))
 360            }
 361        })
 362    }
 363}
 364
 365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 366    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 368where
 369    InnertRetFut: Send + Future<Output = Result<()>>,
 370{
 371    let handler = Arc::new(handler);
 372    move |message, session| {
 373        let handler = handler.clone();
 374        Box::pin(async move {
 375            if let Some(user_session) = session.for_user() {
 376                Ok(handler(message, user_session).await?)
 377            } else {
 378                Err(Error::Internal(anyhow!(
 379                    "must be a user to call {}",
 380                    M::NAME
 381                )))
 382            }
 383        })
 384    }
 385}
 386
 387struct DbHandle(Arc<Database>);
 388
 389impl Deref for DbHandle {
 390    type Target = Database;
 391
 392    fn deref(&self) -> &Self::Target {
 393        self.0.as_ref()
 394    }
 395}
 396
 /// The RPC server: owns the peer, the connection pool and the table of message handlers.
 397pub struct Server {
         /// This server's id; `start` copies it out under the lock.
 398    id: parking_lot::Mutex<ServerId>,
         /// RPC peer, created in `new` from the server id.
 399    peer: Arc<Peer>,
         /// Shared connection pool; the same `Arc` type is carried by every `Session`.
 400    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
         /// Application-wide state (database, config, executor, ...).
 401    app_state: Arc<AppState>,
         /// Registered handlers keyed by `TypeId`.
         /// NOTE(review): presumably the `TypeId` of the message type; confirm against the
         /// `add_*_handler` methods, which are outside this view.
 402    handlers: HashMap<TypeId, MessageHandler>,
         /// Watch-channel sender initialised to `false` in `new`.
         /// NOTE(review): presumably flipped to signal teardown; confirm at the use site.
 403    teardown: watch::Sender<bool>,
 404}
 405
 /// Wrapper around a locked `ConnectionPool`, as returned by `Session::connection_pool`.
 406pub(crate) struct ConnectionPoolGuard<'a> {
         /// The held `parking_lot` mutex guard.
 407    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
         /// Zero-sized marker: `Rc<()>` is `!Send`, so this field keeps the wrapper `!Send`.
         /// NOTE(review): presumably so the guard cannot be held across an `.await` in a
         /// `Send` future; confirm the intent.
 408    _not_send: PhantomData<Rc<()>>,
 409}
 410
 /// Serializable view of the server: a borrowed peer plus a locked connection pool.
 411#[derive(Serialize)]
 412pub struct ServerSnapshot<'a> {
         /// The server's peer, serialized directly.
 413    peer: &'a Peer,
         // Serialized through `serialize_deref`, i.e. as the guard's `Deref` target rather
         // than as the guard itself.
         // NOTE(review): the `Deref` impl for `ConnectionPoolGuard` is not visible in this
         // view; presumably its target is `ConnectionPool`. Confirm.
 414    #[serde(serialize_with = "serialize_deref")]
 415    connection_pool: ConnectionPoolGuard<'a>,
 416}
 417
 418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 419where
 420    S: Serializer,
 421    T: Deref<Target = U>,
 422    U: Serialize,
 423{
 424    Serialize::serialize(value.deref(), serializer)
 425}
 426
 427impl Server {
 428    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 429        let mut server = Self {
 430            id: parking_lot::Mutex::new(id),
 431            peer: Peer::new(id.0 as u32),
 432            app_state: app_state.clone(),
 433            connection_pool: Default::default(),
 434            handlers: Default::default(),
 435            teardown: watch::channel(false).0,
 436        };
 437
 438        server
 439            .add_request_handler(ping)
 440            .add_request_handler(user_handler(create_room))
 441            .add_request_handler(user_handler(join_room))
 442            .add_request_handler(user_handler(rejoin_room))
 443            .add_request_handler(user_handler(leave_room))
 444            .add_request_handler(user_handler(set_room_participant_role))
 445            .add_request_handler(user_handler(call))
 446            .add_request_handler(user_handler(cancel_call))
 447            .add_message_handler(user_message_handler(decline_call))
 448            .add_request_handler(user_handler(update_participant_location))
 449            .add_request_handler(user_handler(share_project))
 450            .add_message_handler(unshare_project)
 451            .add_request_handler(user_handler(join_project))
 452            .add_request_handler(user_handler(join_hosted_project))
 453            .add_request_handler(user_handler(rejoin_dev_server_projects))
 454            .add_request_handler(user_handler(create_dev_server_project))
 455            .add_request_handler(user_handler(update_dev_server_project))
 456            .add_request_handler(user_handler(delete_dev_server_project))
 457            .add_request_handler(user_handler(create_dev_server))
 458            .add_request_handler(user_handler(regenerate_dev_server_token))
 459            .add_request_handler(user_handler(rename_dev_server))
 460            .add_request_handler(user_handler(delete_dev_server))
 461            .add_request_handler(user_handler(list_remote_directory))
 462            .add_request_handler(dev_server_handler(share_dev_server_project))
 463            .add_request_handler(dev_server_handler(shutdown_dev_server))
 464            .add_request_handler(dev_server_handler(reconnect_dev_server))
 465            .add_message_handler(user_message_handler(leave_project))
 466            .add_request_handler(update_project)
 467            .add_request_handler(update_worktree)
 468            .add_message_handler(start_language_server)
 469            .add_message_handler(update_language_server)
 470            .add_message_handler(update_diagnostic_summary)
 471            .add_message_handler(update_worktree_settings)
 472            .add_request_handler(user_handler(
 473                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_project_request_for_owner::<proto::TaskTemplates>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_read_only_project_request::<proto::GetHover>,
 480            ))
 481            .add_request_handler(user_handler(
 482                forward_read_only_project_request::<proto::GetDefinition>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_read_only_project_request::<proto::GetTypeDefinition>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_read_only_project_request::<proto::GetReferences>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_read_only_project_request::<proto::SearchProject>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_read_only_project_request::<proto::GetProjectSymbols>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_read_only_project_request::<proto::OpenBufferById>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 507            ))
 508            .add_request_handler(user_handler(
 509                forward_read_only_project_request::<proto::InlayHints>,
 510            ))
 511            .add_request_handler(user_handler(
 512                forward_read_only_project_request::<proto::OpenBufferByPath>,
 513            ))
 514            .add_request_handler(user_handler(
 515                forward_mutating_project_request::<proto::GetCompletions>,
 516            ))
 517            .add_request_handler(user_handler(
 518                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 519            ))
 520            .add_request_handler(user_handler(
 521                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 522            ))
 523            .add_request_handler(user_handler(
 524                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 525            ))
 526            .add_request_handler(user_handler(
 527                forward_mutating_project_request::<proto::GetCodeActions>,
 528            ))
 529            .add_request_handler(user_handler(
 530                forward_mutating_project_request::<proto::ApplyCodeAction>,
 531            ))
 532            .add_request_handler(user_handler(
 533                forward_mutating_project_request::<proto::PrepareRename>,
 534            ))
 535            .add_request_handler(user_handler(
 536                forward_mutating_project_request::<proto::PerformRename>,
 537            ))
 538            .add_request_handler(user_handler(
 539                forward_mutating_project_request::<proto::ReloadBuffers>,
 540            ))
 541            .add_request_handler(user_handler(
 542                forward_mutating_project_request::<proto::FormatBuffers>,
 543            ))
 544            .add_request_handler(user_handler(
 545                forward_mutating_project_request::<proto::CreateProjectEntry>,
 546            ))
 547            .add_request_handler(user_handler(
 548                forward_mutating_project_request::<proto::RenameProjectEntry>,
 549            ))
 550            .add_request_handler(user_handler(
 551                forward_mutating_project_request::<proto::CopyProjectEntry>,
 552            ))
 553            .add_request_handler(user_handler(
 554                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 555            ))
 556            .add_request_handler(user_handler(
 557                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 558            ))
 559            .add_request_handler(user_handler(
 560                forward_mutating_project_request::<proto::OnTypeFormatting>,
 561            ))
 562            .add_request_handler(user_handler(
 563                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 564            ))
 565            .add_request_handler(user_handler(
 566                forward_mutating_project_request::<proto::BlameBuffer>,
 567            ))
 568            .add_request_handler(user_handler(
 569                forward_mutating_project_request::<proto::MultiLspQuery>,
 570            ))
 571            .add_request_handler(user_handler(
 572                forward_mutating_project_request::<proto::RestartLanguageServers>,
 573            ))
 574            .add_request_handler(user_handler(
 575                forward_mutating_project_request::<proto::LinkedEditingRange>,
 576            ))
 577            .add_message_handler(create_buffer_for_peer)
 578            .add_request_handler(update_buffer)
 579            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 580            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 581            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 582            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 583            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 584            .add_request_handler(get_users)
 585            .add_request_handler(user_handler(fuzzy_search_users))
 586            .add_request_handler(user_handler(request_contact))
 587            .add_request_handler(user_handler(remove_contact))
 588            .add_request_handler(user_handler(respond_to_contact_request))
 589            .add_message_handler(subscribe_to_channels)
 590            .add_request_handler(user_handler(create_channel))
 591            .add_request_handler(user_handler(delete_channel))
 592            .add_request_handler(user_handler(invite_channel_member))
 593            .add_request_handler(user_handler(remove_channel_member))
 594            .add_request_handler(user_handler(set_channel_member_role))
 595            .add_request_handler(user_handler(set_channel_visibility))
 596            .add_request_handler(user_handler(rename_channel))
 597            .add_request_handler(user_handler(join_channel_buffer))
 598            .add_request_handler(user_handler(leave_channel_buffer))
 599            .add_message_handler(user_message_handler(update_channel_buffer))
 600            .add_request_handler(user_handler(rejoin_channel_buffers))
 601            .add_request_handler(user_handler(get_channel_members))
 602            .add_request_handler(user_handler(respond_to_channel_invite))
 603            .add_request_handler(user_handler(join_channel))
 604            .add_request_handler(user_handler(join_channel_chat))
 605            .add_message_handler(user_message_handler(leave_channel_chat))
 606            .add_request_handler(user_handler(send_channel_message))
 607            .add_request_handler(user_handler(remove_channel_message))
 608            .add_request_handler(user_handler(update_channel_message))
 609            .add_request_handler(user_handler(get_channel_messages))
 610            .add_request_handler(user_handler(get_channel_messages_by_id))
 611            .add_request_handler(user_handler(get_notifications))
 612            .add_request_handler(user_handler(mark_notification_as_read))
 613            .add_request_handler(user_handler(move_channel))
 614            .add_request_handler(user_handler(follow))
 615            .add_message_handler(user_message_handler(unfollow))
 616            .add_message_handler(user_message_handler(update_followers))
 617            .add_request_handler(user_handler(get_private_user_info))
 618            .add_message_handler(user_message_handler(acknowledge_channel_message))
 619            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 620            .add_request_handler(user_handler(get_supermaven_api_key))
 621            .add_request_handler(user_handler(
 622                forward_mutating_project_request::<proto::OpenContext>,
 623            ))
 624            .add_request_handler(user_handler(
 625                forward_mutating_project_request::<proto::CreateContext>,
 626            ))
 627            .add_request_handler(user_handler(
 628                forward_mutating_project_request::<proto::SynchronizeContexts>,
 629            ))
 630            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 631            .add_message_handler(update_context)
 632            .add_request_handler({
 633                let app_state = app_state.clone();
 634                move |request, response, session| {
 635                    let app_state = app_state.clone();
 636                    async move {
 637                        complete_with_language_model(request, response, session, &app_state.config)
 638                            .await
 639                    }
 640                }
 641            })
 642            .add_streaming_request_handler({
 643                let app_state = app_state.clone();
 644                move |request, response, session| {
 645                    let app_state = app_state.clone();
 646                    async move {
 647                        stream_complete_with_language_model(
 648                            request,
 649                            response,
 650                            session,
 651                            &app_state.config,
 652                        )
 653                        .await
 654                    }
 655                }
 656            })
 657            .add_request_handler({
 658                let app_state = app_state.clone();
 659                move |request, response, session| {
 660                    let app_state = app_state.clone();
 661                    async move {
 662                        count_language_model_tokens(request, response, session, &app_state.config)
 663                            .await
 664                    }
 665                }
 666            })
 667            .add_request_handler({
 668                user_handler(move |request, response, session| {
 669                    get_cached_embeddings(request, response, session)
 670                })
 671            })
 672            .add_request_handler({
 673                let app_state = app_state.clone();
 674                user_handler(move |request, response, session| {
 675                    compute_embeddings(
 676                        request,
 677                        response,
 678                        session,
 679                        app_state.config.openai_api_key.clone(),
 680                    )
 681                })
 682            });
 683
 684        Arc::new(server)
 685    }
 686
 687    pub async fn start(&self) -> Result<()> {
 688        let server_id = *self.id.lock();
 689        let app_state = self.app_state.clone();
 690        let peer = self.peer.clone();
 691        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 692        let pool = self.connection_pool.clone();
 693        let live_kit_client = self.app_state.live_kit_client.clone();
 694
 695        let span = info_span!("start server");
 696        self.app_state.executor.spawn_detached(
 697            async move {
 698                tracing::info!("waiting for cleanup timeout");
 699                timeout.await;
 700                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 701                if let Some((room_ids, channel_ids)) = app_state
 702                    .db
 703                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 704                    .await
 705                    .trace_err()
 706                {
 707                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 708                    tracing::info!(
 709                        stale_channel_buffer_count = channel_ids.len(),
 710                        "retrieved stale channel buffers"
 711                    );
 712
 713                    for channel_id in channel_ids {
 714                        if let Some(refreshed_channel_buffer) = app_state
 715                            .db
 716                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 717                            .await
 718                            .trace_err()
 719                        {
 720                            for connection_id in refreshed_channel_buffer.connection_ids {
 721                                peer.send(
 722                                    connection_id,
 723                                    proto::UpdateChannelBufferCollaborators {
 724                                        channel_id: channel_id.to_proto(),
 725                                        collaborators: refreshed_channel_buffer
 726                                            .collaborators
 727                                            .clone(),
 728                                    },
 729                                )
 730                                .trace_err();
 731                            }
 732                        }
 733                    }
 734
 735                    for room_id in room_ids {
 736                        let mut contacts_to_update = HashSet::default();
 737                        let mut canceled_calls_to_user_ids = Vec::new();
 738                        let mut live_kit_room = String::new();
 739                        let mut delete_live_kit_room = false;
 740
 741                        if let Some(mut refreshed_room) = app_state
 742                            .db
 743                            .clear_stale_room_participants(room_id, server_id)
 744                            .await
 745                            .trace_err()
 746                        {
 747                            tracing::info!(
 748                                room_id = room_id.0,
 749                                new_participant_count = refreshed_room.room.participants.len(),
 750                                "refreshed room"
 751                            );
 752                            room_updated(&refreshed_room.room, &peer);
 753                            if let Some(channel) = refreshed_room.channel.as_ref() {
 754                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 755                            }
 756                            contacts_to_update
 757                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 758                            contacts_to_update
 759                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 760                            canceled_calls_to_user_ids =
 761                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 762                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 763                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 764                        }
 765
 766                        {
 767                            let pool = pool.lock();
 768                            for canceled_user_id in canceled_calls_to_user_ids {
 769                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 770                                    peer.send(
 771                                        connection_id,
 772                                        proto::CallCanceled {
 773                                            room_id: room_id.to_proto(),
 774                                        },
 775                                    )
 776                                    .trace_err();
 777                                }
 778                            }
 779                        }
 780
 781                        for user_id in contacts_to_update {
 782                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 783                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 784                            if let Some((busy, contacts)) = busy.zip(contacts) {
 785                                let pool = pool.lock();
 786                                let updated_contact = contact_for_user(user_id, busy, &pool);
 787                                for contact in contacts {
 788                                    if let db::Contact::Accepted {
 789                                        user_id: contact_user_id,
 790                                        ..
 791                                    } = contact
 792                                    {
 793                                        for contact_conn_id in
 794                                            pool.user_connection_ids(contact_user_id)
 795                                        {
 796                                            peer.send(
 797                                                contact_conn_id,
 798                                                proto::UpdateContacts {
 799                                                    contacts: vec![updated_contact.clone()],
 800                                                    remove_contacts: Default::default(),
 801                                                    incoming_requests: Default::default(),
 802                                                    remove_incoming_requests: Default::default(),
 803                                                    outgoing_requests: Default::default(),
 804                                                    remove_outgoing_requests: Default::default(),
 805                                                },
 806                                            )
 807                                            .trace_err();
 808                                        }
 809                                    }
 810                                }
 811                            }
 812                        }
 813
 814                        if let Some(live_kit) = live_kit_client.as_ref() {
 815                            if delete_live_kit_room {
 816                                live_kit.delete_room(live_kit_room).await.trace_err();
 817                            }
 818                        }
 819                    }
 820                }
 821
 822                app_state
 823                    .db
 824                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 825                    .await
 826                    .trace_err();
 827            }
 828            .instrument(span),
 829        );
 830        Ok(())
 831    }
 832
 833    pub fn teardown(&self) {
 834        self.peer.teardown();
 835        self.connection_pool.lock().reset();
 836        let _ = self.teardown.send(true);
 837    }
 838
 839    #[cfg(test)]
 840    pub fn reset(&self, id: ServerId) {
 841        self.teardown();
 842        *self.id.lock() = id;
 843        self.peer.reset(id.0 as u32);
 844        let _ = self.teardown.send(false);
 845    }
 846
 847    #[cfg(test)]
 848    pub fn id(&self) -> ServerId {
 849        *self.id.lock()
 850    }
 851
 852    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 853    where
 854        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 855        Fut: 'static + Send + Future<Output = Result<()>>,
 856        M: EnvelopedMessage,
 857    {
 858        let prev_handler = self.handlers.insert(
 859            TypeId::of::<M>(),
 860            Box::new(move |envelope, session| {
 861                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 862                let received_at = envelope.received_at;
 863                tracing::info!("message received");
 864                let start_time = Instant::now();
 865                let future = (handler)(*envelope, session);
 866                async move {
 867                    let result = future.await;
 868                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 869                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 870                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 871                    let payload_type = M::NAME;
 872
 873                    match result {
 874                        Err(error) => {
 875                            tracing::error!(
 876                                ?error,
 877                                total_duration_ms,
 878                                processing_duration_ms,
 879                                queue_duration_ms,
 880                                payload_type,
 881                                "error handling message"
 882                            )
 883                        }
 884                        Ok(()) => tracing::info!(
 885                            total_duration_ms,
 886                            processing_duration_ms,
 887                            queue_duration_ms,
 888                            "finished handling message"
 889                        ),
 890                    }
 891                }
 892                .boxed()
 893            }),
 894        );
 895        if prev_handler.is_some() {
 896            panic!("registered a handler for the same message twice");
 897        }
 898        self
 899    }
 900
 901    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 902    where
 903        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 904        Fut: 'static + Send + Future<Output = Result<()>>,
 905        M: EnvelopedMessage,
 906    {
 907        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 908        self
 909    }
 910
 911    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 912    where
 913        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 914        Fut: Send + Future<Output = Result<()>>,
 915        M: RequestMessage,
 916    {
 917        let handler = Arc::new(handler);
 918        self.add_handler(move |envelope, session| {
 919            let receipt = envelope.receipt();
 920            let handler = handler.clone();
 921            async move {
 922                let peer = session.peer.clone();
 923                let responded = Arc::new(AtomicBool::default());
 924                let response = Response {
 925                    peer: peer.clone(),
 926                    responded: responded.clone(),
 927                    receipt,
 928                };
 929                match (handler)(envelope.payload, response, session).await {
 930                    Ok(()) => {
 931                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 932                            Ok(())
 933                        } else {
 934                            Err(anyhow!("handler did not send a response"))?
 935                        }
 936                    }
 937                    Err(error) => {
 938                        let proto_err = match &error {
 939                            Error::Internal(err) => err.to_proto(),
 940                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 941                        };
 942                        peer.respond_with_error(receipt, proto_err)?;
 943                        Err(error)
 944                    }
 945                }
 946            }
 947        })
 948    }
 949
 950    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 951    where
 952        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 953        Fut: Send + Future<Output = Result<()>>,
 954        M: RequestMessage,
 955    {
 956        let handler = Arc::new(handler);
 957        self.add_handler(move |envelope, session| {
 958            let receipt = envelope.receipt();
 959            let handler = handler.clone();
 960            async move {
 961                let peer = session.peer.clone();
 962                let response = StreamingResponse {
 963                    peer: peer.clone(),
 964                    receipt,
 965                };
 966                match (handler)(envelope.payload, response, session).await {
 967                    Ok(()) => {
 968                        peer.end_stream(receipt)?;
 969                        Ok(())
 970                    }
 971                    Err(error) => {
 972                        let proto_err = match &error {
 973                            Error::Internal(err) => err.to_proto(),
 974                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 975                        };
 976                        peer.respond_with_error(receipt, proto_err)?;
 977                        Err(error)
 978                    }
 979                }
 980            }
 981        })
 982    }
 983
 984    #[allow(clippy::too_many_arguments)]
 985    pub fn handle_connection(
 986        self: &Arc<Self>,
 987        connection: Connection,
 988        address: String,
 989        principal: Principal,
 990        zed_version: ZedVersion,
 991        geoip_country_code: Option<String>,
 992        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 993        executor: Executor,
 994    ) -> impl Future<Output = ()> {
 995        let this = self.clone();
 996        let span = info_span!("handle connection", %address,
 997            connection_id=field::Empty,
 998            user_id=field::Empty,
 999            login=field::Empty,
1000            impersonator=field::Empty,
1001            dev_server_id=field::Empty
1002        );
1003        principal.update_span(&span);
1004
1005        let mut teardown = self.teardown.subscribe();
1006        async move {
1007            if *teardown.borrow() {
1008                tracing::error!("server is tearing down");
1009                return
1010            }
1011            let (connection_id, handle_io, mut incoming_rx) = this
1012                .peer
1013                .add_connection(connection, {
1014                    let executor = executor.clone();
1015                    move |duration| executor.sleep(duration)
1016                });
1017            tracing::Span::current()
1018                .record("connection_id", format!("{}", connection_id))
1019                .record("geoip_country_code", geoip_country_code.clone());
1020
1021            tracing::info!("connection opened");
1022
1023            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
1024            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
1025                Ok(http_client) => Arc::new(http_client),
1026                Err(error) => {
1027                    tracing::error!(?error, "failed to create HTTP client");
1028                    return;
1029                }
1030            };
1031
1032            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1033                Some(Arc::new(SupermavenAdminApi::new(
1034                    supermaven_admin_api_key.to_string(),
1035                    http_client.clone(),
1036                )))
1037            } else {
1038                None
1039            };
1040
1041            let session = Session {
1042                principal: principal.clone(),
1043                connection_id,
1044                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1045                peer: this.peer.clone(),
1046                connection_pool: this.connection_pool.clone(),
1047                live_kit_client: this.app_state.live_kit_client.clone(),
1048                http_client,
1049                rate_limiter: this.app_state.rate_limiter.clone(),
1050                geoip_country_code,
1051                _executor: executor.clone(),
1052                supermaven_client,
1053            };
1054
1055            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1056                tracing::error!(?error, "failed to send initial client update");
1057                return;
1058            }
1059
1060            let handle_io = handle_io.fuse();
1061            futures::pin_mut!(handle_io);
1062
1063            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1064            // This prevents deadlocks when e.g., client A performs a request to client B and
1065            // client B performs a request to client A. If both clients stop processing further
1066            // messages until their respective request completes, they won't have a chance to
1067            // respond to the other client's request and cause a deadlock.
1068            //
1069            // This arrangement ensures we will attempt to process earlier messages first, but fall
1070            // back to processing messages arrived later in the spirit of making progress.
1071            let mut foreground_message_handlers = FuturesUnordered::new();
1072            let concurrent_handlers = Arc::new(Semaphore::new(256));
1073            loop {
1074                let next_message = async {
1075                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1076                    let message = incoming_rx.next().await;
1077                    (permit, message)
1078                }.fuse();
1079                futures::pin_mut!(next_message);
1080                futures::select_biased! {
1081                    _ = teardown.changed().fuse() => return,
1082                    result = handle_io => {
1083                        if let Err(error) = result {
1084                            tracing::error!(?error, "error handling I/O");
1085                        }
1086                        break;
1087                    }
1088                    _ = foreground_message_handlers.next() => {}
1089                    next_message = next_message => {
1090                        let (permit, message) = next_message;
1091                        if let Some(message) = message {
1092                            let type_name = message.payload_type_name();
1093                            // note: we copy all the fields from the parent span so we can query them in the logs.
1094                            // (https://github.com/tokio-rs/tracing/issues/2670).
1095                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1096                                user_id=field::Empty,
1097                                login=field::Empty,
1098                                impersonator=field::Empty,
1099                                dev_server_id=field::Empty
1100                            );
1101                            principal.update_span(&span);
1102                            let span_enter = span.enter();
1103                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1104                                let is_background = message.is_background();
1105                                let handle_message = (handler)(message, session.clone());
1106                                drop(span_enter);
1107
1108                                let handle_message = async move {
1109                                    handle_message.await;
1110                                    drop(permit);
1111                                }.instrument(span);
1112                                if is_background {
1113                                    executor.spawn_detached(handle_message);
1114                                } else {
1115                                    foreground_message_handlers.push(handle_message);
1116                                }
1117                            } else {
1118                                tracing::error!("no message handler");
1119                            }
1120                        } else {
1121                            tracing::info!("connection closed");
1122                            break;
1123                        }
1124                    }
1125                }
1126            }
1127
1128            drop(foreground_message_handlers);
1129            tracing::info!("signing out");
1130            if let Err(error) = connection_lost(session, teardown, executor).await {
1131                tracing::error!(?error, "error signing out");
1132            }
1133
1134        }.instrument(span)
1135    }
1136
1137    async fn send_initial_client_update(
1138        &self,
1139        connection_id: ConnectionId,
1140        principal: &Principal,
1141        zed_version: ZedVersion,
1142        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1143        session: &Session,
1144    ) -> Result<()> {
1145        self.peer.send(
1146            connection_id,
1147            proto::Hello {
1148                peer_id: Some(connection_id.into()),
1149            },
1150        )?;
1151        tracing::info!("sent hello message");
1152        if let Some(send_connection_id) = send_connection_id.take() {
1153            let _ = send_connection_id.send(connection_id);
1154        }
1155
1156        match principal {
1157            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1158                if !user.connected_once {
1159                    self.peer.send(connection_id, proto::ShowContacts {})?;
1160                    self.app_state
1161                        .db
1162                        .set_user_connected_once(user.id, true)
1163                        .await?;
1164                }
1165
1166                update_user_plan(user.id, session).await?;
1167
1168                let (contacts, dev_server_projects) = future::try_join(
1169                    self.app_state.db.get_contacts(user.id),
1170                    self.app_state.db.dev_server_projects_update(user.id),
1171                )
1172                .await?;
1173
1174                {
1175                    let mut pool = self.connection_pool.lock();
1176                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1177                    self.peer.send(
1178                        connection_id,
1179                        build_initial_contacts_update(contacts, &pool),
1180                    )?;
1181                }
1182
1183                if should_auto_subscribe_to_channels(zed_version) {
1184                    subscribe_user_to_channels(user.id, session).await?;
1185                }
1186
1187                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1188
1189                if let Some(incoming_call) =
1190                    self.app_state.db.incoming_call_for_user(user.id).await?
1191                {
1192                    self.peer.send(connection_id, incoming_call)?;
1193                }
1194
1195                update_user_contacts(user.id, &session).await?;
1196            }
1197            Principal::DevServer(dev_server) => {
1198                {
1199                    let mut pool = self.connection_pool.lock();
1200                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1201                    {
1202                        self.peer.send(
1203                            stale_connection_id,
1204                            proto::ShutdownDevServer {
1205                                reason: Some(
1206                                    "another dev server connected with the same token".to_string(),
1207                                ),
1208                            },
1209                        )?;
1210                        pool.remove_connection(stale_connection_id)?;
1211                    };
1212                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1213                }
1214
1215                let projects = self
1216                    .app_state
1217                    .db
1218                    .get_projects_for_dev_server(dev_server.id)
1219                    .await?;
1220                self.peer
1221                    .send(connection_id, proto::DevServerInstructions { projects })?;
1222
1223                let status = self
1224                    .app_state
1225                    .db
1226                    .dev_server_projects_update(dev_server.user_id)
1227                    .await?;
1228                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1229            }
1230        }
1231
1232        Ok(())
1233    }
1234
1235    pub async fn invite_code_redeemed(
1236        self: &Arc<Self>,
1237        inviter_id: UserId,
1238        invitee_id: UserId,
1239    ) -> Result<()> {
1240        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1241            if let Some(code) = &user.invite_code {
1242                let pool = self.connection_pool.lock();
1243                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1244                for connection_id in pool.user_connection_ids(inviter_id) {
1245                    self.peer.send(
1246                        connection_id,
1247                        proto::UpdateContacts {
1248                            contacts: vec![invitee_contact.clone()],
1249                            ..Default::default()
1250                        },
1251                    )?;
1252                    self.peer.send(
1253                        connection_id,
1254                        proto::UpdateInviteInfo {
1255                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1256                            count: user.invite_count as u32,
1257                        },
1258                    )?;
1259                }
1260            }
1261        }
1262        Ok(())
1263    }
1264
1265    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1266        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1267            if let Some(invite_code) = &user.invite_code {
1268                let pool = self.connection_pool.lock();
1269                for connection_id in pool.user_connection_ids(user_id) {
1270                    self.peer.send(
1271                        connection_id,
1272                        proto::UpdateInviteInfo {
1273                            url: format!(
1274                                "{}{}",
1275                                self.app_state.config.invite_link_prefix, invite_code
1276                            ),
1277                            count: user.invite_count as u32,
1278                        },
1279                    )?;
1280                }
1281            }
1282        }
1283        Ok(())
1284    }
1285
1286    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1287        ServerSnapshot {
1288            connection_pool: ConnectionPoolGuard {
1289                guard: self.connection_pool.lock(),
1290                _not_send: PhantomData,
1291            },
1292            peer: &self.peer,
1293        }
1294    }
1295}
1296
1297impl<'a> Deref for ConnectionPoolGuard<'a> {
1298    type Target = ConnectionPool;
1299
1300    fn deref(&self) -> &Self::Target {
1301        &self.guard
1302    }
1303}
1304
1305impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1306    fn deref_mut(&mut self) -> &mut Self::Target {
1307        &mut self.guard
1308    }
1309}
1310
1311impl<'a> Drop for ConnectionPoolGuard<'a> {
1312    fn drop(&mut self) {
1313        #[cfg(test)]
1314        self.check_invariants();
1315    }
1316}
1317
1318fn broadcast<F>(
1319    sender_id: Option<ConnectionId>,
1320    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1321    mut f: F,
1322) where
1323    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1324{
1325    for receiver_id in receiver_ids {
1326        if Some(receiver_id) != sender_id {
1327            if let Err(error) = f(receiver_id) {
1328                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1329            }
1330        }
1331    }
1332}
1333
1334pub struct ProtocolVersion(u32);
1335
1336impl Header for ProtocolVersion {
1337    fn name() -> &'static HeaderName {
1338        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1339        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1340    }
1341
1342    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1343    where
1344        Self: Sized,
1345        I: Iterator<Item = &'i axum::http::HeaderValue>,
1346    {
1347        let version = values
1348            .next()
1349            .ok_or_else(axum::headers::Error::invalid)?
1350            .to_str()
1351            .map_err(|_| axum::headers::Error::invalid())?
1352            .parse()
1353            .map_err(|_| axum::headers::Error::invalid())?;
1354        Ok(Self(version))
1355    }
1356
1357    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1358        values.extend([self.0.to_string().parse().unwrap()]);
1359    }
1360}
1361
1362pub struct AppVersionHeader(SemanticVersion);
1363impl Header for AppVersionHeader {
1364    fn name() -> &'static HeaderName {
1365        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1366        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1367    }
1368
1369    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1370    where
1371        Self: Sized,
1372        I: Iterator<Item = &'i axum::http::HeaderValue>,
1373    {
1374        let version = values
1375            .next()
1376            .ok_or_else(axum::headers::Error::invalid)?
1377            .to_str()
1378            .map_err(|_| axum::headers::Error::invalid())?
1379            .parse()
1380            .map_err(|_| axum::headers::Error::invalid())?;
1381        Ok(Self(version))
1382    }
1383
1384    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1385        values.extend([self.0.to_string().parse().unwrap()]);
1386    }
1387}
1388
1389pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1390    Router::new()
1391        .route("/rpc", get(handle_websocket_request))
1392        .layer(
1393            ServiceBuilder::new()
1394                .layer(Extension(server.app_state.clone()))
1395                .layer(middleware::from_fn(auth::validate_header)),
1396        )
1397        .route("/metrics", get(handle_metrics))
1398        .layer(Extension(server))
1399}
1400
1401pub async fn handle_websocket_request(
1402    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1403    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1404    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1405    Extension(server): Extension<Arc<Server>>,
1406    Extension(principal): Extension<Principal>,
1407    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1408    ws: WebSocketUpgrade,
1409) -> axum::response::Response {
1410    if protocol_version != rpc::PROTOCOL_VERSION {
1411        return (
1412            StatusCode::UPGRADE_REQUIRED,
1413            "client must be upgraded".to_string(),
1414        )
1415            .into_response();
1416    }
1417
1418    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1419        return (
1420            StatusCode::UPGRADE_REQUIRED,
1421            "no version header found".to_string(),
1422        )
1423            .into_response();
1424    };
1425
1426    if !version.can_collaborate() {
1427        return (
1428            StatusCode::UPGRADE_REQUIRED,
1429            "client must be upgraded".to_string(),
1430        )
1431            .into_response();
1432    }
1433
1434    let socket_address = socket_address.to_string();
1435    ws.on_upgrade(move |socket| {
1436        let socket = socket
1437            .map_ok(to_tungstenite_message)
1438            .err_into()
1439            .with(|message| async move { to_axum_message(message) });
1440        let connection = Connection::new(Box::pin(socket));
1441        async move {
1442            server
1443                .handle_connection(
1444                    connection,
1445                    socket_address,
1446                    principal,
1447                    version,
1448                    country_code_header.map(|header| header.to_string()),
1449                    None,
1450                    Executor::Production,
1451                )
1452                .await;
1453        }
1454    })
1455}
1456
1457pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1458    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1459    let connections_metric = CONNECTIONS_METRIC
1460        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1461
1462    let connections = server
1463        .connection_pool
1464        .lock()
1465        .connections()
1466        .filter(|connection| !connection.admin)
1467        .count();
1468    connections_metric.set(connections as _);
1469
1470    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1471    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1472        register_int_gauge!(
1473            "shared_projects",
1474            "number of open projects with one or more guests"
1475        )
1476        .unwrap()
1477    });
1478
1479    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1480    shared_projects_metric.set(shared_projects as _);
1481
1482    let encoder = prometheus::TextEncoder::new();
1483    let metric_families = prometheus::gather();
1484    let encoded_metrics = encoder
1485        .encode_to_string(&metric_families)
1486        .map_err(|err| anyhow!("{}", err))?;
1487    Ok(encoded_metrics)
1488}
1489
/// Tears down server-side state for a connection that has gone away.
///
/// The connection is disconnected from the peer and removed from the
/// connection pool right away, and the loss is recorded in the database
/// (a failure there is only traced). The remaining cleanup is deferred: it
/// runs only if `RECONNECT_TIMEOUT` elapses before `teardown` changes.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool, or if updating
/// the user's contacts / handling the lost dev server connection fails after
/// the timeout.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is traced, not returned.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the
    // reconnect timeout is checked before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    // Room and channel-buffer cleanup are both best-effort.
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only when the pool reports the user as offline is a
                    // call declined on their behalf.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown signalled before the timeout: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1544
1545/// Acknowledges a ping from a client, used to keep the connection alive.
1546async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1547    response.send(proto::Ack {})?;
1548    Ok(())
1549}
1550
1551/// Creates a new room for calling (outside of channels)
1552async fn create_room(
1553    _request: proto::CreateRoom,
1554    response: Response<proto::CreateRoom>,
1555    session: UserSession,
1556) -> Result<()> {
1557    let live_kit_room = nanoid::nanoid!(30);
1558
1559    let live_kit_connection_info = util::maybe!(async {
1560        let live_kit = session.live_kit_client.as_ref();
1561        let live_kit = live_kit?;
1562        let user_id = session.user_id().to_string();
1563
1564        let token = live_kit
1565            .room_token(&live_kit_room, &user_id.to_string())
1566            .trace_err()?;
1567
1568        Some(proto::LiveKitConnectionInfo {
1569            server_url: live_kit.url().into(),
1570            token,
1571            can_publish: true,
1572        })
1573    })
1574    .await;
1575
1576    let room = session
1577        .db()
1578        .await
1579        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1580        .await?;
1581
1582    response.send(proto::CreateRoomResponse {
1583        room: Some(room.clone()),
1584        live_kit_connection_info,
1585    })?;
1586
1587    update_user_contacts(session.user_id(), &session).await?;
1588    Ok(())
1589}
1590
1591/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1592async fn join_room(
1593    request: proto::JoinRoom,
1594    response: Response<proto::JoinRoom>,
1595    session: UserSession,
1596) -> Result<()> {
1597    let room_id = RoomId::from_proto(request.id);
1598
1599    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1600
1601    if let Some(channel_id) = channel_id {
1602        return join_channel_internal(channel_id, Box::new(response), session).await;
1603    }
1604
1605    let joined_room = {
1606        let room = session
1607            .db()
1608            .await
1609            .join_room(room_id, session.user_id(), session.connection_id)
1610            .await?;
1611        room_updated(&room.room, &session.peer);
1612        room.into_inner()
1613    };
1614
1615    for connection_id in session
1616        .connection_pool()
1617        .await
1618        .user_connection_ids(session.user_id())
1619    {
1620        session
1621            .peer
1622            .send(
1623                connection_id,
1624                proto::CallCanceled {
1625                    room_id: room_id.to_proto(),
1626                },
1627            )
1628            .trace_err();
1629    }
1630
1631    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1632        if let Some(token) = live_kit
1633            .room_token(
1634                &joined_room.room.live_kit_room,
1635                &session.user_id().to_string(),
1636            )
1637            .trace_err()
1638        {
1639            Some(proto::LiveKitConnectionInfo {
1640                server_url: live_kit.url().into(),
1641                token,
1642                can_publish: true,
1643            })
1644        } else {
1645            None
1646        }
1647    } else {
1648        None
1649    };
1650
1651    response.send(proto::JoinRoomResponse {
1652        room: Some(joined_room.room),
1653        channel_id: None,
1654        live_kit_connection_info,
1655    })?;
1656
1657    update_user_contacts(session.user_id(), &session).await?;
1658    Ok(())
1659}
1660
1661/// Rejoin room is used to reconnect to a room after connection errors.
1662async fn rejoin_room(
1663    request: proto::RejoinRoom,
1664    response: Response<proto::RejoinRoom>,
1665    session: UserSession,
1666) -> Result<()> {
1667    let room;
1668    let channel;
1669    {
1670        let mut rejoined_room = session
1671            .db()
1672            .await
1673            .rejoin_room(request, session.user_id(), session.connection_id)
1674            .await?;
1675
1676        response.send(proto::RejoinRoomResponse {
1677            room: Some(rejoined_room.room.clone()),
1678            reshared_projects: rejoined_room
1679                .reshared_projects
1680                .iter()
1681                .map(|project| proto::ResharedProject {
1682                    id: project.id.to_proto(),
1683                    collaborators: project
1684                        .collaborators
1685                        .iter()
1686                        .map(|collaborator| collaborator.to_proto())
1687                        .collect(),
1688                })
1689                .collect(),
1690            rejoined_projects: rejoined_room
1691                .rejoined_projects
1692                .iter()
1693                .map(|rejoined_project| rejoined_project.to_proto())
1694                .collect(),
1695        })?;
1696        room_updated(&rejoined_room.room, &session.peer);
1697
1698        for project in &rejoined_room.reshared_projects {
1699            for collaborator in &project.collaborators {
1700                session
1701                    .peer
1702                    .send(
1703                        collaborator.connection_id,
1704                        proto::UpdateProjectCollaborator {
1705                            project_id: project.id.to_proto(),
1706                            old_peer_id: Some(project.old_connection_id.into()),
1707                            new_peer_id: Some(session.connection_id.into()),
1708                        },
1709                    )
1710                    .trace_err();
1711            }
1712
1713            broadcast(
1714                Some(session.connection_id),
1715                project
1716                    .collaborators
1717                    .iter()
1718                    .map(|collaborator| collaborator.connection_id),
1719                |connection_id| {
1720                    session.peer.forward_send(
1721                        session.connection_id,
1722                        connection_id,
1723                        proto::UpdateProject {
1724                            project_id: project.id.to_proto(),
1725                            worktrees: project.worktrees.clone(),
1726                        },
1727                    )
1728                },
1729            );
1730        }
1731
1732        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1733
1734        let rejoined_room = rejoined_room.into_inner();
1735
1736        room = rejoined_room.room;
1737        channel = rejoined_room.channel;
1738    }
1739
1740    if let Some(channel) = channel {
1741        channel_updated(
1742            &channel,
1743            &room,
1744            &session.peer,
1745            &*session.connection_pool().await,
1746        );
1747    }
1748
1749    update_user_contacts(session.user_id(), &session).await?;
1750    Ok(())
1751}
1752
1753fn notify_rejoined_projects(
1754    rejoined_projects: &mut Vec<RejoinedProject>,
1755    session: &UserSession,
1756) -> Result<()> {
1757    for project in rejoined_projects.iter() {
1758        for collaborator in &project.collaborators {
1759            session
1760                .peer
1761                .send(
1762                    collaborator.connection_id,
1763                    proto::UpdateProjectCollaborator {
1764                        project_id: project.id.to_proto(),
1765                        old_peer_id: Some(project.old_connection_id.into()),
1766                        new_peer_id: Some(session.connection_id.into()),
1767                    },
1768                )
1769                .trace_err();
1770        }
1771    }
1772
1773    for project in rejoined_projects {
1774        for worktree in mem::take(&mut project.worktrees) {
1775            #[cfg(any(test, feature = "test-support"))]
1776            const MAX_CHUNK_SIZE: usize = 2;
1777            #[cfg(not(any(test, feature = "test-support")))]
1778            const MAX_CHUNK_SIZE: usize = 256;
1779
1780            // Stream this worktree's entries.
1781            let message = proto::UpdateWorktree {
1782                project_id: project.id.to_proto(),
1783                worktree_id: worktree.id,
1784                abs_path: worktree.abs_path.clone(),
1785                root_name: worktree.root_name,
1786                updated_entries: worktree.updated_entries,
1787                removed_entries: worktree.removed_entries,
1788                scan_id: worktree.scan_id,
1789                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1790                updated_repositories: worktree.updated_repositories,
1791                removed_repositories: worktree.removed_repositories,
1792            };
1793            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1794                session.peer.send(session.connection_id, update.clone())?;
1795            }
1796
1797            // Stream this worktree's diagnostics.
1798            for summary in worktree.diagnostic_summaries {
1799                session.peer.send(
1800                    session.connection_id,
1801                    proto::UpdateDiagnosticSummary {
1802                        project_id: project.id.to_proto(),
1803                        worktree_id: worktree.id,
1804                        summary: Some(summary),
1805                    },
1806                )?;
1807            }
1808
1809            for settings_file in worktree.settings_files {
1810                session.peer.send(
1811                    session.connection_id,
1812                    proto::UpdateWorktreeSettings {
1813                        project_id: project.id.to_proto(),
1814                        worktree_id: worktree.id,
1815                        path: settings_file.path,
1816                        content: Some(settings_file.content),
1817                    },
1818                )?;
1819            }
1820        }
1821
1822        for language_server in &project.language_servers {
1823            session.peer.send(
1824                session.connection_id,
1825                proto::UpdateLanguageServer {
1826                    project_id: project.id.to_proto(),
1827                    language_server_id: language_server.id,
1828                    variant: Some(
1829                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1830                            proto::LspDiskBasedDiagnosticsUpdated {},
1831                        ),
1832                    ),
1833                },
1834            )?;
1835        }
1836    }
1837    Ok(())
1838}
1839
1840/// leave room disconnects from the room.
1841async fn leave_room(
1842    _: proto::LeaveRoom,
1843    response: Response<proto::LeaveRoom>,
1844    session: UserSession,
1845) -> Result<()> {
1846    leave_room_for_session(&session, session.connection_id).await?;
1847    response.send(proto::Ack {})?;
1848    Ok(())
1849}
1850
1851/// Updates the permissions of someone else in the room.
1852async fn set_room_participant_role(
1853    request: proto::SetRoomParticipantRole,
1854    response: Response<proto::SetRoomParticipantRole>,
1855    session: UserSession,
1856) -> Result<()> {
1857    let user_id = UserId::from_proto(request.user_id);
1858    let role = ChannelRole::from(request.role());
1859
1860    let (live_kit_room, can_publish) = {
1861        let room = session
1862            .db()
1863            .await
1864            .set_room_participant_role(
1865                session.user_id(),
1866                RoomId::from_proto(request.room_id),
1867                user_id,
1868                role,
1869            )
1870            .await?;
1871
1872        let live_kit_room = room.live_kit_room.clone();
1873        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1874        room_updated(&room, &session.peer);
1875        (live_kit_room, can_publish)
1876    };
1877
1878    if let Some(live_kit) = session.live_kit_client.as_ref() {
1879        live_kit
1880            .update_participant(
1881                live_kit_room.clone(),
1882                request.user_id.to_string(),
1883                live_kit_server::proto::ParticipantPermission {
1884                    can_subscribe: true,
1885                    can_publish,
1886                    can_publish_data: can_publish,
1887                    hidden: false,
1888                    recorder: false,
1889                },
1890            )
1891            .await
1892            .trace_err();
1893    }
1894
1895    response.send(proto::Ack {})?;
1896    Ok(())
1897}
1898
1899/// Call someone else into the current room
1900async fn call(
1901    request: proto::Call,
1902    response: Response<proto::Call>,
1903    session: UserSession,
1904) -> Result<()> {
1905    let room_id = RoomId::from_proto(request.room_id);
1906    let calling_user_id = session.user_id();
1907    let calling_connection_id = session.connection_id;
1908    let called_user_id = UserId::from_proto(request.called_user_id);
1909    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1910    if !session
1911        .db()
1912        .await
1913        .has_contact(calling_user_id, called_user_id)
1914        .await?
1915    {
1916        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1917    }
1918
1919    let incoming_call = {
1920        let (room, incoming_call) = &mut *session
1921            .db()
1922            .await
1923            .call(
1924                room_id,
1925                calling_user_id,
1926                calling_connection_id,
1927                called_user_id,
1928                initial_project_id,
1929            )
1930            .await?;
1931        room_updated(&room, &session.peer);
1932        mem::take(incoming_call)
1933    };
1934    update_user_contacts(called_user_id, &session).await?;
1935
1936    let mut calls = session
1937        .connection_pool()
1938        .await
1939        .user_connection_ids(called_user_id)
1940        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1941        .collect::<FuturesUnordered<_>>();
1942
1943    while let Some(call_response) = calls.next().await {
1944        match call_response.as_ref() {
1945            Ok(_) => {
1946                response.send(proto::Ack {})?;
1947                return Ok(());
1948            }
1949            Err(_) => {
1950                call_response.trace_err();
1951            }
1952        }
1953    }
1954
1955    {
1956        let room = session
1957            .db()
1958            .await
1959            .call_failed(room_id, called_user_id)
1960            .await?;
1961        room_updated(&room, &session.peer);
1962    }
1963    update_user_contacts(called_user_id, &session).await?;
1964
1965    Err(anyhow!("failed to ring user"))?
1966}
1967
1968/// Cancel an outgoing call.
1969async fn cancel_call(
1970    request: proto::CancelCall,
1971    response: Response<proto::CancelCall>,
1972    session: UserSession,
1973) -> Result<()> {
1974    let called_user_id = UserId::from_proto(request.called_user_id);
1975    let room_id = RoomId::from_proto(request.room_id);
1976    {
1977        let room = session
1978            .db()
1979            .await
1980            .cancel_call(room_id, session.connection_id, called_user_id)
1981            .await?;
1982        room_updated(&room, &session.peer);
1983    }
1984
1985    for connection_id in session
1986        .connection_pool()
1987        .await
1988        .user_connection_ids(called_user_id)
1989    {
1990        session
1991            .peer
1992            .send(
1993                connection_id,
1994                proto::CallCanceled {
1995                    room_id: room_id.to_proto(),
1996                },
1997            )
1998            .trace_err();
1999    }
2000    response.send(proto::Ack {})?;
2001
2002    update_user_contacts(called_user_id, &session).await?;
2003    Ok(())
2004}
2005
2006/// Decline an incoming call.
2007async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2008    let room_id = RoomId::from_proto(message.room_id);
2009    {
2010        let room = session
2011            .db()
2012            .await
2013            .decline_call(Some(room_id), session.user_id())
2014            .await?
2015            .ok_or_else(|| anyhow!("failed to decline call"))?;
2016        room_updated(&room, &session.peer);
2017    }
2018
2019    for connection_id in session
2020        .connection_pool()
2021        .await
2022        .user_connection_ids(session.user_id())
2023    {
2024        session
2025            .peer
2026            .send(
2027                connection_id,
2028                proto::CallCanceled {
2029                    room_id: room_id.to_proto(),
2030                },
2031            )
2032            .trace_err();
2033    }
2034    update_user_contacts(session.user_id(), &session).await?;
2035    Ok(())
2036}
2037
2038/// Updates other participants in the room with your current location.
2039async fn update_participant_location(
2040    request: proto::UpdateParticipantLocation,
2041    response: Response<proto::UpdateParticipantLocation>,
2042    session: UserSession,
2043) -> Result<()> {
2044    let room_id = RoomId::from_proto(request.room_id);
2045    let location = request
2046        .location
2047        .ok_or_else(|| anyhow!("invalid location"))?;
2048
2049    let db = session.db().await;
2050    let room = db
2051        .update_room_participant_location(room_id, session.connection_id, location)
2052        .await?;
2053
2054    room_updated(&room, &session.peer);
2055    response.send(proto::Ack {})?;
2056    Ok(())
2057}
2058
2059/// Share a project into the room.
2060async fn share_project(
2061    request: proto::ShareProject,
2062    response: Response<proto::ShareProject>,
2063    session: UserSession,
2064) -> Result<()> {
2065    let (project_id, room) = &*session
2066        .db()
2067        .await
2068        .share_project(
2069            RoomId::from_proto(request.room_id),
2070            session.connection_id,
2071            &request.worktrees,
2072            request
2073                .dev_server_project_id
2074                .map(|id| DevServerProjectId::from_proto(id)),
2075        )
2076        .await?;
2077    response.send(proto::ShareProjectResponse {
2078        project_id: project_id.to_proto(),
2079    })?;
2080    room_updated(&room, &session.peer);
2081
2082    Ok(())
2083}
2084
2085/// Unshare a project from the room.
2086async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2087    let project_id = ProjectId::from_proto(message.project_id);
2088    unshare_project_internal(
2089        project_id,
2090        session.connection_id,
2091        session.user_id(),
2092        &session,
2093    )
2094    .await
2095}
2096
2097async fn unshare_project_internal(
2098    project_id: ProjectId,
2099    connection_id: ConnectionId,
2100    user_id: Option<UserId>,
2101    session: &Session,
2102) -> Result<()> {
2103    let delete = {
2104        let room_guard = session
2105            .db()
2106            .await
2107            .unshare_project(project_id, connection_id, user_id)
2108            .await?;
2109
2110        let (delete, room, guest_connection_ids) = &*room_guard;
2111
2112        let message = proto::UnshareProject {
2113            project_id: project_id.to_proto(),
2114        };
2115
2116        broadcast(
2117            Some(connection_id),
2118            guest_connection_ids.iter().copied(),
2119            |conn_id| session.peer.send(conn_id, message.clone()),
2120        );
2121        if let Some(room) = room {
2122            room_updated(room, &session.peer);
2123        }
2124
2125        *delete
2126    };
2127
2128    if delete {
2129        let db = session.db().await;
2130        db.delete_project(project_id).await?;
2131    }
2132
2133    Ok(())
2134}
2135
2136/// DevServer makes a project available online
2137async fn share_dev_server_project(
2138    request: proto::ShareDevServerProject,
2139    response: Response<proto::ShareDevServerProject>,
2140    session: DevServerSession,
2141) -> Result<()> {
2142    let (dev_server_project, user_id, status) = session
2143        .db()
2144        .await
2145        .share_dev_server_project(
2146            DevServerProjectId::from_proto(request.dev_server_project_id),
2147            session.dev_server_id(),
2148            session.connection_id,
2149            &request.worktrees,
2150        )
2151        .await?;
2152    let Some(project_id) = dev_server_project.project_id else {
2153        return Err(anyhow!("failed to share remote project"))?;
2154    };
2155
2156    send_dev_server_projects_update(user_id, status, &session).await;
2157
2158    response.send(proto::ShareProjectResponse { project_id })?;
2159
2160    Ok(())
2161}
2162
2163/// Join someone elses shared project.
2164async fn join_project(
2165    request: proto::JoinProject,
2166    response: Response<proto::JoinProject>,
2167    session: UserSession,
2168) -> Result<()> {
2169    let project_id = ProjectId::from_proto(request.project_id);
2170
2171    tracing::info!(%project_id, "join project");
2172
2173    let db = session.db().await;
2174    let (project, replica_id) = &mut *db
2175        .join_project(project_id, session.connection_id, session.user_id())
2176        .await?;
2177    drop(db);
2178    tracing::info!(%project_id, "join remote project");
2179    join_project_internal(response, session, project, replica_id)
2180}
2181
2182trait JoinProjectInternalResponse {
2183    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2184}
2185impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2186    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2187        Response::<proto::JoinProject>::send(self, result)
2188    }
2189}
2190impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2191    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2192        Response::<proto::JoinHostedProject>::send(self, result)
2193    }
2194}
2195
2196fn join_project_internal(
2197    response: impl JoinProjectInternalResponse,
2198    session: UserSession,
2199    project: &mut Project,
2200    replica_id: &ReplicaId,
2201) -> Result<()> {
2202    let collaborators = project
2203        .collaborators
2204        .iter()
2205        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2206        .map(|collaborator| collaborator.to_proto())
2207        .collect::<Vec<_>>();
2208    let project_id = project.id;
2209    let guest_user_id = session.user_id();
2210
2211    let worktrees = project
2212        .worktrees
2213        .iter()
2214        .map(|(id, worktree)| proto::WorktreeMetadata {
2215            id: *id,
2216            root_name: worktree.root_name.clone(),
2217            visible: worktree.visible,
2218            abs_path: worktree.abs_path.clone(),
2219        })
2220        .collect::<Vec<_>>();
2221
2222    let add_project_collaborator = proto::AddProjectCollaborator {
2223        project_id: project_id.to_proto(),
2224        collaborator: Some(proto::Collaborator {
2225            peer_id: Some(session.connection_id.into()),
2226            replica_id: replica_id.0 as u32,
2227            user_id: guest_user_id.to_proto(),
2228        }),
2229    };
2230
2231    for collaborator in &collaborators {
2232        session
2233            .peer
2234            .send(
2235                collaborator.peer_id.unwrap().into(),
2236                add_project_collaborator.clone(),
2237            )
2238            .trace_err();
2239    }
2240
2241    // First, we send the metadata associated with each worktree.
2242    response.send(proto::JoinProjectResponse {
2243        project_id: project.id.0 as u64,
2244        worktrees: worktrees.clone(),
2245        replica_id: replica_id.0 as u32,
2246        collaborators: collaborators.clone(),
2247        language_servers: project.language_servers.clone(),
2248        role: project.role.into(),
2249        dev_server_project_id: project
2250            .dev_server_project_id
2251            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2252    })?;
2253
2254    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2255        #[cfg(any(test, feature = "test-support"))]
2256        const MAX_CHUNK_SIZE: usize = 2;
2257        #[cfg(not(any(test, feature = "test-support")))]
2258        const MAX_CHUNK_SIZE: usize = 256;
2259
2260        // Stream this worktree's entries.
2261        let message = proto::UpdateWorktree {
2262            project_id: project_id.to_proto(),
2263            worktree_id,
2264            abs_path: worktree.abs_path.clone(),
2265            root_name: worktree.root_name,
2266            updated_entries: worktree.entries,
2267            removed_entries: Default::default(),
2268            scan_id: worktree.scan_id,
2269            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2270            updated_repositories: worktree.repository_entries.into_values().collect(),
2271            removed_repositories: Default::default(),
2272        };
2273        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2274            session.peer.send(session.connection_id, update.clone())?;
2275        }
2276
2277        // Stream this worktree's diagnostics.
2278        for summary in worktree.diagnostic_summaries {
2279            session.peer.send(
2280                session.connection_id,
2281                proto::UpdateDiagnosticSummary {
2282                    project_id: project_id.to_proto(),
2283                    worktree_id: worktree.id,
2284                    summary: Some(summary),
2285                },
2286            )?;
2287        }
2288
2289        for settings_file in worktree.settings_files {
2290            session.peer.send(
2291                session.connection_id,
2292                proto::UpdateWorktreeSettings {
2293                    project_id: project_id.to_proto(),
2294                    worktree_id: worktree.id,
2295                    path: settings_file.path,
2296                    content: Some(settings_file.content),
2297                },
2298            )?;
2299        }
2300    }
2301
2302    for language_server in &project.language_servers {
2303        session.peer.send(
2304            session.connection_id,
2305            proto::UpdateLanguageServer {
2306                project_id: project_id.to_proto(),
2307                language_server_id: language_server.id,
2308                variant: Some(
2309                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2310                        proto::LspDiskBasedDiagnosticsUpdated {},
2311                    ),
2312                ),
2313            },
2314        )?;
2315    }
2316
2317    Ok(())
2318}
2319
2320/// Leave someone elses shared project.
2321async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2322    let sender_id = session.connection_id;
2323    let project_id = ProjectId::from_proto(request.project_id);
2324    let db = session.db().await;
2325    if db.is_hosted_project(project_id).await? {
2326        let project = db.leave_hosted_project(project_id, sender_id).await?;
2327        project_left(&project, &session);
2328        return Ok(());
2329    }
2330
2331    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2332    tracing::info!(
2333        %project_id,
2334        "leave project"
2335    );
2336
2337    project_left(&project, &session);
2338    if let Some(room) = room {
2339        room_updated(&room, &session.peer);
2340    }
2341
2342    Ok(())
2343}
2344
2345async fn join_hosted_project(
2346    request: proto::JoinHostedProject,
2347    response: Response<proto::JoinHostedProject>,
2348    session: UserSession,
2349) -> Result<()> {
2350    let (mut project, replica_id) = session
2351        .db()
2352        .await
2353        .join_hosted_project(
2354            ProjectId(request.project_id as i32),
2355            session.user_id(),
2356            session.connection_id,
2357        )
2358        .await?;
2359
2360    join_project_internal(response, session, &mut project, &replica_id)
2361}
2362
2363async fn list_remote_directory(
2364    request: proto::ListRemoteDirectory,
2365    response: Response<proto::ListRemoteDirectory>,
2366    session: UserSession,
2367) -> Result<()> {
2368    let dev_server_id = DevServerId(request.dev_server_id as i32);
2369    let dev_server_connection_id = session
2370        .connection_pool()
2371        .await
2372        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2373
2374    session
2375        .db()
2376        .await
2377        .get_dev_server_for_user(dev_server_id, session.user_id())
2378        .await?;
2379
2380    response.send(
2381        session
2382            .peer
2383            .forward_request(session.connection_id, dev_server_connection_id, request)
2384            .await?,
2385    )?;
2386    Ok(())
2387}
2388
2389async fn update_dev_server_project(
2390    request: proto::UpdateDevServerProject,
2391    response: Response<proto::UpdateDevServerProject>,
2392    session: UserSession,
2393) -> Result<()> {
2394    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2395
2396    let (dev_server_project, update) = session
2397        .db()
2398        .await
2399        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2400        .await?;
2401
2402    let projects = session
2403        .db()
2404        .await
2405        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2406        .await?;
2407
2408    let dev_server_connection_id = session
2409        .connection_pool()
2410        .await
2411        .dev_server_connection_id_supporting(
2412            dev_server_project.dev_server_id,
2413            ZedVersion::with_list_directory(),
2414        )?;
2415
2416    session.peer.send(
2417        dev_server_connection_id,
2418        proto::DevServerInstructions { projects },
2419    )?;
2420
2421    send_dev_server_projects_update(session.user_id(), update, &session).await;
2422
2423    response.send(proto::Ack {})
2424}
2425
2426async fn create_dev_server_project(
2427    request: proto::CreateDevServerProject,
2428    response: Response<proto::CreateDevServerProject>,
2429    session: UserSession,
2430) -> Result<()> {
2431    let dev_server_id = DevServerId(request.dev_server_id as i32);
2432    let dev_server_connection_id = session
2433        .connection_pool()
2434        .await
2435        .dev_server_connection_id(dev_server_id);
2436    let Some(dev_server_connection_id) = dev_server_connection_id else {
2437        Err(ErrorCode::DevServerOffline
2438            .message("Cannot create a remote project when the dev server is offline".to_string())
2439            .anyhow())?
2440    };
2441
2442    let path = request.path.clone();
2443    //Check that the path exists on the dev server
2444    session
2445        .peer
2446        .forward_request(
2447            session.connection_id,
2448            dev_server_connection_id,
2449            proto::ValidateDevServerProjectRequest { path: path.clone() },
2450        )
2451        .await?;
2452
2453    let (dev_server_project, update) = session
2454        .db()
2455        .await
2456        .create_dev_server_project(
2457            DevServerId(request.dev_server_id as i32),
2458            &request.path,
2459            session.user_id(),
2460        )
2461        .await?;
2462
2463    let projects = session
2464        .db()
2465        .await
2466        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2467        .await?;
2468
2469    session.peer.send(
2470        dev_server_connection_id,
2471        proto::DevServerInstructions { projects },
2472    )?;
2473
2474    send_dev_server_projects_update(session.user_id(), update, &session).await;
2475
2476    response.send(proto::CreateDevServerProjectResponse {
2477        dev_server_project: Some(dev_server_project.to_proto(None)),
2478    })?;
2479    Ok(())
2480}
2481
2482async fn create_dev_server(
2483    request: proto::CreateDevServer,
2484    response: Response<proto::CreateDevServer>,
2485    session: UserSession,
2486) -> Result<()> {
2487    let access_token = auth::random_token();
2488    let hashed_access_token = auth::hash_access_token(&access_token);
2489
2490    if request.name.is_empty() {
2491        return Err(proto::ErrorCode::Forbidden
2492            .message("Dev server name cannot be empty".to_string())
2493            .anyhow())?;
2494    }
2495
2496    let (dev_server, status) = session
2497        .db()
2498        .await
2499        .create_dev_server(
2500            &request.name,
2501            request.ssh_connection_string.as_deref(),
2502            &hashed_access_token,
2503            session.user_id(),
2504        )
2505        .await?;
2506
2507    send_dev_server_projects_update(session.user_id(), status, &session).await;
2508
2509    response.send(proto::CreateDevServerResponse {
2510        dev_server_id: dev_server.id.0 as u64,
2511        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2512        name: request.name,
2513    })?;
2514    Ok(())
2515}
2516
2517async fn regenerate_dev_server_token(
2518    request: proto::RegenerateDevServerToken,
2519    response: Response<proto::RegenerateDevServerToken>,
2520    session: UserSession,
2521) -> Result<()> {
2522    let dev_server_id = DevServerId(request.dev_server_id as i32);
2523    let access_token = auth::random_token();
2524    let hashed_access_token = auth::hash_access_token(&access_token);
2525
2526    let connection_id = session
2527        .connection_pool()
2528        .await
2529        .dev_server_connection_id(dev_server_id);
2530    if let Some(connection_id) = connection_id {
2531        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2532        session.peer.send(
2533            connection_id,
2534            proto::ShutdownDevServer {
2535                reason: Some("dev server token was regenerated".to_string()),
2536            },
2537        )?;
2538        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2539    }
2540
2541    let status = session
2542        .db()
2543        .await
2544        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2545        .await?;
2546
2547    send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549    response.send(proto::RegenerateDevServerTokenResponse {
2550        dev_server_id: dev_server_id.to_proto(),
2551        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2552    })?;
2553    Ok(())
2554}
2555
2556async fn rename_dev_server(
2557    request: proto::RenameDevServer,
2558    response: Response<proto::RenameDevServer>,
2559    session: UserSession,
2560) -> Result<()> {
2561    if request.name.trim().is_empty() {
2562        return Err(proto::ErrorCode::Forbidden
2563            .message("Dev server name cannot be empty".to_string())
2564            .anyhow())?;
2565    }
2566
2567    let dev_server_id = DevServerId(request.dev_server_id as i32);
2568    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2569    if dev_server.user_id != session.user_id() {
2570        return Err(anyhow!(ErrorCode::Forbidden))?;
2571    }
2572
2573    let status = session
2574        .db()
2575        .await
2576        .rename_dev_server(
2577            dev_server_id,
2578            &request.name,
2579            request.ssh_connection_string.as_deref(),
2580            session.user_id(),
2581        )
2582        .await?;
2583
2584    send_dev_server_projects_update(session.user_id(), status, &session).await;
2585
2586    response.send(proto::Ack {})?;
2587    Ok(())
2588}
2589
2590async fn delete_dev_server(
2591    request: proto::DeleteDevServer,
2592    response: Response<proto::DeleteDevServer>,
2593    session: UserSession,
2594) -> Result<()> {
2595    let dev_server_id = DevServerId(request.dev_server_id as i32);
2596    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2597    if dev_server.user_id != session.user_id() {
2598        return Err(anyhow!(ErrorCode::Forbidden))?;
2599    }
2600
2601    let connection_id = session
2602        .connection_pool()
2603        .await
2604        .dev_server_connection_id(dev_server_id);
2605    if let Some(connection_id) = connection_id {
2606        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2607        session.peer.send(
2608            connection_id,
2609            proto::ShutdownDevServer {
2610                reason: Some("dev server was deleted".to_string()),
2611            },
2612        )?;
2613        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2614    }
2615
2616    let status = session
2617        .db()
2618        .await
2619        .delete_dev_server(dev_server_id, session.user_id())
2620        .await?;
2621
2622    send_dev_server_projects_update(session.user_id(), status, &session).await;
2623
2624    response.send(proto::Ack {})?;
2625    Ok(())
2626}
2627
2628async fn delete_dev_server_project(
2629    request: proto::DeleteDevServerProject,
2630    response: Response<proto::DeleteDevServerProject>,
2631    session: UserSession,
2632) -> Result<()> {
2633    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2634    let dev_server_project = session
2635        .db()
2636        .await
2637        .get_dev_server_project(dev_server_project_id)
2638        .await?;
2639
2640    let dev_server = session
2641        .db()
2642        .await
2643        .get_dev_server(dev_server_project.dev_server_id)
2644        .await?;
2645    if dev_server.user_id != session.user_id() {
2646        return Err(anyhow!(ErrorCode::Forbidden))?;
2647    }
2648
2649    let dev_server_connection_id = session
2650        .connection_pool()
2651        .await
2652        .dev_server_connection_id(dev_server.id);
2653
2654    if let Some(dev_server_connection_id) = dev_server_connection_id {
2655        let project = session
2656            .db()
2657            .await
2658            .find_dev_server_project(dev_server_project_id)
2659            .await;
2660        if let Ok(project) = project {
2661            unshare_project_internal(
2662                project.id,
2663                dev_server_connection_id,
2664                Some(session.user_id()),
2665                &session,
2666            )
2667            .await?;
2668        }
2669    }
2670
2671    let (projects, status) = session
2672        .db()
2673        .await
2674        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2675        .await?;
2676
2677    if let Some(dev_server_connection_id) = dev_server_connection_id {
2678        session.peer.send(
2679            dev_server_connection_id,
2680            proto::DevServerInstructions { projects },
2681        )?;
2682    }
2683
2684    send_dev_server_projects_update(session.user_id(), status, &session).await;
2685
2686    response.send(proto::Ack {})?;
2687    Ok(())
2688}
2689
2690async fn rejoin_dev_server_projects(
2691    request: proto::RejoinRemoteProjects,
2692    response: Response<proto::RejoinRemoteProjects>,
2693    session: UserSession,
2694) -> Result<()> {
2695    let mut rejoined_projects = {
2696        let db = session.db().await;
2697        db.rejoin_dev_server_projects(
2698            &request.rejoined_projects,
2699            session.user_id(),
2700            session.0.connection_id,
2701        )
2702        .await?
2703    };
2704    response.send(proto::RejoinRemoteProjectsResponse {
2705        rejoined_projects: rejoined_projects
2706            .iter()
2707            .map(|project| project.to_proto())
2708            .collect(),
2709    })?;
2710    notify_rejoined_projects(&mut rejoined_projects, &session)
2711}
2712
2713async fn reconnect_dev_server(
2714    request: proto::ReconnectDevServer,
2715    response: Response<proto::ReconnectDevServer>,
2716    session: DevServerSession,
2717) -> Result<()> {
2718    let reshared_projects = {
2719        let db = session.db().await;
2720        db.reshare_dev_server_projects(
2721            &request.reshared_projects,
2722            session.dev_server_id(),
2723            session.0.connection_id,
2724        )
2725        .await?
2726    };
2727
2728    for project in &reshared_projects {
2729        for collaborator in &project.collaborators {
2730            session
2731                .peer
2732                .send(
2733                    collaborator.connection_id,
2734                    proto::UpdateProjectCollaborator {
2735                        project_id: project.id.to_proto(),
2736                        old_peer_id: Some(project.old_connection_id.into()),
2737                        new_peer_id: Some(session.connection_id.into()),
2738                    },
2739                )
2740                .trace_err();
2741        }
2742
2743        broadcast(
2744            Some(session.connection_id),
2745            project
2746                .collaborators
2747                .iter()
2748                .map(|collaborator| collaborator.connection_id),
2749            |connection_id| {
2750                session.peer.forward_send(
2751                    session.connection_id,
2752                    connection_id,
2753                    proto::UpdateProject {
2754                        project_id: project.id.to_proto(),
2755                        worktrees: project.worktrees.clone(),
2756                    },
2757                )
2758            },
2759        );
2760    }
2761
2762    response.send(proto::ReconnectDevServerResponse {
2763        reshared_projects: reshared_projects
2764            .iter()
2765            .map(|project| proto::ResharedProject {
2766                id: project.id.to_proto(),
2767                collaborators: project
2768                    .collaborators
2769                    .iter()
2770                    .map(|collaborator| collaborator.to_proto())
2771                    .collect(),
2772            })
2773            .collect(),
2774    })?;
2775
2776    Ok(())
2777}
2778
2779async fn shutdown_dev_server(
2780    _: proto::ShutdownDevServer,
2781    response: Response<proto::ShutdownDevServer>,
2782    session: DevServerSession,
2783) -> Result<()> {
2784    response.send(proto::Ack {})?;
2785    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2786    remove_dev_server_connection(session.dev_server_id(), &session).await
2787}
2788
2789async fn shutdown_dev_server_internal(
2790    dev_server_id: DevServerId,
2791    connection_id: ConnectionId,
2792    session: &Session,
2793) -> Result<()> {
2794    let (dev_server_projects, dev_server) = {
2795        let db = session.db().await;
2796        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2797        let dev_server = db.get_dev_server(dev_server_id).await?;
2798        (dev_server_projects, dev_server)
2799    };
2800
2801    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2802        unshare_project_internal(
2803            ProjectId::from_proto(project_id),
2804            connection_id,
2805            None,
2806            session,
2807        )
2808        .await?;
2809    }
2810
2811    session
2812        .connection_pool()
2813        .await
2814        .set_dev_server_offline(dev_server_id);
2815
2816    let status = session
2817        .db()
2818        .await
2819        .dev_server_projects_update(dev_server.user_id)
2820        .await?;
2821    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2822
2823    Ok(())
2824}
2825
2826async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2827    let dev_server_connection = session
2828        .connection_pool()
2829        .await
2830        .dev_server_connection_id(dev_server_id);
2831
2832    if let Some(dev_server_connection) = dev_server_connection {
2833        session
2834            .connection_pool()
2835            .await
2836            .remove_connection(dev_server_connection)?;
2837    }
2838    Ok(())
2839}
2840
2841/// Updates other participants with changes to the project
2842async fn update_project(
2843    request: proto::UpdateProject,
2844    response: Response<proto::UpdateProject>,
2845    session: Session,
2846) -> Result<()> {
2847    let project_id = ProjectId::from_proto(request.project_id);
2848    let (room, guest_connection_ids) = &*session
2849        .db()
2850        .await
2851        .update_project(project_id, session.connection_id, &request.worktrees)
2852        .await?;
2853    broadcast(
2854        Some(session.connection_id),
2855        guest_connection_ids.iter().copied(),
2856        |connection_id| {
2857            session
2858                .peer
2859                .forward_send(session.connection_id, connection_id, request.clone())
2860        },
2861    );
2862    if let Some(room) = room {
2863        room_updated(&room, &session.peer);
2864    }
2865    response.send(proto::Ack {})?;
2866
2867    Ok(())
2868}
2869
2870/// Updates other participants with changes to the worktree
2871async fn update_worktree(
2872    request: proto::UpdateWorktree,
2873    response: Response<proto::UpdateWorktree>,
2874    session: Session,
2875) -> Result<()> {
2876    let guest_connection_ids = session
2877        .db()
2878        .await
2879        .update_worktree(&request, session.connection_id)
2880        .await?;
2881
2882    broadcast(
2883        Some(session.connection_id),
2884        guest_connection_ids.iter().copied(),
2885        |connection_id| {
2886            session
2887                .peer
2888                .forward_send(session.connection_id, connection_id, request.clone())
2889        },
2890    );
2891    response.send(proto::Ack {})?;
2892    Ok(())
2893}
2894
2895/// Updates other participants with changes to the diagnostics
2896async fn update_diagnostic_summary(
2897    message: proto::UpdateDiagnosticSummary,
2898    session: Session,
2899) -> Result<()> {
2900    let guest_connection_ids = session
2901        .db()
2902        .await
2903        .update_diagnostic_summary(&message, session.connection_id)
2904        .await?;
2905
2906    broadcast(
2907        Some(session.connection_id),
2908        guest_connection_ids.iter().copied(),
2909        |connection_id| {
2910            session
2911                .peer
2912                .forward_send(session.connection_id, connection_id, message.clone())
2913        },
2914    );
2915
2916    Ok(())
2917}
2918
2919/// Updates other participants with changes to the worktree settings
2920async fn update_worktree_settings(
2921    message: proto::UpdateWorktreeSettings,
2922    session: Session,
2923) -> Result<()> {
2924    let guest_connection_ids = session
2925        .db()
2926        .await
2927        .update_worktree_settings(&message, session.connection_id)
2928        .await?;
2929
2930    broadcast(
2931        Some(session.connection_id),
2932        guest_connection_ids.iter().copied(),
2933        |connection_id| {
2934            session
2935                .peer
2936                .forward_send(session.connection_id, connection_id, message.clone())
2937        },
2938    );
2939
2940    Ok(())
2941}
2942
2943/// Notify other participants that a  language server has started.
2944async fn start_language_server(
2945    request: proto::StartLanguageServer,
2946    session: Session,
2947) -> Result<()> {
2948    let guest_connection_ids = session
2949        .db()
2950        .await
2951        .start_language_server(&request, session.connection_id)
2952        .await?;
2953
2954    broadcast(
2955        Some(session.connection_id),
2956        guest_connection_ids.iter().copied(),
2957        |connection_id| {
2958            session
2959                .peer
2960                .forward_send(session.connection_id, connection_id, request.clone())
2961        },
2962    );
2963    Ok(())
2964}
2965
2966/// Notify other participants that a language server has changed.
2967async fn update_language_server(
2968    request: proto::UpdateLanguageServer,
2969    session: Session,
2970) -> Result<()> {
2971    let project_id = ProjectId::from_proto(request.project_id);
2972    let project_connection_ids = session
2973        .db()
2974        .await
2975        .project_connection_ids(project_id, session.connection_id, true)
2976        .await?;
2977    broadcast(
2978        Some(session.connection_id),
2979        project_connection_ids.iter().copied(),
2980        |connection_id| {
2981            session
2982                .peer
2983                .forward_send(session.connection_id, connection_id, request.clone())
2984        },
2985    );
2986    Ok(())
2987}
2988
2989/// forward a project request to the host. These requests should be read only
2990/// as guests are allowed to send them.
2991async fn forward_read_only_project_request<T>(
2992    request: T,
2993    response: Response<T>,
2994    session: UserSession,
2995) -> Result<()>
2996where
2997    T: EntityMessage + RequestMessage,
2998{
2999    let project_id = ProjectId::from_proto(request.remote_entity_id());
3000    let host_connection_id = session
3001        .db()
3002        .await
3003        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3004        .await?;
3005    let payload = session
3006        .peer
3007        .forward_request(session.connection_id, host_connection_id, request)
3008        .await?;
3009    response.send(payload)?;
3010    Ok(())
3011}
3012
3013/// forward a project request to the dev server. Only allowed
3014/// if it's your dev server.
3015async fn forward_project_request_for_owner<T>(
3016    request: T,
3017    response: Response<T>,
3018    session: UserSession,
3019) -> Result<()>
3020where
3021    T: EntityMessage + RequestMessage,
3022{
3023    let project_id = ProjectId::from_proto(request.remote_entity_id());
3024
3025    let host_connection_id = session
3026        .db()
3027        .await
3028        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3029        .await?;
3030    let payload = session
3031        .peer
3032        .forward_request(session.connection_id, host_connection_id, request)
3033        .await?;
3034    response.send(payload)?;
3035    Ok(())
3036}
3037
3038/// forward a project request to the host. These requests are disallowed
3039/// for guests.
3040async fn forward_mutating_project_request<T>(
3041    request: T,
3042    response: Response<T>,
3043    session: UserSession,
3044) -> Result<()>
3045where
3046    T: EntityMessage + RequestMessage,
3047{
3048    let project_id = ProjectId::from_proto(request.remote_entity_id());
3049
3050    let host_connection_id = session
3051        .db()
3052        .await
3053        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3054        .await?;
3055    let payload = session
3056        .peer
3057        .forward_request(session.connection_id, host_connection_id, request)
3058        .await?;
3059    response.send(payload)?;
3060    Ok(())
3061}
3062
3063/// forward a project request to the host. These requests are disallowed
3064/// for guests.
3065async fn forward_versioned_mutating_project_request<T>(
3066    request: T,
3067    response: Response<T>,
3068    session: UserSession,
3069) -> Result<()>
3070where
3071    T: EntityMessage + RequestMessage + VersionedMessage,
3072{
3073    let project_id = ProjectId::from_proto(request.remote_entity_id());
3074
3075    let host_connection_id = session
3076        .db()
3077        .await
3078        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3079        .await?;
3080    if let Some(host_version) = session
3081        .connection_pool()
3082        .await
3083        .connection(host_connection_id)
3084        .map(|c| c.zed_version)
3085    {
3086        if let Some(min_required_version) = request.required_host_version() {
3087            if min_required_version > host_version {
3088                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3089                    .with_tag("required", &min_required_version.to_string())))?;
3090            }
3091        }
3092    }
3093
3094    let payload = session
3095        .peer
3096        .forward_request(session.connection_id, host_connection_id, request)
3097        .await?;
3098    response.send(payload)?;
3099    Ok(())
3100}
3101
3102/// Notify other participants that a new buffer has been created
3103async fn create_buffer_for_peer(
3104    request: proto::CreateBufferForPeer,
3105    session: Session,
3106) -> Result<()> {
3107    session
3108        .db()
3109        .await
3110        .check_user_is_project_host(
3111            ProjectId::from_proto(request.project_id),
3112            session.connection_id,
3113        )
3114        .await?;
3115    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3116    session
3117        .peer
3118        .forward_send(session.connection_id, peer_id.into(), request)?;
3119    Ok(())
3120}
3121
3122/// Notify other participants that a buffer has been updated. This is
3123/// allowed for guests as long as the update is limited to selections.
3124async fn update_buffer(
3125    request: proto::UpdateBuffer,
3126    response: Response<proto::UpdateBuffer>,
3127    session: Session,
3128) -> Result<()> {
3129    let project_id = ProjectId::from_proto(request.project_id);
3130    let mut capability = Capability::ReadOnly;
3131
3132    for op in request.operations.iter() {
3133        match op.variant {
3134            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3135            Some(_) => capability = Capability::ReadWrite,
3136        }
3137    }
3138
3139    let host = {
3140        let guard = session
3141            .db()
3142            .await
3143            .connections_for_buffer_update(
3144                project_id,
3145                session.principal_id(),
3146                session.connection_id,
3147                capability,
3148            )
3149            .await?;
3150
3151        let (host, guests) = &*guard;
3152
3153        broadcast(
3154            Some(session.connection_id),
3155            guests.clone(),
3156            |connection_id| {
3157                session
3158                    .peer
3159                    .forward_send(session.connection_id, connection_id, request.clone())
3160            },
3161        );
3162
3163        *host
3164    };
3165
3166    if host != session.connection_id {
3167        session
3168            .peer
3169            .forward_request(session.connection_id, host, request.clone())
3170            .await?;
3171    }
3172
3173    response.send(proto::Ack {})?;
3174    Ok(())
3175}
3176
3177async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3178    let project_id = ProjectId::from_proto(message.project_id);
3179
3180    let operation = message.operation.as_ref().context("invalid operation")?;
3181    let capability = match operation.variant.as_ref() {
3182        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3183            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3184                match buffer_op.variant {
3185                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3186                        Capability::ReadOnly
3187                    }
3188                    _ => Capability::ReadWrite,
3189                }
3190            } else {
3191                Capability::ReadWrite
3192            }
3193        }
3194        Some(_) => Capability::ReadWrite,
3195        None => Capability::ReadOnly,
3196    };
3197
3198    let guard = session
3199        .db()
3200        .await
3201        .connections_for_buffer_update(
3202            project_id,
3203            session.principal_id(),
3204            session.connection_id,
3205            capability,
3206        )
3207        .await?;
3208
3209    let (host, guests) = &*guard;
3210
3211    broadcast(
3212        Some(session.connection_id),
3213        guests.iter().chain([host]).copied(),
3214        |connection_id| {
3215            session
3216                .peer
3217                .forward_send(session.connection_id, connection_id, message.clone())
3218        },
3219    );
3220
3221    Ok(())
3222}
3223
3224/// Notify other participants that a project has been updated.
3225async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3226    request: T,
3227    session: Session,
3228) -> Result<()> {
3229    let project_id = ProjectId::from_proto(request.remote_entity_id());
3230    let project_connection_ids = session
3231        .db()
3232        .await
3233        .project_connection_ids(project_id, session.connection_id, false)
3234        .await?;
3235
3236    broadcast(
3237        Some(session.connection_id),
3238        project_connection_ids.iter().copied(),
3239        |connection_id| {
3240            session
3241                .peer
3242                .forward_send(session.connection_id, connection_id, request.clone())
3243        },
3244    );
3245    Ok(())
3246}
3247
3248/// Start following another user in a call.
3249async fn follow(
3250    request: proto::Follow,
3251    response: Response<proto::Follow>,
3252    session: UserSession,
3253) -> Result<()> {
3254    let room_id = RoomId::from_proto(request.room_id);
3255    let project_id = request.project_id.map(ProjectId::from_proto);
3256    let leader_id = request
3257        .leader_id
3258        .ok_or_else(|| anyhow!("invalid leader id"))?
3259        .into();
3260    let follower_id = session.connection_id;
3261
3262    session
3263        .db()
3264        .await
3265        .check_room_participants(room_id, leader_id, session.connection_id)
3266        .await?;
3267
3268    let response_payload = session
3269        .peer
3270        .forward_request(session.connection_id, leader_id, request)
3271        .await?;
3272    response.send(response_payload)?;
3273
3274    if let Some(project_id) = project_id {
3275        let room = session
3276            .db()
3277            .await
3278            .follow(room_id, project_id, leader_id, follower_id)
3279            .await?;
3280        room_updated(&room, &session.peer);
3281    }
3282
3283    Ok(())
3284}
3285
3286/// Stop following another user in a call.
3287async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3288    let room_id = RoomId::from_proto(request.room_id);
3289    let project_id = request.project_id.map(ProjectId::from_proto);
3290    let leader_id = request
3291        .leader_id
3292        .ok_or_else(|| anyhow!("invalid leader id"))?
3293        .into();
3294    let follower_id = session.connection_id;
3295
3296    session
3297        .db()
3298        .await
3299        .check_room_participants(room_id, leader_id, session.connection_id)
3300        .await?;
3301
3302    session
3303        .peer
3304        .forward_send(session.connection_id, leader_id, request)?;
3305
3306    if let Some(project_id) = project_id {
3307        let room = session
3308            .db()
3309            .await
3310            .unfollow(room_id, project_id, leader_id, follower_id)
3311            .await?;
3312        room_updated(&room, &session.peer);
3313    }
3314
3315    Ok(())
3316}
3317
3318/// Notify everyone following you of your current location.
3319async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3320    let room_id = RoomId::from_proto(request.room_id);
3321    let database = session.db.lock().await;
3322
3323    let connection_ids = if let Some(project_id) = request.project_id {
3324        let project_id = ProjectId::from_proto(project_id);
3325        database
3326            .project_connection_ids(project_id, session.connection_id, true)
3327            .await?
3328    } else {
3329        database
3330            .room_connection_ids(room_id, session.connection_id)
3331            .await?
3332    };
3333
3334    // For now, don't send view update messages back to that view's current leader.
3335    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3336        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3337        _ => None,
3338    });
3339
3340    for connection_id in connection_ids.iter().cloned() {
3341        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3342            session
3343                .peer
3344                .forward_send(session.connection_id, connection_id, request.clone())?;
3345        }
3346    }
3347    Ok(())
3348}
3349
3350/// Get public data about users.
3351async fn get_users(
3352    request: proto::GetUsers,
3353    response: Response<proto::GetUsers>,
3354    session: Session,
3355) -> Result<()> {
3356    let user_ids = request
3357        .user_ids
3358        .into_iter()
3359        .map(UserId::from_proto)
3360        .collect();
3361    let users = session
3362        .db()
3363        .await
3364        .get_users_by_ids(user_ids)
3365        .await?
3366        .into_iter()
3367        .map(|user| proto::User {
3368            id: user.id.to_proto(),
3369            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3370            github_login: user.github_login,
3371        })
3372        .collect();
3373    response.send(proto::UsersResponse { users })?;
3374    Ok(())
3375}
3376
3377/// Search for users (to invite) buy Github login
3378async fn fuzzy_search_users(
3379    request: proto::FuzzySearchUsers,
3380    response: Response<proto::FuzzySearchUsers>,
3381    session: UserSession,
3382) -> Result<()> {
3383    let query = request.query;
3384    let users = match query.len() {
3385        0 => vec![],
3386        1 | 2 => session
3387            .db()
3388            .await
3389            .get_user_by_github_login(&query)
3390            .await?
3391            .into_iter()
3392            .collect(),
3393        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3394    };
3395    let users = users
3396        .into_iter()
3397        .filter(|user| user.id != session.user_id())
3398        .map(|user| proto::User {
3399            id: user.id.to_proto(),
3400            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3401            github_login: user.github_login,
3402        })
3403        .collect();
3404    response.send(proto::UsersResponse { users })?;
3405    Ok(())
3406}
3407
3408/// Send a contact request to another user.
3409async fn request_contact(
3410    request: proto::RequestContact,
3411    response: Response<proto::RequestContact>,
3412    session: UserSession,
3413) -> Result<()> {
3414    let requester_id = session.user_id();
3415    let responder_id = UserId::from_proto(request.responder_id);
3416    if requester_id == responder_id {
3417        return Err(anyhow!("cannot add yourself as a contact"))?;
3418    }
3419
3420    let notifications = session
3421        .db()
3422        .await
3423        .send_contact_request(requester_id, responder_id)
3424        .await?;
3425
3426    // Update outgoing contact requests of requester
3427    let mut update = proto::UpdateContacts::default();
3428    update.outgoing_requests.push(responder_id.to_proto());
3429    for connection_id in session
3430        .connection_pool()
3431        .await
3432        .user_connection_ids(requester_id)
3433    {
3434        session.peer.send(connection_id, update.clone())?;
3435    }
3436
3437    // Update incoming contact requests of responder
3438    let mut update = proto::UpdateContacts::default();
3439    update
3440        .incoming_requests
3441        .push(proto::IncomingContactRequest {
3442            requester_id: requester_id.to_proto(),
3443        });
3444    let connection_pool = session.connection_pool().await;
3445    for connection_id in connection_pool.user_connection_ids(responder_id) {
3446        session.peer.send(connection_id, update.clone())?;
3447    }
3448
3449    send_notifications(&connection_pool, &session.peer, notifications);
3450
3451    response.send(proto::Ack {})?;
3452    Ok(())
3453}
3454
3455/// Accept or decline a contact request
3456async fn respond_to_contact_request(
3457    request: proto::RespondToContactRequest,
3458    response: Response<proto::RespondToContactRequest>,
3459    session: UserSession,
3460) -> Result<()> {
3461    let responder_id = session.user_id();
3462    let requester_id = UserId::from_proto(request.requester_id);
3463    let db = session.db().await;
3464    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3465        db.dismiss_contact_notification(responder_id, requester_id)
3466            .await?;
3467    } else {
3468        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3469
3470        let notifications = db
3471            .respond_to_contact_request(responder_id, requester_id, accept)
3472            .await?;
3473        let requester_busy = db.is_user_busy(requester_id).await?;
3474        let responder_busy = db.is_user_busy(responder_id).await?;
3475
3476        let pool = session.connection_pool().await;
3477        // Update responder with new contact
3478        let mut update = proto::UpdateContacts::default();
3479        if accept {
3480            update
3481                .contacts
3482                .push(contact_for_user(requester_id, requester_busy, &pool));
3483        }
3484        update
3485            .remove_incoming_requests
3486            .push(requester_id.to_proto());
3487        for connection_id in pool.user_connection_ids(responder_id) {
3488            session.peer.send(connection_id, update.clone())?;
3489        }
3490
3491        // Update requester with new contact
3492        let mut update = proto::UpdateContacts::default();
3493        if accept {
3494            update
3495                .contacts
3496                .push(contact_for_user(responder_id, responder_busy, &pool));
3497        }
3498        update
3499            .remove_outgoing_requests
3500            .push(responder_id.to_proto());
3501
3502        for connection_id in pool.user_connection_ids(requester_id) {
3503            session.peer.send(connection_id, update.clone())?;
3504        }
3505
3506        send_notifications(&pool, &session.peer, notifications);
3507    }
3508
3509    response.send(proto::Ack {})?;
3510    Ok(())
3511}
3512
3513/// Remove a contact.
3514async fn remove_contact(
3515    request: proto::RemoveContact,
3516    response: Response<proto::RemoveContact>,
3517    session: UserSession,
3518) -> Result<()> {
3519    let requester_id = session.user_id();
3520    let responder_id = UserId::from_proto(request.user_id);
3521    let db = session.db().await;
3522    let (contact_accepted, deleted_notification_id) =
3523        db.remove_contact(requester_id, responder_id).await?;
3524
3525    let pool = session.connection_pool().await;
3526    // Update outgoing contact requests of requester
3527    let mut update = proto::UpdateContacts::default();
3528    if contact_accepted {
3529        update.remove_contacts.push(responder_id.to_proto());
3530    } else {
3531        update
3532            .remove_outgoing_requests
3533            .push(responder_id.to_proto());
3534    }
3535    for connection_id in pool.user_connection_ids(requester_id) {
3536        session.peer.send(connection_id, update.clone())?;
3537    }
3538
3539    // Update incoming contact requests of responder
3540    let mut update = proto::UpdateContacts::default();
3541    if contact_accepted {
3542        update.remove_contacts.push(requester_id.to_proto());
3543    } else {
3544        update
3545            .remove_incoming_requests
3546            .push(requester_id.to_proto());
3547    }
3548    for connection_id in pool.user_connection_ids(responder_id) {
3549        session.peer.send(connection_id, update.clone())?;
3550        if let Some(notification_id) = deleted_notification_id {
3551            session.peer.send(
3552                connection_id,
3553                proto::DeleteNotification {
3554                    notification_id: notification_id.to_proto(),
3555                },
3556            )?;
3557        }
3558    }
3559
3560    response.send(proto::Ack {})?;
3561    Ok(())
3562}
3563
3564fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3565    version.0.minor() < 139
3566}
3567
3568async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3569    let plan = session.current_plan().await?;
3570
3571    session
3572        .peer
3573        .send(
3574            session.connection_id,
3575            proto::UpdateUserPlan { plan: plan.into() },
3576        )
3577        .trace_err();
3578
3579    Ok(())
3580}
3581
3582async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3583    subscribe_user_to_channels(
3584        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3585        &session,
3586    )
3587    .await?;
3588    Ok(())
3589}
3590
3591async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3592    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3593    let mut pool = session.connection_pool().await;
3594    for membership in &channels_for_user.channel_memberships {
3595        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3596    }
3597    session.peer.send(
3598        session.connection_id,
3599        build_update_user_channels(&channels_for_user),
3600    )?;
3601    session.peer.send(
3602        session.connection_id,
3603        build_channels_update(channels_for_user),
3604    )?;
3605    Ok(())
3606}
3607
3608/// Creates a new channel.
3609async fn create_channel(
3610    request: proto::CreateChannel,
3611    response: Response<proto::CreateChannel>,
3612    session: UserSession,
3613) -> Result<()> {
3614    let db = session.db().await;
3615
3616    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3617    let (channel, membership) = db
3618        .create_channel(&request.name, parent_id, session.user_id())
3619        .await?;
3620
3621    let root_id = channel.root_id();
3622    let channel = Channel::from_model(channel);
3623
3624    response.send(proto::CreateChannelResponse {
3625        channel: Some(channel.to_proto()),
3626        parent_id: request.parent_id,
3627    })?;
3628
3629    let mut connection_pool = session.connection_pool().await;
3630    if let Some(membership) = membership {
3631        connection_pool.subscribe_to_channel(
3632            membership.user_id,
3633            membership.channel_id,
3634            membership.role,
3635        );
3636        let update = proto::UpdateUserChannels {
3637            channel_memberships: vec![proto::ChannelMembership {
3638                channel_id: membership.channel_id.to_proto(),
3639                role: membership.role.into(),
3640            }],
3641            ..Default::default()
3642        };
3643        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3644            session.peer.send(connection_id, update.clone())?;
3645        }
3646    }
3647
3648    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3649        if !role.can_see_channel(channel.visibility) {
3650            continue;
3651        }
3652
3653        let update = proto::UpdateChannels {
3654            channels: vec![channel.to_proto()],
3655            ..Default::default()
3656        };
3657        session.peer.send(connection_id, update.clone())?;
3658    }
3659
3660    Ok(())
3661}
3662
3663/// Delete a channel
3664async fn delete_channel(
3665    request: proto::DeleteChannel,
3666    response: Response<proto::DeleteChannel>,
3667    session: UserSession,
3668) -> Result<()> {
3669    let db = session.db().await;
3670
3671    let channel_id = request.channel_id;
3672    let (root_channel, removed_channels) = db
3673        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3674        .await?;
3675    response.send(proto::Ack {})?;
3676
3677    // Notify members of removed channels
3678    let mut update = proto::UpdateChannels::default();
3679    update
3680        .delete_channels
3681        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3682
3683    let connection_pool = session.connection_pool().await;
3684    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3685        session.peer.send(connection_id, update.clone())?;
3686    }
3687
3688    Ok(())
3689}
3690
3691/// Invite someone to join a channel.
3692async fn invite_channel_member(
3693    request: proto::InviteChannelMember,
3694    response: Response<proto::InviteChannelMember>,
3695    session: UserSession,
3696) -> Result<()> {
3697    let db = session.db().await;
3698    let channel_id = ChannelId::from_proto(request.channel_id);
3699    let invitee_id = UserId::from_proto(request.user_id);
3700    let InviteMemberResult {
3701        channel,
3702        notifications,
3703    } = db
3704        .invite_channel_member(
3705            channel_id,
3706            invitee_id,
3707            session.user_id(),
3708            request.role().into(),
3709        )
3710        .await?;
3711
3712    let update = proto::UpdateChannels {
3713        channel_invitations: vec![channel.to_proto()],
3714        ..Default::default()
3715    };
3716
3717    let connection_pool = session.connection_pool().await;
3718    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3719        session.peer.send(connection_id, update.clone())?;
3720    }
3721
3722    send_notifications(&connection_pool, &session.peer, notifications);
3723
3724    response.send(proto::Ack {})?;
3725    Ok(())
3726}
3727
3728/// remove someone from a channel
3729async fn remove_channel_member(
3730    request: proto::RemoveChannelMember,
3731    response: Response<proto::RemoveChannelMember>,
3732    session: UserSession,
3733) -> Result<()> {
3734    let db = session.db().await;
3735    let channel_id = ChannelId::from_proto(request.channel_id);
3736    let member_id = UserId::from_proto(request.user_id);
3737
3738    let RemoveChannelMemberResult {
3739        membership_update,
3740        notification_id,
3741    } = db
3742        .remove_channel_member(channel_id, member_id, session.user_id())
3743        .await?;
3744
3745    let mut connection_pool = session.connection_pool().await;
3746    notify_membership_updated(
3747        &mut connection_pool,
3748        membership_update,
3749        member_id,
3750        &session.peer,
3751    );
3752    for connection_id in connection_pool.user_connection_ids(member_id) {
3753        if let Some(notification_id) = notification_id {
3754            session
3755                .peer
3756                .send(
3757                    connection_id,
3758                    proto::DeleteNotification {
3759                        notification_id: notification_id.to_proto(),
3760                    },
3761                )
3762                .trace_err();
3763        }
3764    }
3765
3766    response.send(proto::Ack {})?;
3767    Ok(())
3768}
3769
3770/// Toggle the channel between public and private.
3771/// Care is taken to maintain the invariant that public channels only descend from public channels,
3772/// (though members-only channels can appear at any point in the hierarchy).
3773async fn set_channel_visibility(
3774    request: proto::SetChannelVisibility,
3775    response: Response<proto::SetChannelVisibility>,
3776    session: UserSession,
3777) -> Result<()> {
3778    let db = session.db().await;
3779    let channel_id = ChannelId::from_proto(request.channel_id);
3780    let visibility = request.visibility().into();
3781
3782    let channel_model = db
3783        .set_channel_visibility(channel_id, visibility, session.user_id())
3784        .await?;
3785    let root_id = channel_model.root_id();
3786    let channel = Channel::from_model(channel_model);
3787
3788    let mut connection_pool = session.connection_pool().await;
3789    for (user_id, role) in connection_pool
3790        .channel_user_ids(root_id)
3791        .collect::<Vec<_>>()
3792        .into_iter()
3793    {
3794        let update = if role.can_see_channel(channel.visibility) {
3795            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3796            proto::UpdateChannels {
3797                channels: vec![channel.to_proto()],
3798                ..Default::default()
3799            }
3800        } else {
3801            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3802            proto::UpdateChannels {
3803                delete_channels: vec![channel.id.to_proto()],
3804                ..Default::default()
3805            }
3806        };
3807
3808        for connection_id in connection_pool.user_connection_ids(user_id) {
3809            session.peer.send(connection_id, update.clone())?;
3810        }
3811    }
3812
3813    response.send(proto::Ack {})?;
3814    Ok(())
3815}
3816
3817/// Alter the role for a user in the channel.
3818async fn set_channel_member_role(
3819    request: proto::SetChannelMemberRole,
3820    response: Response<proto::SetChannelMemberRole>,
3821    session: UserSession,
3822) -> Result<()> {
3823    let db = session.db().await;
3824    let channel_id = ChannelId::from_proto(request.channel_id);
3825    let member_id = UserId::from_proto(request.user_id);
3826    let result = db
3827        .set_channel_member_role(
3828            channel_id,
3829            session.user_id(),
3830            member_id,
3831            request.role().into(),
3832        )
3833        .await?;
3834
3835    match result {
3836        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3837            let mut connection_pool = session.connection_pool().await;
3838            notify_membership_updated(
3839                &mut connection_pool,
3840                membership_update,
3841                member_id,
3842                &session.peer,
3843            )
3844        }
3845        db::SetMemberRoleResult::InviteUpdated(channel) => {
3846            let update = proto::UpdateChannels {
3847                channel_invitations: vec![channel.to_proto()],
3848                ..Default::default()
3849            };
3850
3851            for connection_id in session
3852                .connection_pool()
3853                .await
3854                .user_connection_ids(member_id)
3855            {
3856                session.peer.send(connection_id, update.clone())?;
3857            }
3858        }
3859    }
3860
3861    response.send(proto::Ack {})?;
3862    Ok(())
3863}
3864
3865/// Change the name of a channel
3866async fn rename_channel(
3867    request: proto::RenameChannel,
3868    response: Response<proto::RenameChannel>,
3869    session: UserSession,
3870) -> Result<()> {
3871    let db = session.db().await;
3872    let channel_id = ChannelId::from_proto(request.channel_id);
3873    let channel_model = db
3874        .rename_channel(channel_id, session.user_id(), &request.name)
3875        .await?;
3876    let root_id = channel_model.root_id();
3877    let channel = Channel::from_model(channel_model);
3878
3879    response.send(proto::RenameChannelResponse {
3880        channel: Some(channel.to_proto()),
3881    })?;
3882
3883    let connection_pool = session.connection_pool().await;
3884    let update = proto::UpdateChannels {
3885        channels: vec![channel.to_proto()],
3886        ..Default::default()
3887    };
3888    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3889        if role.can_see_channel(channel.visibility) {
3890            session.peer.send(connection_id, update.clone())?;
3891        }
3892    }
3893
3894    Ok(())
3895}
3896
3897/// Move a channel to a new parent.
3898async fn move_channel(
3899    request: proto::MoveChannel,
3900    response: Response<proto::MoveChannel>,
3901    session: UserSession,
3902) -> Result<()> {
3903    let channel_id = ChannelId::from_proto(request.channel_id);
3904    let to = ChannelId::from_proto(request.to);
3905
3906    let (root_id, channels) = session
3907        .db()
3908        .await
3909        .move_channel(channel_id, to, session.user_id())
3910        .await?;
3911
3912    let connection_pool = session.connection_pool().await;
3913    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3914        let channels = channels
3915            .iter()
3916            .filter_map(|channel| {
3917                if role.can_see_channel(channel.visibility) {
3918                    Some(channel.to_proto())
3919                } else {
3920                    None
3921                }
3922            })
3923            .collect::<Vec<_>>();
3924        if channels.is_empty() {
3925            continue;
3926        }
3927
3928        let update = proto::UpdateChannels {
3929            channels,
3930            ..Default::default()
3931        };
3932
3933        session.peer.send(connection_id, update.clone())?;
3934    }
3935
3936    response.send(Ack {})?;
3937    Ok(())
3938}
3939
3940/// Get the list of channel members
3941async fn get_channel_members(
3942    request: proto::GetChannelMembers,
3943    response: Response<proto::GetChannelMembers>,
3944    session: UserSession,
3945) -> Result<()> {
3946    let db = session.db().await;
3947    let channel_id = ChannelId::from_proto(request.channel_id);
3948    let limit = if request.limit == 0 {
3949        u16::MAX as u64
3950    } else {
3951        request.limit
3952    };
3953    let (members, users) = db
3954        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3955        .await?;
3956    response.send(proto::GetChannelMembersResponse { members, users })?;
3957    Ok(())
3958}
3959
3960/// Accept or decline a channel invitation.
3961async fn respond_to_channel_invite(
3962    request: proto::RespondToChannelInvite,
3963    response: Response<proto::RespondToChannelInvite>,
3964    session: UserSession,
3965) -> Result<()> {
3966    let db = session.db().await;
3967    let channel_id = ChannelId::from_proto(request.channel_id);
3968    let RespondToChannelInvite {
3969        membership_update,
3970        notifications,
3971    } = db
3972        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3973        .await?;
3974
3975    let mut connection_pool = session.connection_pool().await;
3976    if let Some(membership_update) = membership_update {
3977        notify_membership_updated(
3978            &mut connection_pool,
3979            membership_update,
3980            session.user_id(),
3981            &session.peer,
3982        );
3983    } else {
3984        let update = proto::UpdateChannels {
3985            remove_channel_invitations: vec![channel_id.to_proto()],
3986            ..Default::default()
3987        };
3988
3989        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3990            session.peer.send(connection_id, update.clone())?;
3991        }
3992    };
3993
3994    send_notifications(&connection_pool, &session.peer, notifications);
3995
3996    response.send(proto::Ack {})?;
3997
3998    Ok(())
3999}
4000
4001/// Join the channels' room
4002async fn join_channel(
4003    request: proto::JoinChannel,
4004    response: Response<proto::JoinChannel>,
4005    session: UserSession,
4006) -> Result<()> {
4007    let channel_id = ChannelId::from_proto(request.channel_id);
4008    join_channel_internal(channel_id, Box::new(response), session).await
4009}
4010
4011trait JoinChannelInternalResponse {
4012    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
4013}
4014impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
4015    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
4016        Response::<proto::JoinChannel>::send(self, result)
4017    }
4018}
4019impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
4020    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
4021        Response::<proto::JoinRoom>::send(self, result)
4022    }
4023}
4024
4025async fn join_channel_internal(
4026    channel_id: ChannelId,
4027    response: Box<impl JoinChannelInternalResponse>,
4028    session: UserSession,
4029) -> Result<()> {
4030    let joined_room = {
4031        let mut db = session.db().await;
4032        // If zed quits without leaving the room, and the user re-opens zed before the
4033        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
4034        // room they were in.
4035        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
4036            tracing::info!(
4037                stale_connection_id = %connection,
4038                "cleaning up stale connection",
4039            );
4040            drop(db);
4041            leave_room_for_session(&session, connection).await?;
4042            db = session.db().await;
4043        }
4044
4045        let (joined_room, membership_updated, role) = db
4046            .join_channel(channel_id, session.user_id(), session.connection_id)
4047            .await?;
4048
4049        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
4050            let (can_publish, token) = if role == ChannelRole::Guest {
4051                (
4052                    false,
4053                    live_kit
4054                        .guest_token(
4055                            &joined_room.room.live_kit_room,
4056                            &session.user_id().to_string(),
4057                        )
4058                        .trace_err()?,
4059                )
4060            } else {
4061                (
4062                    true,
4063                    live_kit
4064                        .room_token(
4065                            &joined_room.room.live_kit_room,
4066                            &session.user_id().to_string(),
4067                        )
4068                        .trace_err()?,
4069                )
4070            };
4071
4072            Some(LiveKitConnectionInfo {
4073                server_url: live_kit.url().into(),
4074                token,
4075                can_publish,
4076            })
4077        });
4078
4079        response.send(proto::JoinRoomResponse {
4080            room: Some(joined_room.room.clone()),
4081            channel_id: joined_room
4082                .channel
4083                .as_ref()
4084                .map(|channel| channel.id.to_proto()),
4085            live_kit_connection_info,
4086        })?;
4087
4088        let mut connection_pool = session.connection_pool().await;
4089        if let Some(membership_updated) = membership_updated {
4090            notify_membership_updated(
4091                &mut connection_pool,
4092                membership_updated,
4093                session.user_id(),
4094                &session.peer,
4095            );
4096        }
4097
4098        room_updated(&joined_room.room, &session.peer);
4099
4100        joined_room
4101    };
4102
4103    channel_updated(
4104        &joined_room
4105            .channel
4106            .ok_or_else(|| anyhow!("channel not returned"))?,
4107        &joined_room.room,
4108        &session.peer,
4109        &*session.connection_pool().await,
4110    );
4111
4112    update_user_contacts(session.user_id(), &session).await?;
4113    Ok(())
4114}
4115
4116/// Start editing the channel notes
4117async fn join_channel_buffer(
4118    request: proto::JoinChannelBuffer,
4119    response: Response<proto::JoinChannelBuffer>,
4120    session: UserSession,
4121) -> Result<()> {
4122    let db = session.db().await;
4123    let channel_id = ChannelId::from_proto(request.channel_id);
4124
4125    let open_response = db
4126        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4127        .await?;
4128
4129    let collaborators = open_response.collaborators.clone();
4130    response.send(open_response)?;
4131
4132    let update = UpdateChannelBufferCollaborators {
4133        channel_id: channel_id.to_proto(),
4134        collaborators: collaborators.clone(),
4135    };
4136    channel_buffer_updated(
4137        session.connection_id,
4138        collaborators
4139            .iter()
4140            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4141        &update,
4142        &session.peer,
4143    );
4144
4145    Ok(())
4146}
4147
4148/// Edit the channel notes
4149async fn update_channel_buffer(
4150    request: proto::UpdateChannelBuffer,
4151    session: UserSession,
4152) -> Result<()> {
4153    let db = session.db().await;
4154    let channel_id = ChannelId::from_proto(request.channel_id);
4155
4156    let (collaborators, epoch, version) = db
4157        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4158        .await?;
4159
4160    channel_buffer_updated(
4161        session.connection_id,
4162        collaborators.clone(),
4163        &proto::UpdateChannelBuffer {
4164            channel_id: channel_id.to_proto(),
4165            operations: request.operations,
4166        },
4167        &session.peer,
4168    );
4169
4170    let pool = &*session.connection_pool().await;
4171
4172    let non_collaborators =
4173        pool.channel_connection_ids(channel_id)
4174            .filter_map(|(connection_id, _)| {
4175                if collaborators.contains(&connection_id) {
4176                    None
4177                } else {
4178                    Some(connection_id)
4179                }
4180            });
4181
4182    broadcast(None, non_collaborators, |peer_id| {
4183        session.peer.send(
4184            peer_id,
4185            proto::UpdateChannels {
4186                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4187                    channel_id: channel_id.to_proto(),
4188                    epoch: epoch as u64,
4189                    version: version.clone(),
4190                }],
4191                ..Default::default()
4192            },
4193        )
4194    });
4195
4196    Ok(())
4197}
4198
4199/// Rejoin the channel notes after a connection blip
4200async fn rejoin_channel_buffers(
4201    request: proto::RejoinChannelBuffers,
4202    response: Response<proto::RejoinChannelBuffers>,
4203    session: UserSession,
4204) -> Result<()> {
4205    let db = session.db().await;
4206    let buffers = db
4207        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4208        .await?;
4209
4210    for rejoined_buffer in &buffers {
4211        let collaborators_to_notify = rejoined_buffer
4212            .buffer
4213            .collaborators
4214            .iter()
4215            .filter_map(|c| Some(c.peer_id?.into()));
4216        channel_buffer_updated(
4217            session.connection_id,
4218            collaborators_to_notify,
4219            &proto::UpdateChannelBufferCollaborators {
4220                channel_id: rejoined_buffer.buffer.channel_id,
4221                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4222            },
4223            &session.peer,
4224        );
4225    }
4226
4227    response.send(proto::RejoinChannelBuffersResponse {
4228        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4229    })?;
4230
4231    Ok(())
4232}
4233
4234/// Stop editing the channel notes
4235async fn leave_channel_buffer(
4236    request: proto::LeaveChannelBuffer,
4237    response: Response<proto::LeaveChannelBuffer>,
4238    session: UserSession,
4239) -> Result<()> {
4240    let db = session.db().await;
4241    let channel_id = ChannelId::from_proto(request.channel_id);
4242
4243    let left_buffer = db
4244        .leave_channel_buffer(channel_id, session.connection_id)
4245        .await?;
4246
4247    response.send(Ack {})?;
4248
4249    channel_buffer_updated(
4250        session.connection_id,
4251        left_buffer.connections,
4252        &proto::UpdateChannelBufferCollaborators {
4253            channel_id: channel_id.to_proto(),
4254            collaborators: left_buffer.collaborators,
4255        },
4256        &session.peer,
4257    );
4258
4259    Ok(())
4260}
4261
4262fn channel_buffer_updated<T: EnvelopedMessage>(
4263    sender_id: ConnectionId,
4264    collaborators: impl IntoIterator<Item = ConnectionId>,
4265    message: &T,
4266    peer: &Peer,
4267) {
4268    broadcast(Some(sender_id), collaborators, |peer_id| {
4269        peer.send(peer_id, message.clone())
4270    });
4271}
4272
4273fn send_notifications(
4274    connection_pool: &ConnectionPool,
4275    peer: &Peer,
4276    notifications: db::NotificationBatch,
4277) {
4278    for (user_id, notification) in notifications {
4279        for connection_id in connection_pool.user_connection_ids(user_id) {
4280            if let Err(error) = peer.send(
4281                connection_id,
4282                proto::AddNotification {
4283                    notification: Some(notification.clone()),
4284                },
4285            ) {
4286                tracing::error!(
4287                    "failed to send notification to {:?} {}",
4288                    connection_id,
4289                    error
4290                );
4291            }
4292        }
4293    }
4294}
4295
4296/// Send a message to the channel
4297async fn send_channel_message(
4298    request: proto::SendChannelMessage,
4299    response: Response<proto::SendChannelMessage>,
4300    session: UserSession,
4301) -> Result<()> {
4302    // Validate the message body.
4303    let body = request.body.trim().to_string();
4304    if body.len() > MAX_MESSAGE_LEN {
4305        return Err(anyhow!("message is too long"))?;
4306    }
4307    if body.is_empty() {
4308        return Err(anyhow!("message can't be blank"))?;
4309    }
4310
4311    // TODO: adjust mentions if body is trimmed
4312
4313    let timestamp = OffsetDateTime::now_utc();
4314    let nonce = request
4315        .nonce
4316        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4317
4318    let channel_id = ChannelId::from_proto(request.channel_id);
4319    let CreatedChannelMessage {
4320        message_id,
4321        participant_connection_ids,
4322        notifications,
4323    } = session
4324        .db()
4325        .await
4326        .create_channel_message(
4327            channel_id,
4328            session.user_id(),
4329            &body,
4330            &request.mentions,
4331            timestamp,
4332            nonce.clone().into(),
4333            match request.reply_to_message_id {
4334                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4335                None => None,
4336            },
4337        )
4338        .await?;
4339
4340    let message = proto::ChannelMessage {
4341        sender_id: session.user_id().to_proto(),
4342        id: message_id.to_proto(),
4343        body,
4344        mentions: request.mentions,
4345        timestamp: timestamp.unix_timestamp() as u64,
4346        nonce: Some(nonce),
4347        reply_to_message_id: request.reply_to_message_id,
4348        edited_at: None,
4349    };
4350    broadcast(
4351        Some(session.connection_id),
4352        participant_connection_ids.clone(),
4353        |connection| {
4354            session.peer.send(
4355                connection,
4356                proto::ChannelMessageSent {
4357                    channel_id: channel_id.to_proto(),
4358                    message: Some(message.clone()),
4359                },
4360            )
4361        },
4362    );
4363    response.send(proto::SendChannelMessageResponse {
4364        message: Some(message),
4365    })?;
4366
4367    let pool = &*session.connection_pool().await;
4368    let non_participants =
4369        pool.channel_connection_ids(channel_id)
4370            .filter_map(|(connection_id, _)| {
4371                if participant_connection_ids.contains(&connection_id) {
4372                    None
4373                } else {
4374                    Some(connection_id)
4375                }
4376            });
4377    broadcast(None, non_participants, |peer_id| {
4378        session.peer.send(
4379            peer_id,
4380            proto::UpdateChannels {
4381                latest_channel_message_ids: vec![proto::ChannelMessageId {
4382                    channel_id: channel_id.to_proto(),
4383                    message_id: message_id.to_proto(),
4384                }],
4385                ..Default::default()
4386            },
4387        )
4388    });
4389    send_notifications(pool, &session.peer, notifications);
4390
4391    Ok(())
4392}
4393
4394/// Delete a channel message
4395async fn remove_channel_message(
4396    request: proto::RemoveChannelMessage,
4397    response: Response<proto::RemoveChannelMessage>,
4398    session: UserSession,
4399) -> Result<()> {
4400    let channel_id = ChannelId::from_proto(request.channel_id);
4401    let message_id = MessageId::from_proto(request.message_id);
4402    let (connection_ids, existing_notification_ids) = session
4403        .db()
4404        .await
4405        .remove_channel_message(channel_id, message_id, session.user_id())
4406        .await?;
4407
4408    broadcast(
4409        Some(session.connection_id),
4410        connection_ids,
4411        move |connection| {
4412            session.peer.send(connection, request.clone())?;
4413
4414            for notification_id in &existing_notification_ids {
4415                session.peer.send(
4416                    connection,
4417                    proto::DeleteNotification {
4418                        notification_id: (*notification_id).to_proto(),
4419                    },
4420                )?;
4421            }
4422
4423            Ok(())
4424        },
4425    );
4426    response.send(proto::Ack {})?;
4427    Ok(())
4428}
4429
4430async fn update_channel_message(
4431    request: proto::UpdateChannelMessage,
4432    response: Response<proto::UpdateChannelMessage>,
4433    session: UserSession,
4434) -> Result<()> {
4435    let channel_id = ChannelId::from_proto(request.channel_id);
4436    let message_id = MessageId::from_proto(request.message_id);
4437    let updated_at = OffsetDateTime::now_utc();
4438    let UpdatedChannelMessage {
4439        message_id,
4440        participant_connection_ids,
4441        notifications,
4442        reply_to_message_id,
4443        timestamp,
4444        deleted_mention_notification_ids,
4445        updated_mention_notifications,
4446    } = session
4447        .db()
4448        .await
4449        .update_channel_message(
4450            channel_id,
4451            message_id,
4452            session.user_id(),
4453            request.body.as_str(),
4454            &request.mentions,
4455            updated_at,
4456        )
4457        .await?;
4458
4459    let nonce = request
4460        .nonce
4461        .clone()
4462        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4463
4464    let message = proto::ChannelMessage {
4465        sender_id: session.user_id().to_proto(),
4466        id: message_id.to_proto(),
4467        body: request.body.clone(),
4468        mentions: request.mentions.clone(),
4469        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4470        nonce: Some(nonce),
4471        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4472        edited_at: Some(updated_at.unix_timestamp() as u64),
4473    };
4474
4475    response.send(proto::Ack {})?;
4476
4477    let pool = &*session.connection_pool().await;
4478    broadcast(
4479        Some(session.connection_id),
4480        participant_connection_ids,
4481        |connection| {
4482            session.peer.send(
4483                connection,
4484                proto::ChannelMessageUpdate {
4485                    channel_id: channel_id.to_proto(),
4486                    message: Some(message.clone()),
4487                },
4488            )?;
4489
4490            for notification_id in &deleted_mention_notification_ids {
4491                session.peer.send(
4492                    connection,
4493                    proto::DeleteNotification {
4494                        notification_id: (*notification_id).to_proto(),
4495                    },
4496                )?;
4497            }
4498
4499            for notification in &updated_mention_notifications {
4500                session.peer.send(
4501                    connection,
4502                    proto::UpdateNotification {
4503                        notification: Some(notification.clone()),
4504                    },
4505                )?;
4506            }
4507
4508            Ok(())
4509        },
4510    );
4511
4512    send_notifications(pool, &session.peer, notifications);
4513
4514    Ok(())
4515}
4516
4517/// Mark a channel message as read
4518async fn acknowledge_channel_message(
4519    request: proto::AckChannelMessage,
4520    session: UserSession,
4521) -> Result<()> {
4522    let channel_id = ChannelId::from_proto(request.channel_id);
4523    let message_id = MessageId::from_proto(request.message_id);
4524    let notifications = session
4525        .db()
4526        .await
4527        .observe_channel_message(channel_id, session.user_id(), message_id)
4528        .await?;
4529    send_notifications(
4530        &*session.connection_pool().await,
4531        &session.peer,
4532        notifications,
4533    );
4534    Ok(())
4535}
4536
4537/// Mark a buffer version as synced
4538async fn acknowledge_buffer_version(
4539    request: proto::AckBufferOperation,
4540    session: UserSession,
4541) -> Result<()> {
4542    let buffer_id = BufferId::from_proto(request.buffer_id);
4543    session
4544        .db()
4545        .await
4546        .observe_buffer_version(
4547            buffer_id,
4548            session.user_id(),
4549            request.epoch as i32,
4550            &request.version,
4551        )
4552        .await?;
4553    Ok(())
4554}
4555
4556struct ZedProCompleteWithLanguageModelRateLimit;
4557
4558impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4559    fn capacity(&self) -> usize {
4560        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4561            .ok()
4562            .and_then(|v| v.parse().ok())
4563            .unwrap_or(120) // Picked arbitrarily
4564    }
4565
4566    fn refill_duration(&self) -> chrono::Duration {
4567        chrono::Duration::hours(1)
4568    }
4569
4570    fn db_name(&self) -> &'static str {
4571        "zed-pro:complete-with-language-model"
4572    }
4573}
4574
4575struct FreeCompleteWithLanguageModelRateLimit;
4576
4577impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4578    fn capacity(&self) -> usize {
4579        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4580            .ok()
4581            .and_then(|v| v.parse().ok())
4582            .unwrap_or(120 / 10) // Picked arbitrarily
4583    }
4584
4585    fn refill_duration(&self) -> chrono::Duration {
4586        chrono::Duration::hours(1)
4587    }
4588
4589    fn db_name(&self) -> &'static str {
4590        "free:complete-with-language-model"
4591    }
4592}
4593
4594async fn complete_with_language_model(
4595    request: proto::CompleteWithLanguageModel,
4596    response: Response<proto::CompleteWithLanguageModel>,
4597    session: Session,
4598    config: &Config,
4599) -> Result<()> {
4600    let Some(session) = session.for_user() else {
4601        return Err(anyhow!("user not found"))?;
4602    };
4603    authorize_access_to_language_models(&session).await?;
4604
4605    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4606        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4607        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4608    };
4609
4610    session
4611        .rate_limiter
4612        .check(&*rate_limit, session.user_id())
4613        .await?;
4614
4615    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4616        Some(proto::LanguageModelProvider::Anthropic) => {
4617            let api_key = config
4618                .anthropic_api_key
4619                .as_ref()
4620                .context("no Anthropic AI API key configured on the server")?;
4621            anthropic::complete(
4622                session.http_client.as_ref(),
4623                anthropic::ANTHROPIC_API_URL,
4624                api_key,
4625                serde_json::from_str(&request.request)?,
4626            )
4627            .await?
4628        }
4629        _ => return Err(anyhow!("unsupported provider"))?,
4630    };
4631
4632    response.send(proto::CompleteWithLanguageModelResponse {
4633        completion: serde_json::to_string(&result)?,
4634    })?;
4635
4636    Ok(())
4637}
4638
4639async fn stream_complete_with_language_model(
4640    request: proto::StreamCompleteWithLanguageModel,
4641    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4642    session: Session,
4643    config: &Config,
4644) -> Result<()> {
4645    let Some(session) = session.for_user() else {
4646        return Err(anyhow!("user not found"))?;
4647    };
4648    authorize_access_to_language_models(&session).await?;
4649
4650    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4651        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4652        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4653    };
4654
4655    session
4656        .rate_limiter
4657        .check(&*rate_limit, session.user_id())
4658        .await?;
4659
4660    match proto::LanguageModelProvider::from_i32(request.provider) {
4661        Some(proto::LanguageModelProvider::Anthropic) => {
4662            let api_key = config
4663                .anthropic_api_key
4664                .as_ref()
4665                .context("no Anthropic AI API key configured on the server")?;
4666            let mut chunks = anthropic::stream_completion(
4667                session.http_client.as_ref(),
4668                anthropic::ANTHROPIC_API_URL,
4669                api_key,
4670                serde_json::from_str(&request.request)?,
4671                None,
4672            )
4673            .await?;
4674            while let Some(event) = chunks.next().await {
4675                let chunk = event?;
4676                response.send(proto::StreamCompleteWithLanguageModelResponse {
4677                    event: serde_json::to_string(&chunk)?,
4678                })?;
4679            }
4680        }
4681        Some(proto::LanguageModelProvider::OpenAi) => {
4682            let api_key = config
4683                .openai_api_key
4684                .as_ref()
4685                .context("no OpenAI API key configured on the server")?;
4686            let mut events = open_ai::stream_completion(
4687                session.http_client.as_ref(),
4688                open_ai::OPEN_AI_API_URL,
4689                api_key,
4690                serde_json::from_str(&request.request)?,
4691                None,
4692            )
4693            .await?;
4694            while let Some(event) = events.next().await {
4695                let event = event?;
4696                response.send(proto::StreamCompleteWithLanguageModelResponse {
4697                    event: serde_json::to_string(&event)?,
4698                })?;
4699            }
4700        }
4701        Some(proto::LanguageModelProvider::Google) => {
4702            let api_key = config
4703                .google_ai_api_key
4704                .as_ref()
4705                .context("no Google AI API key configured on the server")?;
4706            let mut events = google_ai::stream_generate_content(
4707                session.http_client.as_ref(),
4708                google_ai::API_URL,
4709                api_key,
4710                serde_json::from_str(&request.request)?,
4711            )
4712            .await?;
4713            while let Some(event) = events.next().await {
4714                let event = event?;
4715                response.send(proto::StreamCompleteWithLanguageModelResponse {
4716                    event: serde_json::to_string(&event)?,
4717                })?;
4718            }
4719        }
4720        Some(proto::LanguageModelProvider::Zed) => {
4721            let api_key = config
4722                .qwen2_7b_api_key
4723                .as_ref()
4724                .context("no Qwen2-7B API key configured on the server")?;
4725            let api_url = config
4726                .qwen2_7b_api_url
4727                .as_ref()
4728                .context("no Qwen2-7B URL configured on the server")?;
4729            let mut events = open_ai::stream_completion(
4730                session.http_client.as_ref(),
4731                &api_url,
4732                api_key,
4733                serde_json::from_str(&request.request)?,
4734                None,
4735            )
4736            .await?;
4737            while let Some(event) = events.next().await {
4738                let event = event?;
4739                response.send(proto::StreamCompleteWithLanguageModelResponse {
4740                    event: serde_json::to_string(&event)?,
4741                })?;
4742            }
4743        }
4744        None => return Err(anyhow!("unknown provider"))?,
4745    }
4746
4747    Ok(())
4748}
4749
4750async fn count_language_model_tokens(
4751    request: proto::CountLanguageModelTokens,
4752    response: Response<proto::CountLanguageModelTokens>,
4753    session: Session,
4754    config: &Config,
4755) -> Result<()> {
4756    let Some(session) = session.for_user() else {
4757        return Err(anyhow!("user not found"))?;
4758    };
4759    authorize_access_to_language_models(&session).await?;
4760
4761    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4762        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4763        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4764    };
4765
4766    session
4767        .rate_limiter
4768        .check(&*rate_limit, session.user_id())
4769        .await?;
4770
4771    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4772        Some(proto::LanguageModelProvider::Google) => {
4773            let api_key = config
4774                .google_ai_api_key
4775                .as_ref()
4776                .context("no Google AI API key configured on the server")?;
4777            google_ai::count_tokens(
4778                session.http_client.as_ref(),
4779                google_ai::API_URL,
4780                api_key,
4781                serde_json::from_str(&request.request)?,
4782            )
4783            .await?
4784        }
4785        _ => return Err(anyhow!("unsupported provider"))?,
4786    };
4787
4788    response.send(proto::CountLanguageModelTokensResponse {
4789        token_count: result.total_tokens as u32,
4790    })?;
4791
4792    Ok(())
4793}
4794
4795struct ZedProCountLanguageModelTokensRateLimit;
4796
4797impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4798    fn capacity(&self) -> usize {
4799        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4800            .ok()
4801            .and_then(|v| v.parse().ok())
4802            .unwrap_or(600) // Picked arbitrarily
4803    }
4804
4805    fn refill_duration(&self) -> chrono::Duration {
4806        chrono::Duration::hours(1)
4807    }
4808
4809    fn db_name(&self) -> &'static str {
4810        "zed-pro:count-language-model-tokens"
4811    }
4812}
4813
4814struct FreeCountLanguageModelTokensRateLimit;
4815
4816impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4817    fn capacity(&self) -> usize {
4818        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4819            .ok()
4820            .and_then(|v| v.parse().ok())
4821            .unwrap_or(600 / 10) // Picked arbitrarily
4822    }
4823
4824    fn refill_duration(&self) -> chrono::Duration {
4825        chrono::Duration::hours(1)
4826    }
4827
4828    fn db_name(&self) -> &'static str {
4829        "free:count-language-model-tokens"
4830    }
4831}
4832
4833struct ZedProComputeEmbeddingsRateLimit;
4834
4835impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4836    fn capacity(&self) -> usize {
4837        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4838            .ok()
4839            .and_then(|v| v.parse().ok())
4840            .unwrap_or(5000) // Picked arbitrarily
4841    }
4842
4843    fn refill_duration(&self) -> chrono::Duration {
4844        chrono::Duration::hours(1)
4845    }
4846
4847    fn db_name(&self) -> &'static str {
4848        "zed-pro:compute-embeddings"
4849    }
4850}
4851
4852struct FreeComputeEmbeddingsRateLimit;
4853
4854impl RateLimit for FreeComputeEmbeddingsRateLimit {
4855    fn capacity(&self) -> usize {
4856        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4857            .ok()
4858            .and_then(|v| v.parse().ok())
4859            .unwrap_or(5000 / 10) // Picked arbitrarily
4860    }
4861
4862    fn refill_duration(&self) -> chrono::Duration {
4863        chrono::Duration::hours(1)
4864    }
4865
4866    fn db_name(&self) -> &'static str {
4867        "free:compute-embeddings"
4868    }
4869}
4870
4871async fn compute_embeddings(
4872    request: proto::ComputeEmbeddings,
4873    response: Response<proto::ComputeEmbeddings>,
4874    session: UserSession,
4875    api_key: Option<Arc<str>>,
4876) -> Result<()> {
4877    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4878    authorize_access_to_language_models(&session).await?;
4879
4880    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4881        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4882        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4883    };
4884
4885    session
4886        .rate_limiter
4887        .check(&*rate_limit, session.user_id())
4888        .await?;
4889
4890    let embeddings = match request.model.as_str() {
4891        "openai/text-embedding-3-small" => {
4892            open_ai::embed(
4893                session.http_client.as_ref(),
4894                OPEN_AI_API_URL,
4895                &api_key,
4896                OpenAiEmbeddingModel::TextEmbedding3Small,
4897                request.texts.iter().map(|text| text.as_str()),
4898            )
4899            .await?
4900        }
4901        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4902    };
4903
4904    let embeddings = request
4905        .texts
4906        .iter()
4907        .map(|text| {
4908            let mut hasher = sha2::Sha256::new();
4909            hasher.update(text.as_bytes());
4910            let result = hasher.finalize();
4911            result.to_vec()
4912        })
4913        .zip(
4914            embeddings
4915                .data
4916                .into_iter()
4917                .map(|embedding| embedding.embedding),
4918        )
4919        .collect::<HashMap<_, _>>();
4920
4921    let db = session.db().await;
4922    db.save_embeddings(&request.model, &embeddings)
4923        .await
4924        .context("failed to save embeddings")
4925        .trace_err();
4926
4927    response.send(proto::ComputeEmbeddingsResponse {
4928        embeddings: embeddings
4929            .into_iter()
4930            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4931            .collect(),
4932    })?;
4933    Ok(())
4934}
4935
4936async fn get_cached_embeddings(
4937    request: proto::GetCachedEmbeddings,
4938    response: Response<proto::GetCachedEmbeddings>,
4939    session: UserSession,
4940) -> Result<()> {
4941    authorize_access_to_language_models(&session).await?;
4942
4943    let db = session.db().await;
4944    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4945
4946    response.send(proto::GetCachedEmbeddingsResponse {
4947        embeddings: embeddings
4948            .into_iter()
4949            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4950            .collect(),
4951    })?;
4952    Ok(())
4953}
4954
4955async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4956    let db = session.db().await;
4957    let flags = db.get_user_flags(session.user_id()).await?;
4958    if flags.iter().any(|flag| flag == "language-models") {
4959        return Ok(());
4960    }
4961
4962    Err(anyhow!("permission denied"))?
4963}
4964
4965/// Get a Supermaven API key for the user
4966async fn get_supermaven_api_key(
4967    _request: proto::GetSupermavenApiKey,
4968    response: Response<proto::GetSupermavenApiKey>,
4969    session: UserSession,
4970) -> Result<()> {
4971    let user_id: String = session.user_id().to_string();
4972    if !session.is_staff() {
4973        return Err(anyhow!("supermaven not enabled for this account"))?;
4974    }
4975
4976    let email = session
4977        .email()
4978        .ok_or_else(|| anyhow!("user must have an email"))?;
4979
4980    let supermaven_admin_api = session
4981        .supermaven_client
4982        .as_ref()
4983        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4984
4985    let result = supermaven_admin_api
4986        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4987        .await?;
4988
4989    response.send(proto::GetSupermavenApiKeyResponse {
4990        api_key: result.api_key,
4991    })?;
4992
4993    Ok(())
4994}
4995
4996/// Start receiving chat updates for a channel
4997async fn join_channel_chat(
4998    request: proto::JoinChannelChat,
4999    response: Response<proto::JoinChannelChat>,
5000    session: UserSession,
5001) -> Result<()> {
5002    let channel_id = ChannelId::from_proto(request.channel_id);
5003
5004    let db = session.db().await;
5005    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5006        .await?;
5007    let messages = db
5008        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5009        .await?;
5010    response.send(proto::JoinChannelChatResponse {
5011        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5012        messages,
5013    })?;
5014    Ok(())
5015}
5016
5017/// Stop receiving chat updates for a channel
5018async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5019    let channel_id = ChannelId::from_proto(request.channel_id);
5020    session
5021        .db()
5022        .await
5023        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5024        .await?;
5025    Ok(())
5026}
5027
5028/// Retrieve the chat history for a channel
5029async fn get_channel_messages(
5030    request: proto::GetChannelMessages,
5031    response: Response<proto::GetChannelMessages>,
5032    session: UserSession,
5033) -> Result<()> {
5034    let channel_id = ChannelId::from_proto(request.channel_id);
5035    let messages = session
5036        .db()
5037        .await
5038        .get_channel_messages(
5039            channel_id,
5040            session.user_id(),
5041            MESSAGE_COUNT_PER_PAGE,
5042            Some(MessageId::from_proto(request.before_message_id)),
5043        )
5044        .await?;
5045    response.send(proto::GetChannelMessagesResponse {
5046        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5047        messages,
5048    })?;
5049    Ok(())
5050}
5051
5052/// Retrieve specific chat messages
5053async fn get_channel_messages_by_id(
5054    request: proto::GetChannelMessagesById,
5055    response: Response<proto::GetChannelMessagesById>,
5056    session: UserSession,
5057) -> Result<()> {
5058    let message_ids = request
5059        .message_ids
5060        .iter()
5061        .map(|id| MessageId::from_proto(*id))
5062        .collect::<Vec<_>>();
5063    let messages = session
5064        .db()
5065        .await
5066        .get_channel_messages_by_id(session.user_id(), &message_ids)
5067        .await?;
5068    response.send(proto::GetChannelMessagesResponse {
5069        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5070        messages,
5071    })?;
5072    Ok(())
5073}
5074
5075/// Retrieve the current users notifications
5076async fn get_notifications(
5077    request: proto::GetNotifications,
5078    response: Response<proto::GetNotifications>,
5079    session: UserSession,
5080) -> Result<()> {
5081    let notifications = session
5082        .db()
5083        .await
5084        .get_notifications(
5085            session.user_id(),
5086            NOTIFICATION_COUNT_PER_PAGE,
5087            request
5088                .before_id
5089                .map(|id| db::NotificationId::from_proto(id)),
5090        )
5091        .await?;
5092    response.send(proto::GetNotificationsResponse {
5093        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5094        notifications,
5095    })?;
5096    Ok(())
5097}
5098
5099/// Mark notifications as read
5100async fn mark_notification_as_read(
5101    request: proto::MarkNotificationRead,
5102    response: Response<proto::MarkNotificationRead>,
5103    session: UserSession,
5104) -> Result<()> {
5105    let database = &session.db().await;
5106    let notifications = database
5107        .mark_notification_as_read_by_id(
5108            session.user_id(),
5109            NotificationId::from_proto(request.notification_id),
5110        )
5111        .await?;
5112    send_notifications(
5113        &*session.connection_pool().await,
5114        &session.peer,
5115        notifications,
5116    );
5117    response.send(proto::Ack {})?;
5118    Ok(())
5119}
5120
5121/// Get the current users information
5122async fn get_private_user_info(
5123    _request: proto::GetPrivateUserInfo,
5124    response: Response<proto::GetPrivateUserInfo>,
5125    session: UserSession,
5126) -> Result<()> {
5127    let db = session.db().await;
5128
5129    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5130    let user = db
5131        .get_user_by_id(session.user_id())
5132        .await?
5133        .ok_or_else(|| anyhow!("user not found"))?;
5134    let flags = db.get_user_flags(session.user_id()).await?;
5135
5136    response.send(proto::GetPrivateUserInfoResponse {
5137        metrics_id,
5138        staff: user.admin,
5139        flags,
5140    })?;
5141    Ok(())
5142}
5143
5144fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5145    let message = match message {
5146        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5147        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5148        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5149        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5150        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5151            code: frame.code.into(),
5152            reason: frame.reason,
5153        })),
5154        // We should never receive a frame while reading the message, according
5155        // to the `tungstenite` maintainers:
5156        //
5157        // > It cannot occur when you read messages from the WebSocket, but it
5158        // > can be used when you want to send the raw frames (e.g. you want to
5159        // > send the frames to the WebSocket without composing the full message first).
5160        // >
5161        // > — https://github.com/snapview/tungstenite-rs/issues/268
5162        TungsteniteMessage::Frame(_) => {
5163            bail!("received an unexpected frame while reading the message")
5164        }
5165    };
5166
5167    Ok(message)
5168}
5169
5170fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5171    match message {
5172        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5173        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5174        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5175        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5176        AxumMessage::Close(frame) => {
5177            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5178                code: frame.code.into(),
5179                reason: frame.reason,
5180            }))
5181        }
5182    }
5183}
5184
5185fn notify_membership_updated(
5186    connection_pool: &mut ConnectionPool,
5187    result: MembershipUpdated,
5188    user_id: UserId,
5189    peer: &Peer,
5190) {
5191    for membership in &result.new_channels.channel_memberships {
5192        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5193    }
5194    for channel_id in &result.removed_channels {
5195        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5196    }
5197
5198    let user_channels_update = proto::UpdateUserChannels {
5199        channel_memberships: result
5200            .new_channels
5201            .channel_memberships
5202            .iter()
5203            .map(|cm| proto::ChannelMembership {
5204                channel_id: cm.channel_id.to_proto(),
5205                role: cm.role.into(),
5206            })
5207            .collect(),
5208        ..Default::default()
5209    };
5210
5211    let mut update = build_channels_update(result.new_channels);
5212    update.delete_channels = result
5213        .removed_channels
5214        .into_iter()
5215        .map(|id| id.to_proto())
5216        .collect();
5217    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5218
5219    for connection_id in connection_pool.user_connection_ids(user_id) {
5220        peer.send(connection_id, user_channels_update.clone())
5221            .trace_err();
5222        peer.send(connection_id, update.clone()).trace_err();
5223    }
5224}
5225
5226fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5227    proto::UpdateUserChannels {
5228        channel_memberships: channels
5229            .channel_memberships
5230            .iter()
5231            .map(|m| proto::ChannelMembership {
5232                channel_id: m.channel_id.to_proto(),
5233                role: m.role.into(),
5234            })
5235            .collect(),
5236        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5237        observed_channel_message_id: channels.observed_channel_messages.clone(),
5238    }
5239}
5240
5241fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5242    let mut update = proto::UpdateChannels::default();
5243
5244    for channel in channels.channels {
5245        update.channels.push(channel.to_proto());
5246    }
5247
5248    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5249    update.latest_channel_message_ids = channels.latest_channel_messages;
5250
5251    for (channel_id, participants) in channels.channel_participants {
5252        update
5253            .channel_participants
5254            .push(proto::ChannelParticipants {
5255                channel_id: channel_id.to_proto(),
5256                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5257            });
5258    }
5259
5260    for channel in channels.invited_channels {
5261        update.channel_invitations.push(channel.to_proto());
5262    }
5263
5264    update.hosted_projects = channels.hosted_projects;
5265    update
5266}
5267
5268fn build_initial_contacts_update(
5269    contacts: Vec<db::Contact>,
5270    pool: &ConnectionPool,
5271) -> proto::UpdateContacts {
5272    let mut update = proto::UpdateContacts::default();
5273
5274    for contact in contacts {
5275        match contact {
5276            db::Contact::Accepted { user_id, busy } => {
5277                update.contacts.push(contact_for_user(user_id, busy, &pool));
5278            }
5279            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5280            db::Contact::Incoming { user_id } => {
5281                update
5282                    .incoming_requests
5283                    .push(proto::IncomingContactRequest {
5284                        requester_id: user_id.to_proto(),
5285                    })
5286            }
5287        }
5288    }
5289
5290    update
5291}
5292
5293fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5294    proto::Contact {
5295        user_id: user_id.to_proto(),
5296        online: pool.is_user_online(user_id),
5297        busy,
5298    }
5299}
5300
5301fn room_updated(room: &proto::Room, peer: &Peer) {
5302    broadcast(
5303        None,
5304        room.participants
5305            .iter()
5306            .filter_map(|participant| Some(participant.peer_id?.into())),
5307        |peer_id| {
5308            peer.send(
5309                peer_id,
5310                proto::RoomUpdated {
5311                    room: Some(room.clone()),
5312                },
5313            )
5314        },
5315    );
5316}
5317
5318fn channel_updated(
5319    channel: &db::channel::Model,
5320    room: &proto::Room,
5321    peer: &Peer,
5322    pool: &ConnectionPool,
5323) {
5324    let participants = room
5325        .participants
5326        .iter()
5327        .map(|p| p.user_id)
5328        .collect::<Vec<_>>();
5329
5330    broadcast(
5331        None,
5332        pool.channel_connection_ids(channel.root_id())
5333            .filter_map(|(channel_id, role)| {
5334                role.can_see_channel(channel.visibility).then(|| channel_id)
5335            }),
5336        |peer_id| {
5337            peer.send(
5338                peer_id,
5339                proto::UpdateChannels {
5340                    channel_participants: vec![proto::ChannelParticipants {
5341                        channel_id: channel.id.to_proto(),
5342                        participant_user_ids: participants.clone(),
5343                    }],
5344                    ..Default::default()
5345                },
5346            )
5347        },
5348    );
5349}
5350
5351async fn send_dev_server_projects_update(
5352    user_id: UserId,
5353    mut status: proto::DevServerProjectsUpdate,
5354    session: &Session,
5355) {
5356    let pool = session.connection_pool().await;
5357    for dev_server in &mut status.dev_servers {
5358        dev_server.status =
5359            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5360    }
5361    let connections = pool.user_connection_ids(user_id);
5362    for connection_id in connections {
5363        session.peer.send(connection_id, status.clone()).trace_err();
5364    }
5365}
5366
5367async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5368    let db = session.db().await;
5369
5370    let contacts = db.get_contacts(user_id).await?;
5371    let busy = db.is_user_busy(user_id).await?;
5372
5373    let pool = session.connection_pool().await;
5374    let updated_contact = contact_for_user(user_id, busy, &pool);
5375    for contact in contacts {
5376        if let db::Contact::Accepted {
5377            user_id: contact_user_id,
5378            ..
5379        } = contact
5380        {
5381            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5382                session
5383                    .peer
5384                    .send(
5385                        contact_conn_id,
5386                        proto::UpdateContacts {
5387                            contacts: vec![updated_contact.clone()],
5388                            remove_contacts: Default::default(),
5389                            incoming_requests: Default::default(),
5390                            remove_incoming_requests: Default::default(),
5391                            outgoing_requests: Default::default(),
5392                            remove_outgoing_requests: Default::default(),
5393                        },
5394                    )
5395                    .trace_err();
5396            }
5397        }
5398    }
5399    Ok(())
5400}
5401
5402async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5403    log::info!("lost dev server connection, unsharing projects");
5404    let project_ids = session
5405        .db()
5406        .await
5407        .get_stale_dev_server_projects(session.connection_id)
5408        .await?;
5409
5410    for project_id in project_ids {
5411        // not unshare re-checks the connection ids match, so we get away with no transaction
5412        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5413    }
5414
5415    let user_id = session.dev_server().user_id;
5416    let update = session
5417        .db()
5418        .await
5419        .dev_server_projects_update(user_id)
5420        .await?;
5421
5422    send_dev_server_projects_update(user_id, update, session).await;
5423
5424    Ok(())
5425}
5426
5427async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5428    let mut contacts_to_update = HashSet::default();
5429
5430    let room_id;
5431    let canceled_calls_to_user_ids;
5432    let live_kit_room;
5433    let delete_live_kit_room;
5434    let room;
5435    let channel;
5436
5437    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5438        contacts_to_update.insert(session.user_id());
5439
5440        for project in left_room.left_projects.values() {
5441            project_left(project, session);
5442        }
5443
5444        room_id = RoomId::from_proto(left_room.room.id);
5445        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5446        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5447        delete_live_kit_room = left_room.deleted;
5448        room = mem::take(&mut left_room.room);
5449        channel = mem::take(&mut left_room.channel);
5450
5451        room_updated(&room, &session.peer);
5452    } else {
5453        return Ok(());
5454    }
5455
5456    if let Some(channel) = channel {
5457        channel_updated(
5458            &channel,
5459            &room,
5460            &session.peer,
5461            &*session.connection_pool().await,
5462        );
5463    }
5464
5465    {
5466        let pool = session.connection_pool().await;
5467        for canceled_user_id in canceled_calls_to_user_ids {
5468            for connection_id in pool.user_connection_ids(canceled_user_id) {
5469                session
5470                    .peer
5471                    .send(
5472                        connection_id,
5473                        proto::CallCanceled {
5474                            room_id: room_id.to_proto(),
5475                        },
5476                    )
5477                    .trace_err();
5478            }
5479            contacts_to_update.insert(canceled_user_id);
5480        }
5481    }
5482
5483    for contact_user_id in contacts_to_update {
5484        update_user_contacts(contact_user_id, &session).await?;
5485    }
5486
5487    if let Some(live_kit) = session.live_kit_client.as_ref() {
5488        live_kit
5489            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5490            .await
5491            .trace_err();
5492
5493        if delete_live_kit_room {
5494            live_kit.delete_room(live_kit_room).await.trace_err();
5495        }
5496    }
5497
5498    Ok(())
5499}
5500
5501async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5502    let left_channel_buffers = session
5503        .db()
5504        .await
5505        .leave_channel_buffers(session.connection_id)
5506        .await?;
5507
5508    for left_buffer in left_channel_buffers {
5509        channel_buffer_updated(
5510            session.connection_id,
5511            left_buffer.connections,
5512            &proto::UpdateChannelBufferCollaborators {
5513                channel_id: left_buffer.channel_id.to_proto(),
5514                collaborators: left_buffer.collaborators,
5515            },
5516            &session.peer,
5517        );
5518    }
5519
5520    Ok(())
5521}
5522
5523fn project_left(project: &db::LeftProject, session: &UserSession) {
5524    for connection_id in &project.connection_ids {
5525        if project.should_unshare {
5526            session
5527                .peer
5528                .send(
5529                    *connection_id,
5530                    proto::UnshareProject {
5531                        project_id: project.id.to_proto(),
5532                    },
5533                )
5534                .trace_err();
5535        } else {
5536            session
5537                .peer
5538                .send(
5539                    *connection_id,
5540                    proto::RemoveProjectCollaborator {
5541                        project_id: project.id.to_proto(),
5542                        peer_id: Some(session.connection_id.into()),
5543                    },
5544                )
5545                .trace_err();
5546        }
5547    }
5548}
5549
5550pub trait ResultExt {
5551    type Ok;
5552
5553    fn trace_err(self) -> Option<Self::Ok>;
5554}
5555
5556impl<T, E> ResultExt for Result<T, E>
5557where
5558    E: std::fmt::Debug,
5559{
5560    type Ok = T;
5561
5562    #[track_caller]
5563    fn trace_err(self) -> Option<T> {
5564        match self {
5565            Ok(value) => Some(value),
5566            Err(error) => {
5567                tracing::error!("{:?}", error);
5568                None
5569            }
5570        }
5571    }
5572}