rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Config, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// Per-connection state handed to every RPC handler (see `MessageHandler`).
///
/// Cloning is cheap apart from `principal`: the shared resources are all behind `Arc`s.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Shared database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    // NOTE(review): unused in this view (underscore-prefixed); presumably kept so the
    // executor lives as long as the session — confirm.
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
 203        if self.is_staff() {
 204            return Ok(proto::Plan::ZedPro);
 205        }
 206
 207        let Some(user_id) = self.user_id() else {
 208            return Ok(proto::Plan::Free);
 209        };
 210
 211        let db = self.db().await;
 212        if db.has_active_billing_subscription(user_id).await? {
 213            Ok(proto::Plan::ZedPro)
 214        } else {
 215            Ok(proto::Plan::Free)
 216        }
 217    }
 218
 219    fn dev_server_id(&self) -> Option<DevServerId> {
 220        match &self.principal {
 221            Principal::User(_) | Principal::Impersonated { .. } => None,
 222            Principal::DevServer(dev_server) => Some(dev_server.id),
 223        }
 224    }
 225
 226    fn principal_id(&self) -> PrincipalId {
 227        match &self.principal {
 228            Principal::User(user) => PrincipalId::UserId(user.id),
 229            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 230            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 231        }
 232    }
 233}
 234
 235impl Debug for Session {
 236    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 237        let mut result = f.debug_struct("Session");
 238        match &self.principal {
 239            Principal::User(user) => {
 240                result.field("user", &user.github_login);
 241            }
 242            Principal::Impersonated { user, admin } => {
 243                result.field("user", &user.github_login);
 244                result.field("impersonator", &admin.github_login);
 245            }
 246            Principal::DevServer(dev_server) => {
 247                result.field("dev_server", &dev_server.id);
 248            }
 249        }
 250        result.field("connection_id", &self.connection_id).finish()
 251    }
 252}
 253
 254struct UserSession(Session);
 255
 256impl UserSession {
 257    pub fn new(s: Session) -> Option<Self> {
 258        s.user_id().map(|_| UserSession(s))
 259    }
 260    pub fn user_id(&self) -> UserId {
 261        self.0.user_id().unwrap()
 262    }
 263
 264    pub fn email(&self) -> Option<String> {
 265        match &self.0.principal {
 266            Principal::User(user) => user.email_address.clone(),
 267            Principal::Impersonated { user, .. } => user.email_address.clone(),
 268            Principal::DevServer(..) => None,
 269        }
 270    }
 271}
 272
 273impl Deref for UserSession {
 274    type Target = Session;
 275
 276    fn deref(&self) -> &Self::Target {
 277        &self.0
 278    }
 279}
 280impl DerefMut for UserSession {
 281    fn deref_mut(&mut self) -> &mut Self::Target {
 282        &mut self.0
 283    }
 284}
 285
 286struct DevServerSession(Session);
 287
 288impl DevServerSession {
 289    pub fn new(s: Session) -> Option<Self> {
 290        s.dev_server_id().map(|_| DevServerSession(s))
 291    }
 292    pub fn dev_server_id(&self) -> DevServerId {
 293        self.0.dev_server_id().unwrap()
 294    }
 295
 296    fn dev_server(&self) -> &dev_server::Model {
 297        match &self.0.principal {
 298            Principal::DevServer(dev_server) => dev_server,
 299            _ => unreachable!(),
 300        }
 301    }
 302}
 303
 304impl Deref for DevServerSession {
 305    type Target = Session;
 306
 307    fn deref(&self) -> &Self::Target {
 308        &self.0
 309    }
 310}
 311impl DerefMut for DevServerSession {
 312    fn deref_mut(&mut self) -> &mut Self::Target {
 313        &mut self.0
 314    }
 315}
 316
 317fn user_handler<M: RequestMessage, Fut>(
 318    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 319) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 320where
 321    Fut: Send + Future<Output = Result<()>>,
 322{
 323    let handler = Arc::new(handler);
 324    move |message, response, session| {
 325        let handler = handler.clone();
 326        Box::pin(async move {
 327            if let Some(user_session) = session.for_user() {
 328                Ok(handler(message, response, user_session).await?)
 329            } else {
 330                Err(Error::Internal(anyhow!(
 331                    "must be a user to call {}",
 332                    M::NAME
 333                )))
 334            }
 335        })
 336    }
 337}
 338
 339fn dev_server_handler<M: RequestMessage, Fut>(
 340    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 341) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 342where
 343    Fut: Send + Future<Output = Result<()>>,
 344{
 345    let handler = Arc::new(handler);
 346    move |message, response, session| {
 347        let handler = handler.clone();
 348        Box::pin(async move {
 349            if let Some(dev_server_session) = session.for_dev_server() {
 350                Ok(handler(message, response, dev_server_session).await?)
 351            } else {
 352                Err(Error::Internal(anyhow!(
 353                    "must be a dev server to call {}",
 354                    M::NAME
 355                )))
 356            }
 357        })
 358    }
 359}
 360
 361fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 362    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 363) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 364where
 365    InnertRetFut: Send + Future<Output = Result<()>>,
 366{
 367    let handler = Arc::new(handler);
 368    move |message, session| {
 369        let handler = handler.clone();
 370        Box::pin(async move {
 371            if let Some(user_session) = session.for_user() {
 372                Ok(handler(message, user_session).await?)
 373            } else {
 374                Err(Error::Internal(anyhow!(
 375                    "must be a user to call {}",
 376                    M::NAME
 377                )))
 378            }
 379        })
 380    }
 381}
 382
 383struct DbHandle(Arc<Database>);
 384
 385impl Deref for DbHandle {
 386    type Target = Database;
 387
 388    fn deref(&self) -> &Self::Target {
 389        self.0.as_ref()
 390    }
 391}
 392
/// The RPC server: owns the peer, the shared connection pool, and the handler table.
pub struct Server {
    /// This server's id, read under the lock in `Server::start`.
    /// NOTE(review): presumably behind a mutex so it can be replaced later — confirm.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by `TypeId` (presumably of the message type — confirm
    /// against `add_request_handler` / `add_message_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch sender created with `false`.
    /// NOTE(review): presumably flipped to signal teardown — confirm.
    teardown: watch::Sender<bool>,
}
 401
/// Lock guard over the shared [`ConnectionPool`], as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` makes this type `!Send`. NOTE(review): presumably so the synchronous
    // parking_lot guard cannot be held across an `.await` in a `Send` future — confirm.
    _not_send: PhantomData<Rc<()>>,
}
 406
/// Serializable view of the server's peer and connection pool.
///
/// Because it holds a `ConnectionPoolGuard`, the connection-pool lock stays held for as
/// long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not serialized; `serialize_deref` serializes its `Deref` target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 413
 414pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 415where
 416    S: Serializer,
 417    T: Deref<Target = U>,
 418    U: Serialize,
 419{
 420    Serialize::serialize(value.deref(), serializer)
 421}
 422
 423impl Server {
    /// Builds a server with id `id`, registers every RPC handler, and returns it in an `Arc`.
    ///
    /// Handlers registered via `user_handler` / `user_message_handler` reject sessions that
    /// are not users, and those registered via `dev_server_handler` reject non-dev-servers.
    /// The remaining handlers receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if `id.0` does not fit in a u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial receiver is dropped here.
            teardown: watch::channel(false).0,
        };

        server
            // Liveness, rooms, and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Project sharing and joining.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev servers and dev server projects, managed by users.
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests only a dev-server principal may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project membership and project/worktree state updates.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the `forward_*` helpers (defined elsewhere).
            // NOTE(review): the helper names suggest owner-only, read-only, mutating, and
            // versioned-mutating forwarding rules — confirm against their definitions.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers, plus messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications, channel moves, following, and per-user odds and ends.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model and embedding requests. These closures capture a clone of
            // `app_state` so each call can read its config. The first three receive the raw
            // `Session`, since they are not wrapped in `user_handler`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 682
 683    pub async fn start(&self) -> Result<()> {
 684        let server_id = *self.id.lock();
 685        let app_state = self.app_state.clone();
 686        let peer = self.peer.clone();
 687        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 688        let pool = self.connection_pool.clone();
 689        let live_kit_client = self.app_state.live_kit_client.clone();
 690
 691        let span = info_span!("start server");
 692        self.app_state.executor.spawn_detached(
 693            async move {
 694                tracing::info!("waiting for cleanup timeout");
 695                timeout.await;
 696                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 697                if let Some((room_ids, channel_ids)) = app_state
 698                    .db
 699                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 700                    .await
 701                    .trace_err()
 702                {
 703                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 704                    tracing::info!(
 705                        stale_channel_buffer_count = channel_ids.len(),
 706                        "retrieved stale channel buffers"
 707                    );
 708
 709                    for channel_id in channel_ids {
 710                        if let Some(refreshed_channel_buffer) = app_state
 711                            .db
 712                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 713                            .await
 714                            .trace_err()
 715                        {
 716                            for connection_id in refreshed_channel_buffer.connection_ids {
 717                                peer.send(
 718                                    connection_id,
 719                                    proto::UpdateChannelBufferCollaborators {
 720                                        channel_id: channel_id.to_proto(),
 721                                        collaborators: refreshed_channel_buffer
 722                                            .collaborators
 723                                            .clone(),
 724                                    },
 725                                )
 726                                .trace_err();
 727                            }
 728                        }
 729                    }
 730
 731                    for room_id in room_ids {
 732                        let mut contacts_to_update = HashSet::default();
 733                        let mut canceled_calls_to_user_ids = Vec::new();
 734                        let mut live_kit_room = String::new();
 735                        let mut delete_live_kit_room = false;
 736
 737                        if let Some(mut refreshed_room) = app_state
 738                            .db
 739                            .clear_stale_room_participants(room_id, server_id)
 740                            .await
 741                            .trace_err()
 742                        {
 743                            tracing::info!(
 744                                room_id = room_id.0,
 745                                new_participant_count = refreshed_room.room.participants.len(),
 746                                "refreshed room"
 747                            );
 748                            room_updated(&refreshed_room.room, &peer);
 749                            if let Some(channel) = refreshed_room.channel.as_ref() {
 750                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 751                            }
 752                            contacts_to_update
 753                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 754                            contacts_to_update
 755                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 756                            canceled_calls_to_user_ids =
 757                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 758                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 759                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 760                        }
 761
 762                        {
 763                            let pool = pool.lock();
 764                            for canceled_user_id in canceled_calls_to_user_ids {
 765                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 766                                    peer.send(
 767                                        connection_id,
 768                                        proto::CallCanceled {
 769                                            room_id: room_id.to_proto(),
 770                                        },
 771                                    )
 772                                    .trace_err();
 773                                }
 774                            }
 775                        }
 776
 777                        for user_id in contacts_to_update {
 778                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 779                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 780                            if let Some((busy, contacts)) = busy.zip(contacts) {
 781                                let pool = pool.lock();
 782                                let updated_contact = contact_for_user(user_id, busy, &pool);
 783                                for contact in contacts {
 784                                    if let db::Contact::Accepted {
 785                                        user_id: contact_user_id,
 786                                        ..
 787                                    } = contact
 788                                    {
 789                                        for contact_conn_id in
 790                                            pool.user_connection_ids(contact_user_id)
 791                                        {
 792                                            peer.send(
 793                                                contact_conn_id,
 794                                                proto::UpdateContacts {
 795                                                    contacts: vec![updated_contact.clone()],
 796                                                    remove_contacts: Default::default(),
 797                                                    incoming_requests: Default::default(),
 798                                                    remove_incoming_requests: Default::default(),
 799                                                    outgoing_requests: Default::default(),
 800                                                    remove_outgoing_requests: Default::default(),
 801                                                },
 802                                            )
 803                                            .trace_err();
 804                                        }
 805                                    }
 806                                }
 807                            }
 808                        }
 809
 810                        if let Some(live_kit) = live_kit_client.as_ref() {
 811                            if delete_live_kit_room {
 812                                live_kit.delete_room(live_kit_room).await.trace_err();
 813                            }
 814                        }
 815                    }
 816                }
 817
 818                app_state
 819                    .db
 820                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 821                    .await
 822                    .trace_err();
 823            }
 824            .instrument(span),
 825        );
 826        Ok(())
 827    }
 828
 829    pub fn teardown(&self) {
 830        self.peer.teardown();
 831        self.connection_pool.lock().reset();
 832        let _ = self.teardown.send(true);
 833    }
 834
 835    #[cfg(test)]
 836    pub fn reset(&self, id: ServerId) {
 837        self.teardown();
 838        *self.id.lock() = id;
 839        self.peer.reset(id.0 as u32);
 840        let _ = self.teardown.send(false);
 841    }
 842
 843    #[cfg(test)]
 844    pub fn id(&self) -> ServerId {
 845        *self.id.lock()
 846    }
 847
 848    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 849    where
 850        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 851        Fut: 'static + Send + Future<Output = Result<()>>,
 852        M: EnvelopedMessage,
 853    {
 854        let prev_handler = self.handlers.insert(
 855            TypeId::of::<M>(),
 856            Box::new(move |envelope, session| {
 857                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 858                let received_at = envelope.received_at;
 859                tracing::info!("message received");
 860                let start_time = Instant::now();
 861                let future = (handler)(*envelope, session);
 862                async move {
 863                    let result = future.await;
 864                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 865                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 866                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 867                    let payload_type = M::NAME;
 868
 869                    match result {
 870                        Err(error) => {
 871                            tracing::error!(
 872                                ?error,
 873                                total_duration_ms,
 874                                processing_duration_ms,
 875                                queue_duration_ms,
 876                                payload_type,
 877                                "error handling message"
 878                            )
 879                        }
 880                        Ok(()) => tracing::info!(
 881                            total_duration_ms,
 882                            processing_duration_ms,
 883                            queue_duration_ms,
 884                            "finished handling message"
 885                        ),
 886                    }
 887                }
 888                .boxed()
 889            }),
 890        );
 891        if prev_handler.is_some() {
 892            panic!("registered a handler for the same message twice");
 893        }
 894        self
 895    }
 896
 897    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 898    where
 899        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 900        Fut: 'static + Send + Future<Output = Result<()>>,
 901        M: EnvelopedMessage,
 902    {
 903        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 904        self
 905    }
 906
 907    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 908    where
 909        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 910        Fut: Send + Future<Output = Result<()>>,
 911        M: RequestMessage,
 912    {
 913        let handler = Arc::new(handler);
 914        self.add_handler(move |envelope, session| {
 915            let receipt = envelope.receipt();
 916            let handler = handler.clone();
 917            async move {
 918                let peer = session.peer.clone();
 919                let responded = Arc::new(AtomicBool::default());
 920                let response = Response {
 921                    peer: peer.clone(),
 922                    responded: responded.clone(),
 923                    receipt,
 924                };
 925                match (handler)(envelope.payload, response, session).await {
 926                    Ok(()) => {
 927                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 928                            Ok(())
 929                        } else {
 930                            Err(anyhow!("handler did not send a response"))?
 931                        }
 932                    }
 933                    Err(error) => {
 934                        let proto_err = match &error {
 935                            Error::Internal(err) => err.to_proto(),
 936                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 937                        };
 938                        peer.respond_with_error(receipt, proto_err)?;
 939                        Err(error)
 940                    }
 941                }
 942            }
 943        })
 944    }
 945
 946    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 947    where
 948        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 949        Fut: Send + Future<Output = Result<()>>,
 950        M: RequestMessage,
 951    {
 952        let handler = Arc::new(handler);
 953        self.add_handler(move |envelope, session| {
 954            let receipt = envelope.receipt();
 955            let handler = handler.clone();
 956            async move {
 957                let peer = session.peer.clone();
 958                let response = StreamingResponse {
 959                    peer: peer.clone(),
 960                    receipt,
 961                };
 962                match (handler)(envelope.payload, response, session).await {
 963                    Ok(()) => {
 964                        peer.end_stream(receipt)?;
 965                        Ok(())
 966                    }
 967                    Err(error) => {
 968                        let proto_err = match &error {
 969                            Error::Internal(err) => err.to_proto(),
 970                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 971                        };
 972                        peer.respond_with_error(receipt, proto_err)?;
 973                        Err(error)
 974                    }
 975                }
 976            }
 977        })
 978    }
 979
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and records its `ConnectionId` on the span;
    /// 3. builds the per-connection HTTP client, the optional Supermaven admin client, and
    ///    the `Session`, then sends the initial client update (which also hands the
    ///    `ConnectionId` to `send_connection_id`, if one was provided);
    /// 4. dispatches incoming messages to the registered handlers until the I/O future
    ///    finishes, the incoming stream ends, or teardown is signalled;
    /// 5. if the loop ended because of I/O completion or stream end (not teardown), calls
    ///    `connection_lost` to clean up.
    ///
    /// If the HTTP client cannot be built or the initial update fails, the future returns
    /// early without calling `connection_lost`. Errors are logged, not returned.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record` / `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribed outside the async block: the current value is checked first, and
        // subsequent changes are observed through `teardown.changed()` in the loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure gives the peer a way to sleep on this connection's executor
            // (presumably for its timers/timeouts — TODO confirm against `Peer`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client only exists when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection: a permit is
            // acquired before the next message is read and released when that message's
            // handler finishes (or right away if no handler is registered for it).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // The semaphore is never closed in this function, so acquiring cannot fail.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in order: teardown, then I/O, then
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives foreground handlers; their results are already logged by the
                    // wrapper installed in `add_handler`.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before moving it; `.instrument(span)` below
                                // re-enters it whenever the handler future is polled.
                                drop(span_enter);

                                // The permit moves into the handler future so it is held
                                // until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop through the `FuturesUnordered`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still running.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1127
