rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::{
   5    auth,
   6    db::{
   7        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   8        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   9        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  10        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  11        ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, RateLimiter, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use collections::{HashMap, HashSet};
  34pub use connection_pool::{ConnectionPool, ZedVersion};
  35use core::fmt::{self, Debug, Formatter};
  36use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  37use sha2::Digest;
  38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  39
  40use futures::{
  41    channel::oneshot,
  42    future::{self, BoxFuture},
  43    stream::FuturesUnordered,
  44    FutureExt, SinkExt, StreamExt, TryStreamExt,
  45};
  46use http_client::IsahcHttpClient;
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79use self::connection_pool::VersionedMessage;
  80
// NOTE(review): presumably the grace period a disconnected client has to
// reconnect before its state is discarded — usage is outside this chunk; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] waits before looking up and cleaning up the
/// rooms and channel buffers left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are used outside this chunk; from their
// names they are presumably the channel-message page size, the maximum channel
// message length, and the notification page size — confirm against callers.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers` (keyed by `TypeId`): it
/// receives a boxed envelope plus the sender's [`Session`] and returns a
/// `'static`, `Send` boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107struct StreamingResponse<R: RequestMessage> {
 108    peer: Arc<Peer>,
 109    receipt: Receipt<R>,
 110}
 111
 112impl<R: RequestMessage> StreamingResponse<R> {
 113    fn send(&self, payload: R::Response) -> Result<()> {
 114        self.peer.respond(self.receipt, payload)?;
 115        Ok(())
 116    }
 117}
 118
 119#[derive(Clone, Debug)]
 120pub enum Principal {
 121    User(User),
 122    Impersonated { user: User, admin: User },
 123    DevServer(dev_server::Model),
 124}
 125
 126impl Principal {
 127    fn update_span(&self, span: &tracing::Span) {
 128        match &self {
 129            Principal::User(user) => {
 130                span.record("user_id", &user.id.0);
 131                span.record("login", &user.github_login);
 132            }
 133            Principal::Impersonated { user, admin } => {
 134                span.record("user_id", &user.id.0);
 135                span.record("login", &user.github_login);
 136                span.record("impersonator", &admin.github_login);
 137            }
 138            Principal::DevServer(dev_server) => {
 139                span.record("dev_server_id", &dev_server.id.0);
 140            }
 141        }
 142    }
 143}
 144
 145#[derive(Clone)]
 146struct Session {
 147    principal: Principal,
 148    connection_id: ConnectionId,
 149    db: Arc<tokio::sync::Mutex<DbHandle>>,
 150    peer: Arc<Peer>,
 151    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 152    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 153    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 154    http_client: Arc<IsahcHttpClient>,
 155    rate_limiter: Arc<RateLimiter>,
 156    /// The GeoIP country code for the user.
 157    #[allow(unused)]
 158    geoip_country_code: Option<String>,
 159    _executor: Executor,
 160}
 161
 162impl Session {
 163    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 164        #[cfg(test)]
 165        tokio::task::yield_now().await;
 166        let guard = self.db.lock().await;
 167        #[cfg(test)]
 168        tokio::task::yield_now().await;
 169        guard
 170    }
 171
 172    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 173        #[cfg(test)]
 174        tokio::task::yield_now().await;
 175        let guard = self.connection_pool.lock();
 176        ConnectionPoolGuard {
 177            guard,
 178            _not_send: PhantomData,
 179        }
 180    }
 181
 182    fn for_user(self) -> Option<UserSession> {
 183        UserSession::new(self)
 184    }
 185
 186    fn for_dev_server(self) -> Option<DevServerSession> {
 187        DevServerSession::new(self)
 188    }
 189
 190    fn user_id(&self) -> Option<UserId> {
 191        match &self.principal {
 192            Principal::User(user) => Some(user.id),
 193            Principal::Impersonated { user, .. } => Some(user.id),
 194            Principal::DevServer(_) => None,
 195        }
 196    }
 197
 198    fn is_staff(&self) -> bool {
 199        match &self.principal {
 200            Principal::User(user) => user.admin,
 201            Principal::Impersonated { .. } => true,
 202            Principal::DevServer(_) => false,
 203        }
 204    }
 205
 206    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
 207        if self.is_staff() {
 208            return Ok(proto::Plan::ZedPro);
 209        }
 210
 211        let Some(user_id) = self.user_id() else {
 212            return Ok(proto::Plan::Free);
 213        };
 214
 215        let db = self.db().await;
 216        if db.has_active_billing_subscription(user_id).await? {
 217            Ok(proto::Plan::ZedPro)
 218        } else {
 219            Ok(proto::Plan::Free)
 220        }
 221    }
 222
 223    fn dev_server_id(&self) -> Option<DevServerId> {
 224        match &self.principal {
 225            Principal::User(_) | Principal::Impersonated { .. } => None,
 226            Principal::DevServer(dev_server) => Some(dev_server.id),
 227        }
 228    }
 229
 230    fn principal_id(&self) -> PrincipalId {
 231        match &self.principal {
 232            Principal::User(user) => PrincipalId::UserId(user.id),
 233            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 234            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 235        }
 236    }
 237}
 238
 239impl Debug for Session {
 240    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 241        let mut result = f.debug_struct("Session");
 242        match &self.principal {
 243            Principal::User(user) => {
 244                result.field("user", &user.github_login);
 245            }
 246            Principal::Impersonated { user, admin } => {
 247                result.field("user", &user.github_login);
 248                result.field("impersonator", &admin.github_login);
 249            }
 250            Principal::DevServer(dev_server) => {
 251                result.field("dev_server", &dev_server.id);
 252            }
 253        }
 254        result.field("connection_id", &self.connection_id).finish()
 255    }
 256}
 257
 258struct UserSession(Session);
 259
 260impl UserSession {
 261    pub fn new(s: Session) -> Option<Self> {
 262        s.user_id().map(|_| UserSession(s))
 263    }
 264    pub fn user_id(&self) -> UserId {
 265        self.0.user_id().unwrap()
 266    }
 267
 268    pub fn email(&self) -> Option<String> {
 269        match &self.0.principal {
 270            Principal::User(user) => user.email_address.clone(),
 271            Principal::Impersonated { user, .. } => user.email_address.clone(),
 272            Principal::DevServer(..) => None,
 273        }
 274    }
 275}
 276
 277impl Deref for UserSession {
 278    type Target = Session;
 279
 280    fn deref(&self) -> &Self::Target {
 281        &self.0
 282    }
 283}
 284impl DerefMut for UserSession {
 285    fn deref_mut(&mut self) -> &mut Self::Target {
 286        &mut self.0
 287    }
 288}
 289
 290struct DevServerSession(Session);
 291
 292impl DevServerSession {
 293    pub fn new(s: Session) -> Option<Self> {
 294        s.dev_server_id().map(|_| DevServerSession(s))
 295    }
 296    pub fn dev_server_id(&self) -> DevServerId {
 297        self.0.dev_server_id().unwrap()
 298    }
 299
 300    fn dev_server(&self) -> &dev_server::Model {
 301        match &self.0.principal {
 302            Principal::DevServer(dev_server) => dev_server,
 303            _ => unreachable!(),
 304        }
 305    }
 306}
 307
 308impl Deref for DevServerSession {
 309    type Target = Session;
 310
 311    fn deref(&self) -> &Self::Target {
 312        &self.0
 313    }
 314}
 315impl DerefMut for DevServerSession {
 316    fn deref_mut(&mut self) -> &mut Self::Target {
 317        &mut self.0
 318    }
 319}
 320
 321fn user_handler<M: RequestMessage, Fut>(
 322    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 324where
 325    Fut: Send + Future<Output = Result<()>>,
 326{
 327    let handler = Arc::new(handler);
 328    move |message, response, session| {
 329        let handler = handler.clone();
 330        Box::pin(async move {
 331            if let Some(user_session) = session.for_user() {
 332                Ok(handler(message, response, user_session).await?)
 333            } else {
 334                Err(Error::Internal(anyhow!(
 335                    "must be a user to call {}",
 336                    M::NAME
 337                )))
 338            }
 339        })
 340    }
 341}
 342
 343fn dev_server_handler<M: RequestMessage, Fut>(
 344    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 346where
 347    Fut: Send + Future<Output = Result<()>>,
 348{
 349    let handler = Arc::new(handler);
 350    move |message, response, session| {
 351        let handler = handler.clone();
 352        Box::pin(async move {
 353            if let Some(dev_server_session) = session.for_dev_server() {
 354                Ok(handler(message, response, dev_server_session).await?)
 355            } else {
 356                Err(Error::Internal(anyhow!(
 357                    "must be a dev server to call {}",
 358                    M::NAME
 359                )))
 360            }
 361        })
 362    }
 363}
 364
 365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 366    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 368where
 369    InnertRetFut: Send + Future<Output = Result<()>>,
 370{
 371    let handler = Arc::new(handler);
 372    move |message, session| {
 373        let handler = handler.clone();
 374        Box::pin(async move {
 375            if let Some(user_session) = session.for_user() {
 376                Ok(handler(message, user_session).await?)
 377            } else {
 378                Err(Error::Internal(anyhow!(
 379                    "must be a user to call {}",
 380                    M::NAME
 381                )))
 382            }
 383        })
 384    }
 385}
 386
 387struct DbHandle(Arc<Database>);
 388
 389impl Deref for DbHandle {
 390    type Target = Database;
 391
 392    fn deref(&self) -> &Self::Target {
 393        self.0.as_ref()
 394    }
 395}
 396
/// The collaboration RPC server. It owns the peer, the pool of live
/// connections, and the table of registered message handlers.
pub struct Server {
    // Read via `lock()` in `start`. NOTE(review): presumably behind a mutex so
    // the id can be replaced after construction — the writer is outside this chunk.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Presumably keyed by the `TypeId` of the message type each handler serves.
    handlers: HashMap<TypeId, MessageHandler>,
    // Initialized to `false` in `new`. NOTE(review): presumably flipped to
    // signal teardown to subscribers — confirm.
    teardown: watch::Sender<bool>,
}
 405
/// Lock guard over the shared [`ConnectionPool`], returned by
/// `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send`. That stops the
    // synchronous pool lock from being held across an `.await` inside the
    // `Send` futures that handlers return.
    _not_send: PhantomData<Rc<()>>,
}
 410
/// Serializable view of the server's peer and connection pool. The pool lock is
/// held for as long as the snapshot lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 417
 418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 419where
 420    S: Serializer,
 421    T: Deref<Target = U>,
 422    U: Serialize,
 423{
 424    Serialize::serialize(value.deref(), serializer)
 425}
 426
 427impl Server {
    /// Builds the server for `id`, registers every RPC handler on it, and
    /// returns it wrapped in an `Arc`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if the id is negative
            // or exceeds u32 range — confirm server ids always fit.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped right away. NOTE(review): presumably
            // receivers are created later via `subscribe()` — confirm.
            teardown: watch::channel(false).0,
        };

        // Handler registration. Handlers wrapped in `user_handler` or
        // `user_message_handler` reject non-user principals, and
        // `dev_server_handler` rejects non-dev-server principals. Unwrapped
        // handlers receive the raw `Session`.
        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Generic project-forwarding helpers, instantiated once per proto
            // request type (the helpers are defined outside this chunk).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The language-model handlers below capture a clone of `app_state` so
            // each request can read `app_state.config`. They take the raw
            // `Session` instead of going through `user_handler`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // The OpenAI API key is cloned out of the config on every request.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 686
 687    pub async fn start(&self) -> Result<()> {
 688        let server_id = *self.id.lock();
 689        let app_state = self.app_state.clone();
 690        let peer = self.peer.clone();
 691        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 692        let pool = self.connection_pool.clone();
 693        let live_kit_client = self.app_state.live_kit_client.clone();
 694
 695        let span = info_span!("start server");
 696        self.app_state.executor.spawn_detached(
 697            async move {
 698                tracing::info!("waiting for cleanup timeout");
 699                timeout.await;
 700                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 701                if let Some((room_ids, channel_ids)) = app_state
 702                    .db
 703                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 704                    .await
 705                    .trace_err()
 706                {
 707                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 708                    tracing::info!(
 709                        stale_channel_buffer_count = channel_ids.len(),
 710                        "retrieved stale channel buffers"
 711                    );
 712
 713                    for channel_id in channel_ids {
 714                        if let Some(refreshed_channel_buffer) = app_state
 715                            .db
 716                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 717                            .await
 718                            .trace_err()
 719                        {
 720                            for connection_id in refreshed_channel_buffer.connection_ids {
 721                                peer.send(
 722                                    connection_id,
 723                                    proto::UpdateChannelBufferCollaborators {
 724                                        channel_id: channel_id.to_proto(),
 725                                        collaborators: refreshed_channel_buffer
 726                                            .collaborators
 727                                            .clone(),
 728                                    },
 729                                )
 730                                .trace_err();
 731                            }
 732                        }
 733                    }
 734
 735                    for room_id in room_ids {
 736                        let mut contacts_to_update = HashSet::default();
 737                        let mut canceled_calls_to_user_ids = Vec::new();
 738                        let mut live_kit_room = String::new();
 739                        let mut delete_live_kit_room = false;
 740
 741                        if let Some(mut refreshed_room) = app_state
 742                            .db
 743                            .clear_stale_room_participants(room_id, server_id)
 744                            .await
 745                            .trace_err()
 746                        {
 747                            tracing::info!(
 748                                room_id = room_id.0,
 749                                new_participant_count = refreshed_room.room.participants.len(),
 750                                "refreshed room"
 751                            );
 752                            room_updated(&refreshed_room.room, &peer);
 753                            if let Some(channel) = refreshed_room.channel.as_ref() {
 754                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 755                            }
 756                            contacts_to_update
 757                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 758                            contacts_to_update
 759                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 760                            canceled_calls_to_user_ids =
 761                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 762                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 763                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 764                        }
 765
 766                        {
 767                            let pool = pool.lock();
 768                            for canceled_user_id in canceled_calls_to_user_ids {
 769                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 770                                    peer.send(
 771                                        connection_id,
 772                                        proto::CallCanceled {
 773                                            room_id: room_id.to_proto(),
 774                                        },
 775                                    )
 776                                    .trace_err();
 777                                }
 778                            }
 779                        }
 780
 781                        for user_id in contacts_to_update {
 782                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 783                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 784                            if let Some((busy, contacts)) = busy.zip(contacts) {
 785                                let pool = pool.lock();
 786                                let updated_contact = contact_for_user(user_id, busy, &pool);
 787                                for contact in contacts {
 788                                    if let db::Contact::Accepted {
 789                                        user_id: contact_user_id,
 790                                        ..
 791                                    } = contact
 792                                    {
 793                                        for contact_conn_id in
 794                                            pool.user_connection_ids(contact_user_id)
 795                                        {
 796                                            peer.send(
 797                                                contact_conn_id,
 798                                                proto::UpdateContacts {
 799                                                    contacts: vec![updated_contact.clone()],
 800                                                    remove_contacts: Default::default(),
 801                                                    incoming_requests: Default::default(),
 802                                                    remove_incoming_requests: Default::default(),
 803                                                    outgoing_requests: Default::default(),
 804                                                    remove_outgoing_requests: Default::default(),
 805                                                },
 806                                            )
 807                                            .trace_err();
 808                                        }
 809                                    }
 810                                }
 811                            }
 812                        }
 813
 814                        if let Some(live_kit) = live_kit_client.as_ref() {
 815                            if delete_live_kit_room {
 816                                live_kit.delete_room(live_kit_room).await.trace_err();
 817                            }
 818                        }
 819                    }
 820                }
 821
 822                app_state
 823                    .db
 824                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 825                    .await
 826                    .trace_err();
 827            }
 828            .instrument(span),
 829        );
 830        Ok(())
 831    }
 832
 833    pub fn teardown(&self) {
 834        self.peer.teardown();
 835        self.connection_pool.lock().reset();
 836        let _ = self.teardown.send(true);
 837    }
 838
 839    #[cfg(test)]
 840    pub fn reset(&self, id: ServerId) {
 841        self.teardown();
 842        *self.id.lock() = id;
 843        self.peer.reset(id.0 as u32);
 844        let _ = self.teardown.send(false);
 845    }
 846
 847    #[cfg(test)]
 848    pub fn id(&self) -> ServerId {
 849        *self.id.lock()
 850    }
 851
 852    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 853    where
 854        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 855        Fut: 'static + Send + Future<Output = Result<()>>,
 856        M: EnvelopedMessage,
 857    {
 858        let prev_handler = self.handlers.insert(
 859            TypeId::of::<M>(),
 860            Box::new(move |envelope, session| {
 861                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 862                let received_at = envelope.received_at;
 863                tracing::info!("message received");
 864                let start_time = Instant::now();
 865                let future = (handler)(*envelope, session);
 866                async move {
 867                    let result = future.await;
 868                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 869                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 870                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 871                    let payload_type = M::NAME;
 872
 873                    match result {
 874                        Err(error) => {
 875                            tracing::error!(
 876                                ?error,
 877                                total_duration_ms,
 878                                processing_duration_ms,
 879                                queue_duration_ms,
 880                                payload_type,
 881                                "error handling message"
 882                            )
 883                        }
 884                        Ok(()) => tracing::info!(
 885                            total_duration_ms,
 886                            processing_duration_ms,
 887                            queue_duration_ms,
 888                            "finished handling message"
 889                        ),
 890                    }
 891                }
 892                .boxed()
 893            }),
 894        );
 895        if prev_handler.is_some() {
 896            panic!("registered a handler for the same message twice");
 897        }
 898        self
 899    }
 900
 901    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 902    where
 903        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 904        Fut: 'static + Send + Future<Output = Result<()>>,
 905        M: EnvelopedMessage,
 906    {
 907        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 908        self
 909    }
 910
 911    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 912    where
 913        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 914        Fut: Send + Future<Output = Result<()>>,
 915        M: RequestMessage,
 916    {
 917        let handler = Arc::new(handler);
 918        self.add_handler(move |envelope, session| {
 919            let receipt = envelope.receipt();
 920            let handler = handler.clone();
 921            async move {
 922                let peer = session.peer.clone();
 923                let responded = Arc::new(AtomicBool::default());
 924                let response = Response {
 925                    peer: peer.clone(),
 926                    responded: responded.clone(),
 927                    receipt,
 928                };
 929                match (handler)(envelope.payload, response, session).await {
 930                    Ok(()) => {
 931                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 932                            Ok(())
 933                        } else {
 934                            Err(anyhow!("handler did not send a response"))?
 935                        }
 936                    }
 937                    Err(error) => {
 938                        let proto_err = match &error {
 939                            Error::Internal(err) => err.to_proto(),
 940                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 941                        };
 942                        peer.respond_with_error(receipt, proto_err)?;
 943                        Err(error)
 944                    }
 945                }
 946            }
 947        })
 948    }
 949
 950    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 951    where
 952        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 953        Fut: Send + Future<Output = Result<()>>,
 954        M: RequestMessage,
 955    {
 956        let handler = Arc::new(handler);
 957        self.add_handler(move |envelope, session| {
 958            let receipt = envelope.receipt();
 959            let handler = handler.clone();
 960            async move {
 961                let peer = session.peer.clone();
 962                let response = StreamingResponse {
 963                    peer: peer.clone(),
 964                    receipt,
 965                };
 966                match (handler)(envelope.payload, response, session).await {
 967                    Ok(()) => {
 968                        peer.end_stream(receipt)?;
 969                        Ok(())
 970                    }
 971                    Err(error) => {
 972                        let proto_err = match &error {
 973                            Error::Internal(err) => err.to_proto(),
 974                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 975                        };
 976                        peer.respond_with_error(receipt, proto_err)?;
 977                        Err(error)
 978                    }
 979                }
 980            }
 981        })
 982    }
 983
    /// Returns a future that drives a single client connection until it closes, its I/O fails,
    /// or the server starts tearing down.
    ///
    /// The future:
    /// 1. bails out immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer and builds the per-connection `Session`;
    /// 3. sends the initial client update via `send_initial_client_update` (which also reports
    ///    the assigned `ConnectionId` through `send_connection_id`, if provided);
    /// 4. dispatches incoming messages to the registered handlers. Background messages are
    ///    spawned detached on `executor`; foreground messages are polled on this task through a
    ///    `FuturesUnordered`. Both kinds hold one of 256 semaphore permits while they run.
    ///
    /// When the connection closes or its I/O errors, `connection_lost` is awaited to release the
    /// session's resources. When teardown is signalled, the future returns without doing so.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before entering the async block so a teardown signalled in between is
        // still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): this early return, and the one after a failed initial client update
            // below, skip `connection_lost`. The connection registered with `this.peer` above
            // (and, on the latter path, possibly its connection-pool entry) is therefore never
            // cleaned up. Confirm whether that is intended.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client only exists when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* the next message is read, so once 256 handlers
                // are in flight this connection stops reading until one of them finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in order: teardown, then connection I/O, then
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown returns without `connection_lost`; `Server::teardown` resets the
                    // peer and the connection pool itself.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released only
                                // when the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still running;
            // detached background handlers are unaffected.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1138