1128    async fn send_initial_client_update(
1129        &self,
1130        connection_id: ConnectionId,
1131        principal: &Principal,
1132        zed_version: ZedVersion,
1133        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1134        session: &Session,
1135    ) -> Result<()> {
1136        self.peer.send(
1137            connection_id,
1138            proto::Hello {
1139                peer_id: Some(connection_id.into()),
1140            },
1141        )?;
1142        tracing::info!("sent hello message");
1143        if let Some(send_connection_id) = send_connection_id.take() {
1144            let _ = send_connection_id.send(connection_id);
1145        }
1146
1147        match principal {
1148            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1149                if !user.connected_once {
1150                    self.peer.send(connection_id, proto::ShowContacts {})?;
1151                    self.app_state
1152                        .db
1153                        .set_user_connected_once(user.id, true)
1154                        .await?;
1155                }
1156
1157                update_user_plan(user.id, session).await?;
1158
1159                let (contacts, dev_server_projects) = future::try_join(
1160                    self.app_state.db.get_contacts(user.id),
1161                    self.app_state.db.dev_server_projects_update(user.id),
1162                )
1163                .await?;
1164
1165                {
1166                    let mut pool = self.connection_pool.lock();
1167                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1168                    self.peer.send(
1169                        connection_id,
1170                        build_initial_contacts_update(contacts, &pool),
1171                    )?;
1172                }
1173
1174                if should_auto_subscribe_to_channels(zed_version) {
1175                    subscribe_user_to_channels(user.id, session).await?;
1176                }
1177
1178                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1179
1180                if let Some(incoming_call) =
1181                    self.app_state.db.incoming_call_for_user(user.id).await?
1182                {
1183                    self.peer.send(connection_id, incoming_call)?;
1184                }
1185
1186                update_user_contacts(user.id, &session).await?;
1187            }
1188            Principal::DevServer(dev_server) => {
1189                {
1190                    let mut pool = self.connection_pool.lock();
1191                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1192                    {
1193                        self.peer.send(
1194                            stale_connection_id,
1195                            proto::ShutdownDevServer {
1196                                reason: Some(
1197                                    "another dev server connected with the same token".to_string(),
1198                                ),
1199                            },
1200                        )?;
1201                        pool.remove_connection(stale_connection_id)?;
1202                    };
1203                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1204                }
1205
1206                let projects = self
1207                    .app_state
1208                    .db
1209                    .get_projects_for_dev_server(dev_server.id)
1210                    .await?;
1211                self.peer
1212                    .send(connection_id, proto::DevServerInstructions { projects })?;
1213
1214                let status = self
1215                    .app_state
1216                    .db
1217                    .dev_server_projects_update(dev_server.user_id)
1218                    .await?;
1219                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1220            }
1221        }
1222
1223        Ok(())
1224    }
1225
1226    pub async fn invite_code_redeemed(
1227        self: &Arc<Self>,
1228        inviter_id: UserId,
1229        invitee_id: UserId,
1230    ) -> Result<()> {
1231        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1232            if let Some(code) = &user.invite_code {
1233                let pool = self.connection_pool.lock();
1234                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1235                for connection_id in pool.user_connection_ids(inviter_id) {
1236                    self.peer.send(
1237                        connection_id,
1238                        proto::UpdateContacts {
1239                            contacts: vec![invitee_contact.clone()],
1240                            ..Default::default()
1241                        },
1242                    )?;
1243                    self.peer.send(
1244                        connection_id,
1245                        proto::UpdateInviteInfo {
1246                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1247                            count: user.invite_count as u32,
1248                        },
1249                    )?;
1250                }
1251            }
1252        }
1253        Ok(())
1254    }
1255
1256    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1257        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1258            if let Some(invite_code) = &user.invite_code {
1259                let pool = self.connection_pool.lock();
1260                for connection_id in pool.user_connection_ids(user_id) {
1261                    self.peer.send(
1262                        connection_id,
1263                        proto::UpdateInviteInfo {
1264                            url: format!(
1265                                "{}{}",
1266                                self.app_state.config.invite_link_prefix, invite_code
1267                            ),
1268                            count: user.invite_count as u32,
1269                        },
1270                    )?;
1271                }
1272            }
1273        }
1274        Ok(())
1275    }
1276
1277    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1278        ServerSnapshot {
1279            connection_pool: ConnectionPoolGuard {
1280                guard: self.connection_pool.lock(),
1281                _not_send: PhantomData,
1282            },
1283            peer: &self.peer,
1284        }
1285    }
1286}
1287
1288impl<'a> Deref for ConnectionPoolGuard<'a> {
1289    type Target = ConnectionPool;
1290
1291    fn deref(&self) -> &Self::Target {
1292        &self.guard
1293    }
1294}
1295
1296impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1297    fn deref_mut(&mut self) -> &mut Self::Target {
1298        &mut self.guard
1299    }
1300}
1301
1302impl<'a> Drop for ConnectionPoolGuard<'a> {
1303    fn drop(&mut self) {
1304        #[cfg(test)]
1305        self.check_invariants();
1306    }
1307}
1308
1309fn broadcast<F>(
1310    sender_id: Option<ConnectionId>,
1311    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1312    mut f: F,
1313) where
1314    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1315{
1316    for receiver_id in receiver_ids {
1317        if Some(receiver_id) != sender_id {
1318            if let Err(error) = f(receiver_id) {
1319                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1320            }
1321        }
1322    }
1323}
1324
1325pub struct ProtocolVersion(u32);
1326
1327impl Header for ProtocolVersion {
1328    fn name() -> &'static HeaderName {
1329        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1330        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1331    }
1332
1333    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1334    where
1335        Self: Sized,
1336        I: Iterator<Item = &'i axum::http::HeaderValue>,
1337    {
1338        let version = values
1339            .next()
1340            .ok_or_else(axum::headers::Error::invalid)?
1341            .to_str()
1342            .map_err(|_| axum::headers::Error::invalid())?
1343            .parse()
1344            .map_err(|_| axum::headers::Error::invalid())?;
1345        Ok(Self(version))
1346    }
1347
1348    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1349        values.extend([self.0.to_string().parse().unwrap()]);
1350    }
1351}
1352
1353pub struct AppVersionHeader(SemanticVersion);
1354impl Header for AppVersionHeader {
1355    fn name() -> &'static HeaderName {
1356        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1357        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1358    }
1359
1360    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1361    where
1362        Self: Sized,
1363        I: Iterator<Item = &'i axum::http::HeaderValue>,
1364    {
1365        let version = values
1366            .next()
1367            .ok_or_else(axum::headers::Error::invalid)?
1368            .to_str()
1369            .map_err(|_| axum::headers::Error::invalid())?
1370            .parse()
1371            .map_err(|_| axum::headers::Error::invalid())?;
1372        Ok(Self(version))
1373    }
1374
1375    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1376        values.extend([self.0.to_string().parse().unwrap()]);
1377    }
1378}
1379
1380pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1381    Router::new()
1382        .route("/rpc", get(handle_websocket_request))
1383        .layer(
1384            ServiceBuilder::new()
1385                .layer(Extension(server.app_state.clone()))
1386                .layer(middleware::from_fn(auth::validate_header)),
1387        )
1388        .route("/metrics", get(handle_metrics))
1389        .layer(Extension(server))
1390}
1391
1392pub async fn handle_websocket_request(
1393    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1394    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1395    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1396    Extension(server): Extension<Arc<Server>>,
1397    Extension(principal): Extension<Principal>,
1398    ws: WebSocketUpgrade,
1399) -> axum::response::Response {
1400    if protocol_version != rpc::PROTOCOL_VERSION {
1401        return (
1402            StatusCode::UPGRADE_REQUIRED,
1403            "client must be upgraded".to_string(),
1404        )
1405            .into_response();
1406    }
1407
1408    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1409        return (
1410            StatusCode::UPGRADE_REQUIRED,
1411            "no version header found".to_string(),
1412        )
1413            .into_response();
1414    };
1415
1416    if !version.can_collaborate() {
1417        return (
1418            StatusCode::UPGRADE_REQUIRED,
1419            "client must be upgraded".to_string(),
1420        )
1421            .into_response();
1422    }
1423
1424    let socket_address = socket_address.to_string();
1425    ws.on_upgrade(move |socket| {
1426        let socket = socket
1427            .map_ok(to_tungstenite_message)
1428            .err_into()
1429            .with(|message| async move { to_axum_message(message) });
1430        let connection = Connection::new(Box::pin(socket));
1431        async move {
1432            server
1433                .handle_connection(
1434                    connection,
1435                    socket_address,
1436                    principal,
1437                    version,
1438                    None,
1439                    Executor::Production,
1440                )
1441                .await;
1442        }
1443    })
1444}
1445
1446pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1447    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1448    let connections_metric = CONNECTIONS_METRIC
1449        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1450
1451    let connections = server
1452        .connection_pool
1453        .lock()
1454        .connections()
1455        .filter(|connection| !connection.admin)
1456        .count();
1457    connections_metric.set(connections as _);
1458
1459    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1460    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1461        register_int_gauge!(
1462            "shared_projects",
1463            "number of open projects with one or more guests"
1464        )
1465        .unwrap()
1466    });
1467
1468    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1469    shared_projects_metric.set(shared_projects as _);
1470
1471    let encoder = prometheus::TextEncoder::new();
1472    let metric_families = prometheus::gather();
1473    let encoded_metrics = encoder
1474        .encode_to_string(&metric_families)
1475        .map_err(|err| anyhow!("{}", err))?;
1476    Ok(encoded_metrics)
1477}
1478
/// Cleans up after a connection drops.
///
/// The peer is disconnected and the connection is removed from the pool
/// right away (a pool error is returned to the caller). The database is then
/// told the connection was lost; errors from that call are only logged.
///
/// After that, the function waits for either `RECONNECT_TIMEOUT` to elapse or
/// `teardown` to change. If the timeout wins, the principal's remaining
/// resources are released:
/// - users (including impersonated sessions) leave their room and channel
///   buffers; if the user has no other connection online, any pending call is
///   declined; finally their contacts are updated;
/// - dev servers go through `lost_dev_server_connection`.
///
/// If `teardown` changes first, no further cleanup happens here.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the
    // reconnect timeout is checked before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so `for_user` is expected to succeed.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown requested before the timeout: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1533
1534/// Acknowledges a ping from a client, used to keep the connection alive.
1535async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1536    response.send(proto::Ack {})?;
1537    Ok(())
1538}
1539
1540/// Creates a new room for calling (outside of channels)
1541async fn create_room(
1542    _request: proto::CreateRoom,
1543    response: Response<proto::CreateRoom>,
1544    session: UserSession,
1545) -> Result<()> {
1546    let live_kit_room = nanoid::nanoid!(30);
1547
1548    let live_kit_connection_info = util::maybe!(async {
1549        let live_kit = session.live_kit_client.as_ref();
1550        let live_kit = live_kit?;
1551        let user_id = session.user_id().to_string();
1552
1553        let token = live_kit
1554            .room_token(&live_kit_room, &user_id.to_string())
1555            .trace_err()?;
1556
1557        Some(proto::LiveKitConnectionInfo {
1558            server_url: live_kit.url().into(),
1559            token,
1560            can_publish: true,
1561        })
1562    })
1563    .await;
1564
1565    let room = session
1566        .db()
1567        .await
1568        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1569        .await?;
1570
1571    response.send(proto::CreateRoomResponse {
1572        room: Some(room.clone()),
1573        live_kit_connection_info,
1574    })?;
1575
1576    update_user_contacts(session.user_id(), &session).await?;
1577    Ok(())
1578}
1579
1580/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1581async fn join_room(
1582    request: proto::JoinRoom,
1583    response: Response<proto::JoinRoom>,
1584    session: UserSession,
1585) -> Result<()> {
1586    let room_id = RoomId::from_proto(request.id);
1587
1588    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1589
1590    if let Some(channel_id) = channel_id {
1591        return join_channel_internal(channel_id, Box::new(response), session).await;
1592    }
1593
1594    let joined_room = {
1595        let room = session
1596            .db()
1597            .await
1598            .join_room(room_id, session.user_id(), session.connection_id)
1599            .await?;
1600        room_updated(&room.room, &session.peer);
1601        room.into_inner()
1602    };
1603
1604    for connection_id in session
1605        .connection_pool()
1606        .await
1607        .user_connection_ids(session.user_id())
1608    {
1609        session
1610            .peer
1611            .send(
1612                connection_id,
1613                proto::CallCanceled {
1614                    room_id: room_id.to_proto(),
1615                },
1616            )
1617            .trace_err();
1618    }
1619
1620    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1621        if let Some(token) = live_kit
1622            .room_token(
1623                &joined_room.room.live_kit_room,
1624                &session.user_id().to_string(),
1625            )
1626            .trace_err()
1627        {
1628            Some(proto::LiveKitConnectionInfo {
1629                server_url: live_kit.url().into(),
1630                token,
1631                can_publish: true,
1632            })
1633        } else {
1634            None
1635        }
1636    } else {
1637        None
1638    };
1639
1640    response.send(proto::JoinRoomResponse {
1641        room: Some(joined_room.room),
1642        channel_id: None,
1643        live_kit_connection_info,
1644    })?;
1645
1646    update_user_contacts(session.user_id(), &session).await?;
1647    Ok(())
1648}
1649
/// Reconnects a user to a room after a connection error.
///
/// The database rejoin returns the room plus the projects this user had
/// shared ("reshared") and joined ("rejoined"). The caller gets those back in
/// the response, and the room update is broadcast. For each reshared project,
/// its collaborators are told the host's peer id changed and are sent the
/// project's worktrees. Rejoined projects are then resynced through
/// `notify_rejoined_projects`.
///
/// The channel (if any) is notified only after the database guard's scope has
/// ended. Finally the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // This block bounds the lifetime of the guard returned by `rejoin_room`;
    // it ends before the connection pool is locked below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at this connection instead of the old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1741
1742fn notify_rejoined_projects(
1743    rejoined_projects: &mut Vec<RejoinedProject>,
1744    session: &UserSession,
1745) -> Result<()> {
1746    for project in rejoined_projects.iter() {
1747        for collaborator in &project.collaborators {
1748            session
1749                .peer
1750                .send(
1751                    collaborator.connection_id,
1752                    proto::UpdateProjectCollaborator {
1753                        project_id: project.id.to_proto(),
1754                        old_peer_id: Some(project.old_connection_id.into()),
1755                        new_peer_id: Some(session.connection_id.into()),
1756                    },
1757                )
1758                .trace_err();
1759        }
1760    }
1761
1762    for project in rejoined_projects {
1763        for worktree in mem::take(&mut project.worktrees) {
1764            #[cfg(any(test, feature = "test-support"))]
1765            const MAX_CHUNK_SIZE: usize = 2;
1766            #[cfg(not(any(test, feature = "test-support")))]
1767            const MAX_CHUNK_SIZE: usize = 256;
1768
1769            // Stream this worktree's entries.
1770            let message = proto::UpdateWorktree {
1771                project_id: project.id.to_proto(),
1772                worktree_id: worktree.id,
1773                abs_path: worktree.abs_path.clone(),
1774                root_name: worktree.root_name,
1775                updated_entries: worktree.updated_entries,
1776                removed_entries: worktree.removed_entries,
1777                scan_id: worktree.scan_id,
1778                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1779                updated_repositories: worktree.updated_repositories,
1780                removed_repositories: worktree.removed_repositories,
1781            };
1782            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1783                session.peer.send(session.connection_id, update.clone())?;
1784            }
1785
1786            // Stream this worktree's diagnostics.
1787            for summary in worktree.diagnostic_summaries {
1788                session.peer.send(
1789                    session.connection_id,
1790                    proto::UpdateDiagnosticSummary {
1791                        project_id: project.id.to_proto(),
1792                        worktree_id: worktree.id,
1793                        summary: Some(summary),
1794                    },
1795                )?;
1796            }
1797
1798            for settings_file in worktree.settings_files {
1799                session.peer.send(
1800                    session.connection_id,
1801                    proto::UpdateWorktreeSettings {
1802                        project_id: project.id.to_proto(),
1803                        worktree_id: worktree.id,
1804                        path: settings_file.path,
1805                        content: Some(settings_file.content),
1806                    },
1807                )?;
1808            }
1809        }
1810
1811        for language_server in &project.language_servers {
1812            session.peer.send(
1813                session.connection_id,
1814                proto::UpdateLanguageServer {
1815                    project_id: project.id.to_proto(),
1816                    language_server_id: language_server.id,
1817                    variant: Some(
1818                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1819                            proto::LspDiskBasedDiagnosticsUpdated {},
1820                        ),
1821                    ),
1822                },
1823            )?;
1824        }
1825    }
1826    Ok(())
1827}
1828
1829/// leave room disconnects from the room.
1830async fn leave_room(
1831    _: proto::LeaveRoom,
1832    response: Response<proto::LeaveRoom>,
1833    session: UserSession,
1834) -> Result<()> {
1835    leave_room_for_session(&session, session.connection_id).await?;
1836    response.send(proto::Ack {})?;
1837    Ok(())
1838}
1839
1840/// Updates the permissions of someone else in the room.
1841async fn set_room_participant_role(
1842    request: proto::SetRoomParticipantRole,
1843    response: Response<proto::SetRoomParticipantRole>,
1844    session: UserSession,
1845) -> Result<()> {
1846    let user_id = UserId::from_proto(request.user_id);
1847    let role = ChannelRole::from(request.role());
1848
1849    let (live_kit_room, can_publish) = {
1850        let room = session
1851            .db()
1852            .await
1853            .set_room_participant_role(
1854                session.user_id(),
1855                RoomId::from_proto(request.room_id),
1856                user_id,
1857                role,
1858            )
1859            .await?;
1860
1861        let live_kit_room = room.live_kit_room.clone();
1862        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1863        room_updated(&room, &session.peer);
1864        (live_kit_room, can_publish)
1865    };
1866
1867    if let Some(live_kit) = session.live_kit_client.as_ref() {
1868        live_kit
1869            .update_participant(
1870                live_kit_room.clone(),
1871                request.user_id.to_string(),
1872                live_kit_server::proto::ParticipantPermission {
1873                    can_subscribe: true,
1874                    can_publish,
1875                    can_publish_data: can_publish,
1876                    hidden: false,
1877                    recorder: false,
1878                },
1879            )
1880            .await
1881            .trace_err();
1882    }
1883
1884    response.send(proto::Ack {})?;
1885    Ok(())
1886}
1887
1888/// Call someone else into the current room
1889async fn call(
1890    request: proto::Call,
1891    response: Response<proto::Call>,
1892    session: UserSession,
1893) -> Result<()> {
1894    let room_id = RoomId::from_proto(request.room_id);
1895    let calling_user_id = session.user_id();
1896    let calling_connection_id = session.connection_id;
1897    let called_user_id = UserId::from_proto(request.called_user_id);
1898    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1899    if !session
1900        .db()
1901        .await
1902        .has_contact(calling_user_id, called_user_id)
1903        .await?
1904    {
1905        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1906    }
1907
1908    let incoming_call = {
1909        let (room, incoming_call) = &mut *session
1910            .db()
1911            .await
1912            .call(
1913                room_id,
1914                calling_user_id,
1915                calling_connection_id,
1916                called_user_id,
1917                initial_project_id,
1918            )
1919            .await?;
1920        room_updated(&room, &session.peer);
1921        mem::take(incoming_call)
1922    };
1923    update_user_contacts(called_user_id, &session).await?;
1924
1925    let mut calls = session
1926        .connection_pool()
1927        .await
1928        .user_connection_ids(called_user_id)
1929        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1930        .collect::<FuturesUnordered<_>>();
1931
1932    while let Some(call_response) = calls.next().await {
1933        match call_response.as_ref() {
1934            Ok(_) => {
1935                response.send(proto::Ack {})?;
1936                return Ok(());
1937            }
1938            Err(_) => {
1939                call_response.trace_err();
1940            }
1941        }
1942    }
1943
1944    {
1945        let room = session
1946            .db()
1947            .await
1948            .call_failed(room_id, called_user_id)
1949            .await?;
1950        room_updated(&room, &session.peer);
1951    }
1952    update_user_contacts(called_user_id, &session).await?;
1953
1954    Err(anyhow!("failed to ring user"))?
1955}
1956
1957/// Cancel an outgoing call.
1958async fn cancel_call(
1959    request: proto::CancelCall,
1960    response: Response<proto::CancelCall>,
1961    session: UserSession,
1962) -> Result<()> {
1963    let called_user_id = UserId::from_proto(request.called_user_id);
1964    let room_id = RoomId::from_proto(request.room_id);
1965    {
1966        let room = session
1967            .db()
1968            .await
1969            .cancel_call(room_id, session.connection_id, called_user_id)
1970            .await?;
1971        room_updated(&room, &session.peer);
1972    }
1973
1974    for connection_id in session
1975        .connection_pool()
1976        .await
1977        .user_connection_ids(called_user_id)
1978    {
1979        session
1980            .peer
1981            .send(
1982                connection_id,
1983                proto::CallCanceled {
1984                    room_id: room_id.to_proto(),
1985                },
1986            )
1987            .trace_err();
1988    }
1989    response.send(proto::Ack {})?;
1990
1991    update_user_contacts(called_user_id, &session).await?;
1992    Ok(())
1993}
1994
1995/// Decline an incoming call.
1996async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1997    let room_id = RoomId::from_proto(message.room_id);
1998    {
1999        let room = session
2000            .db()
2001            .await
2002            .decline_call(Some(room_id), session.user_id())
2003            .await?
2004            .ok_or_else(|| anyhow!("failed to decline call"))?;
2005        room_updated(&room, &session.peer);
2006    }
2007
2008    for connection_id in session
2009        .connection_pool()
2010        .await
2011        .user_connection_ids(session.user_id())
2012    {
2013        session
2014            .peer
2015            .send(
2016                connection_id,
2017                proto::CallCanceled {
2018                    room_id: room_id.to_proto(),
2019                },
2020            )
2021            .trace_err();
2022    }
2023    update_user_contacts(session.user_id(), &session).await?;
2024    Ok(())
2025}
2026
2027/// Updates other participants in the room with your current location.
2028async fn update_participant_location(
2029    request: proto::UpdateParticipantLocation,
2030    response: Response<proto::UpdateParticipantLocation>,
2031    session: UserSession,
2032) -> Result<()> {
2033    let room_id = RoomId::from_proto(request.room_id);
2034    let location = request
2035        .location
2036        .ok_or_else(|| anyhow!("invalid location"))?;
2037
2038    let db = session.db().await;
2039    let room = db
2040        .update_room_participant_location(room_id, session.connection_id, location)
2041        .await?;
2042
2043    room_updated(&room, &session.peer);
2044    response.send(proto::Ack {})?;
2045    Ok(())
2046}
2047
2048/// Share a project into the room.
2049async fn share_project(
2050    request: proto::ShareProject,
2051    response: Response<proto::ShareProject>,
2052    session: UserSession,
2053) -> Result<()> {
2054    let (project_id, room) = &*session
2055        .db()
2056        .await
2057        .share_project(
2058            RoomId::from_proto(request.room_id),
2059            session.connection_id,
2060            &request.worktrees,
2061            request
2062                .dev_server_project_id
2063                .map(|id| DevServerProjectId::from_proto(id)),
2064        )
2065        .await?;
2066    response.send(proto::ShareProjectResponse {
2067        project_id: project_id.to_proto(),
2068    })?;
2069    room_updated(&room, &session.peer);
2070
2071    Ok(())
2072}
2073
2074/// Unshare a project from the room.
2075async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2076    let project_id = ProjectId::from_proto(message.project_id);
2077    unshare_project_internal(
2078        project_id,
2079        session.connection_id,
2080        session.user_id(),
2081        &session,
2082    )
2083    .await
2084}
2085
2086async fn unshare_project_internal(
2087    project_id: ProjectId,
2088    connection_id: ConnectionId,
2089    user_id: Option<UserId>,
2090    session: &Session,
2091) -> Result<()> {
2092    let delete = {
2093        let room_guard = session
2094            .db()
2095            .await
2096            .unshare_project(project_id, connection_id, user_id)
2097            .await?;
2098
2099        let (delete, room, guest_connection_ids) = &*room_guard;
2100
2101        let message = proto::UnshareProject {
2102            project_id: project_id.to_proto(),
2103        };
2104
2105        broadcast(
2106            Some(connection_id),
2107            guest_connection_ids.iter().copied(),
2108            |conn_id| session.peer.send(conn_id, message.clone()),
2109        );
2110        if let Some(room) = room {
2111            room_updated(room, &session.peer);
2112        }
2113
2114        *delete
2115    };
2116
2117    if delete {
2118        let db = session.db().await;
2119        db.delete_project(project_id).await?;
2120    }
2121
2122    Ok(())
2123}
2124
2125/// DevServer makes a project available online
2126async fn share_dev_server_project(
2127    request: proto::ShareDevServerProject,
2128    response: Response<proto::ShareDevServerProject>,
2129    session: DevServerSession,
2130) -> Result<()> {
2131    let (dev_server_project, user_id, status) = session
2132        .db()
2133        .await
2134        .share_dev_server_project(
2135            DevServerProjectId::from_proto(request.dev_server_project_id),
2136            session.dev_server_id(),
2137            session.connection_id,
2138            &request.worktrees,
2139        )
2140        .await?;
2141    let Some(project_id) = dev_server_project.project_id else {
2142        return Err(anyhow!("failed to share remote project"))?;
2143    };
2144
2145    send_dev_server_projects_update(user_id, status, &session).await;
2146
2147    response.send(proto::ShareProjectResponse { project_id })?;
2148
2149    Ok(())
2150}
2151
/// Joins someone else's shared project.
///
/// The join is recorded in the database. The database handle is released
/// before the rest of the join (collaborator notifications and worktree
/// streaming) is done by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The guard returned by `join_project` lives until the end of the function
    // (temporary lifetime extension through `&mut *`).
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2170
/// A response handle whose reply payload is a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests with the same code.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to the inherent `Response::send`; the fully qualified call avoids
// any ambiguity with this trait's method of the same name.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2184
2185fn join_project_internal(
2186    response: impl JoinProjectInternalResponse,
2187    session: UserSession,
2188    project: &mut Project,
2189    replica_id: &ReplicaId,
2190) -> Result<()> {
2191    let collaborators = project
2192        .collaborators
2193        .iter()
2194        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2195        .map(|collaborator| collaborator.to_proto())
2196        .collect::<Vec<_>>();
2197    let project_id = project.id;
2198    let guest_user_id = session.user_id();
2199
2200    let worktrees = project
2201        .worktrees
2202        .iter()
2203        .map(|(id, worktree)| proto::WorktreeMetadata {
2204            id: *id,
2205            root_name: worktree.root_name.clone(),
2206            visible: worktree.visible,
2207            abs_path: worktree.abs_path.clone(),
2208        })
2209        .collect::<Vec<_>>();
2210
2211    let add_project_collaborator = proto::AddProjectCollaborator {
2212        project_id: project_id.to_proto(),
2213        collaborator: Some(proto::Collaborator {
2214            peer_id: Some(session.connection_id.into()),
2215            replica_id: replica_id.0 as u32,
2216            user_id: guest_user_id.to_proto(),
2217        }),
2218    };
2219
2220    for collaborator in &collaborators {
2221        session
2222            .peer
2223            .send(
2224                collaborator.peer_id.unwrap().into(),
2225                add_project_collaborator.clone(),
2226            )
2227            .trace_err();
2228    }
2229
2230    // First, we send the metadata associated with each worktree.
2231    response.send(proto::JoinProjectResponse {
2232        project_id: project.id.0 as u64,
2233        worktrees: worktrees.clone(),
2234        replica_id: replica_id.0 as u32,
2235        collaborators: collaborators.clone(),
2236        language_servers: project.language_servers.clone(),
2237        role: project.role.into(),
2238        dev_server_project_id: project
2239            .dev_server_project_id
2240            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2241    })?;
2242
2243    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2244        #[cfg(any(test, feature = "test-support"))]
2245        const MAX_CHUNK_SIZE: usize = 2;
2246        #[cfg(not(any(test, feature = "test-support")))]
2247        const MAX_CHUNK_SIZE: usize = 256;
2248
2249        // Stream this worktree's entries.
2250        let message = proto::UpdateWorktree {
2251            project_id: project_id.to_proto(),
2252            worktree_id,
2253            abs_path: worktree.abs_path.clone(),
2254            root_name: worktree.root_name,
2255            updated_entries: worktree.entries,
2256            removed_entries: Default::default(),
2257            scan_id: worktree.scan_id,
2258            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2259            updated_repositories: worktree.repository_entries.into_values().collect(),
2260            removed_repositories: Default::default(),
2261        };
2262        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2263            session.peer.send(session.connection_id, update.clone())?;
2264        }
2265
2266        // Stream this worktree's diagnostics.
2267        for summary in worktree.diagnostic_summaries {
2268            session.peer.send(
2269                session.connection_id,
2270                proto::UpdateDiagnosticSummary {
2271                    project_id: project_id.to_proto(),
2272                    worktree_id: worktree.id,
2273                    summary: Some(summary),
2274                },
2275            )?;
2276        }
2277
2278        for settings_file in worktree.settings_files {
2279            session.peer.send(
2280                session.connection_id,
2281                proto::UpdateWorktreeSettings {
2282                    project_id: project_id.to_proto(),
2283                    worktree_id: worktree.id,
2284                    path: settings_file.path,
2285                    content: Some(settings_file.content),
2286                },
2287            )?;
2288        }
2289    }
2290
2291    for language_server in &project.language_servers {
2292        session.peer.send(
2293            session.connection_id,
2294            proto::UpdateLanguageServer {
2295                project_id: project_id.to_proto(),
2296                language_server_id: language_server.id,
2297                variant: Some(
2298                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2299                        proto::LspDiskBasedDiagnosticsUpdated {},
2300                    ),
2301                ),
2302            },
2303        )?;
2304    }
2305
2306    Ok(())
2307}
2308
2309/// Leave someone elses shared project.
2310async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2311    let sender_id = session.connection_id;
2312    let project_id = ProjectId::from_proto(request.project_id);
2313    let db = session.db().await;
2314    if db.is_hosted_project(project_id).await? {
2315        let project = db.leave_hosted_project(project_id, sender_id).await?;
2316        project_left(&project, &session);
2317        return Ok(());
2318    }
2319
2320    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2321    tracing::info!(
2322        %project_id,
2323        "leave project"
2324    );
2325
2326    project_left(&project, &session);
2327    if let Some(room) = room {
2328        room_updated(&room, &session.peer);
2329    }
2330
2331    Ok(())
2332}
2333
2334async fn join_hosted_project(
2335    request: proto::JoinHostedProject,
2336    response: Response<proto::JoinHostedProject>,
2337    session: UserSession,
2338) -> Result<()> {
2339    let (mut project, replica_id) = session
2340        .db()
2341        .await
2342        .join_hosted_project(
2343            ProjectId(request.project_id as i32),
2344            session.user_id(),
2345            session.connection_id,
2346        )
2347        .await?;
2348
2349    join_project_internal(response, session, &mut project, &replica_id)
2350}
2351
2352async fn list_remote_directory(
2353    request: proto::ListRemoteDirectory,
2354    response: Response<proto::ListRemoteDirectory>,
2355    session: UserSession,
2356) -> Result<()> {
2357    let dev_server_id = DevServerId(request.dev_server_id as i32);
2358    let dev_server_connection_id = session
2359        .connection_pool()
2360        .await
2361        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2362
2363    session
2364        .db()
2365        .await
2366        .get_dev_server_for_user(dev_server_id, session.user_id())
2367        .await?;
2368
2369    response.send(
2370        session
2371            .peer
2372            .forward_request(session.connection_id, dev_server_connection_id, request)
2373            .await?,
2374    )?;
2375    Ok(())
2376}
2377
2378async fn update_dev_server_project(
2379    request: proto::UpdateDevServerProject,
2380    response: Response<proto::UpdateDevServerProject>,
2381    session: UserSession,
2382) -> Result<()> {
2383    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2384
2385    let (dev_server_project, update) = session
2386        .db()
2387        .await
2388        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2389        .await?;
2390
2391    let projects = session
2392        .db()
2393        .await
2394        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2395        .await?;
2396
2397    let dev_server_connection_id = session
2398        .connection_pool()
2399        .await
2400        .dev_server_connection_id_supporting(
2401            dev_server_project.dev_server_id,
2402            ZedVersion::with_list_directory(),
2403        )?;
2404
2405    session.peer.send(
2406        dev_server_connection_id,
2407        proto::DevServerInstructions { projects },
2408    )?;
2409
2410    send_dev_server_projects_update(session.user_id(), update, &session).await;
2411
2412    response.send(proto::Ack {})
2413}
2414
2415async fn create_dev_server_project(
2416    request: proto::CreateDevServerProject,
2417    response: Response<proto::CreateDevServerProject>,
2418    session: UserSession,
2419) -> Result<()> {
2420    let dev_server_id = DevServerId(request.dev_server_id as i32);
2421    let dev_server_connection_id = session
2422        .connection_pool()
2423        .await
2424        .dev_server_connection_id(dev_server_id);
2425    let Some(dev_server_connection_id) = dev_server_connection_id else {
2426        Err(ErrorCode::DevServerOffline
2427            .message("Cannot create a remote project when the dev server is offline".to_string())
2428            .anyhow())?
2429    };
2430
2431    let path = request.path.clone();
2432    //Check that the path exists on the dev server
2433    session
2434        .peer
2435        .forward_request(
2436            session.connection_id,
2437            dev_server_connection_id,
2438            proto::ValidateDevServerProjectRequest { path: path.clone() },
2439        )
2440        .await?;
2441
2442    let (dev_server_project, update) = session
2443        .db()
2444        .await
2445        .create_dev_server_project(
2446            DevServerId(request.dev_server_id as i32),
2447            &request.path,
2448            session.user_id(),
2449        )
2450        .await?;
2451
2452    let projects = session
2453        .db()
2454        .await
2455        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2456        .await?;
2457
2458    session.peer.send(
2459        dev_server_connection_id,
2460        proto::DevServerInstructions { projects },
2461    )?;
2462
2463    send_dev_server_projects_update(session.user_id(), update, &session).await;
2464
2465    response.send(proto::CreateDevServerProjectResponse {
2466        dev_server_project: Some(dev_server_project.to_proto(None)),
2467    })?;
2468    Ok(())
2469}
2470
2471async fn create_dev_server(
2472    request: proto::CreateDevServer,
2473    response: Response<proto::CreateDevServer>,
2474    session: UserSession,
2475) -> Result<()> {
2476    let access_token = auth::random_token();
2477    let hashed_access_token = auth::hash_access_token(&access_token);
2478
2479    if request.name.is_empty() {
2480        return Err(proto::ErrorCode::Forbidden
2481            .message("Dev server name cannot be empty".to_string())
2482            .anyhow())?;
2483    }
2484
2485    let (dev_server, status) = session
2486        .db()
2487        .await
2488        .create_dev_server(
2489            &request.name,
2490            request.ssh_connection_string.as_deref(),
2491            &hashed_access_token,
2492            session.user_id(),
2493        )
2494        .await?;
2495
2496    send_dev_server_projects_update(session.user_id(), status, &session).await;
2497
2498    response.send(proto::CreateDevServerResponse {
2499        dev_server_id: dev_server.id.0 as u64,
2500        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2501        name: request.name,
2502    })?;
2503    Ok(())
2504}
2505
2506async fn regenerate_dev_server_token(
2507    request: proto::RegenerateDevServerToken,
2508    response: Response<proto::RegenerateDevServerToken>,
2509    session: UserSession,
2510) -> Result<()> {
2511    let dev_server_id = DevServerId(request.dev_server_id as i32);
2512    let access_token = auth::random_token();
2513    let hashed_access_token = auth::hash_access_token(&access_token);
2514
2515    let connection_id = session
2516        .connection_pool()
2517        .await
2518        .dev_server_connection_id(dev_server_id);
2519    if let Some(connection_id) = connection_id {
2520        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2521        session.peer.send(
2522            connection_id,
2523            proto::ShutdownDevServer {
2524                reason: Some("dev server token was regenerated".to_string()),
2525            },
2526        )?;
2527        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2528    }
2529
2530    let status = session
2531        .db()
2532        .await
2533        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2534        .await?;
2535
2536    send_dev_server_projects_update(session.user_id(), status, &session).await;
2537
2538    response.send(proto::RegenerateDevServerTokenResponse {
2539        dev_server_id: dev_server_id.to_proto(),
2540        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2541    })?;
2542    Ok(())
2543}
2544
2545async fn rename_dev_server(
2546    request: proto::RenameDevServer,
2547    response: Response<proto::RenameDevServer>,
2548    session: UserSession,
2549) -> Result<()> {
2550    if request.name.trim().is_empty() {
2551        return Err(proto::ErrorCode::Forbidden
2552            .message("Dev server name cannot be empty".to_string())
2553            .anyhow())?;
2554    }
2555
2556    let dev_server_id = DevServerId(request.dev_server_id as i32);
2557    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2558    if dev_server.user_id != session.user_id() {
2559        return Err(anyhow!(ErrorCode::Forbidden))?;
2560    }
2561
2562    let status = session
2563        .db()
2564        .await
2565        .rename_dev_server(
2566            dev_server_id,
2567            &request.name,
2568            request.ssh_connection_string.as_deref(),
2569            session.user_id(),
2570        )
2571        .await?;
2572
2573    send_dev_server_projects_update(session.user_id(), status, &session).await;
2574
2575    response.send(proto::Ack {})?;
2576    Ok(())
2577}
2578
2579async fn delete_dev_server(
2580    request: proto::DeleteDevServer,
2581    response: Response<proto::DeleteDevServer>,
2582    session: UserSession,
2583) -> Result<()> {
2584    let dev_server_id = DevServerId(request.dev_server_id as i32);
2585    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2586    if dev_server.user_id != session.user_id() {
2587        return Err(anyhow!(ErrorCode::Forbidden))?;
2588    }
2589
2590    let connection_id = session
2591        .connection_pool()
2592        .await
2593        .dev_server_connection_id(dev_server_id);
2594    if let Some(connection_id) = connection_id {
2595        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2596        session.peer.send(
2597            connection_id,
2598            proto::ShutdownDevServer {
2599                reason: Some("dev server was deleted".to_string()),
2600            },
2601        )?;
2602        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2603    }
2604
2605    let status = session
2606        .db()
2607        .await
2608        .delete_dev_server(dev_server_id, session.user_id())
2609        .await?;
2610
2611    send_dev_server_projects_update(session.user_id(), status, &session).await;
2612
2613    response.send(proto::Ack {})?;
2614    Ok(())
2615}
2616
2617async fn delete_dev_server_project(
2618    request: proto::DeleteDevServerProject,
2619    response: Response<proto::DeleteDevServerProject>,
2620    session: UserSession,
2621) -> Result<()> {
2622    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2623    let dev_server_project = session
2624        .db()
2625        .await
2626        .get_dev_server_project(dev_server_project_id)
2627        .await?;
2628
2629    let dev_server = session
2630        .db()
2631        .await
2632        .get_dev_server(dev_server_project.dev_server_id)
2633        .await?;
2634    if dev_server.user_id != session.user_id() {
2635        return Err(anyhow!(ErrorCode::Forbidden))?;
2636    }
2637
2638    let dev_server_connection_id = session
2639        .connection_pool()
2640        .await
2641        .dev_server_connection_id(dev_server.id);
2642
2643    if let Some(dev_server_connection_id) = dev_server_connection_id {
2644        let project = session
2645            .db()
2646            .await
2647            .find_dev_server_project(dev_server_project_id)
2648            .await;
2649        if let Ok(project) = project {
2650            unshare_project_internal(
2651                project.id,
2652                dev_server_connection_id,
2653                Some(session.user_id()),
2654                &session,
2655            )
2656            .await?;
2657        }
2658    }
2659
2660    let (projects, status) = session
2661        .db()
2662        .await
2663        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2664        .await?;
2665
2666    if let Some(dev_server_connection_id) = dev_server_connection_id {
2667        session.peer.send(
2668            dev_server_connection_id,
2669            proto::DevServerInstructions { projects },
2670        )?;
2671    }
2672
2673    send_dev_server_projects_update(session.user_id(), status, &session).await;
2674
2675    response.send(proto::Ack {})?;
2676    Ok(())
2677}
2678
2679async fn rejoin_dev_server_projects(
2680    request: proto::RejoinRemoteProjects,
2681    response: Response<proto::RejoinRemoteProjects>,
2682    session: UserSession,
2683) -> Result<()> {
2684    let mut rejoined_projects = {
2685        let db = session.db().await;
2686        db.rejoin_dev_server_projects(
2687            &request.rejoined_projects,
2688            session.user_id(),
2689            session.0.connection_id,
2690        )
2691        .await?
2692    };
2693    response.send(proto::RejoinRemoteProjectsResponse {
2694        rejoined_projects: rejoined_projects
2695            .iter()
2696            .map(|project| project.to_proto())
2697            .collect(),
2698    })?;
2699    notify_rejoined_projects(&mut rejoined_projects, &session)
2700}
2701
2702async fn reconnect_dev_server(
2703    request: proto::ReconnectDevServer,
2704    response: Response<proto::ReconnectDevServer>,
2705    session: DevServerSession,
2706) -> Result<()> {
2707    let reshared_projects = {
2708        let db = session.db().await;
2709        db.reshare_dev_server_projects(
2710            &request.reshared_projects,
2711            session.dev_server_id(),
2712            session.0.connection_id,
2713        )
2714        .await?
2715    };
2716
2717    for project in &reshared_projects {
2718        for collaborator in &project.collaborators {
2719            session
2720                .peer
2721                .send(
2722                    collaborator.connection_id,
2723                    proto::UpdateProjectCollaborator {
2724                        project_id: project.id.to_proto(),
2725                        old_peer_id: Some(project.old_connection_id.into()),
2726                        new_peer_id: Some(session.connection_id.into()),
2727                    },
2728                )
2729                .trace_err();
2730        }
2731
2732        broadcast(
2733            Some(session.connection_id),
2734            project
2735                .collaborators
2736                .iter()
2737                .map(|collaborator| collaborator.connection_id),
2738            |connection_id| {
2739                session.peer.forward_send(
2740                    session.connection_id,
2741                    connection_id,
2742                    proto::UpdateProject {
2743                        project_id: project.id.to_proto(),
2744                        worktrees: project.worktrees.clone(),
2745                    },
2746                )
2747            },
2748        );
2749    }
2750
2751    response.send(proto::ReconnectDevServerResponse {
2752        reshared_projects: reshared_projects
2753            .iter()
2754            .map(|project| proto::ResharedProject {
2755                id: project.id.to_proto(),
2756                collaborators: project
2757                    .collaborators
2758                    .iter()
2759                    .map(|collaborator| collaborator.to_proto())
2760                    .collect(),
2761            })
2762            .collect(),
2763    })?;
2764
2765    Ok(())
2766}
2767
2768async fn shutdown_dev_server(
2769    _: proto::ShutdownDevServer,
2770    response: Response<proto::ShutdownDevServer>,
2771    session: DevServerSession,
2772) -> Result<()> {
2773    response.send(proto::Ack {})?;
2774    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2775    remove_dev_server_connection(session.dev_server_id(), &session).await
2776}
2777
2778async fn shutdown_dev_server_internal(
2779    dev_server_id: DevServerId,
2780    connection_id: ConnectionId,
2781    session: &Session,
2782) -> Result<()> {
2783    let (dev_server_projects, dev_server) = {
2784        let db = session.db().await;
2785        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2786        let dev_server = db.get_dev_server(dev_server_id).await?;
2787        (dev_server_projects, dev_server)
2788    };
2789
2790    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2791        unshare_project_internal(
2792            ProjectId::from_proto(project_id),
2793            connection_id,
2794            None,
2795            session,
2796        )
2797        .await?;
2798    }
2799
2800    session
2801        .connection_pool()
2802        .await
2803        .set_dev_server_offline(dev_server_id);
2804
2805    let status = session
2806        .db()
2807        .await
2808        .dev_server_projects_update(dev_server.user_id)
2809        .await?;
2810    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2811
2812    Ok(())
2813}
2814
2815async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2816    let dev_server_connection = session
2817        .connection_pool()
2818        .await
2819        .dev_server_connection_id(dev_server_id);
2820
2821    if let Some(dev_server_connection) = dev_server_connection {
2822        session
2823            .connection_pool()
2824            .await
2825            .remove_connection(dev_server_connection)?;
2826    }
2827    Ok(())
2828}
2829
2830/// Updates other participants with changes to the project
2831async fn update_project(
2832    request: proto::UpdateProject,
2833    response: Response<proto::UpdateProject>,
2834    session: Session,
2835) -> Result<()> {
2836    let project_id = ProjectId::from_proto(request.project_id);
2837    let (room, guest_connection_ids) = &*session
2838        .db()
2839        .await
2840        .update_project(project_id, session.connection_id, &request.worktrees)
2841        .await?;
2842    broadcast(
2843        Some(session.connection_id),
2844        guest_connection_ids.iter().copied(),
2845        |connection_id| {
2846            session
2847                .peer
2848                .forward_send(session.connection_id, connection_id, request.clone())
2849        },
2850    );
2851    if let Some(room) = room {
2852        room_updated(&room, &session.peer);
2853    }
2854    response.send(proto::Ack {})?;
2855
2856    Ok(())
2857}
2858
2859/// Updates other participants with changes to the worktree
2860async fn update_worktree(
2861    request: proto::UpdateWorktree,
2862    response: Response<proto::UpdateWorktree>,
2863    session: Session,
2864) -> Result<()> {
2865    let guest_connection_ids = session
2866        .db()
2867        .await
2868        .update_worktree(&request, session.connection_id)
2869        .await?;
2870
2871    broadcast(
2872        Some(session.connection_id),
2873        guest_connection_ids.iter().copied(),
2874        |connection_id| {
2875            session
2876                .peer
2877                .forward_send(session.connection_id, connection_id, request.clone())
2878        },
2879    );
2880    response.send(proto::Ack {})?;
2881    Ok(())
2882}
2883
2884/// Updates other participants with changes to the diagnostics
2885async fn update_diagnostic_summary(
2886    message: proto::UpdateDiagnosticSummary,
2887    session: Session,
2888) -> Result<()> {
2889    let guest_connection_ids = session
2890        .db()
2891        .await
2892        .update_diagnostic_summary(&message, session.connection_id)
2893        .await?;
2894
2895    broadcast(
2896        Some(session.connection_id),
2897        guest_connection_ids.iter().copied(),
2898        |connection_id| {
2899            session
2900                .peer
2901                .forward_send(session.connection_id, connection_id, message.clone())
2902        },
2903    );
2904
2905    Ok(())
2906}
2907
2908/// Updates other participants with changes to the worktree settings
2909async fn update_worktree_settings(
2910    message: proto::UpdateWorktreeSettings,
2911    session: Session,
2912) -> Result<()> {
2913    let guest_connection_ids = session
2914        .db()
2915        .await
2916        .update_worktree_settings(&message, session.connection_id)
2917        .await?;
2918
2919    broadcast(
2920        Some(session.connection_id),
2921        guest_connection_ids.iter().copied(),
2922        |connection_id| {
2923            session
2924                .peer
2925                .forward_send(session.connection_id, connection_id, message.clone())
2926        },
2927    );
2928
2929    Ok(())
2930}
2931
2932/// Notify other participants that a  language server has started.
2933async fn start_language_server(
2934    request: proto::StartLanguageServer,
2935    session: Session,
2936) -> Result<()> {
2937    let guest_connection_ids = session
2938        .db()
2939        .await
2940        .start_language_server(&request, session.connection_id)
2941        .await?;
2942
2943    broadcast(
2944        Some(session.connection_id),
2945        guest_connection_ids.iter().copied(),
2946        |connection_id| {
2947            session
2948                .peer
2949                .forward_send(session.connection_id, connection_id, request.clone())
2950        },
2951    );
2952    Ok(())
2953}
2954
2955/// Notify other participants that a language server has changed.
2956async fn update_language_server(
2957    request: proto::UpdateLanguageServer,
2958    session: Session,
2959) -> Result<()> {
2960    let project_id = ProjectId::from_proto(request.project_id);
2961    let project_connection_ids = session
2962        .db()
2963        .await
2964        .project_connection_ids(project_id, session.connection_id, true)
2965        .await?;
2966    broadcast(
2967        Some(session.connection_id),
2968        project_connection_ids.iter().copied(),
2969        |connection_id| {
2970            session
2971                .peer
2972                .forward_send(session.connection_id, connection_id, request.clone())
2973        },
2974    );
2975    Ok(())
2976}
2977
2978/// forward a project request to the host. These requests should be read only
2979/// as guests are allowed to send them.
2980async fn forward_read_only_project_request<T>(
2981    request: T,
2982    response: Response<T>,
2983    session: UserSession,
2984) -> Result<()>
2985where
2986    T: EntityMessage + RequestMessage,
2987{
2988    let project_id = ProjectId::from_proto(request.remote_entity_id());
2989    let host_connection_id = session
2990        .db()
2991        .await
2992        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2993        .await?;
2994    let payload = session
2995        .peer
2996        .forward_request(session.connection_id, host_connection_id, request)
2997        .await?;
2998    response.send(payload)?;
2999    Ok(())
3000}
3001
3002/// forward a project request to the dev server. Only allowed
3003/// if it's your dev server.
3004async fn forward_project_request_for_owner<T>(
3005    request: T,
3006    response: Response<T>,
3007    session: UserSession,
3008) -> Result<()>
3009where
3010    T: EntityMessage + RequestMessage,
3011{
3012    let project_id = ProjectId::from_proto(request.remote_entity_id());
3013
3014    let host_connection_id = session
3015        .db()
3016        .await
3017        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3018        .await?;
3019    let payload = session
3020        .peer
3021        .forward_request(session.connection_id, host_connection_id, request)
3022        .await?;
3023    response.send(payload)?;
3024    Ok(())
3025}
3026
3027/// forward a project request to the host. These requests are disallowed
3028/// for guests.
3029async fn forward_mutating_project_request<T>(
3030    request: T,
3031    response: Response<T>,
3032    session: UserSession,
3033) -> Result<()>
3034where
3035    T: EntityMessage + RequestMessage,
3036{
3037    let project_id = ProjectId::from_proto(request.remote_entity_id());
3038
3039    let host_connection_id = session
3040        .db()
3041        .await
3042        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3043        .await?;
3044    let payload = session
3045        .peer
3046        .forward_request(session.connection_id, host_connection_id, request)
3047        .await?;
3048    response.send(payload)?;
3049    Ok(())
3050}
3051
3052/// forward a project request to the host. These requests are disallowed
3053/// for guests.
3054async fn forward_versioned_mutating_project_request<T>(
3055    request: T,
3056    response: Response<T>,
3057    session: UserSession,
3058) -> Result<()>
3059where
3060    T: EntityMessage + RequestMessage + VersionedMessage,
3061{
3062    let project_id = ProjectId::from_proto(request.remote_entity_id());
3063
3064    let host_connection_id = session
3065        .db()
3066        .await
3067        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3068        .await?;
3069    if let Some(host_version) = session
3070        .connection_pool()
3071        .await
3072        .connection(host_connection_id)
3073        .map(|c| c.zed_version)
3074    {
3075        if let Some(min_required_version) = request.required_host_version() {
3076            if min_required_version > host_version {
3077                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3078                    .with_tag("required", &min_required_version.to_string())))?;
3079            }
3080        }
3081    }
3082
3083    let payload = session
3084        .peer
3085        .forward_request(session.connection_id, host_connection_id, request)
3086        .await?;
3087    response.send(payload)?;
3088    Ok(())
3089}
3090
3091/// Notify other participants that a new buffer has been created
3092async fn create_buffer_for_peer(
3093    request: proto::CreateBufferForPeer,
3094    session: Session,
3095) -> Result<()> {
3096    session
3097        .db()
3098        .await
3099        .check_user_is_project_host(
3100            ProjectId::from_proto(request.project_id),
3101            session.connection_id,
3102        )
3103        .await?;
3104    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3105    session
3106        .peer
3107        .forward_send(session.connection_id, peer_id.into(), request)?;
3108    Ok(())
3109}
3110
3111/// Notify other participants that a buffer has been updated. This is
3112/// allowed for guests as long as the update is limited to selections.
3113async fn update_buffer(
3114    request: proto::UpdateBuffer,
3115    response: Response<proto::UpdateBuffer>,
3116    session: Session,
3117) -> Result<()> {
3118    let project_id = ProjectId::from_proto(request.project_id);
3119    let mut capability = Capability::ReadOnly;
3120
3121    for op in request.operations.iter() {
3122        match op.variant {
3123            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3124            Some(_) => capability = Capability::ReadWrite,
3125        }
3126    }
3127
3128    let host = {
3129        let guard = session
3130            .db()
3131            .await
3132            .connections_for_buffer_update(
3133                project_id,
3134                session.principal_id(),
3135                session.connection_id,
3136                capability,
3137            )
3138            .await?;
3139
3140        let (host, guests) = &*guard;
3141
3142        broadcast(
3143            Some(session.connection_id),
3144            guests.clone(),
3145            |connection_id| {
3146                session
3147                    .peer
3148                    .forward_send(session.connection_id, connection_id, request.clone())
3149            },
3150        );
3151
3152        *host
3153    };
3154
3155    if host != session.connection_id {
3156        session
3157            .peer
3158            .forward_request(session.connection_id, host, request.clone())
3159            .await?;
3160    }
3161
3162    response.send(proto::Ack {})?;
3163    Ok(())
3164}
3165
3166async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3167    let project_id = ProjectId::from_proto(message.project_id);
3168
3169    let operation = message.operation.as_ref().context("invalid operation")?;
3170    let capability = match operation.variant.as_ref() {
3171        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3172            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3173                match buffer_op.variant {
3174                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3175                        Capability::ReadOnly
3176                    }
3177                    _ => Capability::ReadWrite,
3178                }
3179            } else {
3180                Capability::ReadWrite
3181            }
3182        }
3183        Some(_) => Capability::ReadWrite,
3184        None => Capability::ReadOnly,
3185    };
3186
3187    let guard = session
3188        .db()
3189        .await
3190        .connections_for_buffer_update(
3191            project_id,
3192            session.principal_id(),
3193            session.connection_id,
3194            capability,
3195        )
3196        .await?;
3197
3198    let (host, guests) = &*guard;
3199
3200    broadcast(
3201        Some(session.connection_id),
3202        guests.iter().chain([host]).copied(),
3203        |connection_id| {
3204            session
3205                .peer
3206                .forward_send(session.connection_id, connection_id, message.clone())
3207        },
3208    );
3209
3210    Ok(())
3211}
3212
3213/// Notify other participants that a project has been updated.
3214async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3215    request: T,
3216    session: Session,
3217) -> Result<()> {
3218    let project_id = ProjectId::from_proto(request.remote_entity_id());
3219    let project_connection_ids = session
3220        .db()
3221        .await
3222        .project_connection_ids(project_id, session.connection_id, false)
3223        .await?;
3224
3225    broadcast(
3226        Some(session.connection_id),
3227        project_connection_ids.iter().copied(),
3228        |connection_id| {
3229            session
3230                .peer
3231                .forward_send(session.connection_id, connection_id, request.clone())
3232        },
3233    );
3234    Ok(())
3235}
3236
3237/// Start following another user in a call.
3238async fn follow(
3239    request: proto::Follow,
3240    response: Response<proto::Follow>,
3241    session: UserSession,
3242) -> Result<()> {
3243    let room_id = RoomId::from_proto(request.room_id);
3244    let project_id = request.project_id.map(ProjectId::from_proto);
3245    let leader_id = request
3246        .leader_id
3247        .ok_or_else(|| anyhow!("invalid leader id"))?
3248        .into();
3249    let follower_id = session.connection_id;
3250
3251    session
3252        .db()
3253        .await
3254        .check_room_participants(room_id, leader_id, session.connection_id)
3255        .await?;
3256
3257    let response_payload = session
3258        .peer
3259        .forward_request(session.connection_id, leader_id, request)
3260        .await?;
3261    response.send(response_payload)?;
3262
3263    if let Some(project_id) = project_id {
3264        let room = session
3265            .db()
3266            .await
3267            .follow(room_id, project_id, leader_id, follower_id)
3268            .await?;
3269        room_updated(&room, &session.peer);
3270    }
3271
3272    Ok(())
3273}
3274
3275/// Stop following another user in a call.
3276async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3277    let room_id = RoomId::from_proto(request.room_id);
3278    let project_id = request.project_id.map(ProjectId::from_proto);
3279    let leader_id = request
3280        .leader_id
3281        .ok_or_else(|| anyhow!("invalid leader id"))?
3282        .into();
3283    let follower_id = session.connection_id;
3284
3285    session
3286        .db()
3287        .await
3288        .check_room_participants(room_id, leader_id, session.connection_id)
3289        .await?;
3290
3291    session
3292        .peer
3293        .forward_send(session.connection_id, leader_id, request)?;
3294
3295    if let Some(project_id) = project_id {
3296        let room = session
3297            .db()
3298            .await
3299            .unfollow(room_id, project_id, leader_id, follower_id)
3300            .await?;
3301        room_updated(&room, &session.peer);
3302    }
3303
3304    Ok(())
3305}
3306
3307/// Notify everyone following you of your current location.
3308async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3309    let room_id = RoomId::from_proto(request.room_id);
3310    let database = session.db.lock().await;
3311
3312    let connection_ids = if let Some(project_id) = request.project_id {
3313        let project_id = ProjectId::from_proto(project_id);
3314        database
3315            .project_connection_ids(project_id, session.connection_id, true)
3316            .await?
3317    } else {
3318        database
3319            .room_connection_ids(room_id, session.connection_id)
3320            .await?
3321    };
3322
3323    // For now, don't send view update messages back to that view's current leader.
3324    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3325        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3326        _ => None,
3327    });
3328
3329    for connection_id in connection_ids.iter().cloned() {
3330        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3331            session
3332                .peer
3333                .forward_send(session.connection_id, connection_id, request.clone())?;
3334        }
3335    }
3336    Ok(())
3337}
3338
3339/// Get public data about users.
3340async fn get_users(
3341    request: proto::GetUsers,
3342    response: Response<proto::GetUsers>,
3343    session: Session,
3344) -> Result<()> {
3345    let user_ids = request
3346        .user_ids
3347        .into_iter()
3348        .map(UserId::from_proto)
3349        .collect();
3350    let users = session
3351        .db()
3352        .await
3353        .get_users_by_ids(user_ids)
3354        .await?
3355        .into_iter()
3356        .map(|user| proto::User {
3357            id: user.id.to_proto(),
3358            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3359            github_login: user.github_login,
3360        })
3361        .collect();
3362    response.send(proto::UsersResponse { users })?;
3363    Ok(())
3364}
3365
3366/// Search for users (to invite) buy Github login
3367async fn fuzzy_search_users(
3368    request: proto::FuzzySearchUsers,
3369    response: Response<proto::FuzzySearchUsers>,
3370    session: UserSession,
3371) -> Result<()> {
3372    let query = request.query;
3373    let users = match query.len() {
3374        0 => vec![],
3375        1 | 2 => session
3376            .db()
3377            .await
3378            .get_user_by_github_login(&query)
3379            .await?
3380            .into_iter()
3381            .collect(),
3382        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3383    };
3384    let users = users
3385        .into_iter()
3386        .filter(|user| user.id != session.user_id())
3387        .map(|user| proto::User {
3388            id: user.id.to_proto(),
3389            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3390            github_login: user.github_login,
3391        })
3392        .collect();
3393    response.send(proto::UsersResponse { users })?;
3394    Ok(())
3395}
3396
3397/// Send a contact request to another user.
3398async fn request_contact(
3399    request: proto::RequestContact,
3400    response: Response<proto::RequestContact>,
3401    session: UserSession,
3402) -> Result<()> {
3403    let requester_id = session.user_id();
3404    let responder_id = UserId::from_proto(request.responder_id);
3405    if requester_id == responder_id {
3406        return Err(anyhow!("cannot add yourself as a contact"))?;
3407    }
3408
3409    let notifications = session
3410        .db()
3411        .await
3412        .send_contact_request(requester_id, responder_id)
3413        .await?;
3414
3415    // Update outgoing contact requests of requester
3416    let mut update = proto::UpdateContacts::default();
3417    update.outgoing_requests.push(responder_id.to_proto());
3418    for connection_id in session
3419        .connection_pool()
3420        .await
3421        .user_connection_ids(requester_id)
3422    {
3423        session.peer.send(connection_id, update.clone())?;
3424    }
3425
3426    // Update incoming contact requests of responder
3427    let mut update = proto::UpdateContacts::default();
3428    update
3429        .incoming_requests
3430        .push(proto::IncomingContactRequest {
3431            requester_id: requester_id.to_proto(),
3432        });
3433    let connection_pool = session.connection_pool().await;
3434    for connection_id in connection_pool.user_connection_ids(responder_id) {
3435        session.peer.send(connection_id, update.clone())?;
3436    }
3437
3438    send_notifications(&connection_pool, &session.peer, notifications);
3439
3440    response.send(proto::Ack {})?;
3441    Ok(())
3442}
3443
3444/// Accept or decline a contact request
3445async fn respond_to_contact_request(
3446    request: proto::RespondToContactRequest,
3447    response: Response<proto::RespondToContactRequest>,
3448    session: UserSession,
3449) -> Result<()> {
3450    let responder_id = session.user_id();
3451    let requester_id = UserId::from_proto(request.requester_id);
3452    let db = session.db().await;
3453    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3454        db.dismiss_contact_notification(responder_id, requester_id)
3455            .await?;
3456    } else {
3457        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3458
3459        let notifications = db
3460            .respond_to_contact_request(responder_id, requester_id, accept)
3461            .await?;
3462        let requester_busy = db.is_user_busy(requester_id).await?;
3463        let responder_busy = db.is_user_busy(responder_id).await?;
3464
3465        let pool = session.connection_pool().await;
3466        // Update responder with new contact
3467        let mut update = proto::UpdateContacts::default();
3468        if accept {
3469            update
3470                .contacts
3471                .push(contact_for_user(requester_id, requester_busy, &pool));
3472        }
3473        update
3474            .remove_incoming_requests
3475            .push(requester_id.to_proto());
3476        for connection_id in pool.user_connection_ids(responder_id) {
3477            session.peer.send(connection_id, update.clone())?;
3478        }
3479
3480        // Update requester with new contact
3481        let mut update = proto::UpdateContacts::default();
3482        if accept {
3483            update
3484                .contacts
3485                .push(contact_for_user(responder_id, responder_busy, &pool));
3486        }
3487        update
3488            .remove_outgoing_requests
3489            .push(responder_id.to_proto());
3490
3491        for connection_id in pool.user_connection_ids(requester_id) {
3492            session.peer.send(connection_id, update.clone())?;
3493        }
3494
3495        send_notifications(&pool, &session.peer, notifications);
3496    }
3497
3498    response.send(proto::Ack {})?;
3499    Ok(())
3500}
3501
3502/// Remove a contact.
3503async fn remove_contact(
3504    request: proto::RemoveContact,
3505    response: Response<proto::RemoveContact>,
3506    session: UserSession,
3507) -> Result<()> {
3508    let requester_id = session.user_id();
3509    let responder_id = UserId::from_proto(request.user_id);
3510    let db = session.db().await;
3511    let (contact_accepted, deleted_notification_id) =
3512        db.remove_contact(requester_id, responder_id).await?;
3513
3514    let pool = session.connection_pool().await;
3515    // Update outgoing contact requests of requester
3516    let mut update = proto::UpdateContacts::default();
3517    if contact_accepted {
3518        update.remove_contacts.push(responder_id.to_proto());
3519    } else {
3520        update
3521            .remove_outgoing_requests
3522            .push(responder_id.to_proto());
3523    }
3524    for connection_id in pool.user_connection_ids(requester_id) {
3525        session.peer.send(connection_id, update.clone())?;
3526    }
3527
3528    // Update incoming contact requests of responder
3529    let mut update = proto::UpdateContacts::default();
3530    if contact_accepted {
3531        update.remove_contacts.push(requester_id.to_proto());
3532    } else {
3533        update
3534            .remove_incoming_requests
3535            .push(requester_id.to_proto());
3536    }
3537    for connection_id in pool.user_connection_ids(responder_id) {
3538        session.peer.send(connection_id, update.clone())?;
3539        if let Some(notification_id) = deleted_notification_id {
3540            session.peer.send(
3541                connection_id,
3542                proto::DeleteNotification {
3543                    notification_id: notification_id.to_proto(),
3544                },
3545            )?;
3546        }
3547    }
3548
3549    response.send(proto::Ack {})?;
3550    Ok(())
3551}
3552
3553fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3554    version.0.minor() < 139
3555}
3556
3557async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3558    let plan = session.current_plan().await?;
3559
3560    session
3561        .peer
3562        .send(
3563            session.connection_id,
3564            proto::UpdateUserPlan { plan: plan.into() },
3565        )
3566        .trace_err();
3567
3568    Ok(())
3569}
3570
3571async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3572    subscribe_user_to_channels(
3573        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3574        &session,
3575    )
3576    .await?;
3577    Ok(())
3578}
3579
3580async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3581    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3582    let mut pool = session.connection_pool().await;
3583    for membership in &channels_for_user.channel_memberships {
3584        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3585    }
3586    session.peer.send(
3587        session.connection_id,
3588        build_update_user_channels(&channels_for_user),
3589    )?;
3590    session.peer.send(
3591        session.connection_id,
3592        build_channels_update(channels_for_user),
3593    )?;
3594    Ok(())
3595}
3596
3597/// Creates a new channel.
3598async fn create_channel(
3599    request: proto::CreateChannel,
3600    response: Response<proto::CreateChannel>,
3601    session: UserSession,
3602) -> Result<()> {
3603    let db = session.db().await;
3604
3605    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3606    let (channel, membership) = db
3607        .create_channel(&request.name, parent_id, session.user_id())
3608        .await?;
3609
3610    let root_id = channel.root_id();
3611    let channel = Channel::from_model(channel);
3612
3613    response.send(proto::CreateChannelResponse {
3614        channel: Some(channel.to_proto()),
3615        parent_id: request.parent_id,
3616    })?;
3617
3618    let mut connection_pool = session.connection_pool().await;
3619    if let Some(membership) = membership {
3620        connection_pool.subscribe_to_channel(
3621            membership.user_id,
3622            membership.channel_id,
3623            membership.role,
3624        );
3625        let update = proto::UpdateUserChannels {
3626            channel_memberships: vec![proto::ChannelMembership {
3627                channel_id: membership.channel_id.to_proto(),
3628                role: membership.role.into(),
3629            }],
3630            ..Default::default()
3631        };
3632        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3633            session.peer.send(connection_id, update.clone())?;
3634        }
3635    }
3636
3637    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3638        if !role.can_see_channel(channel.visibility) {
3639            continue;
3640        }
3641
3642        let update = proto::UpdateChannels {
3643            channels: vec![channel.to_proto()],
3644            ..Default::default()
3645        };
3646        session.peer.send(connection_id, update.clone())?;
3647    }
3648
3649    Ok(())
3650}
3651
3652/// Delete a channel
3653async fn delete_channel(
3654    request: proto::DeleteChannel,
3655    response: Response<proto::DeleteChannel>,
3656    session: UserSession,
3657) -> Result<()> {
3658    let db = session.db().await;
3659
3660    let channel_id = request.channel_id;
3661    let (root_channel, removed_channels) = db
3662        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3663        .await?;
3664    response.send(proto::Ack {})?;
3665
3666    // Notify members of removed channels
3667    let mut update = proto::UpdateChannels::default();
3668    update
3669        .delete_channels
3670        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3671
3672    let connection_pool = session.connection_pool().await;
3673    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3674        session.peer.send(connection_id, update.clone())?;
3675    }
3676
3677    Ok(())
3678}
3679
3680/// Invite someone to join a channel.
3681async fn invite_channel_member(
3682    request: proto::InviteChannelMember,
3683    response: Response<proto::InviteChannelMember>,
3684    session: UserSession,
3685) -> Result<()> {
3686    let db = session.db().await;
3687    let channel_id = ChannelId::from_proto(request.channel_id);
3688    let invitee_id = UserId::from_proto(request.user_id);
3689    let InviteMemberResult {
3690        channel,
3691        notifications,
3692    } = db
3693        .invite_channel_member(
3694            channel_id,
3695            invitee_id,
3696            session.user_id(),
3697            request.role().into(),
3698        )
3699        .await?;
3700
3701    let update = proto::UpdateChannels {
3702        channel_invitations: vec![channel.to_proto()],
3703        ..Default::default()
3704    };
3705
3706    let connection_pool = session.connection_pool().await;
3707    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3708        session.peer.send(connection_id, update.clone())?;
3709    }
3710
3711    send_notifications(&connection_pool, &session.peer, notifications);
3712
3713    response.send(proto::Ack {})?;
3714    Ok(())
3715}
3716
3717/// remove someone from a channel
3718async fn remove_channel_member(
3719    request: proto::RemoveChannelMember,
3720    response: Response<proto::RemoveChannelMember>,
3721    session: UserSession,
3722) -> Result<()> {
3723    let db = session.db().await;
3724    let channel_id = ChannelId::from_proto(request.channel_id);
3725    let member_id = UserId::from_proto(request.user_id);
3726
3727    let RemoveChannelMemberResult {
3728        membership_update,
3729        notification_id,
3730    } = db
3731        .remove_channel_member(channel_id, member_id, session.user_id())
3732        .await?;
3733
3734    let mut connection_pool = session.connection_pool().await;
3735    notify_membership_updated(
3736        &mut connection_pool,
3737        membership_update,
3738        member_id,
3739        &session.peer,
3740    );
3741    for connection_id in connection_pool.user_connection_ids(member_id) {
3742        if let Some(notification_id) = notification_id {
3743            session
3744                .peer
3745                .send(
3746                    connection_id,
3747                    proto::DeleteNotification {
3748                        notification_id: notification_id.to_proto(),
3749                    },
3750                )
3751                .trace_err();
3752        }
3753    }
3754
3755    response.send(proto::Ack {})?;
3756    Ok(())
3757}
3758
3759/// Toggle the channel between public and private.
3760/// Care is taken to maintain the invariant that public channels only descend from public channels,
3761/// (though members-only channels can appear at any point in the hierarchy).
3762async fn set_channel_visibility(
3763    request: proto::SetChannelVisibility,
3764    response: Response<proto::SetChannelVisibility>,
3765    session: UserSession,
3766) -> Result<()> {
3767    let db = session.db().await;
3768    let channel_id = ChannelId::from_proto(request.channel_id);
3769    let visibility = request.visibility().into();
3770
3771    let channel_model = db
3772        .set_channel_visibility(channel_id, visibility, session.user_id())
3773        .await?;
3774    let root_id = channel_model.root_id();
3775    let channel = Channel::from_model(channel_model);
3776
3777    let mut connection_pool = session.connection_pool().await;
3778    for (user_id, role) in connection_pool
3779        .channel_user_ids(root_id)
3780        .collect::<Vec<_>>()
3781        .into_iter()
3782    {
3783        let update = if role.can_see_channel(channel.visibility) {
3784            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3785            proto::UpdateChannels {
3786                channels: vec![channel.to_proto()],
3787                ..Default::default()
3788            }
3789        } else {
3790            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3791            proto::UpdateChannels {
3792                delete_channels: vec![channel.id.to_proto()],
3793                ..Default::default()
3794            }
3795        };
3796
3797        for connection_id in connection_pool.user_connection_ids(user_id) {
3798            session.peer.send(connection_id, update.clone())?;
3799        }
3800    }
3801
3802    response.send(proto::Ack {})?;
3803    Ok(())
3804}
3805
3806/// Alter the role for a user in the channel.
3807async fn set_channel_member_role(
3808    request: proto::SetChannelMemberRole,
3809    response: Response<proto::SetChannelMemberRole>,
3810    session: UserSession,
3811) -> Result<()> {
3812    let db = session.db().await;
3813    let channel_id = ChannelId::from_proto(request.channel_id);
3814    let member_id = UserId::from_proto(request.user_id);
3815    let result = db
3816        .set_channel_member_role(
3817            channel_id,
3818            session.user_id(),
3819            member_id,
3820            request.role().into(),
3821        )
3822        .await?;
3823
3824    match result {
3825        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3826            let mut connection_pool = session.connection_pool().await;
3827            notify_membership_updated(
3828                &mut connection_pool,
3829                membership_update,
3830                member_id,
3831                &session.peer,
3832            )
3833        }
3834        db::SetMemberRoleResult::InviteUpdated(channel) => {
3835            let update = proto::UpdateChannels {
3836                channel_invitations: vec![channel.to_proto()],
3837                ..Default::default()
3838            };
3839
3840            for connection_id in session
3841                .connection_pool()
3842                .await
3843                .user_connection_ids(member_id)
3844            {
3845                session.peer.send(connection_id, update.clone())?;
3846            }
3847        }
3848    }
3849
3850    response.send(proto::Ack {})?;
3851    Ok(())
3852}
3853
3854/// Change the name of a channel
3855async fn rename_channel(
3856    request: proto::RenameChannel,
3857    response: Response<proto::RenameChannel>,
3858    session: UserSession,
3859) -> Result<()> {
3860    let db = session.db().await;
3861    let channel_id = ChannelId::from_proto(request.channel_id);
3862    let channel_model = db
3863        .rename_channel(channel_id, session.user_id(), &request.name)
3864        .await?;
3865    let root_id = channel_model.root_id();
3866    let channel = Channel::from_model(channel_model);
3867
3868    response.send(proto::RenameChannelResponse {
3869        channel: Some(channel.to_proto()),
3870    })?;
3871
3872    let connection_pool = session.connection_pool().await;
3873    let update = proto::UpdateChannels {
3874        channels: vec![channel.to_proto()],
3875        ..Default::default()
3876    };
3877    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3878        if role.can_see_channel(channel.visibility) {
3879            session.peer.send(connection_id, update.clone())?;
3880        }
3881    }
3882
3883    Ok(())
3884}
3885
3886/// Move a channel to a new parent.
3887async fn move_channel(
3888    request: proto::MoveChannel,
3889    response: Response<proto::MoveChannel>,
3890    session: UserSession,
3891) -> Result<()> {
3892    let channel_id = ChannelId::from_proto(request.channel_id);
3893    let to = ChannelId::from_proto(request.to);
3894
3895    let (root_id, channels) = session
3896        .db()
3897        .await
3898        .move_channel(channel_id, to, session.user_id())
3899        .await?;
3900
3901    let connection_pool = session.connection_pool().await;
3902    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3903        let channels = channels
3904            .iter()
3905            .filter_map(|channel| {
3906                if role.can_see_channel(channel.visibility) {
3907                    Some(channel.to_proto())
3908                } else {
3909                    None
3910                }
3911            })
3912            .collect::<Vec<_>>();
3913        if channels.is_empty() {
3914            continue;
3915        }
3916
3917        let update = proto::UpdateChannels {
3918            channels,
3919            ..Default::default()
3920        };
3921
3922        session.peer.send(connection_id, update.clone())?;
3923    }
3924
3925    response.send(Ack {})?;
3926    Ok(())
3927}
3928
3929/// Get the list of channel members
3930async fn get_channel_members(
3931    request: proto::GetChannelMembers,
3932    response: Response<proto::GetChannelMembers>,
3933    session: UserSession,
3934) -> Result<()> {
3935    let db = session.db().await;
3936    let channel_id = ChannelId::from_proto(request.channel_id);
3937    let limit = if request.limit == 0 {
3938        u16::MAX as u64
3939    } else {
3940        request.limit
3941    };
3942    let (members, users) = db
3943        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3944        .await?;
3945    response.send(proto::GetChannelMembersResponse { members, users })?;
3946    Ok(())
3947}
3948
3949/// Accept or decline a channel invitation.
3950async fn respond_to_channel_invite(
3951    request: proto::RespondToChannelInvite,
3952    response: Response<proto::RespondToChannelInvite>,
3953    session: UserSession,
3954) -> Result<()> {
3955    let db = session.db().await;
3956    let channel_id = ChannelId::from_proto(request.channel_id);
3957    let RespondToChannelInvite {
3958        membership_update,
3959        notifications,
3960    } = db
3961        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3962        .await?;
3963
3964    let mut connection_pool = session.connection_pool().await;
3965    if let Some(membership_update) = membership_update {
3966        notify_membership_updated(
3967            &mut connection_pool,
3968            membership_update,
3969            session.user_id(),
3970            &session.peer,
3971        );
3972    } else {
3973        let update = proto::UpdateChannels {
3974            remove_channel_invitations: vec![channel_id.to_proto()],
3975            ..Default::default()
3976        };
3977
3978        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3979            session.peer.send(connection_id, update.clone())?;
3980        }
3981    };
3982
3983    send_notifications(&connection_pool, &session.peer, notifications);
3984
3985    response.send(proto::Ack {})?;
3986
3987    Ok(())
3988}
3989
3990/// Join the channels' room
3991async fn join_channel(
3992    request: proto::JoinChannel,
3993    response: Response<proto::JoinChannel>,
3994    session: UserSession,
3995) -> Result<()> {
3996    let channel_id = ChannelId::from_proto(request.channel_id);
3997    join_channel_internal(channel_id, Box::new(response), session).await
3998}
3999
/// A response handle that `join_channel_internal` can reply through.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so both request kinds can share the join logic.
/// Either way, the reply is a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request with a `proto::JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified path: in path resolution, inherent associated functions
        // take precedence over trait methods. This therefore calls `Response`'s
        // own `send` rather than recursing into this trait method.
        // NOTE(review): relies on `Response` having an inherent `send` taking
        // `proto::JoinRoomResponse` for this request type — keep the turbofish.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request with a `proto::JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified path: in path resolution, inherent associated functions
        // take precedence over trait methods. This therefore calls `Response`'s
        // own `send` rather than recursing into this trait method.
        // NOTE(review): relies on `Response` having an inherent `send` taking
        // `proto::JoinRoomResponse` for this request type — keep the turbofish.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4013