1139    async fn send_initial_client_update(
1140        &self,
1141        connection_id: ConnectionId,
1142        principal: &Principal,
1143        zed_version: ZedVersion,
1144        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1145        session: &Session,
1146    ) -> Result<()> {
1147        self.peer.send(
1148            connection_id,
1149            proto::Hello {
1150                peer_id: Some(connection_id.into()),
1151            },
1152        )?;
1153        tracing::info!("sent hello message");
1154        if let Some(send_connection_id) = send_connection_id.take() {
1155            let _ = send_connection_id.send(connection_id);
1156        }
1157
1158        match principal {
1159            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1160                if !user.connected_once {
1161                    self.peer.send(connection_id, proto::ShowContacts {})?;
1162                    self.app_state
1163                        .db
1164                        .set_user_connected_once(user.id, true)
1165                        .await?;
1166                }
1167
1168                update_user_plan(user.id, session).await?;
1169
1170                let (contacts, dev_server_projects) = future::try_join(
1171                    self.app_state.db.get_contacts(user.id),
1172                    self.app_state.db.dev_server_projects_update(user.id),
1173                )
1174                .await?;
1175
1176                {
1177                    let mut pool = self.connection_pool.lock();
1178                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1179                    self.peer.send(
1180                        connection_id,
1181                        build_initial_contacts_update(contacts, &pool),
1182                    )?;
1183                }
1184
1185                if should_auto_subscribe_to_channels(zed_version) {
1186                    subscribe_user_to_channels(user.id, session).await?;
1187                }
1188
1189                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1190
1191                if let Some(incoming_call) =
1192                    self.app_state.db.incoming_call_for_user(user.id).await?
1193                {
1194                    self.peer.send(connection_id, incoming_call)?;
1195                }
1196
1197                update_user_contacts(user.id, &session).await?;
1198            }
1199            Principal::DevServer(dev_server) => {
1200                {
1201                    let mut pool = self.connection_pool.lock();
1202                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1203                    {
1204                        self.peer.send(
1205                            stale_connection_id,
1206                            proto::ShutdownDevServer {
1207                                reason: Some(
1208                                    "another dev server connected with the same token".to_string(),
1209                                ),
1210                            },
1211                        )?;
1212                        pool.remove_connection(stale_connection_id)?;
1213                    };
1214                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1215                }
1216
1217                let projects = self
1218                    .app_state
1219                    .db
1220                    .get_projects_for_dev_server(dev_server.id)
1221                    .await?;
1222                self.peer
1223                    .send(connection_id, proto::DevServerInstructions { projects })?;
1224
1225                let status = self
1226                    .app_state
1227                    .db
1228                    .dev_server_projects_update(dev_server.user_id)
1229                    .await?;
1230                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1231            }
1232        }
1233
1234        Ok(())
1235    }
1236
1237    pub async fn invite_code_redeemed(
1238        self: &Arc<Self>,
1239        inviter_id: UserId,
1240        invitee_id: UserId,
1241    ) -> Result<()> {
1242        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1243            if let Some(code) = &user.invite_code {
1244                let pool = self.connection_pool.lock();
1245                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1246                for connection_id in pool.user_connection_ids(inviter_id) {
1247                    self.peer.send(
1248                        connection_id,
1249                        proto::UpdateContacts {
1250                            contacts: vec![invitee_contact.clone()],
1251                            ..Default::default()
1252                        },
1253                    )?;
1254                    self.peer.send(
1255                        connection_id,
1256                        proto::UpdateInviteInfo {
1257                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1258                            count: user.invite_count as u32,
1259                        },
1260                    )?;
1261                }
1262            }
1263        }
1264        Ok(())
1265    }
1266
1267    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1268        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1269            if let Some(invite_code) = &user.invite_code {
1270                let pool = self.connection_pool.lock();
1271                for connection_id in pool.user_connection_ids(user_id) {
1272                    self.peer.send(
1273                        connection_id,
1274                        proto::UpdateInviteInfo {
1275                            url: format!(
1276                                "{}{}",
1277                                self.app_state.config.invite_link_prefix, invite_code
1278                            ),
1279                            count: user.invite_count as u32,
1280                        },
1281                    )?;
1282                }
1283            }
1284        }
1285        Ok(())
1286    }
1287
1288    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1289        ServerSnapshot {
1290            connection_pool: ConnectionPoolGuard {
1291                guard: self.connection_pool.lock(),
1292                _not_send: PhantomData,
1293            },
1294            peer: &self.peer,
1295        }
1296    }
1297}
1298
1299impl<'a> Deref for ConnectionPoolGuard<'a> {
1300    type Target = ConnectionPool;
1301
1302    fn deref(&self) -> &Self::Target {
1303        &self.guard
1304    }
1305}
1306
1307impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1308    fn deref_mut(&mut self) -> &mut Self::Target {
1309        &mut self.guard
1310    }
1311}
1312
/// In test builds, validates the pool's invariants when the guard is dropped. This runs
/// before the `guard` field itself is dropped, i.e. while the pool is still locked.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled in test builds; in other builds dropping the guard just releases the lock.
        #[cfg(test)]
        self.check_invariants();
    }
}
1319
1320fn broadcast<F>(
1321    sender_id: Option<ConnectionId>,
1322    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1323    mut f: F,
1324) where
1325    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1326{
1327    for receiver_id in receiver_ids {
1328        if Some(receiver_id) != sender_id {
1329            if let Err(error) = f(receiver_id) {
1330                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1331            }
1332        }
1333    }
1334}
1335
1336pub struct ProtocolVersion(u32);
1337
1338impl Header for ProtocolVersion {
1339    fn name() -> &'static HeaderName {
1340        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1341        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1342    }
1343
1344    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1345    where
1346        Self: Sized,
1347        I: Iterator<Item = &'i axum::http::HeaderValue>,
1348    {
1349        let version = values
1350            .next()
1351            .ok_or_else(axum::headers::Error::invalid)?
1352            .to_str()
1353            .map_err(|_| axum::headers::Error::invalid())?
1354            .parse()
1355            .map_err(|_| axum::headers::Error::invalid())?;
1356        Ok(Self(version))
1357    }
1358
1359    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1360        values.extend([self.0.to_string().parse().unwrap()]);
1361    }
1362}
1363
1364pub struct AppVersionHeader(SemanticVersion);
1365impl Header for AppVersionHeader {
1366    fn name() -> &'static HeaderName {
1367        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1368        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1369    }
1370
1371    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1372    where
1373        Self: Sized,
1374        I: Iterator<Item = &'i axum::http::HeaderValue>,
1375    {
1376        let version = values
1377            .next()
1378            .ok_or_else(axum::headers::Error::invalid)?
1379            .to_str()
1380            .map_err(|_| axum::headers::Error::invalid())?
1381            .parse()
1382            .map_err(|_| axum::headers::Error::invalid())?;
1383        Ok(Self(version))
1384    }
1385
1386    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1387        values.extend([self.0.to_string().parse().unwrap()]);
1388    }
1389}
1390
1391pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1392    Router::new()
1393        .route("/rpc", get(handle_websocket_request))
1394        .layer(
1395            ServiceBuilder::new()
1396                .layer(Extension(server.app_state.clone()))
1397                .layer(middleware::from_fn(auth::validate_header)),
1398        )
1399        .route("/metrics", get(handle_metrics))
1400        .layer(Extension(server))
1401}
1402
1403pub async fn handle_websocket_request(
1404    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1405    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1406    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1407    Extension(server): Extension<Arc<Server>>,
1408    Extension(principal): Extension<Principal>,
1409    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1410    ws: WebSocketUpgrade,
1411) -> axum::response::Response {
1412    if protocol_version != rpc::PROTOCOL_VERSION {
1413        return (
1414            StatusCode::UPGRADE_REQUIRED,
1415            "client must be upgraded".to_string(),
1416        )
1417            .into_response();
1418    }
1419
1420    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1421        return (
1422            StatusCode::UPGRADE_REQUIRED,
1423            "no version header found".to_string(),
1424        )
1425            .into_response();
1426    };
1427
1428    if !version.can_collaborate() {
1429        return (
1430            StatusCode::UPGRADE_REQUIRED,
1431            "client must be upgraded".to_string(),
1432        )
1433            .into_response();
1434    }
1435
1436    let socket_address = socket_address.to_string();
1437    ws.on_upgrade(move |socket| {
1438        let socket = socket
1439            .map_ok(to_tungstenite_message)
1440            .err_into()
1441            .with(|message| async move { to_axum_message(message) });
1442        let connection = Connection::new(Box::pin(socket));
1443        async move {
1444            server
1445                .handle_connection(
1446                    connection,
1447                    socket_address,
1448                    principal,
1449                    version,
1450                    country_code_header.map(|header| header.to_string()),
1451                    None,
1452                    Executor::Production,
1453                )
1454                .await;
1455        }
1456    })
1457}
1458
1459pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1460    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1461    let connections_metric = CONNECTIONS_METRIC
1462        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1463
1464    let connections = server
1465        .connection_pool
1466        .lock()
1467        .connections()
1468        .filter(|connection| !connection.admin)
1469        .count();
1470    connections_metric.set(connections as _);
1471
1472    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1473    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1474        register_int_gauge!(
1475            "shared_projects",
1476            "number of open projects with one or more guests"
1477        )
1478        .unwrap()
1479    });
1480
1481    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1482    shared_projects_metric.set(shared_projects as _);
1483
1484    let encoder = prometheus::TextEncoder::new();
1485    let metric_families = prometheus::gather();
1486    let encoded_metrics = encoder
1487        .encode_to_string(&metric_families)
1488        .map_err(|err| anyhow!("{}", err))?;
1489    Ok(encoded_metrics)
1490}
1491
/// Cleans up after a client connection drops.
///
/// Runs immediately:
/// - disconnects the peer;
/// - removes the connection from the connection pool (an error here aborts the cleanup);
/// - records the loss in the database (errors are only traced).
///
/// It then races `RECONNECT_TIMEOUT` against `teardown` changing. If `teardown` changes first,
/// nothing else happens here. Otherwise, once the timeout elapses:
/// - for user principals (including impersonated ones), it leaves the room and channel buffers.
///   If the user has no other online connection, it calls `decline_call(None, ..)` and
///   broadcasts the resulting room. Finally it refreshes the user's contacts.
/// - for dev server principals, it runs `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the branches in order: the reconnect timeout first, then teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline pending calls once the user has no remaining online connection.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown signalled before the timeout: skip the delayed resource cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1546
1547/// Acknowledges a ping from a client, used to keep the connection alive.
1548async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1549    response.send(proto::Ack {})?;
1550    Ok(())
1551}
1552
1553/// Creates a new room for calling (outside of channels)
1554async fn create_room(
1555    _request: proto::CreateRoom,
1556    response: Response<proto::CreateRoom>,
1557    session: UserSession,
1558) -> Result<()> {
1559    let live_kit_room = nanoid::nanoid!(30);
1560
1561    let live_kit_connection_info = util::maybe!(async {
1562        let live_kit = session.live_kit_client.as_ref();
1563        let live_kit = live_kit?;
1564        let user_id = session.user_id().to_string();
1565
1566        let token = live_kit
1567            .room_token(&live_kit_room, &user_id.to_string())
1568            .trace_err()?;
1569
1570        Some(proto::LiveKitConnectionInfo {
1571            server_url: live_kit.url().into(),
1572            token,
1573            can_publish: true,
1574        })
1575    })
1576    .await;
1577
1578    let room = session
1579        .db()
1580        .await
1581        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1582        .await?;
1583
1584    response.send(proto::CreateRoomResponse {
1585        room: Some(room.clone()),
1586        live_kit_connection_info,
1587    })?;
1588
1589    update_user_contacts(session.user_id(), &session).await?;
1590    Ok(())
1591}
1592
1593/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1594async fn join_room(
1595    request: proto::JoinRoom,
1596    response: Response<proto::JoinRoom>,
1597    session: UserSession,
1598) -> Result<()> {
1599    let room_id = RoomId::from_proto(request.id);
1600
1601    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1602
1603    if let Some(channel_id) = channel_id {
1604        return join_channel_internal(channel_id, Box::new(response), session).await;
1605    }
1606
1607    let joined_room = {
1608        let room = session
1609            .db()
1610            .await
1611            .join_room(room_id, session.user_id(), session.connection_id)
1612            .await?;
1613        room_updated(&room.room, &session.peer);
1614        room.into_inner()
1615    };
1616
1617    for connection_id in session
1618        .connection_pool()
1619        .await
1620        .user_connection_ids(session.user_id())
1621    {
1622        session
1623            .peer
1624            .send(
1625                connection_id,
1626                proto::CallCanceled {
1627                    room_id: room_id.to_proto(),
1628                },
1629            )
1630            .trace_err();
1631    }
1632
1633    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1634        if let Some(token) = live_kit
1635            .room_token(
1636                &joined_room.room.live_kit_room,
1637                &session.user_id().to_string(),
1638            )
1639            .trace_err()
1640        {
1641            Some(proto::LiveKitConnectionInfo {
1642                server_url: live_kit.url().into(),
1643                token,
1644                can_publish: true,
1645            })
1646        } else {
1647            None
1648        }
1649    } else {
1650        None
1651    };
1652
1653    response.send(proto::JoinRoomResponse {
1654        room: Some(joined_room.room),
1655        channel_id: None,
1656        live_kit_connection_info,
1657    })?;
1658
1659    update_user_contacts(session.user_id(), &session).await?;
1660    Ok(())
1661}
1662
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Calls the database's `rejoin_room` with the request and the caller's current connection id.
/// While still holding the returned guard, it:
/// - replies with the room and the reshared/rejoined project lists;
/// - broadcasts the room;
/// - for each reshared project, tells every listed collaborator that `old_connection_id` was
///   replaced by this connection, and forwards an `UpdateProject` with the project's worktrees
///   to the collaborators (excluding this connection);
/// - streams rejoined project state via `notify_rejoined_projects`.
///
/// After the guard is released, it refreshes the room's channel (if any) and the caller's
/// contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so the rejoined-room guard is dropped before the channel/contacts updates below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at this connection's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1754
1755fn notify_rejoined_projects(
1756    rejoined_projects: &mut Vec<RejoinedProject>,
1757    session: &UserSession,
1758) -> Result<()> {
1759    for project in rejoined_projects.iter() {
1760        for collaborator in &project.collaborators {
1761            session
1762                .peer
1763                .send(
1764                    collaborator.connection_id,
1765                    proto::UpdateProjectCollaborator {
1766                        project_id: project.id.to_proto(),
1767                        old_peer_id: Some(project.old_connection_id.into()),
1768                        new_peer_id: Some(session.connection_id.into()),
1769                    },
1770                )
1771                .trace_err();
1772        }
1773    }
1774
1775    for project in rejoined_projects {
1776        for worktree in mem::take(&mut project.worktrees) {
1777            #[cfg(any(test, feature = "test-support"))]
1778            const MAX_CHUNK_SIZE: usize = 2;
1779            #[cfg(not(any(test, feature = "test-support")))]
1780            const MAX_CHUNK_SIZE: usize = 256;
1781
1782            // Stream this worktree's entries.
1783            let message = proto::UpdateWorktree {
1784                project_id: project.id.to_proto(),
1785                worktree_id: worktree.id,
1786                abs_path: worktree.abs_path.clone(),
1787                root_name: worktree.root_name,
1788                updated_entries: worktree.updated_entries,
1789                removed_entries: worktree.removed_entries,
1790                scan_id: worktree.scan_id,
1791                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1792                updated_repositories: worktree.updated_repositories,
1793                removed_repositories: worktree.removed_repositories,
1794            };
1795            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1796                session.peer.send(session.connection_id, update.clone())?;
1797            }
1798
1799            // Stream this worktree's diagnostics.
1800            for summary in worktree.diagnostic_summaries {
1801                session.peer.send(
1802                    session.connection_id,
1803                    proto::UpdateDiagnosticSummary {
1804                        project_id: project.id.to_proto(),
1805                        worktree_id: worktree.id,
1806                        summary: Some(summary),
1807                    },
1808                )?;
1809            }
1810
1811            for settings_file in worktree.settings_files {
1812                session.peer.send(
1813                    session.connection_id,
1814                    proto::UpdateWorktreeSettings {
1815                        project_id: project.id.to_proto(),
1816                        worktree_id: worktree.id,
1817                        path: settings_file.path,
1818                        content: Some(settings_file.content),
1819                    },
1820                )?;
1821            }
1822        }
1823
1824        for language_server in &project.language_servers {
1825            session.peer.send(
1826                session.connection_id,
1827                proto::UpdateLanguageServer {
1828                    project_id: project.id.to_proto(),
1829                    language_server_id: language_server.id,
1830                    variant: Some(
1831                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1832                            proto::LspDiskBasedDiagnosticsUpdated {},
1833                        ),
1834                    ),
1835                },
1836            )?;
1837        }
1838    }
1839    Ok(())
1840}
1841
1842/// leave room disconnects from the room.
1843async fn leave_room(
1844    _: proto::LeaveRoom,
1845    response: Response<proto::LeaveRoom>,
1846    session: UserSession,
1847) -> Result<()> {
1848    leave_room_for_session(&session, session.connection_id).await?;
1849    response.send(proto::Ack {})?;
1850    Ok(())
1851}
1852
1853/// Updates the permissions of someone else in the room.
1854async fn set_room_participant_role(
1855    request: proto::SetRoomParticipantRole,
1856    response: Response<proto::SetRoomParticipantRole>,
1857    session: UserSession,
1858) -> Result<()> {
1859    let user_id = UserId::from_proto(request.user_id);
1860    let role = ChannelRole::from(request.role());
1861
1862    let (live_kit_room, can_publish) = {
1863        let room = session
1864            .db()
1865            .await
1866            .set_room_participant_role(
1867                session.user_id(),
1868                RoomId::from_proto(request.room_id),
1869                user_id,
1870                role,
1871            )
1872            .await?;
1873
1874        let live_kit_room = room.live_kit_room.clone();
1875        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1876        room_updated(&room, &session.peer);
1877        (live_kit_room, can_publish)
1878    };
1879
1880    if let Some(live_kit) = session.live_kit_client.as_ref() {
1881        live_kit
1882            .update_participant(
1883                live_kit_room.clone(),
1884                request.user_id.to_string(),
1885                live_kit_server::proto::ParticipantPermission {
1886                    can_subscribe: true,
1887                    can_publish,
1888                    can_publish_data: can_publish,
1889                    hidden: false,
1890                    recorder: false,
1891                },
1892            )
1893            .await
1894            .trace_err();
1895    }
1896
1897    response.send(proto::Ack {})?;
1898    Ok(())
1899}
1900
1901/// Call someone else into the current room
1902async fn call(
1903    request: proto::Call,
1904    response: Response<proto::Call>,
1905    session: UserSession,
1906) -> Result<()> {
1907    let room_id = RoomId::from_proto(request.room_id);
1908    let calling_user_id = session.user_id();
1909    let calling_connection_id = session.connection_id;
1910    let called_user_id = UserId::from_proto(request.called_user_id);
1911    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1912    if !session
1913        .db()
1914        .await
1915        .has_contact(calling_user_id, called_user_id)
1916        .await?
1917    {
1918        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1919    }
1920
1921    let incoming_call = {
1922        let (room, incoming_call) = &mut *session
1923            .db()
1924            .await
1925            .call(
1926                room_id,
1927                calling_user_id,
1928                calling_connection_id,
1929                called_user_id,
1930                initial_project_id,
1931            )
1932            .await?;
1933        room_updated(&room, &session.peer);
1934        mem::take(incoming_call)
1935    };
1936    update_user_contacts(called_user_id, &session).await?;
1937
1938    let mut calls = session
1939        .connection_pool()
1940        .await
1941        .user_connection_ids(called_user_id)
1942        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1943        .collect::<FuturesUnordered<_>>();
1944
1945    while let Some(call_response) = calls.next().await {
1946        match call_response.as_ref() {
1947            Ok(_) => {
1948                response.send(proto::Ack {})?;
1949                return Ok(());
1950            }
1951            Err(_) => {
1952                call_response.trace_err();
1953            }
1954        }
1955    }
1956
1957    {
1958        let room = session
1959            .db()
1960            .await
1961            .call_failed(room_id, called_user_id)
1962            .await?;
1963        room_updated(&room, &session.peer);
1964    }
1965    update_user_contacts(called_user_id, &session).await?;
1966
1967    Err(anyhow!("failed to ring user"))?
1968}
1969
1970/// Cancel an outgoing call.
1971async fn cancel_call(
1972    request: proto::CancelCall,
1973    response: Response<proto::CancelCall>,
1974    session: UserSession,
1975) -> Result<()> {
1976    let called_user_id = UserId::from_proto(request.called_user_id);
1977    let room_id = RoomId::from_proto(request.room_id);
1978    {
1979        let room = session
1980            .db()
1981            .await
1982            .cancel_call(room_id, session.connection_id, called_user_id)
1983            .await?;
1984        room_updated(&room, &session.peer);
1985    }
1986
1987    for connection_id in session
1988        .connection_pool()
1989        .await
1990        .user_connection_ids(called_user_id)
1991    {
1992        session
1993            .peer
1994            .send(
1995                connection_id,
1996                proto::CallCanceled {
1997                    room_id: room_id.to_proto(),
1998                },
1999            )
2000            .trace_err();
2001    }
2002    response.send(proto::Ack {})?;
2003
2004    update_user_contacts(called_user_id, &session).await?;
2005    Ok(())
2006}
2007
2008/// Decline an incoming call.
2009async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2010    let room_id = RoomId::from_proto(message.room_id);
2011    {
2012        let room = session
2013            .db()
2014            .await
2015            .decline_call(Some(room_id), session.user_id())
2016            .await?
2017            .ok_or_else(|| anyhow!("failed to decline call"))?;
2018        room_updated(&room, &session.peer);
2019    }
2020
2021    for connection_id in session
2022        .connection_pool()
2023        .await
2024        .user_connection_ids(session.user_id())
2025    {
2026        session
2027            .peer
2028            .send(
2029                connection_id,
2030                proto::CallCanceled {
2031                    room_id: room_id.to_proto(),
2032                },
2033            )
2034            .trace_err();
2035    }
2036    update_user_contacts(session.user_id(), &session).await?;
2037    Ok(())
2038}
2039
2040/// Updates other participants in the room with your current location.
2041async fn update_participant_location(
2042    request: proto::UpdateParticipantLocation,
2043    response: Response<proto::UpdateParticipantLocation>,
2044    session: UserSession,
2045) -> Result<()> {
2046    let room_id = RoomId::from_proto(request.room_id);
2047    let location = request
2048        .location
2049        .ok_or_else(|| anyhow!("invalid location"))?;
2050
2051    let db = session.db().await;
2052    let room = db
2053        .update_room_participant_location(room_id, session.connection_id, location)
2054        .await?;
2055
2056    room_updated(&room, &session.peer);
2057    response.send(proto::Ack {})?;
2058    Ok(())
2059}
2060
2061/// Share a project into the room.
2062async fn share_project(
2063    request: proto::ShareProject,
2064    response: Response<proto::ShareProject>,
2065    session: UserSession,
2066) -> Result<()> {
2067    let (project_id, room) = &*session
2068        .db()
2069        .await
2070        .share_project(
2071            RoomId::from_proto(request.room_id),
2072            session.connection_id,
2073            &request.worktrees,
2074            request
2075                .dev_server_project_id
2076                .map(|id| DevServerProjectId::from_proto(id)),
2077        )
2078        .await?;
2079    response.send(proto::ShareProjectResponse {
2080        project_id: project_id.to_proto(),
2081    })?;
2082    room_updated(&room, &session.peer);
2083
2084    Ok(())
2085}
2086
2087/// Unshare a project from the room.
2088async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2089    let project_id = ProjectId::from_proto(message.project_id);
2090    unshare_project_internal(
2091        project_id,
2092        session.connection_id,
2093        session.user_id(),
2094        &session,
2095    )
2096    .await
2097}
2098
2099async fn unshare_project_internal(
2100    project_id: ProjectId,
2101    connection_id: ConnectionId,
2102    user_id: Option<UserId>,
2103    session: &Session,
2104) -> Result<()> {
2105    let delete = {
2106        let room_guard = session
2107            .db()
2108            .await
2109            .unshare_project(project_id, connection_id, user_id)
2110            .await?;
2111
2112        let (delete, room, guest_connection_ids) = &*room_guard;
2113
2114        let message = proto::UnshareProject {
2115            project_id: project_id.to_proto(),
2116        };
2117
2118        broadcast(
2119            Some(connection_id),
2120            guest_connection_ids.iter().copied(),
2121            |conn_id| session.peer.send(conn_id, message.clone()),
2122        );
2123        if let Some(room) = room {
2124            room_updated(room, &session.peer);
2125        }
2126
2127        *delete
2128    };
2129
2130    if delete {
2131        let db = session.db().await;
2132        db.delete_project(project_id).await?;
2133    }
2134
2135    Ok(())
2136}
2137
2138/// DevServer makes a project available online
2139async fn share_dev_server_project(
2140    request: proto::ShareDevServerProject,
2141    response: Response<proto::ShareDevServerProject>,
2142    session: DevServerSession,
2143) -> Result<()> {
2144    let (dev_server_project, user_id, status) = session
2145        .db()
2146        .await
2147        .share_dev_server_project(
2148            DevServerProjectId::from_proto(request.dev_server_project_id),
2149            session.dev_server_id(),
2150            session.connection_id,
2151            &request.worktrees,
2152        )
2153        .await?;
2154    let Some(project_id) = dev_server_project.project_id else {
2155        return Err(anyhow!("failed to share remote project"))?;
2156    };
2157
2158    send_dev_server_projects_update(user_id, status, &session).await;
2159
2160    response.send(proto::ShareProjectResponse { project_id })?;
2161
2162    Ok(())
2163}
2164
2165/// Join someone elses shared project.
2166async fn join_project(
2167    request: proto::JoinProject,
2168    response: Response<proto::JoinProject>,
2169    session: UserSession,
2170) -> Result<()> {
2171    let project_id = ProjectId::from_proto(request.project_id);
2172
2173    tracing::info!(%project_id, "join project");
2174
2175    let db = session.db().await;
2176    let (project, replica_id) = &mut *db
2177        .join_project(project_id, session.connection_id, session.user_id())
2178        .await?;
2179    drop(db);
2180    tracing::info!(%project_id, "join remote project");
2181    join_project_internal(response, session, project, replica_id)
2182}
2183
/// Abstracts over the response handles of the requests that join a project (`JoinProject` and
/// `JoinHostedProject`), both of which reply with a `proto::JoinProjectResponse`, so that
/// `join_project_internal` can serve either.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully qualified `Response::<_>::send` path resolves to the inherent method on `Response`,
// not back into this trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2197
2198fn join_project_internal(
2199    response: impl JoinProjectInternalResponse,
2200    session: UserSession,
2201    project: &mut Project,
2202    replica_id: &ReplicaId,
2203) -> Result<()> {
2204    let collaborators = project
2205        .collaborators
2206        .iter()
2207        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2208        .map(|collaborator| collaborator.to_proto())
2209        .collect::<Vec<_>>();
2210    let project_id = project.id;
2211    let guest_user_id = session.user_id();
2212
2213    let worktrees = project
2214        .worktrees
2215        .iter()
2216        .map(|(id, worktree)| proto::WorktreeMetadata {
2217            id: *id,
2218            root_name: worktree.root_name.clone(),
2219            visible: worktree.visible,
2220            abs_path: worktree.abs_path.clone(),
2221        })
2222        .collect::<Vec<_>>();
2223
2224    let add_project_collaborator = proto::AddProjectCollaborator {
2225        project_id: project_id.to_proto(),
2226        collaborator: Some(proto::Collaborator {
2227            peer_id: Some(session.connection_id.into()),
2228            replica_id: replica_id.0 as u32,
2229            user_id: guest_user_id.to_proto(),
2230        }),
2231    };
2232
2233    for collaborator in &collaborators {
2234        session
2235            .peer
2236            .send(
2237                collaborator.peer_id.unwrap().into(),
2238                add_project_collaborator.clone(),
2239            )
2240            .trace_err();
2241    }
2242
2243    // First, we send the metadata associated with each worktree.
2244    response.send(proto::JoinProjectResponse {
2245        project_id: project.id.0 as u64,
2246        worktrees: worktrees.clone(),
2247        replica_id: replica_id.0 as u32,
2248        collaborators: collaborators.clone(),
2249        language_servers: project.language_servers.clone(),
2250        role: project.role.into(),
2251        dev_server_project_id: project
2252            .dev_server_project_id
2253            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2254    })?;
2255
2256    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2257        #[cfg(any(test, feature = "test-support"))]
2258        const MAX_CHUNK_SIZE: usize = 2;
2259        #[cfg(not(any(test, feature = "test-support")))]
2260        const MAX_CHUNK_SIZE: usize = 256;
2261
2262        // Stream this worktree's entries.
2263        let message = proto::UpdateWorktree {
2264            project_id: project_id.to_proto(),
2265            worktree_id,
2266            abs_path: worktree.abs_path.clone(),
2267            root_name: worktree.root_name,
2268            updated_entries: worktree.entries,
2269            removed_entries: Default::default(),
2270            scan_id: worktree.scan_id,
2271            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2272            updated_repositories: worktree.repository_entries.into_values().collect(),
2273            removed_repositories: Default::default(),
2274        };
2275        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2276            session.peer.send(session.connection_id, update.clone())?;
2277        }
2278
2279        // Stream this worktree's diagnostics.
2280        for summary in worktree.diagnostic_summaries {
2281            session.peer.send(
2282                session.connection_id,
2283                proto::UpdateDiagnosticSummary {
2284                    project_id: project_id.to_proto(),
2285                    worktree_id: worktree.id,
2286                    summary: Some(summary),
2287                },
2288            )?;
2289        }
2290
2291        for settings_file in worktree.settings_files {
2292            session.peer.send(
2293                session.connection_id,
2294                proto::UpdateWorktreeSettings {
2295                    project_id: project_id.to_proto(),
2296                    worktree_id: worktree.id,
2297                    path: settings_file.path,
2298                    content: Some(settings_file.content),
2299                },
2300            )?;
2301        }
2302    }
2303
2304    for language_server in &project.language_servers {
2305        session.peer.send(
2306            session.connection_id,
2307            proto::UpdateLanguageServer {
2308                project_id: project_id.to_proto(),
2309                language_server_id: language_server.id,
2310                variant: Some(
2311                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2312                        proto::LspDiskBasedDiagnosticsUpdated {},
2313                    ),
2314                ),
2315            },
2316        )?;
2317    }
2318
2319    Ok(())
2320}
2321
2322/// Leave someone elses shared project.
2323async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2324    let sender_id = session.connection_id;
2325    let project_id = ProjectId::from_proto(request.project_id);
2326    let db = session.db().await;
2327    if db.is_hosted_project(project_id).await? {
2328        let project = db.leave_hosted_project(project_id, sender_id).await?;
2329        project_left(&project, &session);
2330        return Ok(());
2331    }
2332
2333    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2334    tracing::info!(
2335        %project_id,
2336        "leave project"
2337    );
2338
2339    project_left(&project, &session);
2340    if let Some(room) = room {
2341        room_updated(&room, &session.peer);
2342    }
2343
2344    Ok(())
2345}
2346
2347async fn join_hosted_project(
2348    request: proto::JoinHostedProject,
2349    response: Response<proto::JoinHostedProject>,
2350    session: UserSession,
2351) -> Result<()> {
2352    let (mut project, replica_id) = session
2353        .db()
2354        .await
2355        .join_hosted_project(
2356            ProjectId(request.project_id as i32),
2357            session.user_id(),
2358            session.connection_id,
2359        )
2360        .await?;
2361
2362    join_project_internal(response, session, &mut project, &replica_id)
2363}
2364
2365async fn list_remote_directory(
2366    request: proto::ListRemoteDirectory,
2367    response: Response<proto::ListRemoteDirectory>,
2368    session: UserSession,
2369) -> Result<()> {
2370    let dev_server_id = DevServerId(request.dev_server_id as i32);
2371    let dev_server_connection_id = session
2372        .connection_pool()
2373        .await
2374        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2375
2376    session
2377        .db()
2378        .await
2379        .get_dev_server_for_user(dev_server_id, session.user_id())
2380        .await?;
2381
2382    response.send(
2383        session
2384            .peer
2385            .forward_request(session.connection_id, dev_server_connection_id, request)
2386            .await?,
2387    )?;
2388    Ok(())
2389}
2390
2391async fn update_dev_server_project(
2392    request: proto::UpdateDevServerProject,
2393    response: Response<proto::UpdateDevServerProject>,
2394    session: UserSession,
2395) -> Result<()> {
2396    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2397
2398    let (dev_server_project, update) = session
2399        .db()
2400        .await
2401        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2402        .await?;
2403
2404    let projects = session
2405        .db()
2406        .await
2407        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2408        .await?;
2409
2410    let dev_server_connection_id = session
2411        .connection_pool()
2412        .await
2413        .dev_server_connection_id_supporting(
2414            dev_server_project.dev_server_id,
2415            ZedVersion::with_list_directory(),
2416        )?;
2417
2418    session.peer.send(
2419        dev_server_connection_id,
2420        proto::DevServerInstructions { projects },
2421    )?;
2422
2423    send_dev_server_projects_update(session.user_id(), update, &session).await;
2424
2425    response.send(proto::Ack {})
2426}
2427
2428async fn create_dev_server_project(
2429    request: proto::CreateDevServerProject,
2430    response: Response<proto::CreateDevServerProject>,
2431    session: UserSession,
2432) -> Result<()> {
2433    let dev_server_id = DevServerId(request.dev_server_id as i32);
2434    let dev_server_connection_id = session
2435        .connection_pool()
2436        .await
2437        .dev_server_connection_id(dev_server_id);
2438    let Some(dev_server_connection_id) = dev_server_connection_id else {
2439        Err(ErrorCode::DevServerOffline
2440            .message("Cannot create a remote project when the dev server is offline".to_string())
2441            .anyhow())?
2442    };
2443
2444    let path = request.path.clone();
2445    //Check that the path exists on the dev server
2446    session
2447        .peer
2448        .forward_request(
2449            session.connection_id,
2450            dev_server_connection_id,
2451            proto::ValidateDevServerProjectRequest { path: path.clone() },
2452        )
2453        .await?;
2454
2455    let (dev_server_project, update) = session
2456        .db()
2457        .await
2458        .create_dev_server_project(
2459            DevServerId(request.dev_server_id as i32),
2460            &request.path,
2461            session.user_id(),
2462        )
2463        .await?;
2464
2465    let projects = session
2466        .db()
2467        .await
2468        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2469        .await?;
2470
2471    session.peer.send(
2472        dev_server_connection_id,
2473        proto::DevServerInstructions { projects },
2474    )?;
2475
2476    send_dev_server_projects_update(session.user_id(), update, &session).await;
2477
2478    response.send(proto::CreateDevServerProjectResponse {
2479        dev_server_project: Some(dev_server_project.to_proto(None)),
2480    })?;
2481    Ok(())
2482}
2483
2484async fn create_dev_server(
2485    request: proto::CreateDevServer,
2486    response: Response<proto::CreateDevServer>,
2487    session: UserSession,
2488) -> Result<()> {
2489    let access_token = auth::random_token();
2490    let hashed_access_token = auth::hash_access_token(&access_token);
2491
2492    if request.name.is_empty() {
2493        return Err(proto::ErrorCode::Forbidden
2494            .message("Dev server name cannot be empty".to_string())
2495            .anyhow())?;
2496    }
2497
2498    let (dev_server, status) = session
2499        .db()
2500        .await
2501        .create_dev_server(
2502            &request.name,
2503            request.ssh_connection_string.as_deref(),
2504            &hashed_access_token,
2505            session.user_id(),
2506        )
2507        .await?;
2508
2509    send_dev_server_projects_update(session.user_id(), status, &session).await;
2510
2511    response.send(proto::CreateDevServerResponse {
2512        dev_server_id: dev_server.id.0 as u64,
2513        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2514        name: request.name,
2515    })?;
2516    Ok(())
2517}
2518
2519async fn regenerate_dev_server_token(
2520    request: proto::RegenerateDevServerToken,
2521    response: Response<proto::RegenerateDevServerToken>,
2522    session: UserSession,
2523) -> Result<()> {
2524    let dev_server_id = DevServerId(request.dev_server_id as i32);
2525    let access_token = auth::random_token();
2526    let hashed_access_token = auth::hash_access_token(&access_token);
2527
2528    let connection_id = session
2529        .connection_pool()
2530        .await
2531        .dev_server_connection_id(dev_server_id);
2532    if let Some(connection_id) = connection_id {
2533        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2534        session.peer.send(
2535            connection_id,
2536            proto::ShutdownDevServer {
2537                reason: Some("dev server token was regenerated".to_string()),
2538            },
2539        )?;
2540        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2541    }
2542
2543    let status = session
2544        .db()
2545        .await
2546        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2547        .await?;
2548
2549    send_dev_server_projects_update(session.user_id(), status, &session).await;
2550
2551    response.send(proto::RegenerateDevServerTokenResponse {
2552        dev_server_id: dev_server_id.to_proto(),
2553        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2554    })?;
2555    Ok(())
2556}
2557
2558async fn rename_dev_server(
2559    request: proto::RenameDevServer,
2560    response: Response<proto::RenameDevServer>,
2561    session: UserSession,
2562) -> Result<()> {
2563    if request.name.trim().is_empty() {
2564        return Err(proto::ErrorCode::Forbidden
2565            .message("Dev server name cannot be empty".to_string())
2566            .anyhow())?;
2567    }
2568
2569    let dev_server_id = DevServerId(request.dev_server_id as i32);
2570    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2571    if dev_server.user_id != session.user_id() {
2572        return Err(anyhow!(ErrorCode::Forbidden))?;
2573    }
2574
2575    let status = session
2576        .db()
2577        .await
2578        .rename_dev_server(
2579            dev_server_id,
2580            &request.name,
2581            request.ssh_connection_string.as_deref(),
2582            session.user_id(),
2583        )
2584        .await?;
2585
2586    send_dev_server_projects_update(session.user_id(), status, &session).await;
2587
2588    response.send(proto::Ack {})?;
2589    Ok(())
2590}
2591
2592async fn delete_dev_server(
2593    request: proto::DeleteDevServer,
2594    response: Response<proto::DeleteDevServer>,
2595    session: UserSession,
2596) -> Result<()> {
2597    let dev_server_id = DevServerId(request.dev_server_id as i32);
2598    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2599    if dev_server.user_id != session.user_id() {
2600        return Err(anyhow!(ErrorCode::Forbidden))?;
2601    }
2602
2603    let connection_id = session
2604        .connection_pool()
2605        .await
2606        .dev_server_connection_id(dev_server_id);
2607    if let Some(connection_id) = connection_id {
2608        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2609        session.peer.send(
2610            connection_id,
2611            proto::ShutdownDevServer {
2612                reason: Some("dev server was deleted".to_string()),
2613            },
2614        )?;
2615        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2616    }
2617
2618    let status = session
2619        .db()
2620        .await
2621        .delete_dev_server(dev_server_id, session.user_id())
2622        .await?;
2623
2624    send_dev_server_projects_update(session.user_id(), status, &session).await;
2625
2626    response.send(proto::Ack {})?;
2627    Ok(())
2628}
2629
2630async fn delete_dev_server_project(
2631    request: proto::DeleteDevServerProject,
2632    response: Response<proto::DeleteDevServerProject>,
2633    session: UserSession,
2634) -> Result<()> {
2635    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2636    let dev_server_project = session
2637        .db()
2638        .await
2639        .get_dev_server_project(dev_server_project_id)
2640        .await?;
2641
2642    let dev_server = session
2643        .db()
2644        .await
2645        .get_dev_server(dev_server_project.dev_server_id)
2646        .await?;
2647    if dev_server.user_id != session.user_id() {
2648        return Err(anyhow!(ErrorCode::Forbidden))?;
2649    }
2650
2651    let dev_server_connection_id = session
2652        .connection_pool()
2653        .await
2654        .dev_server_connection_id(dev_server.id);
2655
2656    if let Some(dev_server_connection_id) = dev_server_connection_id {
2657        let project = session
2658            .db()
2659            .await
2660            .find_dev_server_project(dev_server_project_id)
2661            .await;
2662        if let Ok(project) = project {
2663            unshare_project_internal(
2664                project.id,
2665                dev_server_connection_id,
2666                Some(session.user_id()),
2667                &session,
2668            )
2669            .await?;
2670        }
2671    }
2672
2673    let (projects, status) = session
2674        .db()
2675        .await
2676        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2677        .await?;
2678
2679    if let Some(dev_server_connection_id) = dev_server_connection_id {
2680        session.peer.send(
2681            dev_server_connection_id,
2682            proto::DevServerInstructions { projects },
2683        )?;
2684    }
2685
2686    send_dev_server_projects_update(session.user_id(), status, &session).await;
2687
2688    response.send(proto::Ack {})?;
2689    Ok(())
2690}
2691
2692async fn rejoin_dev_server_projects(
2693    request: proto::RejoinRemoteProjects,
2694    response: Response<proto::RejoinRemoteProjects>,
2695    session: UserSession,
2696) -> Result<()> {
2697    let mut rejoined_projects = {
2698        let db = session.db().await;
2699        db.rejoin_dev_server_projects(
2700            &request.rejoined_projects,
2701            session.user_id(),
2702            session.0.connection_id,
2703        )
2704        .await?
2705    };
2706    response.send(proto::RejoinRemoteProjectsResponse {
2707        rejoined_projects: rejoined_projects
2708            .iter()
2709            .map(|project| project.to_proto())
2710            .collect(),
2711    })?;
2712    notify_rejoined_projects(&mut rejoined_projects, &session)
2713}
2714
2715async fn reconnect_dev_server(
2716    request: proto::ReconnectDevServer,
2717    response: Response<proto::ReconnectDevServer>,
2718    session: DevServerSession,
2719) -> Result<()> {
2720    let reshared_projects = {
2721        let db = session.db().await;
2722        db.reshare_dev_server_projects(
2723            &request.reshared_projects,
2724            session.dev_server_id(),
2725            session.0.connection_id,
2726        )
2727        .await?
2728    };
2729
2730    for project in &reshared_projects {
2731        for collaborator in &project.collaborators {
2732            session
2733                .peer
2734                .send(
2735                    collaborator.connection_id,
2736                    proto::UpdateProjectCollaborator {
2737                        project_id: project.id.to_proto(),
2738                        old_peer_id: Some(project.old_connection_id.into()),
2739                        new_peer_id: Some(session.connection_id.into()),
2740                    },
2741                )
2742                .trace_err();
2743        }
2744
2745        broadcast(
2746            Some(session.connection_id),
2747            project
2748                .collaborators
2749                .iter()
2750                .map(|collaborator| collaborator.connection_id),
2751            |connection_id| {
2752                session.peer.forward_send(
2753                    session.connection_id,
2754                    connection_id,
2755                    proto::UpdateProject {
2756                        project_id: project.id.to_proto(),
2757                        worktrees: project.worktrees.clone(),
2758                    },
2759                )
2760            },
2761        );
2762    }
2763
2764    response.send(proto::ReconnectDevServerResponse {
2765        reshared_projects: reshared_projects
2766            .iter()
2767            .map(|project| proto::ResharedProject {
2768                id: project.id.to_proto(),
2769                collaborators: project
2770                    .collaborators
2771                    .iter()
2772                    .map(|collaborator| collaborator.to_proto())
2773                    .collect(),
2774            })
2775            .collect(),
2776    })?;
2777
2778    Ok(())
2779}
2780
2781async fn shutdown_dev_server(
2782    _: proto::ShutdownDevServer,
2783    response: Response<proto::ShutdownDevServer>,
2784    session: DevServerSession,
2785) -> Result<()> {
2786    response.send(proto::Ack {})?;
2787    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2788    remove_dev_server_connection(session.dev_server_id(), &session).await
2789}
2790
2791async fn shutdown_dev_server_internal(
2792    dev_server_id: DevServerId,
2793    connection_id: ConnectionId,
2794    session: &Session,
2795) -> Result<()> {
2796    let (dev_server_projects, dev_server) = {
2797        let db = session.db().await;
2798        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2799        let dev_server = db.get_dev_server(dev_server_id).await?;
2800        (dev_server_projects, dev_server)
2801    };
2802
2803    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2804        unshare_project_internal(
2805            ProjectId::from_proto(project_id),
2806            connection_id,
2807            None,
2808            session,
2809        )
2810        .await?;
2811    }
2812
2813    session
2814        .connection_pool()
2815        .await
2816        .set_dev_server_offline(dev_server_id);
2817
2818    let status = session
2819        .db()
2820        .await
2821        .dev_server_projects_update(dev_server.user_id)
2822        .await?;
2823    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2824
2825    Ok(())
2826}
2827
2828async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2829    let dev_server_connection = session
2830        .connection_pool()
2831        .await
2832        .dev_server_connection_id(dev_server_id);
2833
2834    if let Some(dev_server_connection) = dev_server_connection {
2835        session
2836            .connection_pool()
2837            .await
2838            .remove_connection(dev_server_connection)?;
2839    }
2840    Ok(())
2841}
2842
2843/// Updates other participants with changes to the project
2844async fn update_project(
2845    request: proto::UpdateProject,
2846    response: Response<proto::UpdateProject>,
2847    session: Session,
2848) -> Result<()> {
2849    let project_id = ProjectId::from_proto(request.project_id);
2850    let (room, guest_connection_ids) = &*session
2851        .db()
2852        .await
2853        .update_project(project_id, session.connection_id, &request.worktrees)
2854        .await?;
2855    broadcast(
2856        Some(session.connection_id),
2857        guest_connection_ids.iter().copied(),
2858        |connection_id| {
2859            session
2860                .peer
2861                .forward_send(session.connection_id, connection_id, request.clone())
2862        },
2863    );
2864    if let Some(room) = room {
2865        room_updated(&room, &session.peer);
2866    }
2867    response.send(proto::Ack {})?;
2868
2869    Ok(())
2870}
2871
2872/// Updates other participants with changes to the worktree
2873async fn update_worktree(
2874    request: proto::UpdateWorktree,
2875    response: Response<proto::UpdateWorktree>,
2876    session: Session,
2877) -> Result<()> {
2878    let guest_connection_ids = session
2879        .db()
2880        .await
2881        .update_worktree(&request, session.connection_id)
2882        .await?;
2883
2884    broadcast(
2885        Some(session.connection_id),
2886        guest_connection_ids.iter().copied(),
2887        |connection_id| {
2888            session
2889                .peer
2890                .forward_send(session.connection_id, connection_id, request.clone())
2891        },
2892    );
2893    response.send(proto::Ack {})?;
2894    Ok(())
2895}
2896
2897/// Updates other participants with changes to the diagnostics
2898async fn update_diagnostic_summary(
2899    message: proto::UpdateDiagnosticSummary,
2900    session: Session,
2901) -> Result<()> {
2902    let guest_connection_ids = session
2903        .db()
2904        .await
2905        .update_diagnostic_summary(&message, session.connection_id)
2906        .await?;
2907
2908    broadcast(
2909        Some(session.connection_id),
2910        guest_connection_ids.iter().copied(),
2911        |connection_id| {
2912            session
2913                .peer
2914                .forward_send(session.connection_id, connection_id, message.clone())
2915        },
2916    );
2917
2918    Ok(())
2919}
2920
2921/// Updates other participants with changes to the worktree settings
2922async fn update_worktree_settings(
2923    message: proto::UpdateWorktreeSettings,
2924    session: Session,
2925) -> Result<()> {
2926    let guest_connection_ids = session
2927        .db()
2928        .await
2929        .update_worktree_settings(&message, session.connection_id)
2930        .await?;
2931
2932    broadcast(
2933        Some(session.connection_id),
2934        guest_connection_ids.iter().copied(),
2935        |connection_id| {
2936            session
2937                .peer
2938                .forward_send(session.connection_id, connection_id, message.clone())
2939        },
2940    );
2941
2942    Ok(())
2943}
2944
2945/// Notify other participants that a  language server has started.
2946async fn start_language_server(
2947    request: proto::StartLanguageServer,
2948    session: Session,
2949) -> Result<()> {
2950    let guest_connection_ids = session
2951        .db()
2952        .await
2953        .start_language_server(&request, session.connection_id)
2954        .await?;
2955
2956    broadcast(
2957        Some(session.connection_id),
2958        guest_connection_ids.iter().copied(),
2959        |connection_id| {
2960            session
2961                .peer
2962                .forward_send(session.connection_id, connection_id, request.clone())
2963        },
2964    );
2965    Ok(())
2966}
2967
2968/// Notify other participants that a language server has changed.
2969async fn update_language_server(
2970    request: proto::UpdateLanguageServer,
2971    session: Session,
2972) -> Result<()> {
2973    let project_id = ProjectId::from_proto(request.project_id);
2974    let project_connection_ids = session
2975        .db()
2976        .await
2977        .project_connection_ids(project_id, session.connection_id, true)
2978        .await?;
2979    broadcast(
2980        Some(session.connection_id),
2981        project_connection_ids.iter().copied(),
2982        |connection_id| {
2983            session
2984                .peer
2985                .forward_send(session.connection_id, connection_id, request.clone())
2986        },
2987    );
2988    Ok(())
2989}
2990
2991/// forward a project request to the host. These requests should be read only
2992/// as guests are allowed to send them.
2993async fn forward_read_only_project_request<T>(
2994    request: T,
2995    response: Response<T>,
2996    session: UserSession,
2997) -> Result<()>
2998where
2999    T: EntityMessage + RequestMessage,
3000{
3001    let project_id = ProjectId::from_proto(request.remote_entity_id());
3002    let host_connection_id = session
3003        .db()
3004        .await
3005        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3006        .await?;
3007    let payload = session
3008        .peer
3009        .forward_request(session.connection_id, host_connection_id, request)
3010        .await?;
3011    response.send(payload)?;
3012    Ok(())
3013}
3014
3015/// forward a project request to the dev server. Only allowed
3016/// if it's your dev server.
3017async fn forward_project_request_for_owner<T>(
3018    request: T,
3019    response: Response<T>,
3020    session: UserSession,
3021) -> Result<()>
3022where
3023    T: EntityMessage + RequestMessage,
3024{
3025    let project_id = ProjectId::from_proto(request.remote_entity_id());
3026
3027    let host_connection_id = session
3028        .db()
3029        .await
3030        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3031        .await?;
3032    let payload = session
3033        .peer
3034        .forward_request(session.connection_id, host_connection_id, request)
3035        .await?;
3036    response.send(payload)?;
3037    Ok(())
3038}
3039
3040/// forward a project request to the host. These requests are disallowed
3041/// for guests.
3042async fn forward_mutating_project_request<T>(
3043    request: T,
3044    response: Response<T>,
3045    session: UserSession,
3046) -> Result<()>
3047where
3048    T: EntityMessage + RequestMessage,
3049{
3050    let project_id = ProjectId::from_proto(request.remote_entity_id());
3051
3052    let host_connection_id = session
3053        .db()
3054        .await
3055        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3056        .await?;
3057    let payload = session
3058        .peer
3059        .forward_request(session.connection_id, host_connection_id, request)
3060        .await?;
3061    response.send(payload)?;
3062    Ok(())
3063}
3064
3065/// forward a project request to the host. These requests are disallowed
3066/// for guests.
3067async fn forward_versioned_mutating_project_request<T>(
3068    request: T,
3069    response: Response<T>,
3070    session: UserSession,
3071) -> Result<()>
3072where
3073    T: EntityMessage + RequestMessage + VersionedMessage,
3074{
3075    let project_id = ProjectId::from_proto(request.remote_entity_id());
3076
3077    let host_connection_id = session
3078        .db()
3079        .await
3080        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3081        .await?;
3082    if let Some(host_version) = session
3083        .connection_pool()
3084        .await
3085        .connection(host_connection_id)
3086        .map(|c| c.zed_version)
3087    {
3088        if let Some(min_required_version) = request.required_host_version() {
3089            if min_required_version > host_version {
3090                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3091                    .with_tag("required", &min_required_version.to_string())))?;
3092            }
3093        }
3094    }
3095
3096    let payload = session
3097        .peer
3098        .forward_request(session.connection_id, host_connection_id, request)
3099        .await?;
3100    response.send(payload)?;
3101    Ok(())
3102}
3103
3104/// Notify other participants that a new buffer has been created
3105async fn create_buffer_for_peer(
3106    request: proto::CreateBufferForPeer,
3107    session: Session,
3108) -> Result<()> {
3109    session
3110        .db()
3111        .await
3112        .check_user_is_project_host(
3113            ProjectId::from_proto(request.project_id),
3114            session.connection_id,
3115        )
3116        .await?;
3117    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3118    session
3119        .peer
3120        .forward_send(session.connection_id, peer_id.into(), request)?;
3121    Ok(())
3122}
3123
3124/// Notify other participants that a buffer has been updated. This is
3125/// allowed for guests as long as the update is limited to selections.
3126async fn update_buffer(
3127    request: proto::UpdateBuffer,
3128    response: Response<proto::UpdateBuffer>,
3129    session: Session,
3130) -> Result<()> {
3131    let project_id = ProjectId::from_proto(request.project_id);
3132    let mut capability = Capability::ReadOnly;
3133
3134    for op in request.operations.iter() {
3135        match op.variant {
3136            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3137            Some(_) => capability = Capability::ReadWrite,
3138        }
3139    }
3140
3141    let host = {
3142        let guard = session
3143            .db()
3144            .await
3145            .connections_for_buffer_update(
3146                project_id,
3147                session.principal_id(),
3148                session.connection_id,
3149                capability,
3150            )
3151            .await?;
3152
3153        let (host, guests) = &*guard;
3154
3155        broadcast(
3156            Some(session.connection_id),
3157            guests.clone(),
3158            |connection_id| {
3159                session
3160                    .peer
3161                    .forward_send(session.connection_id, connection_id, request.clone())
3162            },
3163        );
3164
3165        *host
3166    };
3167
3168    if host != session.connection_id {
3169        session
3170            .peer
3171            .forward_request(session.connection_id, host, request.clone())
3172            .await?;
3173    }
3174
3175    response.send(proto::Ack {})?;
3176    Ok(())
3177}
3178
3179async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3180    let project_id = ProjectId::from_proto(message.project_id);
3181
3182    let operation = message.operation.as_ref().context("invalid operation")?;
3183    let capability = match operation.variant.as_ref() {
3184        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3185            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3186                match buffer_op.variant {
3187                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3188                        Capability::ReadOnly
3189                    }
3190                    _ => Capability::ReadWrite,
3191                }
3192            } else {
3193                Capability::ReadWrite
3194            }
3195        }
3196        Some(_) => Capability::ReadWrite,
3197        None => Capability::ReadOnly,
3198    };
3199
3200    let guard = session
3201        .db()
3202        .await
3203        .connections_for_buffer_update(
3204            project_id,
3205            session.principal_id(),
3206            session.connection_id,
3207            capability,
3208        )
3209        .await?;
3210
3211    let (host, guests) = &*guard;
3212
3213    broadcast(
3214        Some(session.connection_id),
3215        guests.iter().chain([host]).copied(),
3216        |connection_id| {
3217            session
3218                .peer
3219                .forward_send(session.connection_id, connection_id, message.clone())
3220        },
3221    );
3222
3223    Ok(())
3224}
3225
3226/// Notify other participants that a project has been updated.
3227async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3228    request: T,
3229    session: Session,
3230) -> Result<()> {
3231    let project_id = ProjectId::from_proto(request.remote_entity_id());
3232    let project_connection_ids = session
3233        .db()
3234        .await
3235        .project_connection_ids(project_id, session.connection_id, false)
3236        .await?;
3237
3238    broadcast(
3239        Some(session.connection_id),
3240        project_connection_ids.iter().copied(),
3241        |connection_id| {
3242            session
3243                .peer
3244                .forward_send(session.connection_id, connection_id, request.clone())
3245        },
3246    );
3247    Ok(())
3248}
3249
3250/// Start following another user in a call.
3251async fn follow(
3252    request: proto::Follow,
3253    response: Response<proto::Follow>,
3254    session: UserSession,
3255) -> Result<()> {
3256    let room_id = RoomId::from_proto(request.room_id);
3257    let project_id = request.project_id.map(ProjectId::from_proto);
3258    let leader_id = request
3259        .leader_id
3260        .ok_or_else(|| anyhow!("invalid leader id"))?
3261        .into();
3262    let follower_id = session.connection_id;
3263
3264    session
3265        .db()
3266        .await
3267        .check_room_participants(room_id, leader_id, session.connection_id)
3268        .await?;
3269
3270    let response_payload = session
3271        .peer
3272        .forward_request(session.connection_id, leader_id, request)
3273        .await?;
3274    response.send(response_payload)?;
3275
3276    if let Some(project_id) = project_id {
3277        let room = session
3278            .db()
3279            .await
3280            .follow(room_id, project_id, leader_id, follower_id)
3281            .await?;
3282        room_updated(&room, &session.peer);
3283    }
3284
3285    Ok(())
3286}
3287
3288/// Stop following another user in a call.
3289async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3290    let room_id = RoomId::from_proto(request.room_id);
3291    let project_id = request.project_id.map(ProjectId::from_proto);
3292    let leader_id = request
3293        .leader_id
3294        .ok_or_else(|| anyhow!("invalid leader id"))?
3295        .into();
3296    let follower_id = session.connection_id;
3297
3298    session
3299        .db()
3300        .await
3301        .check_room_participants(room_id, leader_id, session.connection_id)
3302        .await?;
3303
3304    session
3305        .peer
3306        .forward_send(session.connection_id, leader_id, request)?;
3307
3308    if let Some(project_id) = project_id {
3309        let room = session
3310            .db()
3311            .await
3312            .unfollow(room_id, project_id, leader_id, follower_id)
3313            .await?;
3314        room_updated(&room, &session.peer);
3315    }
3316
3317    Ok(())
3318}
3319
3320/// Notify everyone following you of your current location.
3321async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3322    let room_id = RoomId::from_proto(request.room_id);
3323    let database = session.db.lock().await;
3324
3325    let connection_ids = if let Some(project_id) = request.project_id {
3326        let project_id = ProjectId::from_proto(project_id);
3327        database
3328            .project_connection_ids(project_id, session.connection_id, true)
3329            .await?
3330    } else {
3331        database
3332            .room_connection_ids(room_id, session.connection_id)
3333            .await?
3334    };
3335
3336    // For now, don't send view update messages back to that view's current leader.
3337    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3338        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3339        _ => None,
3340    });
3341
3342    for connection_id in connection_ids.iter().cloned() {
3343        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3344            session
3345                .peer
3346                .forward_send(session.connection_id, connection_id, request.clone())?;
3347        }
3348    }
3349    Ok(())
3350}
3351
3352/// Get public data about users.
3353async fn get_users(
3354    request: proto::GetUsers,
3355    response: Response<proto::GetUsers>,
3356    session: Session,
3357) -> Result<()> {
3358    let user_ids = request
3359        .user_ids
3360        .into_iter()
3361        .map(UserId::from_proto)
3362        .collect();
3363    let users = session
3364        .db()
3365        .await
3366        .get_users_by_ids(user_ids)
3367        .await?
3368        .into_iter()
3369        .map(|user| proto::User {
3370            id: user.id.to_proto(),
3371            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3372            github_login: user.github_login,
3373        })
3374        .collect();
3375    response.send(proto::UsersResponse { users })?;
3376    Ok(())
3377}
3378
3379/// Search for users (to invite) buy Github login
3380async fn fuzzy_search_users(
3381    request: proto::FuzzySearchUsers,
3382    response: Response<proto::FuzzySearchUsers>,
3383    session: UserSession,
3384) -> Result<()> {
3385    let query = request.query;
3386    let users = match query.len() {
3387        0 => vec![],
3388        1 | 2 => session
3389            .db()
3390            .await
3391            .get_user_by_github_login(&query)
3392            .await?
3393            .into_iter()
3394            .collect(),
3395        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3396    };
3397    let users = users
3398        .into_iter()
3399        .filter(|user| user.id != session.user_id())
3400        .map(|user| proto::User {
3401            id: user.id.to_proto(),
3402            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3403            github_login: user.github_login,
3404        })
3405        .collect();
3406    response.send(proto::UsersResponse { users })?;
3407    Ok(())
3408}
3409
3410/// Send a contact request to another user.
3411async fn request_contact(
3412    request: proto::RequestContact,
3413    response: Response<proto::RequestContact>,
3414    session: UserSession,
3415) -> Result<()> {
3416    let requester_id = session.user_id();
3417    let responder_id = UserId::from_proto(request.responder_id);
3418    if requester_id == responder_id {
3419        return Err(anyhow!("cannot add yourself as a contact"))?;
3420    }
3421
3422    let notifications = session
3423        .db()
3424        .await
3425        .send_contact_request(requester_id, responder_id)
3426        .await?;
3427
3428    // Update outgoing contact requests of requester
3429    let mut update = proto::UpdateContacts::default();
3430    update.outgoing_requests.push(responder_id.to_proto());
3431    for connection_id in session
3432        .connection_pool()
3433        .await
3434        .user_connection_ids(requester_id)
3435    {
3436        session.peer.send(connection_id, update.clone())?;
3437    }
3438
3439    // Update incoming contact requests of responder
3440    let mut update = proto::UpdateContacts::default();
3441    update
3442        .incoming_requests
3443        .push(proto::IncomingContactRequest {
3444            requester_id: requester_id.to_proto(),
3445        });
3446    let connection_pool = session.connection_pool().await;
3447    for connection_id in connection_pool.user_connection_ids(responder_id) {
3448        session.peer.send(connection_id, update.clone())?;
3449    }
3450
3451    send_notifications(&connection_pool, &session.peer, notifications);
3452
3453    response.send(proto::Ack {})?;
3454    Ok(())
3455}
3456
3457/// Accept or decline a contact request
3458async fn respond_to_contact_request(
3459    request: proto::RespondToContactRequest,
3460    response: Response<proto::RespondToContactRequest>,
3461    session: UserSession,
3462) -> Result<()> {
3463    let responder_id = session.user_id();
3464    let requester_id = UserId::from_proto(request.requester_id);
3465    let db = session.db().await;
3466    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3467        db.dismiss_contact_notification(responder_id, requester_id)
3468            .await?;
3469    } else {
3470        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3471
3472        let notifications = db
3473            .respond_to_contact_request(responder_id, requester_id, accept)
3474            .await?;
3475        let requester_busy = db.is_user_busy(requester_id).await?;
3476        let responder_busy = db.is_user_busy(responder_id).await?;
3477
3478        let pool = session.connection_pool().await;
3479        // Update responder with new contact
3480        let mut update = proto::UpdateContacts::default();
3481        if accept {
3482            update
3483                .contacts
3484                .push(contact_for_user(requester_id, requester_busy, &pool));
3485        }
3486        update
3487            .remove_incoming_requests
3488            .push(requester_id.to_proto());
3489        for connection_id in pool.user_connection_ids(responder_id) {
3490            session.peer.send(connection_id, update.clone())?;
3491        }
3492
3493        // Update requester with new contact
3494        let mut update = proto::UpdateContacts::default();
3495        if accept {
3496            update
3497                .contacts
3498                .push(contact_for_user(responder_id, responder_busy, &pool));
3499        }
3500        update
3501            .remove_outgoing_requests
3502            .push(responder_id.to_proto());
3503
3504        for connection_id in pool.user_connection_ids(requester_id) {
3505            session.peer.send(connection_id, update.clone())?;
3506        }
3507
3508        send_notifications(&pool, &session.peer, notifications);
3509    }
3510
3511    response.send(proto::Ack {})?;
3512    Ok(())
3513}
3514
3515/// Remove a contact.
3516async fn remove_contact(
3517    request: proto::RemoveContact,
3518    response: Response<proto::RemoveContact>,
3519    session: UserSession,
3520) -> Result<()> {
3521    let requester_id = session.user_id();
3522    let responder_id = UserId::from_proto(request.user_id);
3523    let db = session.db().await;
3524    let (contact_accepted, deleted_notification_id) =
3525        db.remove_contact(requester_id, responder_id).await?;
3526
3527    let pool = session.connection_pool().await;
3528    // Update outgoing contact requests of requester
3529    let mut update = proto::UpdateContacts::default();
3530    if contact_accepted {
3531        update.remove_contacts.push(responder_id.to_proto());
3532    } else {
3533        update
3534            .remove_outgoing_requests
3535            .push(responder_id.to_proto());
3536    }
3537    for connection_id in pool.user_connection_ids(requester_id) {
3538        session.peer.send(connection_id, update.clone())?;
3539    }
3540
3541    // Update incoming contact requests of responder
3542    let mut update = proto::UpdateContacts::default();
3543    if contact_accepted {
3544        update.remove_contacts.push(requester_id.to_proto());
3545    } else {
3546        update
3547            .remove_incoming_requests
3548            .push(requester_id.to_proto());
3549    }
3550    for connection_id in pool.user_connection_ids(responder_id) {
3551        session.peer.send(connection_id, update.clone())?;
3552        if let Some(notification_id) = deleted_notification_id {
3553            session.peer.send(
3554                connection_id,
3555                proto::DeleteNotification {
3556                    notification_id: notification_id.to_proto(),
3557                },
3558            )?;
3559        }
3560    }
3561
3562    response.send(proto::Ack {})?;
3563    Ok(())
3564}
3565
3566fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3567    version.0.minor() < 139
3568}
3569
3570async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3571    let plan = session.current_plan().await?;
3572
3573    session
3574        .peer
3575        .send(
3576            session.connection_id,
3577            proto::UpdateUserPlan { plan: plan.into() },
3578        )
3579        .trace_err();
3580
3581    Ok(())
3582}
3583
3584async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3585    subscribe_user_to_channels(
3586        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3587        &session,
3588    )
3589    .await?;
3590    Ok(())
3591}
3592
3593async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3594    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3595    let mut pool = session.connection_pool().await;
3596    for membership in &channels_for_user.channel_memberships {
3597        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3598    }
3599    session.peer.send(
3600        session.connection_id,
3601        build_update_user_channels(&channels_for_user),
3602    )?;
3603    session.peer.send(
3604        session.connection_id,
3605        build_channels_update(channels_for_user),
3606    )?;
3607    Ok(())
3608}
3609
3610/// Creates a new channel.
3611async fn create_channel(
3612    request: proto::CreateChannel,
3613    response: Response<proto::CreateChannel>,
3614    session: UserSession,
3615) -> Result<()> {
3616    let db = session.db().await;
3617
3618    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3619    let (channel, membership) = db
3620        .create_channel(&request.name, parent_id, session.user_id())
3621        .await?;
3622
3623    let root_id = channel.root_id();
3624    let channel = Channel::from_model(channel);
3625
3626    response.send(proto::CreateChannelResponse {
3627        channel: Some(channel.to_proto()),
3628        parent_id: request.parent_id,
3629    })?;
3630
3631    let mut connection_pool = session.connection_pool().await;
3632    if let Some(membership) = membership {
3633        connection_pool.subscribe_to_channel(
3634            membership.user_id,
3635            membership.channel_id,
3636            membership.role,
3637        );
3638        let update = proto::UpdateUserChannels {
3639            channel_memberships: vec![proto::ChannelMembership {
3640                channel_id: membership.channel_id.to_proto(),
3641                role: membership.role.into(),
3642            }],
3643            ..Default::default()
3644        };
3645        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3646            session.peer.send(connection_id, update.clone())?;
3647        }
3648    }
3649
3650    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3651        if !role.can_see_channel(channel.visibility) {
3652            continue;
3653        }
3654
3655        let update = proto::UpdateChannels {
3656            channels: vec![channel.to_proto()],
3657            ..Default::default()
3658        };
3659        session.peer.send(connection_id, update.clone())?;
3660    }
3661
3662    Ok(())
3663}
3664
3665/// Delete a channel
3666async fn delete_channel(
3667    request: proto::DeleteChannel,
3668    response: Response<proto::DeleteChannel>,
3669    session: UserSession,
3670) -> Result<()> {
3671    let db = session.db().await;
3672
3673    let channel_id = request.channel_id;
3674    let (root_channel, removed_channels) = db
3675        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3676        .await?;
3677    response.send(proto::Ack {})?;
3678
3679    // Notify members of removed channels
3680    let mut update = proto::UpdateChannels::default();
3681    update
3682        .delete_channels
3683        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3684
3685    let connection_pool = session.connection_pool().await;
3686    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3687        session.peer.send(connection_id, update.clone())?;
3688    }
3689
3690    Ok(())
3691}
3692
3693/// Invite someone to join a channel.
3694async fn invite_channel_member(
3695    request: proto::InviteChannelMember,
3696    response: Response<proto::InviteChannelMember>,
3697    session: UserSession,
3698) -> Result<()> {
3699    let db = session.db().await;
3700    let channel_id = ChannelId::from_proto(request.channel_id);
3701    let invitee_id = UserId::from_proto(request.user_id);
3702    let InviteMemberResult {
3703        channel,
3704        notifications,
3705    } = db
3706        .invite_channel_member(
3707            channel_id,
3708            invitee_id,
3709            session.user_id(),
3710            request.role().into(),
3711        )
3712        .await?;
3713
3714    let update = proto::UpdateChannels {
3715        channel_invitations: vec![channel.to_proto()],
3716        ..Default::default()
3717    };
3718
3719    let connection_pool = session.connection_pool().await;
3720    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3721        session.peer.send(connection_id, update.clone())?;
3722    }
3723
3724    send_notifications(&connection_pool, &session.peer, notifications);
3725
3726    response.send(proto::Ack {})?;
3727    Ok(())
3728}
3729
3730/// remove someone from a channel
3731async fn remove_channel_member(
3732    request: proto::RemoveChannelMember,
3733    response: Response<proto::RemoveChannelMember>,
3734    session: UserSession,
3735) -> Result<()> {
3736    let db = session.db().await;
3737    let channel_id = ChannelId::from_proto(request.channel_id);
3738    let member_id = UserId::from_proto(request.user_id);
3739
3740    let RemoveChannelMemberResult {
3741        membership_update,
3742        notification_id,
3743    } = db
3744        .remove_channel_member(channel_id, member_id, session.user_id())
3745        .await?;
3746
3747    let mut connection_pool = session.connection_pool().await;
3748    notify_membership_updated(
3749        &mut connection_pool,
3750        membership_update,
3751        member_id,
3752        &session.peer,
3753    );
3754    for connection_id in connection_pool.user_connection_ids(member_id) {
3755        if let Some(notification_id) = notification_id {
3756            session
3757                .peer
3758                .send(
3759                    connection_id,
3760                    proto::DeleteNotification {
3761                        notification_id: notification_id.to_proto(),
3762                    },
3763                )
3764                .trace_err();
3765        }
3766    }
3767
3768    response.send(proto::Ack {})?;
3769    Ok(())
3770}
3771
3772/// Toggle the channel between public and private.
3773/// Care is taken to maintain the invariant that public channels only descend from public channels,
3774/// (though members-only channels can appear at any point in the hierarchy).
3775async fn set_channel_visibility(
3776    request: proto::SetChannelVisibility,
3777    response: Response<proto::SetChannelVisibility>,
3778    session: UserSession,
3779) -> Result<()> {
3780    let db = session.db().await;
3781    let channel_id = ChannelId::from_proto(request.channel_id);
3782    let visibility = request.visibility().into();
3783
3784    let channel_model = db
3785        .set_channel_visibility(channel_id, visibility, session.user_id())
3786        .await?;
3787    let root_id = channel_model.root_id();
3788    let channel = Channel::from_model(channel_model);
3789
3790    let mut connection_pool = session.connection_pool().await;
3791    for (user_id, role) in connection_pool
3792        .channel_user_ids(root_id)
3793        .collect::<Vec<_>>()
3794        .into_iter()
3795    {
3796        let update = if role.can_see_channel(channel.visibility) {
3797            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3798            proto::UpdateChannels {
3799                channels: vec![channel.to_proto()],
3800                ..Default::default()
3801            }
3802        } else {
3803            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3804            proto::UpdateChannels {
3805                delete_channels: vec![channel.id.to_proto()],
3806                ..Default::default()
3807            }
3808        };
3809
3810        for connection_id in connection_pool.user_connection_ids(user_id) {
3811            session.peer.send(connection_id, update.clone())?;
3812        }
3813    }
3814
3815    response.send(proto::Ack {})?;
3816    Ok(())
3817}
3818
3819/// Alter the role for a user in the channel.
3820async fn set_channel_member_role(
3821    request: proto::SetChannelMemberRole,
3822    response: Response<proto::SetChannelMemberRole>,
3823    session: UserSession,
3824) -> Result<()> {
3825    let db = session.db().await;
3826    let channel_id = ChannelId::from_proto(request.channel_id);
3827    let member_id = UserId::from_proto(request.user_id);
3828    let result = db
3829        .set_channel_member_role(
3830            channel_id,
3831            session.user_id(),
3832            member_id,
3833            request.role().into(),
3834        )
3835        .await?;
3836
3837    match result {
3838        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3839            let mut connection_pool = session.connection_pool().await;
3840            notify_membership_updated(
3841                &mut connection_pool,
3842                membership_update,
3843                member_id,
3844                &session.peer,
3845            )
3846        }
3847        db::SetMemberRoleResult::InviteUpdated(channel) => {
3848            let update = proto::UpdateChannels {
3849                channel_invitations: vec![channel.to_proto()],
3850                ..Default::default()
3851            };
3852
3853            for connection_id in session
3854                .connection_pool()
3855                .await
3856                .user_connection_ids(member_id)
3857            {
3858                session.peer.send(connection_id, update.clone())?;
3859            }
3860        }
3861    }
3862
3863    response.send(proto::Ack {})?;
3864    Ok(())
3865}
3866
3867/// Change the name of a channel
3868async fn rename_channel(
3869    request: proto::RenameChannel,
3870    response: Response<proto::RenameChannel>,
3871    session: UserSession,
3872) -> Result<()> {
3873    let db = session.db().await;
3874    let channel_id = ChannelId::from_proto(request.channel_id);
3875    let channel_model = db
3876        .rename_channel(channel_id, session.user_id(), &request.name)
3877        .await?;
3878    let root_id = channel_model.root_id();
3879    let channel = Channel::from_model(channel_model);
3880
3881    response.send(proto::RenameChannelResponse {
3882        channel: Some(channel.to_proto()),
3883    })?;
3884
3885    let connection_pool = session.connection_pool().await;
3886    let update = proto::UpdateChannels {
3887        channels: vec![channel.to_proto()],
3888        ..Default::default()
3889    };
3890    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3891        if role.can_see_channel(channel.visibility) {
3892            session.peer.send(connection_id, update.clone())?;
3893        }
3894    }
3895
3896    Ok(())
3897}
3898
3899/// Move a channel to a new parent.
3900async fn move_channel(
3901    request: proto::MoveChannel,
3902    response: Response<proto::MoveChannel>,
3903    session: UserSession,
3904) -> Result<()> {
3905    let channel_id = ChannelId::from_proto(request.channel_id);
3906    let to = ChannelId::from_proto(request.to);
3907
3908    let (root_id, channels) = session
3909        .db()
3910        .await
3911        .move_channel(channel_id, to, session.user_id())
3912        .await?;
3913
3914    let connection_pool = session.connection_pool().await;
3915    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3916        let channels = channels
3917            .iter()
3918            .filter_map(|channel| {
3919                if role.can_see_channel(channel.visibility) {
3920                    Some(channel.to_proto())
3921                } else {
3922                    None
3923                }
3924            })
3925            .collect::<Vec<_>>();
3926        if channels.is_empty() {
3927            continue;
3928        }
3929
3930        let update = proto::UpdateChannels {
3931            channels,
3932            ..Default::default()
3933        };
3934
3935        session.peer.send(connection_id, update.clone())?;
3936    }
3937
3938    response.send(Ack {})?;
3939    Ok(())
3940}
3941
3942/// Get the list of channel members
3943async fn get_channel_members(
3944    request: proto::GetChannelMembers,
3945    response: Response<proto::GetChannelMembers>,
3946    session: UserSession,
3947) -> Result<()> {
3948    let db = session.db().await;
3949    let channel_id = ChannelId::from_proto(request.channel_id);
3950    let limit = if request.limit == 0 {
3951        u16::MAX as u64
3952    } else {
3953        request.limit
3954    };
3955    let (members, users) = db
3956        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3957        .await?;
3958    response.send(proto::GetChannelMembersResponse { members, users })?;
3959    Ok(())
3960}
3961
3962/// Accept or decline a channel invitation.
3963async fn respond_to_channel_invite(
3964    request: proto::RespondToChannelInvite,
3965    response: Response<proto::RespondToChannelInvite>,
3966    session: UserSession,
3967) -> Result<()> {
3968    let db = session.db().await;
3969    let channel_id = ChannelId::from_proto(request.channel_id);
3970    let RespondToChannelInvite {
3971        membership_update,
3972        notifications,
3973    } = db
3974        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3975        .await?;
3976
3977    let mut connection_pool = session.connection_pool().await;
3978    if let Some(membership_update) = membership_update {
3979        notify_membership_updated(
3980            &mut connection_pool,
3981            membership_update,
3982            session.user_id(),
3983            &session.peer,
3984        );
3985    } else {
3986        let update = proto::UpdateChannels {
3987            remove_channel_invitations: vec![channel_id.to_proto()],
3988            ..Default::default()
3989        };
3990
3991        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3992            session.peer.send(connection_id, update.clone())?;
3993        }
3994    };
3995
3996    send_notifications(&connection_pool, &session.peer, notifications);
3997
3998    response.send(proto::Ack {})?;
3999
4000    Ok(())
4001}
4002
4003/// Join the channels' room
4004async fn join_channel(
4005    request: proto::JoinChannel,
4006    response: Response<proto::JoinChannel>,
4007    session: UserSession,
4008) -> Result<()> {
4009    let channel_id = ChannelId::from_proto(request.channel_id);
4010    join_channel_internal(channel_id, Box::new(response), session).await
4011}
4012
/// A response handle that can carry a `JoinRoomResponse`.
///
/// Implemented for both `JoinChannel` and `JoinRoom` responses, so `join_channel_internal`
/// can reply on either kind of handle.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// A `JoinChannel` request is answered with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send` for this request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// A `JoinRoom` request is answered with the same `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to `Response`'s own `send` for this request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4026
/// Join the room belonging to `channel_id` for the current session and reply on `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection is
///    cleaned up first via `leave_room_for_session`.
/// 2. The user joins the channel's room. The reply contains the room, the channel id (if the
///    database returned a channel), and LiveKit connection info when a LiveKit client is
///    configured and a token could be obtained. Guests get a guest token with
///    `can_publish: false`; every other role gets a room token with `can_publish: true`.
/// 3. Any membership change is passed to `notify_membership_updated`, then `room_updated`
///    and `channel_updated` are called, and finally the user's contacts are updated.
///
/// Returns an error ("channel not returned") if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped block: the database guard and the connection-pool guard taken inside are both
    // released before the connection pool is acquired again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room; presumably
            // `leave_room_for_session` acquires the database itself. Re-acquire afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client, or when minting the token failed (the
        // error goes through `trace_err` and the join proceeds without LiveKit info).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may not publish and receive a guest token; all other roles may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4117