/// Shared implementation for joining a channel's room. It is used with any
/// response handle implementing `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is removed with `leave_room_for_session`. The DB lock is
///    released around that call and re-acquired afterwards.
/// 2. The user joins the channel in the database, which yields the joined room,
///    an optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, connection info is created. Guests get
///    a guest token with `can_publish: false`; any other role gets a room token
///    with `can_publish: true`. If token creation fails, the error goes through
///    `trace_err` and the reply carries no LiveKit info.
/// 4. The `JoinRoomResponse` is sent, any membership update is propagated, and
///    the room update is broadcast.
/// 5. Once the DB lock and connection-pool guard from the inner block are
///    dropped, the channel update is broadcast and the user's contacts are
///    refreshed.
///
/// # Errors
/// Propagates database, send, and contact-update errors. It also fails if the
/// joined room has no channel. That check happens after the response has
/// already been sent (step 5).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scope the DB lock (and the connection-pool guard taken below) to this block.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB lock before `leave_room_for_session`, then re-acquire it.
            // NOTE(review): presumably because that helper locks the DB itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails
        // (the `?` after `trace_err()` short-circuits the closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; other roles get a room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The response has already been sent, so a missing channel here surfaces
    // only as the handler's error result.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4104
4105/// Start editing the channel notes
4106async fn join_channel_buffer(
4107    request: proto::JoinChannelBuffer,
4108    response: Response<proto::JoinChannelBuffer>,
4109    session: UserSession,
4110) -> Result<()> {
4111    let db = session.db().await;
4112    let channel_id = ChannelId::from_proto(request.channel_id);
4113
4114    let open_response = db
4115        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4116        .await?;
4117
4118    let collaborators = open_response.collaborators.clone();
4119    response.send(open_response)?;
4120
4121    let update = UpdateChannelBufferCollaborators {
4122        channel_id: channel_id.to_proto(),
4123        collaborators: collaborators.clone(),
4124    };
4125    channel_buffer_updated(
4126        session.connection_id,
4127        collaborators
4128            .iter()
4129            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4130        &update,
4131        &session.peer,
4132    );
4133
4134    Ok(())
4135}
4136
4137/// Edit the channel notes
4138async fn update_channel_buffer(
4139    request: proto::UpdateChannelBuffer,
4140    session: UserSession,
4141) -> Result<()> {
4142    let db = session.db().await;
4143    let channel_id = ChannelId::from_proto(request.channel_id);
4144
4145    let (collaborators, epoch, version) = db
4146        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4147        .await?;
4148
4149    channel_buffer_updated(
4150        session.connection_id,
4151        collaborators.clone(),
4152        &proto::UpdateChannelBuffer {
4153            channel_id: channel_id.to_proto(),
4154            operations: request.operations,
4155        },
4156        &session.peer,
4157    );
4158
4159    let pool = &*session.connection_pool().await;
4160
4161    let non_collaborators =
4162        pool.channel_connection_ids(channel_id)
4163            .filter_map(|(connection_id, _)| {
4164                if collaborators.contains(&connection_id) {
4165                    None
4166                } else {
4167                    Some(connection_id)
4168                }
4169            });
4170
4171    broadcast(None, non_collaborators, |peer_id| {
4172        session.peer.send(
4173            peer_id,
4174            proto::UpdateChannels {
4175                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4176                    channel_id: channel_id.to_proto(),
4177                    epoch: epoch as u64,
4178                    version: version.clone(),
4179                }],
4180                ..Default::default()
4181            },
4182        )
4183    });
4184
4185    Ok(())
4186}
4187
4188/// Rejoin the channel notes after a connection blip
4189async fn rejoin_channel_buffers(
4190    request: proto::RejoinChannelBuffers,
4191    response: Response<proto::RejoinChannelBuffers>,
4192    session: UserSession,
4193) -> Result<()> {
4194    let db = session.db().await;
4195    let buffers = db
4196        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4197        .await?;
4198
4199    for rejoined_buffer in &buffers {
4200        let collaborators_to_notify = rejoined_buffer
4201            .buffer
4202            .collaborators
4203            .iter()
4204            .filter_map(|c| Some(c.peer_id?.into()));
4205        channel_buffer_updated(
4206            session.connection_id,
4207            collaborators_to_notify,
4208            &proto::UpdateChannelBufferCollaborators {
4209                channel_id: rejoined_buffer.buffer.channel_id,
4210                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4211            },
4212            &session.peer,
4213        );
4214    }
4215
4216    response.send(proto::RejoinChannelBuffersResponse {
4217        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4218    })?;
4219
4220    Ok(())
4221}
4222
4223/// Stop editing the channel notes
4224async fn leave_channel_buffer(
4225    request: proto::LeaveChannelBuffer,
4226    response: Response<proto::LeaveChannelBuffer>,
4227    session: UserSession,
4228) -> Result<()> {
4229    let db = session.db().await;
4230    let channel_id = ChannelId::from_proto(request.channel_id);
4231
4232    let left_buffer = db
4233        .leave_channel_buffer(channel_id, session.connection_id)
4234        .await?;
4235
4236    response.send(Ack {})?;
4237
4238    channel_buffer_updated(
4239        session.connection_id,
4240        left_buffer.connections,
4241        &proto::UpdateChannelBufferCollaborators {
4242            channel_id: channel_id.to_proto(),
4243            collaborators: left_buffer.collaborators,
4244        },
4245        &session.peer,
4246    );
4247
4248    Ok(())
4249}
4250
4251fn channel_buffer_updated<T: EnvelopedMessage>(
4252    sender_id: ConnectionId,
4253    collaborators: impl IntoIterator<Item = ConnectionId>,
4254    message: &T,
4255    peer: &Peer,
4256) {
4257    broadcast(Some(sender_id), collaborators, |peer_id| {
4258        peer.send(peer_id, message.clone())
4259    });
4260}
4261
4262fn send_notifications(
4263    connection_pool: &ConnectionPool,
4264    peer: &Peer,
4265    notifications: db::NotificationBatch,
4266) {
4267    for (user_id, notification) in notifications {
4268        for connection_id in connection_pool.user_connection_ids(user_id) {
4269            if let Err(error) = peer.send(
4270                connection_id,
4271                proto::AddNotification {
4272                    notification: Some(notification.clone()),
4273                },
4274            ) {
4275                tracing::error!(
4276                    "failed to send notification to {:?} {}",
4277                    connection_id,
4278                    error
4279                );
4280            }
4281        }
4282    }
4283}
4284
4285/// Send a message to the channel
4286async fn send_channel_message(
4287    request: proto::SendChannelMessage,
4288    response: Response<proto::SendChannelMessage>,
4289    session: UserSession,
4290) -> Result<()> {
4291    // Validate the message body.
4292    let body = request.body.trim().to_string();
4293    if body.len() > MAX_MESSAGE_LEN {
4294        return Err(anyhow!("message is too long"))?;
4295    }
4296    if body.is_empty() {
4297        return Err(anyhow!("message can't be blank"))?;
4298    }
4299
4300    // TODO: adjust mentions if body is trimmed
4301
4302    let timestamp = OffsetDateTime::now_utc();
4303    let nonce = request
4304        .nonce
4305        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4306
4307    let channel_id = ChannelId::from_proto(request.channel_id);
4308    let CreatedChannelMessage {
4309        message_id,
4310        participant_connection_ids,
4311        notifications,
4312    } = session
4313        .db()
4314        .await
4315        .create_channel_message(
4316            channel_id,
4317            session.user_id(),
4318            &body,
4319            &request.mentions,
4320            timestamp,
4321            nonce.clone().into(),
4322            match request.reply_to_message_id {
4323                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4324                None => None,
4325            },
4326        )
4327        .await?;
4328
4329    let message = proto::ChannelMessage {
4330        sender_id: session.user_id().to_proto(),
4331        id: message_id.to_proto(),
4332        body,
4333        mentions: request.mentions,
4334        timestamp: timestamp.unix_timestamp() as u64,
4335        nonce: Some(nonce),
4336        reply_to_message_id: request.reply_to_message_id,
4337        edited_at: None,
4338    };
4339    broadcast(
4340        Some(session.connection_id),
4341        participant_connection_ids.clone(),
4342        |connection| {
4343            session.peer.send(
4344                connection,
4345                proto::ChannelMessageSent {
4346                    channel_id: channel_id.to_proto(),
4347                    message: Some(message.clone()),
4348                },
4349            )
4350        },
4351    );
4352    response.send(proto::SendChannelMessageResponse {
4353        message: Some(message),
4354    })?;
4355
4356    let pool = &*session.connection_pool().await;
4357    let non_participants =
4358        pool.channel_connection_ids(channel_id)
4359            .filter_map(|(connection_id, _)| {
4360                if participant_connection_ids.contains(&connection_id) {
4361                    None
4362                } else {
4363                    Some(connection_id)
4364                }
4365            });
4366    broadcast(None, non_participants, |peer_id| {
4367        session.peer.send(
4368            peer_id,
4369            proto::UpdateChannels {
4370                latest_channel_message_ids: vec![proto::ChannelMessageId {
4371                    channel_id: channel_id.to_proto(),
4372                    message_id: message_id.to_proto(),
4373                }],
4374                ..Default::default()
4375            },
4376        )
4377    });
4378    send_notifications(pool, &session.peer, notifications);
4379
4380    Ok(())
4381}
4382
4383/// Delete a channel message
4384async fn remove_channel_message(
4385    request: proto::RemoveChannelMessage,
4386    response: Response<proto::RemoveChannelMessage>,
4387    session: UserSession,
4388) -> Result<()> {
4389    let channel_id = ChannelId::from_proto(request.channel_id);
4390    let message_id = MessageId::from_proto(request.message_id);
4391    let (connection_ids, existing_notification_ids) = session
4392        .db()
4393        .await
4394        .remove_channel_message(channel_id, message_id, session.user_id())
4395        .await?;
4396
4397    broadcast(
4398        Some(session.connection_id),
4399        connection_ids,
4400        move |connection| {
4401            session.peer.send(connection, request.clone())?;
4402
4403            for notification_id in &existing_notification_ids {
4404                session.peer.send(
4405                    connection,
4406                    proto::DeleteNotification {
4407                        notification_id: (*notification_id).to_proto(),
4408                    },
4409                )?;
4410            }
4411
4412            Ok(())
4413        },
4414    );
4415    response.send(proto::Ack {})?;
4416    Ok(())
4417}
4418
4419async fn update_channel_message(
4420    request: proto::UpdateChannelMessage,
4421    response: Response<proto::UpdateChannelMessage>,
4422    session: UserSession,
4423) -> Result<()> {
4424    let channel_id = ChannelId::from_proto(request.channel_id);
4425    let message_id = MessageId::from_proto(request.message_id);
4426    let updated_at = OffsetDateTime::now_utc();
4427    let UpdatedChannelMessage {
4428        message_id,
4429        participant_connection_ids,
4430        notifications,
4431        reply_to_message_id,
4432        timestamp,
4433        deleted_mention_notification_ids,
4434        updated_mention_notifications,
4435    } = session
4436        .db()
4437        .await
4438        .update_channel_message(
4439            channel_id,
4440            message_id,
4441            session.user_id(),
4442            request.body.as_str(),
4443            &request.mentions,
4444            updated_at,
4445        )
4446        .await?;
4447
4448    let nonce = request
4449        .nonce
4450        .clone()
4451        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4452
4453    let message = proto::ChannelMessage {
4454        sender_id: session.user_id().to_proto(),
4455        id: message_id.to_proto(),
4456        body: request.body.clone(),
4457        mentions: request.mentions.clone(),
4458        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4459        nonce: Some(nonce),
4460        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4461        edited_at: Some(updated_at.unix_timestamp() as u64),
4462    };
4463
4464    response.send(proto::Ack {})?;
4465
4466    let pool = &*session.connection_pool().await;
4467    broadcast(
4468        Some(session.connection_id),
4469        participant_connection_ids,
4470        |connection| {
4471            session.peer.send(
4472                connection,
4473                proto::ChannelMessageUpdate {
4474                    channel_id: channel_id.to_proto(),
4475                    message: Some(message.clone()),
4476                },
4477            )?;
4478
4479            for notification_id in &deleted_mention_notification_ids {
4480                session.peer.send(
4481                    connection,
4482                    proto::DeleteNotification {
4483                        notification_id: (*notification_id).to_proto(),
4484                    },
4485                )?;
4486            }
4487
4488            for notification in &updated_mention_notifications {
4489                session.peer.send(
4490                    connection,
4491                    proto::UpdateNotification {
4492                        notification: Some(notification.clone()),
4493                    },
4494                )?;
4495            }
4496
4497            Ok(())
4498        },
4499    );
4500
4501    send_notifications(pool, &session.peer, notifications);
4502
4503    Ok(())
4504}
4505
4506/// Mark a channel message as read
4507async fn acknowledge_channel_message(
4508    request: proto::AckChannelMessage,
4509    session: UserSession,
4510) -> Result<()> {
4511    let channel_id = ChannelId::from_proto(request.channel_id);
4512    let message_id = MessageId::from_proto(request.message_id);
4513    let notifications = session
4514        .db()
4515        .await
4516        .observe_channel_message(channel_id, session.user_id(), message_id)
4517        .await?;
4518    send_notifications(
4519        &*session.connection_pool().await,
4520        &session.peer,
4521        notifications,
4522    );
4523    Ok(())
4524}
4525
4526/// Mark a buffer version as synced
4527async fn acknowledge_buffer_version(
4528    request: proto::AckBufferOperation,
4529    session: UserSession,
4530) -> Result<()> {
4531    let buffer_id = BufferId::from_proto(request.buffer_id);
4532    session
4533        .db()
4534        .await
4535        .observe_buffer_version(
4536            buffer_id,
4537            session.user_id(),
4538            request.epoch as i32,
4539            &request.version,
4540        )
4541        .await?;
4542    Ok(())
4543}
4544
4545struct ZedProCompleteWithLanguageModelRateLimit;
4546
4547impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4548    fn capacity(&self) -> usize {
4549        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4550            .ok()
4551            .and_then(|v| v.parse().ok())
4552            .unwrap_or(120) // Picked arbitrarily
4553    }
4554
4555    fn refill_duration(&self) -> chrono::Duration {
4556        chrono::Duration::hours(1)
4557    }
4558
4559    fn db_name(&self) -> &'static str {
4560        "zed-pro:complete-with-language-model"
4561    }
4562}
4563
4564struct FreeCompleteWithLanguageModelRateLimit;
4565
4566impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4567    fn capacity(&self) -> usize {
4568        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4569            .ok()
4570            .and_then(|v| v.parse().ok())
4571            .unwrap_or(120 / 10) // Picked arbitrarily
4572    }
4573
4574    fn refill_duration(&self) -> chrono::Duration {
4575        chrono::Duration::hours(1)
4576    }
4577
4578    fn db_name(&self) -> &'static str {
4579        "free:complete-with-language-model"
4580    }
4581}
4582
4583async fn complete_with_language_model(
4584    request: proto::CompleteWithLanguageModel,
4585    response: Response<proto::CompleteWithLanguageModel>,
4586    session: Session,
4587    config: &Config,
4588) -> Result<()> {
4589    let Some(session) = session.for_user() else {
4590        return Err(anyhow!("user not found"))?;
4591    };
4592    authorize_access_to_language_models(&session).await?;
4593
4594    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4595        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4596        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4597    };
4598
4599    session
4600        .rate_limiter
4601        .check(&*rate_limit, session.user_id())
4602        .await?;
4603
4604    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4605        Some(proto::LanguageModelProvider::Anthropic) => {
4606            let api_key = config
4607                .anthropic_api_key
4608                .as_ref()
4609                .context("no Anthropic AI API key configured on the server")?;
4610            anthropic::complete(
4611                session.http_client.as_ref(),
4612                anthropic::ANTHROPIC_API_URL,
4613                api_key,
4614                serde_json::from_str(&request.request)?,
4615            )
4616            .await?
4617        }
4618        _ => return Err(anyhow!("unsupported provider"))?,
4619    };
4620
4621    response.send(proto::CompleteWithLanguageModelResponse {
4622        completion: serde_json::to_string(&result)?,
4623    })?;
4624
4625    Ok(())
4626}
4627
4628async fn stream_complete_with_language_model(
4629    request: proto::StreamCompleteWithLanguageModel,
4630    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4631    session: Session,
4632    config: &Config,
4633) -> Result<()> {
4634    let Some(session) = session.for_user() else {
4635        return Err(anyhow!("user not found"))?;
4636    };
4637    authorize_access_to_language_models(&session).await?;
4638
4639    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4640        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4641        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4642    };
4643
4644    session
4645        .rate_limiter
4646        .check(&*rate_limit, session.user_id())
4647        .await?;
4648
4649    match proto::LanguageModelProvider::from_i32(request.provider) {
4650        Some(proto::LanguageModelProvider::Anthropic) => {
4651            let api_key = config
4652                .anthropic_api_key
4653                .as_ref()
4654                .context("no Anthropic AI API key configured on the server")?;
4655            let mut chunks = anthropic::stream_completion(
4656                session.http_client.as_ref(),
4657                anthropic::ANTHROPIC_API_URL,
4658                api_key,
4659                serde_json::from_str(&request.request)?,
4660                None,
4661            )
4662            .await?;
4663            while let Some(event) = chunks.next().await {
4664                let chunk = event?;
4665                response.send(proto::StreamCompleteWithLanguageModelResponse {
4666                    event: serde_json::to_string(&chunk)?,
4667                })?;
4668            }
4669        }
4670        Some(proto::LanguageModelProvider::OpenAi) => {
4671            let api_key = config
4672                .openai_api_key
4673                .as_ref()
4674                .context("no OpenAI API key configured on the server")?;
4675            let mut events = open_ai::stream_completion(
4676                session.http_client.as_ref(),
4677                open_ai::OPEN_AI_API_URL,
4678                api_key,
4679                serde_json::from_str(&request.request)?,
4680                None,
4681            )
4682            .await?;
4683            while let Some(event) = events.next().await {
4684                let event = event?;
4685                response.send(proto::StreamCompleteWithLanguageModelResponse {
4686                    event: serde_json::to_string(&event)?,
4687                })?;
4688            }
4689        }
4690        Some(proto::LanguageModelProvider::Google) => {
4691            let api_key = config
4692                .google_ai_api_key
4693                .as_ref()
4694                .context("no Google AI API key configured on the server")?;
4695            let mut events = google_ai::stream_generate_content(
4696                session.http_client.as_ref(),
4697                google_ai::API_URL,
4698                api_key,
4699                serde_json::from_str(&request.request)?,
4700            )
4701            .await?;
4702            while let Some(event) = events.next().await {
4703                let event = event?;
4704                response.send(proto::StreamCompleteWithLanguageModelResponse {
4705                    event: serde_json::to_string(&event)?,
4706                })?;
4707            }
4708        }
4709        Some(proto::LanguageModelProvider::Zed) => {
4710            let api_key = config
4711                .qwen2_7b_api_key
4712                .as_ref()
4713                .context("no Qwen2-7B API key configured on the server")?;
4714            let api_url = config
4715                .qwen2_7b_api_url
4716                .as_ref()
4717                .context("no Qwen2-7B URL configured on the server")?;
4718            let mut events = open_ai::stream_completion(
4719                session.http_client.as_ref(),
4720                &api_url,
4721                api_key,
4722                serde_json::from_str(&request.request)?,
4723                None,
4724            )
4725            .await?;
4726            while let Some(event) = events.next().await {
4727                let event = event?;
4728                response.send(proto::StreamCompleteWithLanguageModelResponse {
4729                    event: serde_json::to_string(&event)?,
4730                })?;
4731            }
4732        }
4733        None => return Err(anyhow!("unknown provider"))?,
4734    }
4735
4736    Ok(())
4737}
4738
4739async fn count_language_model_tokens(
4740    request: proto::CountLanguageModelTokens,
4741    response: Response<proto::CountLanguageModelTokens>,
4742    session: Session,
4743    config: &Config,
4744) -> Result<()> {
4745    let Some(session) = session.for_user() else {
4746        return Err(anyhow!("user not found"))?;
4747    };
4748    authorize_access_to_language_models(&session).await?;
4749
4750    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4751        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4752        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4753    };
4754
4755    session
4756        .rate_limiter
4757        .check(&*rate_limit, session.user_id())
4758        .await?;
4759
4760    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4761        Some(proto::LanguageModelProvider::Google) => {
4762            let api_key = config
4763                .google_ai_api_key
4764                .as_ref()
4765                .context("no Google AI API key configured on the server")?;
4766            google_ai::count_tokens(
4767                session.http_client.as_ref(),
4768                google_ai::API_URL,
4769                api_key,
4770                serde_json::from_str(&request.request)?,
4771            )
4772            .await?
4773        }
4774        _ => return Err(anyhow!("unsupported provider"))?,
4775    };
4776
4777    response.send(proto::CountLanguageModelTokensResponse {
4778        token_count: result.total_tokens as u32,
4779    })?;
4780
4781    Ok(())
4782}
4783
4784struct ZedProCountLanguageModelTokensRateLimit;
4785
4786impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4787    fn capacity(&self) -> usize {
4788        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4789            .ok()
4790            .and_then(|v| v.parse().ok())
4791            .unwrap_or(600) // Picked arbitrarily
4792    }
4793
4794    fn refill_duration(&self) -> chrono::Duration {
4795        chrono::Duration::hours(1)
4796    }
4797
4798    fn db_name(&self) -> &'static str {
4799        "zed-pro:count-language-model-tokens"
4800    }
4801}
4802
4803struct FreeCountLanguageModelTokensRateLimit;
4804
4805impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4806    fn capacity(&self) -> usize {
4807        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4808            .ok()
4809            .and_then(|v| v.parse().ok())
4810            .unwrap_or(600 / 10) // Picked arbitrarily
4811    }
4812
4813    fn refill_duration(&self) -> chrono::Duration {
4814        chrono::Duration::hours(1)
4815    }
4816
4817    fn db_name(&self) -> &'static str {
4818        "free:count-language-model-tokens"
4819    }
4820}
4821
4822struct ZedProComputeEmbeddingsRateLimit;
4823
4824impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4825    fn capacity(&self) -> usize {
4826        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4827            .ok()
4828            .and_then(|v| v.parse().ok())
4829            .unwrap_or(5000) // Picked arbitrarily
4830    }
4831
4832    fn refill_duration(&self) -> chrono::Duration {
4833        chrono::Duration::hours(1)
4834    }
4835
4836    fn db_name(&self) -> &'static str {
4837        "zed-pro:compute-embeddings"
4838    }
4839}
4840
4841struct FreeComputeEmbeddingsRateLimit;
4842
4843impl RateLimit for FreeComputeEmbeddingsRateLimit {
4844    fn capacity(&self) -> usize {
4845        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4846            .ok()
4847            .and_then(|v| v.parse().ok())
4848            .unwrap_or(5000 / 10) // Picked arbitrarily
4849    }
4850
4851    fn refill_duration(&self) -> chrono::Duration {
4852        chrono::Duration::hours(1)
4853    }
4854
4855    fn db_name(&self) -> &'static str {
4856        "free:compute-embeddings"
4857    }
4858}
4859
4860async fn compute_embeddings(
4861    request: proto::ComputeEmbeddings,
4862    response: Response<proto::ComputeEmbeddings>,
4863    session: UserSession,
4864    api_key: Option<Arc<str>>,
4865) -> Result<()> {
4866    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4867    authorize_access_to_language_models(&session).await?;
4868
4869    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4870        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4871        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4872    };
4873
4874    session
4875        .rate_limiter
4876        .check(&*rate_limit, session.user_id())
4877        .await?;
4878
4879    let embeddings = match request.model.as_str() {
4880        "openai/text-embedding-3-small" => {
4881            open_ai::embed(
4882                session.http_client.as_ref(),
4883                OPEN_AI_API_URL,
4884                &api_key,
4885                OpenAiEmbeddingModel::TextEmbedding3Small,
4886                request.texts.iter().map(|text| text.as_str()),
4887            )
4888            .await?
4889        }
4890        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4891    };
4892
4893    let embeddings = request
4894        .texts
4895        .iter()
4896        .map(|text| {
4897            let mut hasher = sha2::Sha256::new();
4898            hasher.update(text.as_bytes());
4899            let result = hasher.finalize();
4900            result.to_vec()
4901        })
4902        .zip(
4903            embeddings
4904                .data
4905                .into_iter()
4906                .map(|embedding| embedding.embedding),
4907        )
4908        .collect::<HashMap<_, _>>();
4909
4910    let db = session.db().await;
4911    db.save_embeddings(&request.model, &embeddings)
4912        .await
4913        .context("failed to save embeddings")
4914        .trace_err();
4915
4916    response.send(proto::ComputeEmbeddingsResponse {
4917        embeddings: embeddings
4918            .into_iter()
4919            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4920            .collect(),
4921    })?;
4922    Ok(())
4923}
4924
4925async fn get_cached_embeddings(
4926    request: proto::GetCachedEmbeddings,
4927    response: Response<proto::GetCachedEmbeddings>,
4928    session: UserSession,
4929) -> Result<()> {
4930    authorize_access_to_language_models(&session).await?;
4931
4932    let db = session.db().await;
4933    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4934
4935    response.send(proto::GetCachedEmbeddingsResponse {
4936        embeddings: embeddings
4937            .into_iter()
4938            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4939            .collect(),
4940    })?;
4941    Ok(())
4942}
4943
4944async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4945    let db = session.db().await;
4946    let flags = db.get_user_flags(session.user_id()).await?;
4947    if flags.iter().any(|flag| flag == "language-models") {
4948        return Ok(());
4949    }
4950
4951    Err(anyhow!("permission denied"))?
4952}
4953
4954/// Get a Supermaven API key for the user
4955async fn get_supermaven_api_key(
4956    _request: proto::GetSupermavenApiKey,
4957    response: Response<proto::GetSupermavenApiKey>,
4958    session: UserSession,
4959) -> Result<()> {
4960    let user_id: String = session.user_id().to_string();
4961    if !session.is_staff() {
4962        return Err(anyhow!("supermaven not enabled for this account"))?;
4963    }
4964
4965    let email = session
4966        .email()
4967        .ok_or_else(|| anyhow!("user must have an email"))?;
4968
4969    let supermaven_admin_api = session
4970        .supermaven_client
4971        .as_ref()
4972        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4973
4974    let result = supermaven_admin_api
4975        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4976        .await?;
4977
4978    response.send(proto::GetSupermavenApiKeyResponse {
4979        api_key: result.api_key,
4980    })?;
4981
4982    Ok(())
4983}
4984
4985/// Start receiving chat updates for a channel
4986async fn join_channel_chat(
4987    request: proto::JoinChannelChat,
4988    response: Response<proto::JoinChannelChat>,
4989    session: UserSession,
4990) -> Result<()> {
4991    let channel_id = ChannelId::from_proto(request.channel_id);
4992
4993    let db = session.db().await;
4994    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4995        .await?;
4996    let messages = db
4997        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4998        .await?;
4999    response.send(proto::JoinChannelChatResponse {
5000        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5001        messages,
5002    })?;
5003    Ok(())
5004}
5005
5006/// Stop receiving chat updates for a channel
5007async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5008    let channel_id = ChannelId::from_proto(request.channel_id);
5009    session
5010        .db()
5011        .await
5012        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5013        .await?;
5014    Ok(())
5015}
5016
5017/// Retrieve the chat history for a channel
5018async fn get_channel_messages(
5019    request: proto::GetChannelMessages,
5020    response: Response<proto::GetChannelMessages>,
5021    session: UserSession,
5022) -> Result<()> {
5023    let channel_id = ChannelId::from_proto(request.channel_id);
5024    let messages = session
5025        .db()
5026        .await
5027        .get_channel_messages(
5028            channel_id,
5029            session.user_id(),
5030            MESSAGE_COUNT_PER_PAGE,
5031            Some(MessageId::from_proto(request.before_message_id)),
5032        )
5033        .await?;
5034    response.send(proto::GetChannelMessagesResponse {
5035        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5036        messages,
5037    })?;
5038    Ok(())
5039}
5040
5041/// Retrieve specific chat messages
5042async fn get_channel_messages_by_id(
5043    request: proto::GetChannelMessagesById,
5044    response: Response<proto::GetChannelMessagesById>,
5045    session: UserSession,
5046) -> Result<()> {
5047    let message_ids = request
5048        .message_ids
5049        .iter()
5050        .map(|id| MessageId::from_proto(*id))
5051        .collect::<Vec<_>>();
5052    let messages = session
5053        .db()
5054        .await
5055        .get_channel_messages_by_id(session.user_id(), &message_ids)
5056        .await?;
5057    response.send(proto::GetChannelMessagesResponse {
5058        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5059        messages,
5060    })?;
5061    Ok(())
5062}
5063
5064/// Retrieve the current users notifications
5065async fn get_notifications(
5066    request: proto::GetNotifications,
5067    response: Response<proto::GetNotifications>,
5068    session: UserSession,
5069) -> Result<()> {
5070    let notifications = session
5071        .db()
5072        .await
5073        .get_notifications(
5074            session.user_id(),
5075            NOTIFICATION_COUNT_PER_PAGE,
5076            request
5077                .before_id
5078                .map(|id| db::NotificationId::from_proto(id)),
5079        )
5080        .await?;
5081    response.send(proto::GetNotificationsResponse {
5082        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5083        notifications,
5084    })?;
5085    Ok(())
5086}
5087
5088/// Mark notifications as read
5089async fn mark_notification_as_read(
5090    request: proto::MarkNotificationRead,
5091    response: Response<proto::MarkNotificationRead>,
5092    session: UserSession,
5093) -> Result<()> {
5094    let database = &session.db().await;
5095    let notifications = database
5096        .mark_notification_as_read_by_id(
5097            session.user_id(),
5098            NotificationId::from_proto(request.notification_id),
5099        )
5100        .await?;
5101    send_notifications(
5102        &*session.connection_pool().await,
5103        &session.peer,
5104        notifications,
5105    );
5106    response.send(proto::Ack {})?;
5107    Ok(())
5108}
5109
5110/// Get the current users information
5111async fn get_private_user_info(
5112    _request: proto::GetPrivateUserInfo,
5113    response: Response<proto::GetPrivateUserInfo>,
5114    session: UserSession,
5115) -> Result<()> {
5116    let db = session.db().await;
5117
5118    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5119    let user = db
5120        .get_user_by_id(session.user_id())
5121        .await?
5122        .ok_or_else(|| anyhow!("user not found"))?;
5123    let flags = db.get_user_flags(session.user_id()).await?;
5124
5125    response.send(proto::GetPrivateUserInfoResponse {
5126        metrics_id,
5127        staff: user.admin,
5128        flags,
5129    })?;
5130    Ok(())
5131}
5132
5133fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5134    let message = match message {
5135        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5136        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5137        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5138        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5139        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5140            code: frame.code.into(),
5141            reason: frame.reason,
5142        })),
5143        // We should never receive a frame while reading the message, according
5144        // to the `tungstenite` maintainers:
5145        //
5146        // > It cannot occur when you read messages from the WebSocket, but it
5147        // > can be used when you want to send the raw frames (e.g. you want to
5148        // > send the frames to the WebSocket without composing the full message first).
5149        // >
5150        // > — https://github.com/snapview/tungstenite-rs/issues/268
5151        TungsteniteMessage::Frame(_) => {
5152            bail!("received an unexpected frame while reading the message")
5153        }
5154    };
5155
5156    Ok(message)
5157}
5158
5159fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5160    match message {
5161        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5162        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5163        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5164        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5165        AxumMessage::Close(frame) => {
5166            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5167                code: frame.code.into(),
5168                reason: frame.reason,
5169            }))
5170        }
5171    }
5172}
5173
5174fn notify_membership_updated(
5175    connection_pool: &mut ConnectionPool,
5176    result: MembershipUpdated,
5177    user_id: UserId,
5178    peer: &Peer,
5179) {
5180    for membership in &result.new_channels.channel_memberships {
5181        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5182    }
5183    for channel_id in &result.removed_channels {
5184        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5185    }
5186
5187    let user_channels_update = proto::UpdateUserChannels {
5188        channel_memberships: result
5189            .new_channels
5190            .channel_memberships
5191            .iter()
5192            .map(|cm| proto::ChannelMembership {
5193                channel_id: cm.channel_id.to_proto(),
5194                role: cm.role.into(),
5195            })
5196            .collect(),
5197        ..Default::default()
5198    };
5199
5200    let mut update = build_channels_update(result.new_channels);
5201    update.delete_channels = result
5202        .removed_channels
5203        .into_iter()
5204        .map(|id| id.to_proto())
5205        .collect();
5206    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5207
5208    for connection_id in connection_pool.user_connection_ids(user_id) {
5209        peer.send(connection_id, user_channels_update.clone())
5210            .trace_err();
5211        peer.send(connection_id, update.clone()).trace_err();
5212    }
5213}
5214
5215fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5216    proto::UpdateUserChannels {
5217        channel_memberships: channels
5218            .channel_memberships
5219            .iter()
5220            .map(|m| proto::ChannelMembership {
5221                channel_id: m.channel_id.to_proto(),
5222                role: m.role.into(),
5223            })
5224            .collect(),
5225        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5226        observed_channel_message_id: channels.observed_channel_messages.clone(),
5227    }
5228}
5229
5230fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5231    let mut update = proto::UpdateChannels::default();
5232
5233    for channel in channels.channels {
5234        update.channels.push(channel.to_proto());
5235    }
5236
5237    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5238    update.latest_channel_message_ids = channels.latest_channel_messages;
5239
5240    for (channel_id, participants) in channels.channel_participants {
5241        update
5242            .channel_participants
5243            .push(proto::ChannelParticipants {
5244                channel_id: channel_id.to_proto(),
5245                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5246            });
5247    }
5248
5249    for channel in channels.invited_channels {
5250        update.channel_invitations.push(channel.to_proto());
5251    }
5252
5253    update.hosted_projects = channels.hosted_projects;
5254    update
5255}
5256
5257fn build_initial_contacts_update(
5258    contacts: Vec<db::Contact>,
5259    pool: &ConnectionPool,
5260) -> proto::UpdateContacts {
5261    let mut update = proto::UpdateContacts::default();
5262
5263    for contact in contacts {
5264        match contact {
5265            db::Contact::Accepted { user_id, busy } => {
5266                update.contacts.push(contact_for_user(user_id, busy, &pool));
5267            }
5268            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5269            db::Contact::Incoming { user_id } => {
5270                update
5271                    .incoming_requests
5272                    .push(proto::IncomingContactRequest {
5273                        requester_id: user_id.to_proto(),
5274                    })
5275            }
5276        }
5277    }
5278
5279    update
5280}
5281
5282fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5283    proto::Contact {
5284        user_id: user_id.to_proto(),
5285        online: pool.is_user_online(user_id),
5286        busy,
5287    }
5288}
5289
5290fn room_updated(room: &proto::Room, peer: &Peer) {
5291    broadcast(
5292        None,
5293        room.participants
5294            .iter()
5295            .filter_map(|participant| Some(participant.peer_id?.into())),
5296        |peer_id| {
5297            peer.send(
5298                peer_id,
5299                proto::RoomUpdated {
5300                    room: Some(room.clone()),
5301                },
5302            )
5303        },
5304    );
5305}
5306
5307fn channel_updated(
5308    channel: &db::channel::Model,
5309    room: &proto::Room,
5310    peer: &Peer,
5311    pool: &ConnectionPool,
5312) {
5313    let participants = room
5314        .participants
5315        .iter()
5316        .map(|p| p.user_id)
5317        .collect::<Vec<_>>();
5318
5319    broadcast(
5320        None,
5321        pool.channel_connection_ids(channel.root_id())
5322            .filter_map(|(channel_id, role)| {
5323                role.can_see_channel(channel.visibility).then(|| channel_id)
5324            }),
5325        |peer_id| {
5326            peer.send(
5327                peer_id,
5328                proto::UpdateChannels {
5329                    channel_participants: vec![proto::ChannelParticipants {
5330                        channel_id: channel.id.to_proto(),
5331                        participant_user_ids: participants.clone(),
5332                    }],
5333                    ..Default::default()
5334                },
5335            )
5336        },
5337    );
5338}
5339
5340async fn send_dev_server_projects_update(
5341    user_id: UserId,
5342    mut status: proto::DevServerProjectsUpdate,
5343    session: &Session,
5344) {
5345    let pool = session.connection_pool().await;
5346    for dev_server in &mut status.dev_servers {
5347        dev_server.status =
5348            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5349    }
5350    let connections = pool.user_connection_ids(user_id);
5351    for connection_id in connections {
5352        session.peer.send(connection_id, status.clone()).trace_err();
5353    }
5354}
5355
5356async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5357    let db = session.db().await;
5358
5359    let contacts = db.get_contacts(user_id).await?;
5360    let busy = db.is_user_busy(user_id).await?;
5361
5362    let pool = session.connection_pool().await;
5363    let updated_contact = contact_for_user(user_id, busy, &pool);
5364    for contact in contacts {
5365        if let db::Contact::Accepted {
5366            user_id: contact_user_id,
5367            ..
5368        } = contact
5369        {
5370            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5371                session
5372                    .peer
5373                    .send(
5374                        contact_conn_id,
5375                        proto::UpdateContacts {
5376                            contacts: vec![updated_contact.clone()],
5377                            remove_contacts: Default::default(),
5378                            incoming_requests: Default::default(),
5379                            remove_incoming_requests: Default::default(),
5380                            outgoing_requests: Default::default(),
5381                            remove_outgoing_requests: Default::default(),
5382                        },
5383                    )
5384                    .trace_err();
5385            }
5386        }
5387    }
5388    Ok(())
5389}
5390
5391async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5392    log::info!("lost dev server connection, unsharing projects");
5393    let project_ids = session
5394        .db()
5395        .await
5396        .get_stale_dev_server_projects(session.connection_id)
5397        .await?;
5398
5399    for project_id in project_ids {
5400        // not unshare re-checks the connection ids match, so we get away with no transaction
5401        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5402    }
5403
5404    let user_id = session.dev_server().user_id;
5405    let update = session
5406        .db()
5407        .await
5408        .dev_server_projects_update(user_id)
5409        .await?;
5410
5411    send_dev_server_projects_update(user_id, update, session).await;
5412
5413    Ok(())
5414}
5415
/// Removes `connection_id` from the room it is in and fans out the resulting
/// notifications.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything. Otherwise it:
/// - notifies the other connections of each project the connection left
///   (see `project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room belongs to a channel, broadcasts the channel's participant list
///   to the channel's subscribers;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - refreshes the contact entry of this session's user and of every canceled callee;
/// - if a LiveKit client is configured, removes this user from the LiveKit room,
///   and deletes that LiveKit room when `left_room.deleted` is set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// Message sends and LiveKit calls are best-effort: their failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below so the values remain available after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the temporary returned by `session.db().await` lives until the
    // end of this `if let`/`else` statement, so it is still held while
    // `project_left` and `room_updated` run. Rewriting this as `let ... else` would
    // drop it earlier; confirm that is acceptable before restructuring.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out on the next lines, so the `room` broadcast
        // to participants below carries a default (empty) `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before `update_user_contacts`,
    // which acquires the pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // This `connection_id` shadows the parameter: it is one of the
            // canceled callee's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5489