4118/// Start editing the channel notes
4119async fn join_channel_buffer(
4120    request: proto::JoinChannelBuffer,
4121    response: Response<proto::JoinChannelBuffer>,
4122    session: UserSession,
4123) -> Result<()> {
4124    let db = session.db().await;
4125    let channel_id = ChannelId::from_proto(request.channel_id);
4126
4127    let open_response = db
4128        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4129        .await?;
4130
4131    let collaborators = open_response.collaborators.clone();
4132    response.send(open_response)?;
4133
4134    let update = UpdateChannelBufferCollaborators {
4135        channel_id: channel_id.to_proto(),
4136        collaborators: collaborators.clone(),
4137    };
4138    channel_buffer_updated(
4139        session.connection_id,
4140        collaborators
4141            .iter()
4142            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4143        &update,
4144        &session.peer,
4145    );
4146
4147    Ok(())
4148}
4149
4150/// Edit the channel notes
4151async fn update_channel_buffer(
4152    request: proto::UpdateChannelBuffer,
4153    session: UserSession,
4154) -> Result<()> {
4155    let db = session.db().await;
4156    let channel_id = ChannelId::from_proto(request.channel_id);
4157
4158    let (collaborators, epoch, version) = db
4159        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4160        .await?;
4161
4162    channel_buffer_updated(
4163        session.connection_id,
4164        collaborators.clone(),
4165        &proto::UpdateChannelBuffer {
4166            channel_id: channel_id.to_proto(),
4167            operations: request.operations,
4168        },
4169        &session.peer,
4170    );
4171
4172    let pool = &*session.connection_pool().await;
4173
4174    let non_collaborators =
4175        pool.channel_connection_ids(channel_id)
4176            .filter_map(|(connection_id, _)| {
4177                if collaborators.contains(&connection_id) {
4178                    None
4179                } else {
4180                    Some(connection_id)
4181                }
4182            });
4183
4184    broadcast(None, non_collaborators, |peer_id| {
4185        session.peer.send(
4186            peer_id,
4187            proto::UpdateChannels {
4188                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4189                    channel_id: channel_id.to_proto(),
4190                    epoch: epoch as u64,
4191                    version: version.clone(),
4192                }],
4193                ..Default::default()
4194            },
4195        )
4196    });
4197
4198    Ok(())
4199}
4200
4201/// Rejoin the channel notes after a connection blip
4202async fn rejoin_channel_buffers(
4203    request: proto::RejoinChannelBuffers,
4204    response: Response<proto::RejoinChannelBuffers>,
4205    session: UserSession,
4206) -> Result<()> {
4207    let db = session.db().await;
4208    let buffers = db
4209        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4210        .await?;
4211
4212    for rejoined_buffer in &buffers {
4213        let collaborators_to_notify = rejoined_buffer
4214            .buffer
4215            .collaborators
4216            .iter()
4217            .filter_map(|c| Some(c.peer_id?.into()));
4218        channel_buffer_updated(
4219            session.connection_id,
4220            collaborators_to_notify,
4221            &proto::UpdateChannelBufferCollaborators {
4222                channel_id: rejoined_buffer.buffer.channel_id,
4223                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4224            },
4225            &session.peer,
4226        );
4227    }
4228
4229    response.send(proto::RejoinChannelBuffersResponse {
4230        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4231    })?;
4232
4233    Ok(())
4234}
4235
4236/// Stop editing the channel notes
4237async fn leave_channel_buffer(
4238    request: proto::LeaveChannelBuffer,
4239    response: Response<proto::LeaveChannelBuffer>,
4240    session: UserSession,
4241) -> Result<()> {
4242    let db = session.db().await;
4243    let channel_id = ChannelId::from_proto(request.channel_id);
4244
4245    let left_buffer = db
4246        .leave_channel_buffer(channel_id, session.connection_id)
4247        .await?;
4248
4249    response.send(Ack {})?;
4250
4251    channel_buffer_updated(
4252        session.connection_id,
4253        left_buffer.connections,
4254        &proto::UpdateChannelBufferCollaborators {
4255            channel_id: channel_id.to_proto(),
4256            collaborators: left_buffer.collaborators,
4257        },
4258        &session.peer,
4259    );
4260
4261    Ok(())
4262}
4263
4264fn channel_buffer_updated<T: EnvelopedMessage>(
4265    sender_id: ConnectionId,
4266    collaborators: impl IntoIterator<Item = ConnectionId>,
4267    message: &T,
4268    peer: &Peer,
4269) {
4270    broadcast(Some(sender_id), collaborators, |peer_id| {
4271        peer.send(peer_id, message.clone())
4272    });
4273}
4274
4275fn send_notifications(
4276    connection_pool: &ConnectionPool,
4277    peer: &Peer,
4278    notifications: db::NotificationBatch,
4279) {
4280    for (user_id, notification) in notifications {
4281        for connection_id in connection_pool.user_connection_ids(user_id) {
4282            if let Err(error) = peer.send(
4283                connection_id,
4284                proto::AddNotification {
4285                    notification: Some(notification.clone()),
4286                },
4287            ) {
4288                tracing::error!(
4289                    "failed to send notification to {:?} {}",
4290                    connection_id,
4291                    error
4292                );
4293            }
4294        }
4295    }
4296}
4297
4298/// Send a message to the channel
4299async fn send_channel_message(
4300    request: proto::SendChannelMessage,
4301    response: Response<proto::SendChannelMessage>,
4302    session: UserSession,
4303) -> Result<()> {
4304    // Validate the message body.
4305    let body = request.body.trim().to_string();
4306    if body.len() > MAX_MESSAGE_LEN {
4307        return Err(anyhow!("message is too long"))?;
4308    }
4309    if body.is_empty() {
4310        return Err(anyhow!("message can't be blank"))?;
4311    }
4312
4313    // TODO: adjust mentions if body is trimmed
4314
4315    let timestamp = OffsetDateTime::now_utc();
4316    let nonce = request
4317        .nonce
4318        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4319
4320    let channel_id = ChannelId::from_proto(request.channel_id);
4321    let CreatedChannelMessage {
4322        message_id,
4323        participant_connection_ids,
4324        notifications,
4325    } = session
4326        .db()
4327        .await
4328        .create_channel_message(
4329            channel_id,
4330            session.user_id(),
4331            &body,
4332            &request.mentions,
4333            timestamp,
4334            nonce.clone().into(),
4335            match request.reply_to_message_id {
4336                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4337                None => None,
4338            },
4339        )
4340        .await?;
4341
4342    let message = proto::ChannelMessage {
4343        sender_id: session.user_id().to_proto(),
4344        id: message_id.to_proto(),
4345        body,
4346        mentions: request.mentions,
4347        timestamp: timestamp.unix_timestamp() as u64,
4348        nonce: Some(nonce),
4349        reply_to_message_id: request.reply_to_message_id,
4350        edited_at: None,
4351    };
4352    broadcast(
4353        Some(session.connection_id),
4354        participant_connection_ids.clone(),
4355        |connection| {
4356            session.peer.send(
4357                connection,
4358                proto::ChannelMessageSent {
4359                    channel_id: channel_id.to_proto(),
4360                    message: Some(message.clone()),
4361                },
4362            )
4363        },
4364    );
4365    response.send(proto::SendChannelMessageResponse {
4366        message: Some(message),
4367    })?;
4368
4369    let pool = &*session.connection_pool().await;
4370    let non_participants =
4371        pool.channel_connection_ids(channel_id)
4372            .filter_map(|(connection_id, _)| {
4373                if participant_connection_ids.contains(&connection_id) {
4374                    None
4375                } else {
4376                    Some(connection_id)
4377                }
4378            });
4379    broadcast(None, non_participants, |peer_id| {
4380        session.peer.send(
4381            peer_id,
4382            proto::UpdateChannels {
4383                latest_channel_message_ids: vec![proto::ChannelMessageId {
4384                    channel_id: channel_id.to_proto(),
4385                    message_id: message_id.to_proto(),
4386                }],
4387                ..Default::default()
4388            },
4389        )
4390    });
4391    send_notifications(pool, &session.peer, notifications);
4392
4393    Ok(())
4394}
4395
4396/// Delete a channel message
4397async fn remove_channel_message(
4398    request: proto::RemoveChannelMessage,
4399    response: Response<proto::RemoveChannelMessage>,
4400    session: UserSession,
4401) -> Result<()> {
4402    let channel_id = ChannelId::from_proto(request.channel_id);
4403    let message_id = MessageId::from_proto(request.message_id);
4404    let (connection_ids, existing_notification_ids) = session
4405        .db()
4406        .await
4407        .remove_channel_message(channel_id, message_id, session.user_id())
4408        .await?;
4409
4410    broadcast(
4411        Some(session.connection_id),
4412        connection_ids,
4413        move |connection| {
4414            session.peer.send(connection, request.clone())?;
4415
4416            for notification_id in &existing_notification_ids {
4417                session.peer.send(
4418                    connection,
4419                    proto::DeleteNotification {
4420                        notification_id: (*notification_id).to_proto(),
4421                    },
4422                )?;
4423            }
4424
4425            Ok(())
4426        },
4427    );
4428    response.send(proto::Ack {})?;
4429    Ok(())
4430}
4431
4432async fn update_channel_message(
4433    request: proto::UpdateChannelMessage,
4434    response: Response<proto::UpdateChannelMessage>,
4435    session: UserSession,
4436) -> Result<()> {
4437    let channel_id = ChannelId::from_proto(request.channel_id);
4438    let message_id = MessageId::from_proto(request.message_id);
4439    let updated_at = OffsetDateTime::now_utc();
4440    let UpdatedChannelMessage {
4441        message_id,
4442        participant_connection_ids,
4443        notifications,
4444        reply_to_message_id,
4445        timestamp,
4446        deleted_mention_notification_ids,
4447        updated_mention_notifications,
4448    } = session
4449        .db()
4450        .await
4451        .update_channel_message(
4452            channel_id,
4453            message_id,
4454            session.user_id(),
4455            request.body.as_str(),
4456            &request.mentions,
4457            updated_at,
4458        )
4459        .await?;
4460
4461    let nonce = request
4462        .nonce
4463        .clone()
4464        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4465
4466    let message = proto::ChannelMessage {
4467        sender_id: session.user_id().to_proto(),
4468        id: message_id.to_proto(),
4469        body: request.body.clone(),
4470        mentions: request.mentions.clone(),
4471        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4472        nonce: Some(nonce),
4473        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4474        edited_at: Some(updated_at.unix_timestamp() as u64),
4475    };
4476
4477    response.send(proto::Ack {})?;
4478
4479    let pool = &*session.connection_pool().await;
4480    broadcast(
4481        Some(session.connection_id),
4482        participant_connection_ids,
4483        |connection| {
4484            session.peer.send(
4485                connection,
4486                proto::ChannelMessageUpdate {
4487                    channel_id: channel_id.to_proto(),
4488                    message: Some(message.clone()),
4489                },
4490            )?;
4491
4492            for notification_id in &deleted_mention_notification_ids {
4493                session.peer.send(
4494                    connection,
4495                    proto::DeleteNotification {
4496                        notification_id: (*notification_id).to_proto(),
4497                    },
4498                )?;
4499            }
4500
4501            for notification in &updated_mention_notifications {
4502                session.peer.send(
4503                    connection,
4504                    proto::UpdateNotification {
4505                        notification: Some(notification.clone()),
4506                    },
4507                )?;
4508            }
4509
4510            Ok(())
4511        },
4512    );
4513
4514    send_notifications(pool, &session.peer, notifications);
4515
4516    Ok(())
4517}
4518
4519/// Mark a channel message as read
4520async fn acknowledge_channel_message(
4521    request: proto::AckChannelMessage,
4522    session: UserSession,
4523) -> Result<()> {
4524    let channel_id = ChannelId::from_proto(request.channel_id);
4525    let message_id = MessageId::from_proto(request.message_id);
4526    let notifications = session
4527        .db()
4528        .await
4529        .observe_channel_message(channel_id, session.user_id(), message_id)
4530        .await?;
4531    send_notifications(
4532        &*session.connection_pool().await,
4533        &session.peer,
4534        notifications,
4535    );
4536    Ok(())
4537}
4538
4539/// Mark a buffer version as synced
4540async fn acknowledge_buffer_version(
4541    request: proto::AckBufferOperation,
4542    session: UserSession,
4543) -> Result<()> {
4544    let buffer_id = BufferId::from_proto(request.buffer_id);
4545    session
4546        .db()
4547        .await
4548        .observe_buffer_version(
4549            buffer_id,
4550            session.user_id(),
4551            request.epoch as i32,
4552            &request.version,
4553        )
4554        .await?;
4555    Ok(())
4556}
4557
4558struct ZedProCompleteWithLanguageModelRateLimit;
4559
4560impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4561    fn capacity(&self) -> usize {
4562        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4563            .ok()
4564            .and_then(|v| v.parse().ok())
4565            .unwrap_or(120) // Picked arbitrarily
4566    }
4567
4568    fn refill_duration(&self) -> chrono::Duration {
4569        chrono::Duration::hours(1)
4570    }
4571
4572    fn db_name(&self) -> &'static str {
4573        "zed-pro:complete-with-language-model"
4574    }
4575}
4576
4577struct FreeCompleteWithLanguageModelRateLimit;
4578
4579impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4580    fn capacity(&self) -> usize {
4581        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4582            .ok()
4583            .and_then(|v| v.parse().ok())
4584            .unwrap_or(120 / 10) // Picked arbitrarily
4585    }
4586
4587    fn refill_duration(&self) -> chrono::Duration {
4588        chrono::Duration::hours(1)
4589    }
4590
4591    fn db_name(&self) -> &'static str {
4592        "free:complete-with-language-model"
4593    }
4594}
4595
4596async fn complete_with_language_model(
4597    request: proto::CompleteWithLanguageModel,
4598    response: Response<proto::CompleteWithLanguageModel>,
4599    session: Session,
4600    config: &Config,
4601) -> Result<()> {
4602    let Some(session) = session.for_user() else {
4603        return Err(anyhow!("user not found"))?;
4604    };
4605    authorize_access_to_language_models(&session).await?;
4606
4607    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4608        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4609        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4610    };
4611
4612    session
4613        .rate_limiter
4614        .check(&*rate_limit, session.user_id())
4615        .await?;
4616
4617    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4618        Some(proto::LanguageModelProvider::Anthropic) => {
4619            let api_key = config
4620                .anthropic_api_key
4621                .as_ref()
4622                .context("no Anthropic AI API key configured on the server")?;
4623            anthropic::complete(
4624                session.http_client.as_ref(),
4625                anthropic::ANTHROPIC_API_URL,
4626                api_key,
4627                serde_json::from_str(&request.request)?,
4628            )
4629            .await?
4630        }
4631        _ => return Err(anyhow!("unsupported provider"))?,
4632    };
4633
4634    response.send(proto::CompleteWithLanguageModelResponse {
4635        completion: serde_json::to_string(&result)?,
4636    })?;
4637
4638    Ok(())
4639}
4640
4641async fn stream_complete_with_language_model(
4642    request: proto::StreamCompleteWithLanguageModel,
4643    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4644    session: Session,
4645    config: &Config,
4646) -> Result<()> {
4647    let Some(session) = session.for_user() else {
4648        return Err(anyhow!("user not found"))?;
4649    };
4650    authorize_access_to_language_models(&session).await?;
4651
4652    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4653        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4654        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4655    };
4656
4657    session
4658        .rate_limiter
4659        .check(&*rate_limit, session.user_id())
4660        .await?;
4661
4662    match proto::LanguageModelProvider::from_i32(request.provider) {
4663        Some(proto::LanguageModelProvider::Anthropic) => {
4664            let api_key = config
4665                .anthropic_api_key
4666                .as_ref()
4667                .context("no Anthropic AI API key configured on the server")?;
4668            let mut chunks = anthropic::stream_completion(
4669                session.http_client.as_ref(),
4670                anthropic::ANTHROPIC_API_URL,
4671                api_key,
4672                serde_json::from_str(&request.request)?,
4673                None,
4674            )
4675            .await?;
4676            while let Some(event) = chunks.next().await {
4677                let chunk = event?;
4678                response.send(proto::StreamCompleteWithLanguageModelResponse {
4679                    event: serde_json::to_string(&chunk)?,
4680                })?;
4681            }
4682        }
4683        Some(proto::LanguageModelProvider::OpenAi) => {
4684            let api_key = config
4685                .openai_api_key
4686                .as_ref()
4687                .context("no OpenAI API key configured on the server")?;
4688            let mut events = open_ai::stream_completion(
4689                session.http_client.as_ref(),
4690                open_ai::OPEN_AI_API_URL,
4691                api_key,
4692                serde_json::from_str(&request.request)?,
4693                None,
4694            )
4695            .await?;
4696            while let Some(event) = events.next().await {
4697                let event = event?;
4698                response.send(proto::StreamCompleteWithLanguageModelResponse {
4699                    event: serde_json::to_string(&event)?,
4700                })?;
4701            }
4702        }
4703        Some(proto::LanguageModelProvider::Google) => {
4704            let api_key = config
4705                .google_ai_api_key
4706                .as_ref()
4707                .context("no Google AI API key configured on the server")?;
4708            let mut events = google_ai::stream_generate_content(
4709                session.http_client.as_ref(),
4710                google_ai::API_URL,
4711                api_key,
4712                serde_json::from_str(&request.request)?,
4713            )
4714            .await?;
4715            while let Some(event) = events.next().await {
4716                let event = event?;
4717                response.send(proto::StreamCompleteWithLanguageModelResponse {
4718                    event: serde_json::to_string(&event)?,
4719                })?;
4720            }
4721        }
4722        Some(proto::LanguageModelProvider::Zed) => {
4723            let api_key = config
4724                .qwen2_7b_api_key
4725                .as_ref()
4726                .context("no Qwen2-7B API key configured on the server")?;
4727            let api_url = config
4728                .qwen2_7b_api_url
4729                .as_ref()
4730                .context("no Qwen2-7B URL configured on the server")?;
4731            let mut events = open_ai::stream_completion(
4732                session.http_client.as_ref(),
4733                &api_url,
4734                api_key,
4735                serde_json::from_str(&request.request)?,
4736                None,
4737            )
4738            .await?;
4739            while let Some(event) = events.next().await {
4740                let event = event?;
4741                response.send(proto::StreamCompleteWithLanguageModelResponse {
4742                    event: serde_json::to_string(&event)?,
4743                })?;
4744            }
4745        }
4746        None => return Err(anyhow!("unknown provider"))?,
4747    }
4748
4749    Ok(())
4750}
4751
4752async fn count_language_model_tokens(
4753    request: proto::CountLanguageModelTokens,
4754    response: Response<proto::CountLanguageModelTokens>,
4755    session: Session,
4756    config: &Config,
4757) -> Result<()> {
4758    let Some(session) = session.for_user() else {
4759        return Err(anyhow!("user not found"))?;
4760    };
4761    authorize_access_to_language_models(&session).await?;
4762
4763    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4764        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4765        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4766    };
4767
4768    session
4769        .rate_limiter
4770        .check(&*rate_limit, session.user_id())
4771        .await?;
4772
4773    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4774        Some(proto::LanguageModelProvider::Google) => {
4775            let api_key = config
4776                .google_ai_api_key
4777                .as_ref()
4778                .context("no Google AI API key configured on the server")?;
4779            google_ai::count_tokens(
4780                session.http_client.as_ref(),
4781                google_ai::API_URL,
4782                api_key,
4783                serde_json::from_str(&request.request)?,
4784            )
4785            .await?
4786        }
4787        _ => return Err(anyhow!("unsupported provider"))?,
4788    };
4789
4790    response.send(proto::CountLanguageModelTokensResponse {
4791        token_count: result.total_tokens as u32,
4792    })?;
4793
4794    Ok(())
4795}
4796
4797struct ZedProCountLanguageModelTokensRateLimit;
4798
4799impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4800    fn capacity(&self) -> usize {
4801        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4802            .ok()
4803            .and_then(|v| v.parse().ok())
4804            .unwrap_or(600) // Picked arbitrarily
4805    }
4806
4807    fn refill_duration(&self) -> chrono::Duration {
4808        chrono::Duration::hours(1)
4809    }
4810
4811    fn db_name(&self) -> &'static str {
4812        "zed-pro:count-language-model-tokens"
4813    }
4814}
4815
4816struct FreeCountLanguageModelTokensRateLimit;
4817
4818impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4819    fn capacity(&self) -> usize {
4820        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4821            .ok()
4822            .and_then(|v| v.parse().ok())
4823            .unwrap_or(600 / 10) // Picked arbitrarily
4824    }
4825
4826    fn refill_duration(&self) -> chrono::Duration {
4827        chrono::Duration::hours(1)
4828    }
4829
4830    fn db_name(&self) -> &'static str {
4831        "free:count-language-model-tokens"
4832    }
4833}
4834
4835struct ZedProComputeEmbeddingsRateLimit;
4836
4837impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4838    fn capacity(&self) -> usize {
4839        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4840            .ok()
4841            .and_then(|v| v.parse().ok())
4842            .unwrap_or(5000) // Picked arbitrarily
4843    }
4844
4845    fn refill_duration(&self) -> chrono::Duration {
4846        chrono::Duration::hours(1)
4847    }
4848
4849    fn db_name(&self) -> &'static str {
4850        "zed-pro:compute-embeddings"
4851    }
4852}
4853
4854struct FreeComputeEmbeddingsRateLimit;
4855
4856impl RateLimit for FreeComputeEmbeddingsRateLimit {
4857    fn capacity(&self) -> usize {
4858        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4859            .ok()
4860            .and_then(|v| v.parse().ok())
4861            .unwrap_or(5000 / 10) // Picked arbitrarily
4862    }
4863
4864    fn refill_duration(&self) -> chrono::Duration {
4865        chrono::Duration::hours(1)
4866    }
4867
4868    fn db_name(&self) -> &'static str {
4869        "free:compute-embeddings"
4870    }
4871}
4872
4873async fn compute_embeddings(
4874    request: proto::ComputeEmbeddings,
4875    response: Response<proto::ComputeEmbeddings>,
4876    session: UserSession,
4877    api_key: Option<Arc<str>>,
4878) -> Result<()> {
4879    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4880    authorize_access_to_language_models(&session).await?;
4881
4882    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4883        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4884        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4885    };
4886
4887    session
4888        .rate_limiter
4889        .check(&*rate_limit, session.user_id())
4890        .await?;
4891
4892    let embeddings = match request.model.as_str() {
4893        "openai/text-embedding-3-small" => {
4894            open_ai::embed(
4895                session.http_client.as_ref(),
4896                OPEN_AI_API_URL,
4897                &api_key,
4898                OpenAiEmbeddingModel::TextEmbedding3Small,
4899                request.texts.iter().map(|text| text.as_str()),
4900            )
4901            .await?
4902        }
4903        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4904    };
4905
4906    let embeddings = request
4907        .texts
4908        .iter()
4909        .map(|text| {
4910            let mut hasher = sha2::Sha256::new();
4911            hasher.update(text.as_bytes());
4912            let result = hasher.finalize();
4913            result.to_vec()
4914        })
4915        .zip(
4916            embeddings
4917                .data
4918                .into_iter()
4919                .map(|embedding| embedding.embedding),
4920        )
4921        .collect::<HashMap<_, _>>();
4922
4923    let db = session.db().await;
4924    db.save_embeddings(&request.model, &embeddings)
4925        .await
4926        .context("failed to save embeddings")
4927        .trace_err();
4928
4929    response.send(proto::ComputeEmbeddingsResponse {
4930        embeddings: embeddings
4931            .into_iter()
4932            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4933            .collect(),
4934    })?;
4935    Ok(())
4936}
4937
4938async fn get_cached_embeddings(
4939    request: proto::GetCachedEmbeddings,
4940    response: Response<proto::GetCachedEmbeddings>,
4941    session: UserSession,
4942) -> Result<()> {
4943    authorize_access_to_language_models(&session).await?;
4944
4945    let db = session.db().await;
4946    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4947
4948    response.send(proto::GetCachedEmbeddingsResponse {
4949        embeddings: embeddings
4950            .into_iter()
4951            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4952            .collect(),
4953    })?;
4954    Ok(())
4955}
4956
4957async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4958    let db = session.db().await;
4959    let flags = db.get_user_flags(session.user_id()).await?;
4960    if flags.iter().any(|flag| flag == "language-models") {
4961        return Ok(());
4962    }
4963
4964    Err(anyhow!("permission denied"))?
4965}
4966
4967/// Get a Supermaven API key for the user
4968async fn get_supermaven_api_key(
4969    _request: proto::GetSupermavenApiKey,
4970    response: Response<proto::GetSupermavenApiKey>,
4971    session: UserSession,
4972) -> Result<()> {
4973    let user_id: String = session.user_id().to_string();
4974    if !session.is_staff() {
4975        return Err(anyhow!("supermaven not enabled for this account"))?;
4976    }
4977
4978    let email = session
4979        .email()
4980        .ok_or_else(|| anyhow!("user must have an email"))?;
4981
4982    let supermaven_admin_api = session
4983        .supermaven_client
4984        .as_ref()
4985        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4986
4987    let result = supermaven_admin_api
4988        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4989        .await?;
4990
4991    response.send(proto::GetSupermavenApiKeyResponse {
4992        api_key: result.api_key,
4993    })?;
4994
4995    Ok(())
4996}
4997
4998/// Start receiving chat updates for a channel
4999async fn join_channel_chat(
5000    request: proto::JoinChannelChat,
5001    response: Response<proto::JoinChannelChat>,
5002    session: UserSession,
5003) -> Result<()> {
5004    let channel_id = ChannelId::from_proto(request.channel_id);
5005
5006    let db = session.db().await;
5007    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5008        .await?;
5009    let messages = db
5010        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5011        .await?;
5012    response.send(proto::JoinChannelChatResponse {
5013        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5014        messages,
5015    })?;
5016    Ok(())
5017}
5018
5019/// Stop receiving chat updates for a channel
5020async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5021    let channel_id = ChannelId::from_proto(request.channel_id);
5022    session
5023        .db()
5024        .await
5025        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5026        .await?;
5027    Ok(())
5028}
5029
5030/// Retrieve the chat history for a channel
5031async fn get_channel_messages(
5032    request: proto::GetChannelMessages,
5033    response: Response<proto::GetChannelMessages>,
5034    session: UserSession,
5035) -> Result<()> {
5036    let channel_id = ChannelId::from_proto(request.channel_id);
5037    let messages = session
5038        .db()
5039        .await
5040        .get_channel_messages(
5041            channel_id,
5042            session.user_id(),
5043            MESSAGE_COUNT_PER_PAGE,
5044            Some(MessageId::from_proto(request.before_message_id)),
5045        )
5046        .await?;
5047    response.send(proto::GetChannelMessagesResponse {
5048        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5049        messages,
5050    })?;
5051    Ok(())
5052}
5053
5054/// Retrieve specific chat messages
5055async fn get_channel_messages_by_id(
5056    request: proto::GetChannelMessagesById,
5057    response: Response<proto::GetChannelMessagesById>,
5058    session: UserSession,
5059) -> Result<()> {
5060    let message_ids = request
5061        .message_ids
5062        .iter()
5063        .map(|id| MessageId::from_proto(*id))
5064        .collect::<Vec<_>>();
5065    let messages = session
5066        .db()
5067        .await
5068        .get_channel_messages_by_id(session.user_id(), &message_ids)
5069        .await?;
5070    response.send(proto::GetChannelMessagesResponse {
5071        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5072        messages,
5073    })?;
5074    Ok(())
5075}
5076
5077/// Retrieve the current users notifications
5078async fn get_notifications(
5079    request: proto::GetNotifications,
5080    response: Response<proto::GetNotifications>,
5081    session: UserSession,
5082) -> Result<()> {
5083    let notifications = session
5084        .db()
5085        .await
5086        .get_notifications(
5087            session.user_id(),
5088            NOTIFICATION_COUNT_PER_PAGE,
5089            request
5090                .before_id
5091                .map(|id| db::NotificationId::from_proto(id)),
5092        )
5093        .await?;
5094    response.send(proto::GetNotificationsResponse {
5095        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5096        notifications,
5097    })?;
5098    Ok(())
5099}
5100
5101/// Mark notifications as read
5102async fn mark_notification_as_read(
5103    request: proto::MarkNotificationRead,
5104    response: Response<proto::MarkNotificationRead>,
5105    session: UserSession,
5106) -> Result<()> {
5107    let database = &session.db().await;
5108    let notifications = database
5109        .mark_notification_as_read_by_id(
5110            session.user_id(),
5111            NotificationId::from_proto(request.notification_id),
5112        )
5113        .await?;
5114    send_notifications(
5115        &*session.connection_pool().await,
5116        &session.peer,
5117        notifications,
5118    );
5119    response.send(proto::Ack {})?;
5120    Ok(())
5121}
5122
5123/// Get the current users information
5124async fn get_private_user_info(
5125    _request: proto::GetPrivateUserInfo,
5126    response: Response<proto::GetPrivateUserInfo>,
5127    session: UserSession,
5128) -> Result<()> {
5129    let db = session.db().await;
5130
5131    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5132    let user = db
5133        .get_user_by_id(session.user_id())
5134        .await?
5135        .ok_or_else(|| anyhow!("user not found"))?;
5136    let flags = db.get_user_flags(session.user_id()).await?;
5137
5138    response.send(proto::GetPrivateUserInfoResponse {
5139        metrics_id,
5140        staff: user.admin,
5141        flags,
5142    })?;
5143    Ok(())
5144}
5145
5146fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5147    let message = match message {
5148        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5149        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5150        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5151        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5152        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5153            code: frame.code.into(),
5154            reason: frame.reason,
5155        })),
5156        // We should never receive a frame while reading the message, according
5157        // to the `tungstenite` maintainers:
5158        //
5159        // > It cannot occur when you read messages from the WebSocket, but it
5160        // > can be used when you want to send the raw frames (e.g. you want to
5161        // > send the frames to the WebSocket without composing the full message first).
5162        // >
5163        // > — https://github.com/snapview/tungstenite-rs/issues/268
5164        TungsteniteMessage::Frame(_) => {
5165            bail!("received an unexpected frame while reading the message")
5166        }
5167    };
5168
5169    Ok(message)
5170}
5171
5172fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5173    match message {
5174        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5175        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5176        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5177        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5178        AxumMessage::Close(frame) => {
5179            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5180                code: frame.code.into(),
5181                reason: frame.reason,
5182            }))
5183        }
5184    }
5185}
5186
5187fn notify_membership_updated(
5188    connection_pool: &mut ConnectionPool,
5189    result: MembershipUpdated,
5190    user_id: UserId,
5191    peer: &Peer,
5192) {
5193    for membership in &result.new_channels.channel_memberships {
5194        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5195    }
5196    for channel_id in &result.removed_channels {
5197        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5198    }
5199
5200    let user_channels_update = proto::UpdateUserChannels {
5201        channel_memberships: result
5202            .new_channels
5203            .channel_memberships
5204            .iter()
5205            .map(|cm| proto::ChannelMembership {
5206                channel_id: cm.channel_id.to_proto(),
5207                role: cm.role.into(),
5208            })
5209            .collect(),
5210        ..Default::default()
5211    };
5212
5213    let mut update = build_channels_update(result.new_channels);
5214    update.delete_channels = result
5215        .removed_channels
5216        .into_iter()
5217        .map(|id| id.to_proto())
5218        .collect();
5219    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5220
5221    for connection_id in connection_pool.user_connection_ids(user_id) {
5222        peer.send(connection_id, user_channels_update.clone())
5223            .trace_err();
5224        peer.send(connection_id, update.clone()).trace_err();
5225    }
5226}
5227
5228fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5229    proto::UpdateUserChannels {
5230        channel_memberships: channels
5231            .channel_memberships
5232            .iter()
5233            .map(|m| proto::ChannelMembership {
5234                channel_id: m.channel_id.to_proto(),
5235                role: m.role.into(),
5236            })
5237            .collect(),
5238        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5239        observed_channel_message_id: channels.observed_channel_messages.clone(),
5240    }
5241}
5242
5243fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5244    let mut update = proto::UpdateChannels::default();
5245
5246    for channel in channels.channels {
5247        update.channels.push(channel.to_proto());
5248    }
5249
5250    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5251    update.latest_channel_message_ids = channels.latest_channel_messages;
5252
5253    for (channel_id, participants) in channels.channel_participants {
5254        update
5255            .channel_participants
5256            .push(proto::ChannelParticipants {
5257                channel_id: channel_id.to_proto(),
5258                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5259            });
5260    }
5261
5262    for channel in channels.invited_channels {
5263        update.channel_invitations.push(channel.to_proto());
5264    }
5265
5266    update.hosted_projects = channels.hosted_projects;
5267    update
5268}
5269
5270fn build_initial_contacts_update(
5271    contacts: Vec<db::Contact>,
5272    pool: &ConnectionPool,
5273) -> proto::UpdateContacts {
5274    let mut update = proto::UpdateContacts::default();
5275
5276    for contact in contacts {
5277        match contact {
5278            db::Contact::Accepted { user_id, busy } => {
5279                update.contacts.push(contact_for_user(user_id, busy, &pool));
5280            }
5281            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5282            db::Contact::Incoming { user_id } => {
5283                update
5284                    .incoming_requests
5285                    .push(proto::IncomingContactRequest {
5286                        requester_id: user_id.to_proto(),
5287                    })
5288            }
5289        }
5290    }
5291
5292    update
5293}
5294
5295fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5296    proto::Contact {
5297        user_id: user_id.to_proto(),
5298        online: pool.is_user_online(user_id),
5299        busy,
5300    }
5301}
5302
5303fn room_updated(room: &proto::Room, peer: &Peer) {
5304    broadcast(
5305        None,
5306        room.participants
5307            .iter()
5308            .filter_map(|participant| Some(participant.peer_id?.into())),
5309        |peer_id| {
5310            peer.send(
5311                peer_id,
5312                proto::RoomUpdated {
5313                    room: Some(room.clone()),
5314                },
5315            )
5316        },
5317    );
5318}
5319
5320fn channel_updated(
5321    channel: &db::channel::Model,
5322    room: &proto::Room,
5323    peer: &Peer,
5324    pool: &ConnectionPool,
5325) {
5326    let participants = room
5327        .participants
5328        .iter()
5329        .map(|p| p.user_id)
5330        .collect::<Vec<_>>();
5331
5332    broadcast(
5333        None,
5334        pool.channel_connection_ids(channel.root_id())
5335            .filter_map(|(channel_id, role)| {
5336                role.can_see_channel(channel.visibility).then(|| channel_id)
5337            }),
5338        |peer_id| {
5339            peer.send(
5340                peer_id,
5341                proto::UpdateChannels {
5342                    channel_participants: vec![proto::ChannelParticipants {
5343                        channel_id: channel.id.to_proto(),
5344                        participant_user_ids: participants.clone(),
5345                    }],
5346                    ..Default::default()
5347                },
5348            )
5349        },
5350    );
5351}
5352
5353async fn send_dev_server_projects_update(
5354    user_id: UserId,
5355    mut status: proto::DevServerProjectsUpdate,
5356    session: &Session,
5357) {
5358    let pool = session.connection_pool().await;
5359    for dev_server in &mut status.dev_servers {
5360        dev_server.status =
5361            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5362    }
5363    let connections = pool.user_connection_ids(user_id);
5364    for connection_id in connections {
5365        session.peer.send(connection_id, status.clone()).trace_err();
5366    }
5367}
5368
5369async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5370    let db = session.db().await;
5371
5372    let contacts = db.get_contacts(user_id).await?;
5373    let busy = db.is_user_busy(user_id).await?;
5374
5375    let pool = session.connection_pool().await;
5376    let updated_contact = contact_for_user(user_id, busy, &pool);
5377    for contact in contacts {
5378        if let db::Contact::Accepted {
5379            user_id: contact_user_id,
5380            ..
5381        } = contact
5382        {
5383            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5384                session
5385                    .peer
5386                    .send(
5387                        contact_conn_id,
5388                        proto::UpdateContacts {
5389                            contacts: vec![updated_contact.clone()],
5390                            remove_contacts: Default::default(),
5391                            incoming_requests: Default::default(),
5392                            remove_incoming_requests: Default::default(),
5393                            outgoing_requests: Default::default(),
5394                            remove_outgoing_requests: Default::default(),
5395                        },
5396                    )
5397                    .trace_err();
5398            }
5399        }
5400    }
5401    Ok(())
5402}
5403
5404async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5405    log::info!("lost dev server connection, unsharing projects");
5406    let project_ids = session
5407        .db()
5408        .await
5409        .get_stale_dev_server_projects(session.connection_id)
5410        .await?;
5411
5412    for project_id in project_ids {
5413        // not unshare re-checks the connection ids match, so we get away with no transaction
5414        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5415    }
5416
5417    let user_id = session.dev_server().user_id;
5418    let update = session
5419        .db()
5420        .await
5421        .dev_server_projects_update(user_id)
5422        .await?;
5423
5424    send_dev_server_projects_update(user_id, update, session).await;
5425
5426    Ok(())
5427}
5428
5429async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5430    let mut contacts_to_update = HashSet::default();
5431
5432    let room_id;
5433    let canceled_calls_to_user_ids;
5434    let live_kit_room;
5435    let delete_live_kit_room;
5436    let room;
5437    let channel;
5438
5439    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5440        contacts_to_update.insert(session.user_id());
5441
5442        for project in left_room.left_projects.values() {
5443            project_left(project, session);
5444        }
5445
5446        room_id = RoomId::from_proto(left_room.room.id);
5447        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5448        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5449        delete_live_kit_room = left_room.deleted;
5450        room = mem::take(&mut left_room.room);
5451        channel = mem::take(&mut left_room.channel);
5452
5453        room_updated(&room, &session.peer);
5454    } else {
5455        return Ok(());
5456    }
5457
5458    if let Some(channel) = channel {
5459        channel_updated(
5460            &channel,
5461            &room,
5462            &session.peer,
5463            &*session.connection_pool().await,
5464        );
5465    }
5466
5467    {
5468        let pool = session.connection_pool().await;
5469        for canceled_user_id in canceled_calls_to_user_ids {
5470            for connection_id in pool.user_connection_ids(canceled_user_id) {
5471                session
5472                    .peer
5473                    .send(
5474                        connection_id,
5475                        proto::CallCanceled {
5476                            room_id: room_id.to_proto(),
5477                        },
5478                    )
5479                    .trace_err();
5480            }
5481            contacts_to_update.insert(canceled_user_id);
5482        }
5483    }
5484
5485    for contact_user_id in contacts_to_update {
5486        update_user_contacts(contact_user_id, &session).await?;
5487    }
5488
5489    if let Some(live_kit) = session.live_kit_client.as_ref() {
5490        live_kit
5491            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5492            .await
5493            .trace_err();
5494
5495        if delete_live_kit_room {
5496            live_kit.delete_room(live_kit_room).await.trace_err();
5497        }
5498    }
5499
5500    Ok(())
5501}
5502
5503async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5504    let left_channel_buffers = session
5505        .db()
5506        .await
5507        .leave_channel_buffers(session.connection_id)
5508        .await?;
5509
5510    for left_buffer in left_channel_buffers {
5511        channel_buffer_updated(
5512            session.connection_id,
5513            left_buffer.connections,
5514            &proto::UpdateChannelBufferCollaborators {
5515                channel_id: left_buffer.channel_id.to_proto(),
5516                collaborators: left_buffer.collaborators,
5517            },
5518            &session.peer,
5519        );
5520    }
5521
5522    Ok(())
5523}
5524
5525fn project_left(project: &db::LeftProject, session: &UserSession) {
5526    for connection_id in &project.connection_ids {
5527        if project.should_unshare {
5528            session
5529                .peer
5530                .send(
5531                    *connection_id,
5532                    proto::UnshareProject {
5533                        project_id: project.id.to_proto(),
5534                    },
5535                )
5536                .trace_err();
5537        } else {
5538            session
5539                .peer
5540                .send(
5541                    *connection_id,
5542                    proto::RemoveProjectCollaborator {
5543                        project_id: project.id.to_proto(),
5544                        peer_id: Some(session.connection_id.into()),
5545                    },
5546                )
5547                .trace_err();
5548        }
5549    }
5550}
5551
5552pub trait ResultExt {
5553    type Ok;
5554
5555    fn trace_err(self) -> Option<Self::Ok>;
5556}
5557
5558impl<T, E> ResultExt for Result<T, E>
5559where
5560    E: std::fmt::Debug,
5561{
5562    type Ok = T;
5563
5564    #[track_caller]
5565    fn trace_err(self) -> Option<T> {
5566        match self {
5567            Ok(value) => Some(value),
5568            Err(error) => {
5569                tracing::error!("{:?}", error);
5570                None
5571            }
5572        }
5573    }
5574}