5490async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5491    let left_channel_buffers = session
5492        .db()
5493        .await
5494        .leave_channel_buffers(session.connection_id)
5495        .await?;
5496
5497    for left_buffer in left_channel_buffers {
5498        channel_buffer_updated(
5499            session.connection_id,
5500            left_buffer.connections,
5501            &proto::UpdateChannelBufferCollaborators {
5502                channel_id: left_buffer.channel_id.to_proto(),
5503                collaborators: left_buffer.collaborators,
5504            },
5505            &session.peer,
5506        );
5507    }
5508
5509    Ok(())
5510}
5511
5512fn project_left(project: &db::LeftProject, session: &UserSession) {
5513    for connection_id in &project.connection_ids {
5514        if project.should_unshare {
5515            session
5516                .peer
5517                .send(
5518                    *connection_id,
5519                    proto::UnshareProject {
5520                        project_id: project.id.to_proto(),
5521                    },
5522                )
5523                .trace_err();
5524        } else {
5525            session
5526                .peer
5527                .send(
5528                    *connection_id,
5529                    proto::RemoveProjectCollaborator {
5530                        project_id: project.id.to_proto(),
5531                        peer_id: Some(session.connection_id.into()),
5532                    },
5533                )
5534                .trace_err();
5535        }
5536    }
5537}
5538
5539pub trait ResultExt {
5540    type Ok;
5541
5542    fn trace_err(self) -> Option<Self::Ok>;
5543}
5544
5545impl<T, E> ResultExt for Result<T, E>
5546where
5547    E: std::fmt::Debug,
5548{
5549    type Ok = T;
5550
5551    #[track_caller]
5552    fn trace_err(self) -> Option<T> {
5553        match self {
5554            Ok(value) => Some(value),
5555            Err(error) => {
5556                tracing::error!("{:?}", error);
5557                None
5558            }
5559        }
5560    }
5561}