rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Config, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(update_dev_server_project))
 435            .add_request_handler(user_handler(delete_dev_server_project))
 436            .add_request_handler(user_handler(create_dev_server))
 437            .add_request_handler(user_handler(regenerate_dev_server_token))
 438            .add_request_handler(user_handler(rename_dev_server))
 439            .add_request_handler(user_handler(delete_dev_server))
 440            .add_request_handler(user_handler(list_remote_directory))
 441            .add_request_handler(dev_server_handler(share_dev_server_project))
 442            .add_request_handler(dev_server_handler(shutdown_dev_server))
 443            .add_request_handler(dev_server_handler(reconnect_dev_server))
 444            .add_message_handler(user_message_handler(leave_project))
 445            .add_request_handler(update_project)
 446            .add_request_handler(update_worktree)
 447            .add_message_handler(start_language_server)
 448            .add_message_handler(update_language_server)
 449            .add_message_handler(update_diagnostic_summary)
 450            .add_message_handler(update_worktree_settings)
 451            .add_request_handler(user_handler(
 452                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 453            ))
 454            .add_request_handler(user_handler(
 455                forward_project_request_for_owner::<proto::TaskTemplates>,
 456            ))
 457            .add_request_handler(user_handler(
 458                forward_read_only_project_request::<proto::GetHover>,
 459            ))
 460            .add_request_handler(user_handler(
 461                forward_read_only_project_request::<proto::GetDefinition>,
 462            ))
 463            .add_request_handler(user_handler(
 464                forward_read_only_project_request::<proto::GetTypeDefinition>,
 465            ))
 466            .add_request_handler(user_handler(
 467                forward_read_only_project_request::<proto::GetReferences>,
 468            ))
 469            .add_request_handler(user_handler(
 470                forward_read_only_project_request::<proto::SearchProject>,
 471            ))
 472            .add_request_handler(user_handler(
 473                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_read_only_project_request::<proto::GetProjectSymbols>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 480            ))
 481            .add_request_handler(user_handler(
 482                forward_read_only_project_request::<proto::OpenBufferById>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_read_only_project_request::<proto::InlayHints>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_read_only_project_request::<proto::OpenBufferByPath>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_mutating_project_request::<proto::GetCompletions>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_mutating_project_request::<proto::GetCodeActions>,
 507            ))
 508            .add_request_handler(user_handler(
 509                forward_mutating_project_request::<proto::ApplyCodeAction>,
 510            ))
 511            .add_request_handler(user_handler(
 512                forward_mutating_project_request::<proto::PrepareRename>,
 513            ))
 514            .add_request_handler(user_handler(
 515                forward_mutating_project_request::<proto::PerformRename>,
 516            ))
 517            .add_request_handler(user_handler(
 518                forward_mutating_project_request::<proto::ReloadBuffers>,
 519            ))
 520            .add_request_handler(user_handler(
 521                forward_mutating_project_request::<proto::FormatBuffers>,
 522            ))
 523            .add_request_handler(user_handler(
 524                forward_mutating_project_request::<proto::CreateProjectEntry>,
 525            ))
 526            .add_request_handler(user_handler(
 527                forward_mutating_project_request::<proto::RenameProjectEntry>,
 528            ))
 529            .add_request_handler(user_handler(
 530                forward_mutating_project_request::<proto::CopyProjectEntry>,
 531            ))
 532            .add_request_handler(user_handler(
 533                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 534            ))
 535            .add_request_handler(user_handler(
 536                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 537            ))
 538            .add_request_handler(user_handler(
 539                forward_mutating_project_request::<proto::OnTypeFormatting>,
 540            ))
 541            .add_request_handler(user_handler(
 542                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 543            ))
 544            .add_request_handler(user_handler(
 545                forward_mutating_project_request::<proto::BlameBuffer>,
 546            ))
 547            .add_request_handler(user_handler(
 548                forward_mutating_project_request::<proto::MultiLspQuery>,
 549            ))
 550            .add_request_handler(user_handler(
 551                forward_mutating_project_request::<proto::RestartLanguageServers>,
 552            ))
 553            .add_request_handler(user_handler(
 554                forward_mutating_project_request::<proto::LinkedEditingRange>,
 555            ))
 556            .add_message_handler(create_buffer_for_peer)
 557            .add_request_handler(update_buffer)
 558            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 559            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 560            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 561            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 562            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 563            .add_request_handler(get_users)
 564            .add_request_handler(user_handler(fuzzy_search_users))
 565            .add_request_handler(user_handler(request_contact))
 566            .add_request_handler(user_handler(remove_contact))
 567            .add_request_handler(user_handler(respond_to_contact_request))
 568            .add_message_handler(subscribe_to_channels)
 569            .add_request_handler(user_handler(create_channel))
 570            .add_request_handler(user_handler(delete_channel))
 571            .add_request_handler(user_handler(invite_channel_member))
 572            .add_request_handler(user_handler(remove_channel_member))
 573            .add_request_handler(user_handler(set_channel_member_role))
 574            .add_request_handler(user_handler(set_channel_visibility))
 575            .add_request_handler(user_handler(rename_channel))
 576            .add_request_handler(user_handler(join_channel_buffer))
 577            .add_request_handler(user_handler(leave_channel_buffer))
 578            .add_message_handler(user_message_handler(update_channel_buffer))
 579            .add_request_handler(user_handler(rejoin_channel_buffers))
 580            .add_request_handler(user_handler(get_channel_members))
 581            .add_request_handler(user_handler(respond_to_channel_invite))
 582            .add_request_handler(user_handler(join_channel))
 583            .add_request_handler(user_handler(join_channel_chat))
 584            .add_message_handler(user_message_handler(leave_channel_chat))
 585            .add_request_handler(user_handler(send_channel_message))
 586            .add_request_handler(user_handler(remove_channel_message))
 587            .add_request_handler(user_handler(update_channel_message))
 588            .add_request_handler(user_handler(get_channel_messages))
 589            .add_request_handler(user_handler(get_channel_messages_by_id))
 590            .add_request_handler(user_handler(get_notifications))
 591            .add_request_handler(user_handler(mark_notification_as_read))
 592            .add_request_handler(user_handler(move_channel))
 593            .add_request_handler(user_handler(follow))
 594            .add_message_handler(user_message_handler(unfollow))
 595            .add_message_handler(user_message_handler(update_followers))
 596            .add_request_handler(user_handler(get_private_user_info))
 597            .add_message_handler(user_message_handler(acknowledge_channel_message))
 598            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 599            .add_request_handler(user_handler(get_supermaven_api_key))
 600            .add_request_handler(user_handler(
 601                forward_mutating_project_request::<proto::OpenContext>,
 602            ))
 603            .add_request_handler(user_handler(
 604                forward_mutating_project_request::<proto::CreateContext>,
 605            ))
 606            .add_request_handler(user_handler(
 607                forward_mutating_project_request::<proto::SynchronizeContexts>,
 608            ))
 609            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 610            .add_message_handler(update_context)
 611            .add_request_handler({
 612                let app_state = app_state.clone();
 613                move |request, response, session| {
 614                    let app_state = app_state.clone();
 615                    async move {
 616                        complete_with_language_model(request, response, session, &app_state.config)
 617                            .await
 618                    }
 619                }
 620            })
 621            .add_streaming_request_handler({
 622                let app_state = app_state.clone();
 623                move |request, response, session| {
 624                    let app_state = app_state.clone();
 625                    async move {
 626                        stream_complete_with_language_model(
 627                            request,
 628                            response,
 629                            session,
 630                            &app_state.config,
 631                        )
 632                        .await
 633                    }
 634                }
 635            })
 636            .add_request_handler({
 637                let app_state = app_state.clone();
 638                move |request, response, session| {
 639                    let app_state = app_state.clone();
 640                    async move {
 641                        count_language_model_tokens(request, response, session, &app_state.config)
 642                            .await
 643                    }
 644                }
 645            })
 646            .add_request_handler({
 647                user_handler(move |request, response, session| {
 648                    get_cached_embeddings(request, response, session)
 649                })
 650            })
 651            .add_request_handler({
 652                let app_state = app_state.clone();
 653                user_handler(move |request, response, session| {
 654                    compute_embeddings(
 655                        request,
 656                        response,
 657                        session,
 658                        app_state.config.openai_api_key.clone(),
 659                    )
 660                })
 661            });
 662
 663        Arc::new(server)
 664    }
 665
 666    pub async fn start(&self) -> Result<()> {
 667        let server_id = *self.id.lock();
 668        let app_state = self.app_state.clone();
 669        let peer = self.peer.clone();
 670        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 671        let pool = self.connection_pool.clone();
 672        let live_kit_client = self.app_state.live_kit_client.clone();
 673
 674        let span = info_span!("start server");
 675        self.app_state.executor.spawn_detached(
 676            async move {
 677                tracing::info!("waiting for cleanup timeout");
 678                timeout.await;
 679                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 680                if let Some((room_ids, channel_ids)) = app_state
 681                    .db
 682                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 683                    .await
 684                    .trace_err()
 685                {
 686                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 687                    tracing::info!(
 688                        stale_channel_buffer_count = channel_ids.len(),
 689                        "retrieved stale channel buffers"
 690                    );
 691
 692                    for channel_id in channel_ids {
 693                        if let Some(refreshed_channel_buffer) = app_state
 694                            .db
 695                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 696                            .await
 697                            .trace_err()
 698                        {
 699                            for connection_id in refreshed_channel_buffer.connection_ids {
 700                                peer.send(
 701                                    connection_id,
 702                                    proto::UpdateChannelBufferCollaborators {
 703                                        channel_id: channel_id.to_proto(),
 704                                        collaborators: refreshed_channel_buffer
 705                                            .collaborators
 706                                            .clone(),
 707                                    },
 708                                )
 709                                .trace_err();
 710                            }
 711                        }
 712                    }
 713
 714                    for room_id in room_ids {
 715                        let mut contacts_to_update = HashSet::default();
 716                        let mut canceled_calls_to_user_ids = Vec::new();
 717                        let mut live_kit_room = String::new();
 718                        let mut delete_live_kit_room = false;
 719
 720                        if let Some(mut refreshed_room) = app_state
 721                            .db
 722                            .clear_stale_room_participants(room_id, server_id)
 723                            .await
 724                            .trace_err()
 725                        {
 726                            tracing::info!(
 727                                room_id = room_id.0,
 728                                new_participant_count = refreshed_room.room.participants.len(),
 729                                "refreshed room"
 730                            );
 731                            room_updated(&refreshed_room.room, &peer);
 732                            if let Some(channel) = refreshed_room.channel.as_ref() {
 733                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 734                            }
 735                            contacts_to_update
 736                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 737                            contacts_to_update
 738                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 739                            canceled_calls_to_user_ids =
 740                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 741                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 742                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 743                        }
 744
 745                        {
 746                            let pool = pool.lock();
 747                            for canceled_user_id in canceled_calls_to_user_ids {
 748                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 749                                    peer.send(
 750                                        connection_id,
 751                                        proto::CallCanceled {
 752                                            room_id: room_id.to_proto(),
 753                                        },
 754                                    )
 755                                    .trace_err();
 756                                }
 757                            }
 758                        }
 759
 760                        for user_id in contacts_to_update {
 761                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 762                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 763                            if let Some((busy, contacts)) = busy.zip(contacts) {
 764                                let pool = pool.lock();
 765                                let updated_contact = contact_for_user(user_id, busy, &pool);
 766                                for contact in contacts {
 767                                    if let db::Contact::Accepted {
 768                                        user_id: contact_user_id,
 769                                        ..
 770                                    } = contact
 771                                    {
 772                                        for contact_conn_id in
 773                                            pool.user_connection_ids(contact_user_id)
 774                                        {
 775                                            peer.send(
 776                                                contact_conn_id,
 777                                                proto::UpdateContacts {
 778                                                    contacts: vec![updated_contact.clone()],
 779                                                    remove_contacts: Default::default(),
 780                                                    incoming_requests: Default::default(),
 781                                                    remove_incoming_requests: Default::default(),
 782                                                    outgoing_requests: Default::default(),
 783                                                    remove_outgoing_requests: Default::default(),
 784                                                },
 785                                            )
 786                                            .trace_err();
 787                                        }
 788                                    }
 789                                }
 790                            }
 791                        }
 792
 793                        if let Some(live_kit) = live_kit_client.as_ref() {
 794                            if delete_live_kit_room {
 795                                live_kit.delete_room(live_kit_room).await.trace_err();
 796                            }
 797                        }
 798                    }
 799                }
 800
 801                app_state
 802                    .db
 803                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 804                    .await
 805                    .trace_err();
 806            }
 807            .instrument(span),
 808        );
 809        Ok(())
 810    }
 811
 812    pub fn teardown(&self) {
 813        self.peer.teardown();
 814        self.connection_pool.lock().reset();
 815        let _ = self.teardown.send(true);
 816    }
 817
 818    #[cfg(test)]
 819    pub fn reset(&self, id: ServerId) {
 820        self.teardown();
 821        *self.id.lock() = id;
 822        self.peer.reset(id.0 as u32);
 823        let _ = self.teardown.send(false);
 824    }
 825
 826    #[cfg(test)]
 827    pub fn id(&self) -> ServerId {
 828        *self.id.lock()
 829    }
 830
 831    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 832    where
 833        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 834        Fut: 'static + Send + Future<Output = Result<()>>,
 835        M: EnvelopedMessage,
 836    {
 837        let prev_handler = self.handlers.insert(
 838            TypeId::of::<M>(),
 839            Box::new(move |envelope, session| {
 840                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 841                let received_at = envelope.received_at;
 842                tracing::info!("message received");
 843                let start_time = Instant::now();
 844                let future = (handler)(*envelope, session);
 845                async move {
 846                    let result = future.await;
 847                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 848                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 849                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 850                    let payload_type = M::NAME;
 851
 852                    match result {
 853                        Err(error) => {
 854                            tracing::error!(
 855                                ?error,
 856                                total_duration_ms,
 857                                processing_duration_ms,
 858                                queue_duration_ms,
 859                                payload_type,
 860                                "error handling message"
 861                            )
 862                        }
 863                        Ok(()) => tracing::info!(
 864                            total_duration_ms,
 865                            processing_duration_ms,
 866                            queue_duration_ms,
 867                            "finished handling message"
 868                        ),
 869                    }
 870                }
 871                .boxed()
 872            }),
 873        );
 874        if prev_handler.is_some() {
 875            panic!("registered a handler for the same message twice");
 876        }
 877        self
 878    }
 879
 880    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 881    where
 882        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 883        Fut: 'static + Send + Future<Output = Result<()>>,
 884        M: EnvelopedMessage,
 885    {
 886        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 887        self
 888    }
 889
 890    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 891    where
 892        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 893        Fut: Send + Future<Output = Result<()>>,
 894        M: RequestMessage,
 895    {
 896        let handler = Arc::new(handler);
 897        self.add_handler(move |envelope, session| {
 898            let receipt = envelope.receipt();
 899            let handler = handler.clone();
 900            async move {
 901                let peer = session.peer.clone();
 902                let responded = Arc::new(AtomicBool::default());
 903                let response = Response {
 904                    peer: peer.clone(),
 905                    responded: responded.clone(),
 906                    receipt,
 907                };
 908                match (handler)(envelope.payload, response, session).await {
 909                    Ok(()) => {
 910                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 911                            Ok(())
 912                        } else {
 913                            Err(anyhow!("handler did not send a response"))?
 914                        }
 915                    }
 916                    Err(error) => {
 917                        let proto_err = match &error {
 918                            Error::Internal(err) => err.to_proto(),
 919                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 920                        };
 921                        peer.respond_with_error(receipt, proto_err)?;
 922                        Err(error)
 923                    }
 924                }
 925            }
 926        })
 927    }
 928
 929    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 930    where
 931        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 932        Fut: Send + Future<Output = Result<()>>,
 933        M: RequestMessage,
 934    {
 935        let handler = Arc::new(handler);
 936        self.add_handler(move |envelope, session| {
 937            let receipt = envelope.receipt();
 938            let handler = handler.clone();
 939            async move {
 940                let peer = session.peer.clone();
 941                let response = StreamingResponse {
 942                    peer: peer.clone(),
 943                    receipt,
 944                };
 945                match (handler)(envelope.payload, response, session).await {
 946                    Ok(()) => {
 947                        peer.end_stream(receipt)?;
 948                        Ok(())
 949                    }
 950                    Err(error) => {
 951                        let proto_err = match &error {
 952                            Error::Internal(err) => err.to_proto(),
 953                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 954                        };
 955                        peer.respond_with_error(receipt, proto_err)?;
 956                        Err(error)
 957                    }
 958                }
 959            }
 960        })
 961    }
 962
    /// Serves a single client connection until its I/O ends, the client closes
    /// its message stream, or the server starts tearing down.
    ///
    /// The returned future registers `connection` with the peer, builds the
    /// per-connection [`Session`], sends the initial client update and then
    /// dispatches incoming messages to the registered handlers. When the loop
    /// ends because I/O finished or the stream closed, `connection_lost` runs
    /// to clean up. A teardown notification returns immediately instead.
    ///
    /// `send_connection_id`, if given, is forwarded to
    /// `send_initial_client_update`, which reports the new `ConnectionId` on it.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the
        // peer assigns it (below), the others presumably by
        // `principal.update_span` (TODO confirm).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Watch the server's teardown flag for the lifetime of this connection.
        let mut teardown = self.teardown.subscribe();
        async move {
            // The server is already shutting down: refuse the connection.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The peer is handed an executor-backed sleep function (presumably
            // for its internal timing, e.g. keepalives — TODO confirm).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // Each connection gets its own HTTP client tagged with the server version.
            // NOTE(review): this early return (and the one after
            // `send_initial_client_update`) skips `connection_lost`, even though the
            // connection was already added to the peer — verify it is cleaned up elsewhere.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client only exists when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            // Per-connection state handed (cloned) to every message handler.
            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            // NOTE(review): `send_initial_client_update` may already have added this
            // connection to the connection pool before failing; returning here skips
            // `connection_lost`, so confirm the pool entry is removed some other way.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so once
                // 256 handlers are running no further messages are read.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` checks branches in order: teardown, I/O completion,
                // progress on in-flight foreground handlers, and only then a new message.
                futures::select_biased! {
                    // Teardown returns immediately, without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is
                            // constructed; the future itself is instrumented with it below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Keep the permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming message stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1110
1111    async fn send_initial_client_update(
1112        &self,
1113        connection_id: ConnectionId,
1114        principal: &Principal,
1115        zed_version: ZedVersion,
1116        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1117        session: &Session,
1118    ) -> Result<()> {
1119        self.peer.send(
1120            connection_id,
1121            proto::Hello {
1122                peer_id: Some(connection_id.into()),
1123            },
1124        )?;
1125        tracing::info!("sent hello message");
1126        if let Some(send_connection_id) = send_connection_id.take() {
1127            let _ = send_connection_id.send(connection_id);
1128        }
1129
1130        match principal {
1131            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1132                if !user.connected_once {
1133                    self.peer.send(connection_id, proto::ShowContacts {})?;
1134                    self.app_state
1135                        .db
1136                        .set_user_connected_once(user.id, true)
1137                        .await?;
1138                }
1139
1140                update_user_plan(user.id, session).await?;
1141
1142                let (contacts, dev_server_projects) = future::try_join(
1143                    self.app_state.db.get_contacts(user.id),
1144                    self.app_state.db.dev_server_projects_update(user.id),
1145                )
1146                .await?;
1147
1148                {
1149                    let mut pool = self.connection_pool.lock();
1150                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1151                    self.peer.send(
1152                        connection_id,
1153                        build_initial_contacts_update(contacts, &pool),
1154                    )?;
1155                }
1156
1157                if should_auto_subscribe_to_channels(zed_version) {
1158                    subscribe_user_to_channels(user.id, session).await?;
1159                }
1160
1161                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1162
1163                if let Some(incoming_call) =
1164                    self.app_state.db.incoming_call_for_user(user.id).await?
1165                {
1166                    self.peer.send(connection_id, incoming_call)?;
1167                }
1168
1169                update_user_contacts(user.id, &session).await?;
1170            }
1171            Principal::DevServer(dev_server) => {
1172                {
1173                    let mut pool = self.connection_pool.lock();
1174                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1175                    {
1176                        self.peer.send(
1177                            stale_connection_id,
1178                            proto::ShutdownDevServer {
1179                                reason: Some(
1180                                    "another dev server connected with the same token".to_string(),
1181                                ),
1182                            },
1183                        )?;
1184                        pool.remove_connection(stale_connection_id)?;
1185                    };
1186                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1187                }
1188
1189                let projects = self
1190                    .app_state
1191                    .db
1192                    .get_projects_for_dev_server(dev_server.id)
1193                    .await?;
1194                self.peer
1195                    .send(connection_id, proto::DevServerInstructions { projects })?;
1196
1197                let status = self
1198                    .app_state
1199                    .db
1200                    .dev_server_projects_update(dev_server.user_id)
1201                    .await?;
1202                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1203            }
1204        }
1205
1206        Ok(())
1207    }
1208
1209    pub async fn invite_code_redeemed(
1210        self: &Arc<Self>,
1211        inviter_id: UserId,
1212        invitee_id: UserId,
1213    ) -> Result<()> {
1214        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1215            if let Some(code) = &user.invite_code {
1216                let pool = self.connection_pool.lock();
1217                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1218                for connection_id in pool.user_connection_ids(inviter_id) {
1219                    self.peer.send(
1220                        connection_id,
1221                        proto::UpdateContacts {
1222                            contacts: vec![invitee_contact.clone()],
1223                            ..Default::default()
1224                        },
1225                    )?;
1226                    self.peer.send(
1227                        connection_id,
1228                        proto::UpdateInviteInfo {
1229                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1230                            count: user.invite_count as u32,
1231                        },
1232                    )?;
1233                }
1234            }
1235        }
1236        Ok(())
1237    }
1238
1239    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1240        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1241            if let Some(invite_code) = &user.invite_code {
1242                let pool = self.connection_pool.lock();
1243                for connection_id in pool.user_connection_ids(user_id) {
1244                    self.peer.send(
1245                        connection_id,
1246                        proto::UpdateInviteInfo {
1247                            url: format!(
1248                                "{}{}",
1249                                self.app_state.config.invite_link_prefix, invite_code
1250                            ),
1251                            count: user.invite_count as u32,
1252                        },
1253                    )?;
1254                }
1255            }
1256        }
1257        Ok(())
1258    }
1259
1260    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1261        ServerSnapshot {
1262            connection_pool: ConnectionPoolGuard {
1263                guard: self.connection_pool.lock(),
1264                _not_send: PhantomData,
1265            },
1266            peer: &self.peer,
1267        }
1268    }
1269}
1270
1271impl<'a> Deref for ConnectionPoolGuard<'a> {
1272    type Target = ConnectionPool;
1273
1274    fn deref(&self) -> &Self::Target {
1275        &self.guard
1276    }
1277}
1278
1279impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1280    fn deref_mut(&mut self) -> &mut Self::Target {
1281        &mut self.guard
1282    }
1283}
1284
1285impl<'a> Drop for ConnectionPoolGuard<'a> {
1286    fn drop(&mut self) {
1287        #[cfg(test)]
1288        self.check_invariants();
1289    }
1290}
1291
1292fn broadcast<F>(
1293    sender_id: Option<ConnectionId>,
1294    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1295    mut f: F,
1296) where
1297    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1298{
1299    for receiver_id in receiver_ids {
1300        if Some(receiver_id) != sender_id {
1301            if let Err(error) = f(receiver_id) {
1302                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1303            }
1304        }
1305    }
1306}
1307
1308pub struct ProtocolVersion(u32);
1309
1310impl Header for ProtocolVersion {
1311    fn name() -> &'static HeaderName {
1312        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1313        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1314    }
1315
1316    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1317    where
1318        Self: Sized,
1319        I: Iterator<Item = &'i axum::http::HeaderValue>,
1320    {
1321        let version = values
1322            .next()
1323            .ok_or_else(axum::headers::Error::invalid)?
1324            .to_str()
1325            .map_err(|_| axum::headers::Error::invalid())?
1326            .parse()
1327            .map_err(|_| axum::headers::Error::invalid())?;
1328        Ok(Self(version))
1329    }
1330
1331    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1332        values.extend([self.0.to_string().parse().unwrap()]);
1333    }
1334}
1335
1336pub struct AppVersionHeader(SemanticVersion);
1337impl Header for AppVersionHeader {
1338    fn name() -> &'static HeaderName {
1339        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1340        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1341    }
1342
1343    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1344    where
1345        Self: Sized,
1346        I: Iterator<Item = &'i axum::http::HeaderValue>,
1347    {
1348        let version = values
1349            .next()
1350            .ok_or_else(axum::headers::Error::invalid)?
1351            .to_str()
1352            .map_err(|_| axum::headers::Error::invalid())?
1353            .parse()
1354            .map_err(|_| axum::headers::Error::invalid())?;
1355        Ok(Self(version))
1356    }
1357
1358    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1359        values.extend([self.0.to_string().parse().unwrap()]);
1360    }
1361}
1362
1363pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1364    Router::new()
1365        .route("/rpc", get(handle_websocket_request))
1366        .layer(
1367            ServiceBuilder::new()
1368                .layer(Extension(server.app_state.clone()))
1369                .layer(middleware::from_fn(auth::validate_header)),
1370        )
1371        .route("/metrics", get(handle_metrics))
1372        .layer(Extension(server))
1373}
1374
1375pub async fn handle_websocket_request(
1376    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1377    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1378    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1379    Extension(server): Extension<Arc<Server>>,
1380    Extension(principal): Extension<Principal>,
1381    ws: WebSocketUpgrade,
1382) -> axum::response::Response {
1383    if protocol_version != rpc::PROTOCOL_VERSION {
1384        return (
1385            StatusCode::UPGRADE_REQUIRED,
1386            "client must be upgraded".to_string(),
1387        )
1388            .into_response();
1389    }
1390
1391    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1392        return (
1393            StatusCode::UPGRADE_REQUIRED,
1394            "no version header found".to_string(),
1395        )
1396            .into_response();
1397    };
1398
1399    if !version.can_collaborate() {
1400        return (
1401            StatusCode::UPGRADE_REQUIRED,
1402            "client must be upgraded".to_string(),
1403        )
1404            .into_response();
1405    }
1406
1407    let socket_address = socket_address.to_string();
1408    ws.on_upgrade(move |socket| {
1409        let socket = socket
1410            .map_ok(to_tungstenite_message)
1411            .err_into()
1412            .with(|message| async move { to_axum_message(message) });
1413        let connection = Connection::new(Box::pin(socket));
1414        async move {
1415            server
1416                .handle_connection(
1417                    connection,
1418                    socket_address,
1419                    principal,
1420                    version,
1421                    None,
1422                    Executor::Production,
1423                )
1424                .await;
1425        }
1426    })
1427}
1428
1429pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1430    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1431    let connections_metric = CONNECTIONS_METRIC
1432        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1433
1434    let connections = server
1435        .connection_pool
1436        .lock()
1437        .connections()
1438        .filter(|connection| !connection.admin)
1439        .count();
1440    connections_metric.set(connections as _);
1441
1442    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1443    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1444        register_int_gauge!(
1445            "shared_projects",
1446            "number of open projects with one or more guests"
1447        )
1448        .unwrap()
1449    });
1450
1451    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1452    shared_projects_metric.set(shared_projects as _);
1453
1454    let encoder = prometheus::TextEncoder::new();
1455    let metric_families = prometheus::gather();
1456    let encoded_metrics = encoder
1457        .encode_to_string(&metric_families)
1458        .map_err(|err| anyhow!("{}", err))?;
1459    Ok(encoded_metrics)
1460}
1461
/// Cleans up after a connection whose message loop has ended.
///
/// Immediate cleanup: the connection is disconnected from the peer, removed
/// from the connection pool, and reported to the database.
///
/// Deferred cleanup of the principal's resources waits for `RECONNECT_TIMEOUT`
/// (presumably a grace period for the client to reconnect — TODO confirm):
/// * users: leave rooms and channel buffers, decline pending calls when no
///   other connection is online, update contacts;
/// * dev servers: `lost_dev_server_connection`.
///
/// If the teardown channel changes before the timeout, this deferred cleanup
/// is skipped.
///
/// # Errors
///
/// Returns errors from removing the connection from the pool, from
/// `update_user_contacts`, or from `lost_dev_server_connection`. Other cleanup
/// failures go through `trace_err` and do not abort the function.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the timeout branch is polled before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so this cannot fail.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call once the user has no live connection left.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Server is tearing down: skip the per-principal cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1516
1517/// Acknowledges a ping from a client, used to keep the connection alive.
1518async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1519    response.send(proto::Ack {})?;
1520    Ok(())
1521}
1522
1523/// Creates a new room for calling (outside of channels)
1524async fn create_room(
1525    _request: proto::CreateRoom,
1526    response: Response<proto::CreateRoom>,
1527    session: UserSession,
1528) -> Result<()> {
1529    let live_kit_room = nanoid::nanoid!(30);
1530
1531    let live_kit_connection_info = util::maybe!(async {
1532        let live_kit = session.live_kit_client.as_ref();
1533        let live_kit = live_kit?;
1534        let user_id = session.user_id().to_string();
1535
1536        let token = live_kit
1537            .room_token(&live_kit_room, &user_id.to_string())
1538            .trace_err()?;
1539
1540        Some(proto::LiveKitConnectionInfo {
1541            server_url: live_kit.url().into(),
1542            token,
1543            can_publish: true,
1544        })
1545    })
1546    .await;
1547
1548    let room = session
1549        .db()
1550        .await
1551        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1552        .await?;
1553
1554    response.send(proto::CreateRoomResponse {
1555        room: Some(room.clone()),
1556        live_kit_connection_info,
1557    })?;
1558
1559    update_user_contacts(session.user_id(), &session).await?;
1560    Ok(())
1561}
1562
1563/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1564async fn join_room(
1565    request: proto::JoinRoom,
1566    response: Response<proto::JoinRoom>,
1567    session: UserSession,
1568) -> Result<()> {
1569    let room_id = RoomId::from_proto(request.id);
1570
1571    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1572
1573    if let Some(channel_id) = channel_id {
1574        return join_channel_internal(channel_id, Box::new(response), session).await;
1575    }
1576
1577    let joined_room = {
1578        let room = session
1579            .db()
1580            .await
1581            .join_room(room_id, session.user_id(), session.connection_id)
1582            .await?;
1583        room_updated(&room.room, &session.peer);
1584        room.into_inner()
1585    };
1586
1587    for connection_id in session
1588        .connection_pool()
1589        .await
1590        .user_connection_ids(session.user_id())
1591    {
1592        session
1593            .peer
1594            .send(
1595                connection_id,
1596                proto::CallCanceled {
1597                    room_id: room_id.to_proto(),
1598                },
1599            )
1600            .trace_err();
1601    }
1602
1603    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1604        if let Some(token) = live_kit
1605            .room_token(
1606                &joined_room.room.live_kit_room,
1607                &session.user_id().to_string(),
1608            )
1609            .trace_err()
1610        {
1611            Some(proto::LiveKitConnectionInfo {
1612                server_url: live_kit.url().into(),
1613                token,
1614                can_publish: true,
1615            })
1616        } else {
1617            None
1618        }
1619    } else {
1620        None
1621    };
1622
1623    response.send(proto::JoinRoomResponse {
1624        room: Some(joined_room.room),
1625        channel_id: None,
1626        live_kit_connection_info,
1627    })?;
1628
1629    update_user_contacts(session.user_id(), &session).await?;
1630    Ok(())
1631}
1632
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// - The database records the rejoin for the current connection.
/// - The caller receives the room plus its reshared and rejoined projects, and the room
///   is broadcast.
/// - Collaborators on reshared projects are told about the caller's new peer id and
///   receive the project's worktrees.
/// - Rejoined projects are streamed back via `notify_rejoined_projects`.
/// - If the room belongs to a channel, the channel is broadcast.
/// - Finally the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the inner block so they outlive the value returned by
    // `rejoin_room`, which is dropped at the end of that block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // NOTE(review): reshared projects are presumably the ones this user shares; each
        // collaborator is told the sharer's peer id changed and is sent its worktrees.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Moves the worktrees out of `rejoined_projects` while streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1724
1725fn notify_rejoined_projects(
1726    rejoined_projects: &mut Vec<RejoinedProject>,
1727    session: &UserSession,
1728) -> Result<()> {
1729    for project in rejoined_projects.iter() {
1730        for collaborator in &project.collaborators {
1731            session
1732                .peer
1733                .send(
1734                    collaborator.connection_id,
1735                    proto::UpdateProjectCollaborator {
1736                        project_id: project.id.to_proto(),
1737                        old_peer_id: Some(project.old_connection_id.into()),
1738                        new_peer_id: Some(session.connection_id.into()),
1739                    },
1740                )
1741                .trace_err();
1742        }
1743    }
1744
1745    for project in rejoined_projects {
1746        for worktree in mem::take(&mut project.worktrees) {
1747            #[cfg(any(test, feature = "test-support"))]
1748            const MAX_CHUNK_SIZE: usize = 2;
1749            #[cfg(not(any(test, feature = "test-support")))]
1750            const MAX_CHUNK_SIZE: usize = 256;
1751
1752            // Stream this worktree's entries.
1753            let message = proto::UpdateWorktree {
1754                project_id: project.id.to_proto(),
1755                worktree_id: worktree.id,
1756                abs_path: worktree.abs_path.clone(),
1757                root_name: worktree.root_name,
1758                updated_entries: worktree.updated_entries,
1759                removed_entries: worktree.removed_entries,
1760                scan_id: worktree.scan_id,
1761                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1762                updated_repositories: worktree.updated_repositories,
1763                removed_repositories: worktree.removed_repositories,
1764            };
1765            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1766                session.peer.send(session.connection_id, update.clone())?;
1767            }
1768
1769            // Stream this worktree's diagnostics.
1770            for summary in worktree.diagnostic_summaries {
1771                session.peer.send(
1772                    session.connection_id,
1773                    proto::UpdateDiagnosticSummary {
1774                        project_id: project.id.to_proto(),
1775                        worktree_id: worktree.id,
1776                        summary: Some(summary),
1777                    },
1778                )?;
1779            }
1780
1781            for settings_file in worktree.settings_files {
1782                session.peer.send(
1783                    session.connection_id,
1784                    proto::UpdateWorktreeSettings {
1785                        project_id: project.id.to_proto(),
1786                        worktree_id: worktree.id,
1787                        path: settings_file.path,
1788                        content: Some(settings_file.content),
1789                    },
1790                )?;
1791            }
1792        }
1793
1794        for language_server in &project.language_servers {
1795            session.peer.send(
1796                session.connection_id,
1797                proto::UpdateLanguageServer {
1798                    project_id: project.id.to_proto(),
1799                    language_server_id: language_server.id,
1800                    variant: Some(
1801                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1802                            proto::LspDiskBasedDiagnosticsUpdated {},
1803                        ),
1804                    ),
1805                },
1806            )?;
1807        }
1808    }
1809    Ok(())
1810}
1811
1812/// leave room disconnects from the room.
1813async fn leave_room(
1814    _: proto::LeaveRoom,
1815    response: Response<proto::LeaveRoom>,
1816    session: UserSession,
1817) -> Result<()> {
1818    leave_room_for_session(&session, session.connection_id).await?;
1819    response.send(proto::Ack {})?;
1820    Ok(())
1821}
1822
1823/// Updates the permissions of someone else in the room.
1824async fn set_room_participant_role(
1825    request: proto::SetRoomParticipantRole,
1826    response: Response<proto::SetRoomParticipantRole>,
1827    session: UserSession,
1828) -> Result<()> {
1829    let user_id = UserId::from_proto(request.user_id);
1830    let role = ChannelRole::from(request.role());
1831
1832    let (live_kit_room, can_publish) = {
1833        let room = session
1834            .db()
1835            .await
1836            .set_room_participant_role(
1837                session.user_id(),
1838                RoomId::from_proto(request.room_id),
1839                user_id,
1840                role,
1841            )
1842            .await?;
1843
1844        let live_kit_room = room.live_kit_room.clone();
1845        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1846        room_updated(&room, &session.peer);
1847        (live_kit_room, can_publish)
1848    };
1849
1850    if let Some(live_kit) = session.live_kit_client.as_ref() {
1851        live_kit
1852            .update_participant(
1853                live_kit_room.clone(),
1854                request.user_id.to_string(),
1855                live_kit_server::proto::ParticipantPermission {
1856                    can_subscribe: true,
1857                    can_publish,
1858                    can_publish_data: can_publish,
1859                    hidden: false,
1860                    recorder: false,
1861                },
1862            )
1863            .await
1864            .trace_err();
1865    }
1866
1867    response.send(proto::Ack {})?;
1868    Ok(())
1869}
1870
1871/// Call someone else into the current room
1872async fn call(
1873    request: proto::Call,
1874    response: Response<proto::Call>,
1875    session: UserSession,
1876) -> Result<()> {
1877    let room_id = RoomId::from_proto(request.room_id);
1878    let calling_user_id = session.user_id();
1879    let calling_connection_id = session.connection_id;
1880    let called_user_id = UserId::from_proto(request.called_user_id);
1881    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1882    if !session
1883        .db()
1884        .await
1885        .has_contact(calling_user_id, called_user_id)
1886        .await?
1887    {
1888        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1889    }
1890
1891    let incoming_call = {
1892        let (room, incoming_call) = &mut *session
1893            .db()
1894            .await
1895            .call(
1896                room_id,
1897                calling_user_id,
1898                calling_connection_id,
1899                called_user_id,
1900                initial_project_id,
1901            )
1902            .await?;
1903        room_updated(&room, &session.peer);
1904        mem::take(incoming_call)
1905    };
1906    update_user_contacts(called_user_id, &session).await?;
1907
1908    let mut calls = session
1909        .connection_pool()
1910        .await
1911        .user_connection_ids(called_user_id)
1912        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1913        .collect::<FuturesUnordered<_>>();
1914
1915    while let Some(call_response) = calls.next().await {
1916        match call_response.as_ref() {
1917            Ok(_) => {
1918                response.send(proto::Ack {})?;
1919                return Ok(());
1920            }
1921            Err(_) => {
1922                call_response.trace_err();
1923            }
1924        }
1925    }
1926
1927    {
1928        let room = session
1929            .db()
1930            .await
1931            .call_failed(room_id, called_user_id)
1932            .await?;
1933        room_updated(&room, &session.peer);
1934    }
1935    update_user_contacts(called_user_id, &session).await?;
1936
1937    Err(anyhow!("failed to ring user"))?
1938}
1939
1940/// Cancel an outgoing call.
1941async fn cancel_call(
1942    request: proto::CancelCall,
1943    response: Response<proto::CancelCall>,
1944    session: UserSession,
1945) -> Result<()> {
1946    let called_user_id = UserId::from_proto(request.called_user_id);
1947    let room_id = RoomId::from_proto(request.room_id);
1948    {
1949        let room = session
1950            .db()
1951            .await
1952            .cancel_call(room_id, session.connection_id, called_user_id)
1953            .await?;
1954        room_updated(&room, &session.peer);
1955    }
1956
1957    for connection_id in session
1958        .connection_pool()
1959        .await
1960        .user_connection_ids(called_user_id)
1961    {
1962        session
1963            .peer
1964            .send(
1965                connection_id,
1966                proto::CallCanceled {
1967                    room_id: room_id.to_proto(),
1968                },
1969            )
1970            .trace_err();
1971    }
1972    response.send(proto::Ack {})?;
1973
1974    update_user_contacts(called_user_id, &session).await?;
1975    Ok(())
1976}
1977
1978/// Decline an incoming call.
1979async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1980    let room_id = RoomId::from_proto(message.room_id);
1981    {
1982        let room = session
1983            .db()
1984            .await
1985            .decline_call(Some(room_id), session.user_id())
1986            .await?
1987            .ok_or_else(|| anyhow!("failed to decline call"))?;
1988        room_updated(&room, &session.peer);
1989    }
1990
1991    for connection_id in session
1992        .connection_pool()
1993        .await
1994        .user_connection_ids(session.user_id())
1995    {
1996        session
1997            .peer
1998            .send(
1999                connection_id,
2000                proto::CallCanceled {
2001                    room_id: room_id.to_proto(),
2002                },
2003            )
2004            .trace_err();
2005    }
2006    update_user_contacts(session.user_id(), &session).await?;
2007    Ok(())
2008}
2009
2010/// Updates other participants in the room with your current location.
2011async fn update_participant_location(
2012    request: proto::UpdateParticipantLocation,
2013    response: Response<proto::UpdateParticipantLocation>,
2014    session: UserSession,
2015) -> Result<()> {
2016    let room_id = RoomId::from_proto(request.room_id);
2017    let location = request
2018        .location
2019        .ok_or_else(|| anyhow!("invalid location"))?;
2020
2021    let db = session.db().await;
2022    let room = db
2023        .update_room_participant_location(room_id, session.connection_id, location)
2024        .await?;
2025
2026    room_updated(&room, &session.peer);
2027    response.send(proto::Ack {})?;
2028    Ok(())
2029}
2030
2031/// Share a project into the room.
2032async fn share_project(
2033    request: proto::ShareProject,
2034    response: Response<proto::ShareProject>,
2035    session: UserSession,
2036) -> Result<()> {
2037    let (project_id, room) = &*session
2038        .db()
2039        .await
2040        .share_project(
2041            RoomId::from_proto(request.room_id),
2042            session.connection_id,
2043            &request.worktrees,
2044            request
2045                .dev_server_project_id
2046                .map(|id| DevServerProjectId::from_proto(id)),
2047        )
2048        .await?;
2049    response.send(proto::ShareProjectResponse {
2050        project_id: project_id.to_proto(),
2051    })?;
2052    room_updated(&room, &session.peer);
2053
2054    Ok(())
2055}
2056
2057/// Unshare a project from the room.
2058async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2059    let project_id = ProjectId::from_proto(message.project_id);
2060    unshare_project_internal(
2061        project_id,
2062        session.connection_id,
2063        session.user_id(),
2064        &session,
2065    )
2066    .await
2067}
2068
2069async fn unshare_project_internal(
2070    project_id: ProjectId,
2071    connection_id: ConnectionId,
2072    user_id: Option<UserId>,
2073    session: &Session,
2074) -> Result<()> {
2075    let delete = {
2076        let room_guard = session
2077            .db()
2078            .await
2079            .unshare_project(project_id, connection_id, user_id)
2080            .await?;
2081
2082        let (delete, room, guest_connection_ids) = &*room_guard;
2083
2084        let message = proto::UnshareProject {
2085            project_id: project_id.to_proto(),
2086        };
2087
2088        broadcast(
2089            Some(connection_id),
2090            guest_connection_ids.iter().copied(),
2091            |conn_id| session.peer.send(conn_id, message.clone()),
2092        );
2093        if let Some(room) = room {
2094            room_updated(room, &session.peer);
2095        }
2096
2097        *delete
2098    };
2099
2100    if delete {
2101        let db = session.db().await;
2102        db.delete_project(project_id).await?;
2103    }
2104
2105    Ok(())
2106}
2107
2108/// DevServer makes a project available online
2109async fn share_dev_server_project(
2110    request: proto::ShareDevServerProject,
2111    response: Response<proto::ShareDevServerProject>,
2112    session: DevServerSession,
2113) -> Result<()> {
2114    let (dev_server_project, user_id, status) = session
2115        .db()
2116        .await
2117        .share_dev_server_project(
2118            DevServerProjectId::from_proto(request.dev_server_project_id),
2119            session.dev_server_id(),
2120            session.connection_id,
2121            &request.worktrees,
2122        )
2123        .await?;
2124    let Some(project_id) = dev_server_project.project_id else {
2125        return Err(anyhow!("failed to share remote project"))?;
2126    };
2127
2128    send_dev_server_projects_update(user_id, status, &session).await;
2129
2130    response.send(proto::ShareProjectResponse { project_id })?;
2131
2132    Ok(())
2133}
2134
2135/// Join someone elses shared project.
2136async fn join_project(
2137    request: proto::JoinProject,
2138    response: Response<proto::JoinProject>,
2139    session: UserSession,
2140) -> Result<()> {
2141    let project_id = ProjectId::from_proto(request.project_id);
2142
2143    tracing::info!(%project_id, "join project");
2144
2145    let db = session.db().await;
2146    let (project, replica_id) = &mut *db
2147        .join_project(project_id, session.connection_id, session.user_id())
2148        .await?;
2149    drop(db);
2150    tracing::info!(%project_id, "join remote project");
2151    join_project_internal(response, session, project, replica_id)
2152}
2153
2154trait JoinProjectInternalResponse {
2155    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2156}
2157impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2158    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2159        Response::<proto::JoinProject>::send(self, result)
2160    }
2161}
2162impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2163    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2164        Response::<proto::JoinHostedProject>::send(self, result)
2165    }
2166}
2167
2168fn join_project_internal(
2169    response: impl JoinProjectInternalResponse,
2170    session: UserSession,
2171    project: &mut Project,
2172    replica_id: &ReplicaId,
2173) -> Result<()> {
2174    let collaborators = project
2175        .collaborators
2176        .iter()
2177        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2178        .map(|collaborator| collaborator.to_proto())
2179        .collect::<Vec<_>>();
2180    let project_id = project.id;
2181    let guest_user_id = session.user_id();
2182
2183    let worktrees = project
2184        .worktrees
2185        .iter()
2186        .map(|(id, worktree)| proto::WorktreeMetadata {
2187            id: *id,
2188            root_name: worktree.root_name.clone(),
2189            visible: worktree.visible,
2190            abs_path: worktree.abs_path.clone(),
2191        })
2192        .collect::<Vec<_>>();
2193
2194    let add_project_collaborator = proto::AddProjectCollaborator {
2195        project_id: project_id.to_proto(),
2196        collaborator: Some(proto::Collaborator {
2197            peer_id: Some(session.connection_id.into()),
2198            replica_id: replica_id.0 as u32,
2199            user_id: guest_user_id.to_proto(),
2200        }),
2201    };
2202
2203    for collaborator in &collaborators {
2204        session
2205            .peer
2206            .send(
2207                collaborator.peer_id.unwrap().into(),
2208                add_project_collaborator.clone(),
2209            )
2210            .trace_err();
2211    }
2212
2213    // First, we send the metadata associated with each worktree.
2214    response.send(proto::JoinProjectResponse {
2215        project_id: project.id.0 as u64,
2216        worktrees: worktrees.clone(),
2217        replica_id: replica_id.0 as u32,
2218        collaborators: collaborators.clone(),
2219        language_servers: project.language_servers.clone(),
2220        role: project.role.into(),
2221        dev_server_project_id: project
2222            .dev_server_project_id
2223            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2224    })?;
2225
2226    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2227        #[cfg(any(test, feature = "test-support"))]
2228        const MAX_CHUNK_SIZE: usize = 2;
2229        #[cfg(not(any(test, feature = "test-support")))]
2230        const MAX_CHUNK_SIZE: usize = 256;
2231
2232        // Stream this worktree's entries.
2233        let message = proto::UpdateWorktree {
2234            project_id: project_id.to_proto(),
2235            worktree_id,
2236            abs_path: worktree.abs_path.clone(),
2237            root_name: worktree.root_name,
2238            updated_entries: worktree.entries,
2239            removed_entries: Default::default(),
2240            scan_id: worktree.scan_id,
2241            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2242            updated_repositories: worktree.repository_entries.into_values().collect(),
2243            removed_repositories: Default::default(),
2244        };
2245        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2246            session.peer.send(session.connection_id, update.clone())?;
2247        }
2248
2249        // Stream this worktree's diagnostics.
2250        for summary in worktree.diagnostic_summaries {
2251            session.peer.send(
2252                session.connection_id,
2253                proto::UpdateDiagnosticSummary {
2254                    project_id: project_id.to_proto(),
2255                    worktree_id: worktree.id,
2256                    summary: Some(summary),
2257                },
2258            )?;
2259        }
2260
2261        for settings_file in worktree.settings_files {
2262            session.peer.send(
2263                session.connection_id,
2264                proto::UpdateWorktreeSettings {
2265                    project_id: project_id.to_proto(),
2266                    worktree_id: worktree.id,
2267                    path: settings_file.path,
2268                    content: Some(settings_file.content),
2269                },
2270            )?;
2271        }
2272    }
2273
2274    for language_server in &project.language_servers {
2275        session.peer.send(
2276            session.connection_id,
2277            proto::UpdateLanguageServer {
2278                project_id: project_id.to_proto(),
2279                language_server_id: language_server.id,
2280                variant: Some(
2281                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2282                        proto::LspDiskBasedDiagnosticsUpdated {},
2283                    ),
2284                ),
2285            },
2286        )?;
2287    }
2288
2289    Ok(())
2290}
2291
2292/// Leave someone elses shared project.
2293async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2294    let sender_id = session.connection_id;
2295    let project_id = ProjectId::from_proto(request.project_id);
2296    let db = session.db().await;
2297    if db.is_hosted_project(project_id).await? {
2298        let project = db.leave_hosted_project(project_id, sender_id).await?;
2299        project_left(&project, &session);
2300        return Ok(());
2301    }
2302
2303    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2304    tracing::info!(
2305        %project_id,
2306        "leave project"
2307    );
2308
2309    project_left(&project, &session);
2310    if let Some(room) = room {
2311        room_updated(&room, &session.peer);
2312    }
2313
2314    Ok(())
2315}
2316
2317async fn join_hosted_project(
2318    request: proto::JoinHostedProject,
2319    response: Response<proto::JoinHostedProject>,
2320    session: UserSession,
2321) -> Result<()> {
2322    let (mut project, replica_id) = session
2323        .db()
2324        .await
2325        .join_hosted_project(
2326            ProjectId(request.project_id as i32),
2327            session.user_id(),
2328            session.connection_id,
2329        )
2330        .await?;
2331
2332    join_project_internal(response, session, &mut project, &replica_id)
2333}
2334
2335async fn list_remote_directory(
2336    request: proto::ListRemoteDirectory,
2337    response: Response<proto::ListRemoteDirectory>,
2338    session: UserSession,
2339) -> Result<()> {
2340    let dev_server_id = DevServerId(request.dev_server_id as i32);
2341    let dev_server_connection_id = session
2342        .connection_pool()
2343        .await
2344        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2345
2346    session
2347        .db()
2348        .await
2349        .get_dev_server_for_user(dev_server_id, session.user_id())
2350        .await?;
2351
2352    response.send(
2353        session
2354            .peer
2355            .forward_request(session.connection_id, dev_server_connection_id, request)
2356            .await?,
2357    )?;
2358    Ok(())
2359}
2360
2361async fn update_dev_server_project(
2362    request: proto::UpdateDevServerProject,
2363    response: Response<proto::UpdateDevServerProject>,
2364    session: UserSession,
2365) -> Result<()> {
2366    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2367
2368    let (dev_server_project, update) = session
2369        .db()
2370        .await
2371        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2372        .await?;
2373
2374    let projects = session
2375        .db()
2376        .await
2377        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2378        .await?;
2379
2380    let dev_server_connection_id = session
2381        .connection_pool()
2382        .await
2383        .dev_server_connection_id_supporting(
2384            dev_server_project.dev_server_id,
2385            ZedVersion::with_list_directory(),
2386        )?;
2387
2388    session.peer.send(
2389        dev_server_connection_id,
2390        proto::DevServerInstructions { projects },
2391    )?;
2392
2393    send_dev_server_projects_update(session.user_id(), update, &session).await;
2394
2395    response.send(proto::Ack {})
2396}
2397
2398async fn create_dev_server_project(
2399    request: proto::CreateDevServerProject,
2400    response: Response<proto::CreateDevServerProject>,
2401    session: UserSession,
2402) -> Result<()> {
2403    let dev_server_id = DevServerId(request.dev_server_id as i32);
2404    let dev_server_connection_id = session
2405        .connection_pool()
2406        .await
2407        .dev_server_connection_id(dev_server_id);
2408    let Some(dev_server_connection_id) = dev_server_connection_id else {
2409        Err(ErrorCode::DevServerOffline
2410            .message("Cannot create a remote project when the dev server is offline".to_string())
2411            .anyhow())?
2412    };
2413
2414    let path = request.path.clone();
2415    //Check that the path exists on the dev server
2416    session
2417        .peer
2418        .forward_request(
2419            session.connection_id,
2420            dev_server_connection_id,
2421            proto::ValidateDevServerProjectRequest { path: path.clone() },
2422        )
2423        .await?;
2424
2425    let (dev_server_project, update) = session
2426        .db()
2427        .await
2428        .create_dev_server_project(
2429            DevServerId(request.dev_server_id as i32),
2430            &request.path,
2431            session.user_id(),
2432        )
2433        .await?;
2434
2435    let projects = session
2436        .db()
2437        .await
2438        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2439        .await?;
2440
2441    session.peer.send(
2442        dev_server_connection_id,
2443        proto::DevServerInstructions { projects },
2444    )?;
2445
2446    send_dev_server_projects_update(session.user_id(), update, &session).await;
2447
2448    response.send(proto::CreateDevServerProjectResponse {
2449        dev_server_project: Some(dev_server_project.to_proto(None)),
2450    })?;
2451    Ok(())
2452}
2453
2454async fn create_dev_server(
2455    request: proto::CreateDevServer,
2456    response: Response<proto::CreateDevServer>,
2457    session: UserSession,
2458) -> Result<()> {
2459    let access_token = auth::random_token();
2460    let hashed_access_token = auth::hash_access_token(&access_token);
2461
2462    if request.name.is_empty() {
2463        return Err(proto::ErrorCode::Forbidden
2464            .message("Dev server name cannot be empty".to_string())
2465            .anyhow())?;
2466    }
2467
2468    let (dev_server, status) = session
2469        .db()
2470        .await
2471        .create_dev_server(
2472            &request.name,
2473            request.ssh_connection_string.as_deref(),
2474            &hashed_access_token,
2475            session.user_id(),
2476        )
2477        .await?;
2478
2479    send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481    response.send(proto::CreateDevServerResponse {
2482        dev_server_id: dev_server.id.0 as u64,
2483        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2484        name: request.name,
2485    })?;
2486    Ok(())
2487}
2488
2489async fn regenerate_dev_server_token(
2490    request: proto::RegenerateDevServerToken,
2491    response: Response<proto::RegenerateDevServerToken>,
2492    session: UserSession,
2493) -> Result<()> {
2494    let dev_server_id = DevServerId(request.dev_server_id as i32);
2495    let access_token = auth::random_token();
2496    let hashed_access_token = auth::hash_access_token(&access_token);
2497
2498    let connection_id = session
2499        .connection_pool()
2500        .await
2501        .dev_server_connection_id(dev_server_id);
2502    if let Some(connection_id) = connection_id {
2503        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2504        session.peer.send(
2505            connection_id,
2506            proto::ShutdownDevServer {
2507                reason: Some("dev server token was regenerated".to_string()),
2508            },
2509        )?;
2510        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2511    }
2512
2513    let status = session
2514        .db()
2515        .await
2516        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2517        .await?;
2518
2519    send_dev_server_projects_update(session.user_id(), status, &session).await;
2520
2521    response.send(proto::RegenerateDevServerTokenResponse {
2522        dev_server_id: dev_server_id.to_proto(),
2523        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2524    })?;
2525    Ok(())
2526}
2527
2528async fn rename_dev_server(
2529    request: proto::RenameDevServer,
2530    response: Response<proto::RenameDevServer>,
2531    session: UserSession,
2532) -> Result<()> {
2533    if request.name.trim().is_empty() {
2534        return Err(proto::ErrorCode::Forbidden
2535            .message("Dev server name cannot be empty".to_string())
2536            .anyhow())?;
2537    }
2538
2539    let dev_server_id = DevServerId(request.dev_server_id as i32);
2540    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2541    if dev_server.user_id != session.user_id() {
2542        return Err(anyhow!(ErrorCode::Forbidden))?;
2543    }
2544
2545    let status = session
2546        .db()
2547        .await
2548        .rename_dev_server(
2549            dev_server_id,
2550            &request.name,
2551            request.ssh_connection_string.as_deref(),
2552            session.user_id(),
2553        )
2554        .await?;
2555
2556    send_dev_server_projects_update(session.user_id(), status, &session).await;
2557
2558    response.send(proto::Ack {})?;
2559    Ok(())
2560}
2561
2562async fn delete_dev_server(
2563    request: proto::DeleteDevServer,
2564    response: Response<proto::DeleteDevServer>,
2565    session: UserSession,
2566) -> Result<()> {
2567    let dev_server_id = DevServerId(request.dev_server_id as i32);
2568    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2569    if dev_server.user_id != session.user_id() {
2570        return Err(anyhow!(ErrorCode::Forbidden))?;
2571    }
2572
2573    let connection_id = session
2574        .connection_pool()
2575        .await
2576        .dev_server_connection_id(dev_server_id);
2577    if let Some(connection_id) = connection_id {
2578        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2579        session.peer.send(
2580            connection_id,
2581            proto::ShutdownDevServer {
2582                reason: Some("dev server was deleted".to_string()),
2583            },
2584        )?;
2585        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2586    }
2587
2588    let status = session
2589        .db()
2590        .await
2591        .delete_dev_server(dev_server_id, session.user_id())
2592        .await?;
2593
2594    send_dev_server_projects_update(session.user_id(), status, &session).await;
2595
2596    response.send(proto::Ack {})?;
2597    Ok(())
2598}
2599
2600async fn delete_dev_server_project(
2601    request: proto::DeleteDevServerProject,
2602    response: Response<proto::DeleteDevServerProject>,
2603    session: UserSession,
2604) -> Result<()> {
2605    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2606    let dev_server_project = session
2607        .db()
2608        .await
2609        .get_dev_server_project(dev_server_project_id)
2610        .await?;
2611
2612    let dev_server = session
2613        .db()
2614        .await
2615        .get_dev_server(dev_server_project.dev_server_id)
2616        .await?;
2617    if dev_server.user_id != session.user_id() {
2618        return Err(anyhow!(ErrorCode::Forbidden))?;
2619    }
2620
2621    let dev_server_connection_id = session
2622        .connection_pool()
2623        .await
2624        .dev_server_connection_id(dev_server.id);
2625
2626    if let Some(dev_server_connection_id) = dev_server_connection_id {
2627        let project = session
2628            .db()
2629            .await
2630            .find_dev_server_project(dev_server_project_id)
2631            .await;
2632        if let Ok(project) = project {
2633            unshare_project_internal(
2634                project.id,
2635                dev_server_connection_id,
2636                Some(session.user_id()),
2637                &session,
2638            )
2639            .await?;
2640        }
2641    }
2642
2643    let (projects, status) = session
2644        .db()
2645        .await
2646        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2647        .await?;
2648
2649    if let Some(dev_server_connection_id) = dev_server_connection_id {
2650        session.peer.send(
2651            dev_server_connection_id,
2652            proto::DevServerInstructions { projects },
2653        )?;
2654    }
2655
2656    send_dev_server_projects_update(session.user_id(), status, &session).await;
2657
2658    response.send(proto::Ack {})?;
2659    Ok(())
2660}
2661
2662async fn rejoin_dev_server_projects(
2663    request: proto::RejoinRemoteProjects,
2664    response: Response<proto::RejoinRemoteProjects>,
2665    session: UserSession,
2666) -> Result<()> {
2667    let mut rejoined_projects = {
2668        let db = session.db().await;
2669        db.rejoin_dev_server_projects(
2670            &request.rejoined_projects,
2671            session.user_id(),
2672            session.0.connection_id,
2673        )
2674        .await?
2675    };
2676    response.send(proto::RejoinRemoteProjectsResponse {
2677        rejoined_projects: rejoined_projects
2678            .iter()
2679            .map(|project| project.to_proto())
2680            .collect(),
2681    })?;
2682    notify_rejoined_projects(&mut rejoined_projects, &session)
2683}
2684
2685async fn reconnect_dev_server(
2686    request: proto::ReconnectDevServer,
2687    response: Response<proto::ReconnectDevServer>,
2688    session: DevServerSession,
2689) -> Result<()> {
2690    let reshared_projects = {
2691        let db = session.db().await;
2692        db.reshare_dev_server_projects(
2693            &request.reshared_projects,
2694            session.dev_server_id(),
2695            session.0.connection_id,
2696        )
2697        .await?
2698    };
2699
2700    for project in &reshared_projects {
2701        for collaborator in &project.collaborators {
2702            session
2703                .peer
2704                .send(
2705                    collaborator.connection_id,
2706                    proto::UpdateProjectCollaborator {
2707                        project_id: project.id.to_proto(),
2708                        old_peer_id: Some(project.old_connection_id.into()),
2709                        new_peer_id: Some(session.connection_id.into()),
2710                    },
2711                )
2712                .trace_err();
2713        }
2714
2715        broadcast(
2716            Some(session.connection_id),
2717            project
2718                .collaborators
2719                .iter()
2720                .map(|collaborator| collaborator.connection_id),
2721            |connection_id| {
2722                session.peer.forward_send(
2723                    session.connection_id,
2724                    connection_id,
2725                    proto::UpdateProject {
2726                        project_id: project.id.to_proto(),
2727                        worktrees: project.worktrees.clone(),
2728                    },
2729                )
2730            },
2731        );
2732    }
2733
2734    response.send(proto::ReconnectDevServerResponse {
2735        reshared_projects: reshared_projects
2736            .iter()
2737            .map(|project| proto::ResharedProject {
2738                id: project.id.to_proto(),
2739                collaborators: project
2740                    .collaborators
2741                    .iter()
2742                    .map(|collaborator| collaborator.to_proto())
2743                    .collect(),
2744            })
2745            .collect(),
2746    })?;
2747
2748    Ok(())
2749}
2750
2751async fn shutdown_dev_server(
2752    _: proto::ShutdownDevServer,
2753    response: Response<proto::ShutdownDevServer>,
2754    session: DevServerSession,
2755) -> Result<()> {
2756    response.send(proto::Ack {})?;
2757    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2758    remove_dev_server_connection(session.dev_server_id(), &session).await
2759}
2760
2761async fn shutdown_dev_server_internal(
2762    dev_server_id: DevServerId,
2763    connection_id: ConnectionId,
2764    session: &Session,
2765) -> Result<()> {
2766    let (dev_server_projects, dev_server) = {
2767        let db = session.db().await;
2768        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2769        let dev_server = db.get_dev_server(dev_server_id).await?;
2770        (dev_server_projects, dev_server)
2771    };
2772
2773    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2774        unshare_project_internal(
2775            ProjectId::from_proto(project_id),
2776            connection_id,
2777            None,
2778            session,
2779        )
2780        .await?;
2781    }
2782
2783    session
2784        .connection_pool()
2785        .await
2786        .set_dev_server_offline(dev_server_id);
2787
2788    let status = session
2789        .db()
2790        .await
2791        .dev_server_projects_update(dev_server.user_id)
2792        .await?;
2793    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2794
2795    Ok(())
2796}
2797
2798async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2799    let dev_server_connection = session
2800        .connection_pool()
2801        .await
2802        .dev_server_connection_id(dev_server_id);
2803
2804    if let Some(dev_server_connection) = dev_server_connection {
2805        session
2806            .connection_pool()
2807            .await
2808            .remove_connection(dev_server_connection)?;
2809    }
2810    Ok(())
2811}
2812
2813/// Updates other participants with changes to the project
2814async fn update_project(
2815    request: proto::UpdateProject,
2816    response: Response<proto::UpdateProject>,
2817    session: Session,
2818) -> Result<()> {
2819    let project_id = ProjectId::from_proto(request.project_id);
2820    let (room, guest_connection_ids) = &*session
2821        .db()
2822        .await
2823        .update_project(project_id, session.connection_id, &request.worktrees)
2824        .await?;
2825    broadcast(
2826        Some(session.connection_id),
2827        guest_connection_ids.iter().copied(),
2828        |connection_id| {
2829            session
2830                .peer
2831                .forward_send(session.connection_id, connection_id, request.clone())
2832        },
2833    );
2834    if let Some(room) = room {
2835        room_updated(&room, &session.peer);
2836    }
2837    response.send(proto::Ack {})?;
2838
2839    Ok(())
2840}
2841
2842/// Updates other participants with changes to the worktree
2843async fn update_worktree(
2844    request: proto::UpdateWorktree,
2845    response: Response<proto::UpdateWorktree>,
2846    session: Session,
2847) -> Result<()> {
2848    let guest_connection_ids = session
2849        .db()
2850        .await
2851        .update_worktree(&request, session.connection_id)
2852        .await?;
2853
2854    broadcast(
2855        Some(session.connection_id),
2856        guest_connection_ids.iter().copied(),
2857        |connection_id| {
2858            session
2859                .peer
2860                .forward_send(session.connection_id, connection_id, request.clone())
2861        },
2862    );
2863    response.send(proto::Ack {})?;
2864    Ok(())
2865}
2866
2867/// Updates other participants with changes to the diagnostics
2868async fn update_diagnostic_summary(
2869    message: proto::UpdateDiagnosticSummary,
2870    session: Session,
2871) -> Result<()> {
2872    let guest_connection_ids = session
2873        .db()
2874        .await
2875        .update_diagnostic_summary(&message, session.connection_id)
2876        .await?;
2877
2878    broadcast(
2879        Some(session.connection_id),
2880        guest_connection_ids.iter().copied(),
2881        |connection_id| {
2882            session
2883                .peer
2884                .forward_send(session.connection_id, connection_id, message.clone())
2885        },
2886    );
2887
2888    Ok(())
2889}
2890
2891/// Updates other participants with changes to the worktree settings
2892async fn update_worktree_settings(
2893    message: proto::UpdateWorktreeSettings,
2894    session: Session,
2895) -> Result<()> {
2896    let guest_connection_ids = session
2897        .db()
2898        .await
2899        .update_worktree_settings(&message, session.connection_id)
2900        .await?;
2901
2902    broadcast(
2903        Some(session.connection_id),
2904        guest_connection_ids.iter().copied(),
2905        |connection_id| {
2906            session
2907                .peer
2908                .forward_send(session.connection_id, connection_id, message.clone())
2909        },
2910    );
2911
2912    Ok(())
2913}
2914
2915/// Notify other participants that a  language server has started.
2916async fn start_language_server(
2917    request: proto::StartLanguageServer,
2918    session: Session,
2919) -> Result<()> {
2920    let guest_connection_ids = session
2921        .db()
2922        .await
2923        .start_language_server(&request, session.connection_id)
2924        .await?;
2925
2926    broadcast(
2927        Some(session.connection_id),
2928        guest_connection_ids.iter().copied(),
2929        |connection_id| {
2930            session
2931                .peer
2932                .forward_send(session.connection_id, connection_id, request.clone())
2933        },
2934    );
2935    Ok(())
2936}
2937
2938/// Notify other participants that a language server has changed.
2939async fn update_language_server(
2940    request: proto::UpdateLanguageServer,
2941    session: Session,
2942) -> Result<()> {
2943    let project_id = ProjectId::from_proto(request.project_id);
2944    let project_connection_ids = session
2945        .db()
2946        .await
2947        .project_connection_ids(project_id, session.connection_id, true)
2948        .await?;
2949    broadcast(
2950        Some(session.connection_id),
2951        project_connection_ids.iter().copied(),
2952        |connection_id| {
2953            session
2954                .peer
2955                .forward_send(session.connection_id, connection_id, request.clone())
2956        },
2957    );
2958    Ok(())
2959}
2960
2961/// forward a project request to the host. These requests should be read only
2962/// as guests are allowed to send them.
2963async fn forward_read_only_project_request<T>(
2964    request: T,
2965    response: Response<T>,
2966    session: UserSession,
2967) -> Result<()>
2968where
2969    T: EntityMessage + RequestMessage,
2970{
2971    let project_id = ProjectId::from_proto(request.remote_entity_id());
2972    let host_connection_id = session
2973        .db()
2974        .await
2975        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2976        .await?;
2977    let payload = session
2978        .peer
2979        .forward_request(session.connection_id, host_connection_id, request)
2980        .await?;
2981    response.send(payload)?;
2982    Ok(())
2983}
2984
2985/// forward a project request to the dev server. Only allowed
2986/// if it's your dev server.
2987async fn forward_project_request_for_owner<T>(
2988    request: T,
2989    response: Response<T>,
2990    session: UserSession,
2991) -> Result<()>
2992where
2993    T: EntityMessage + RequestMessage,
2994{
2995    let project_id = ProjectId::from_proto(request.remote_entity_id());
2996
2997    let host_connection_id = session
2998        .db()
2999        .await
3000        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3001        .await?;
3002    let payload = session
3003        .peer
3004        .forward_request(session.connection_id, host_connection_id, request)
3005        .await?;
3006    response.send(payload)?;
3007    Ok(())
3008}
3009
3010/// forward a project request to the host. These requests are disallowed
3011/// for guests.
3012async fn forward_mutating_project_request<T>(
3013    request: T,
3014    response: Response<T>,
3015    session: UserSession,
3016) -> Result<()>
3017where
3018    T: EntityMessage + RequestMessage,
3019{
3020    let project_id = ProjectId::from_proto(request.remote_entity_id());
3021
3022    let host_connection_id = session
3023        .db()
3024        .await
3025        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3026        .await?;
3027    let payload = session
3028        .peer
3029        .forward_request(session.connection_id, host_connection_id, request)
3030        .await?;
3031    response.send(payload)?;
3032    Ok(())
3033}
3034
3035/// forward a project request to the host. These requests are disallowed
3036/// for guests.
3037async fn forward_versioned_mutating_project_request<T>(
3038    request: T,
3039    response: Response<T>,
3040    session: UserSession,
3041) -> Result<()>
3042where
3043    T: EntityMessage + RequestMessage + VersionedMessage,
3044{
3045    let project_id = ProjectId::from_proto(request.remote_entity_id());
3046
3047    let host_connection_id = session
3048        .db()
3049        .await
3050        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3051        .await?;
3052    if let Some(host_version) = session
3053        .connection_pool()
3054        .await
3055        .connection(host_connection_id)
3056        .map(|c| c.zed_version)
3057    {
3058        if let Some(min_required_version) = request.required_host_version() {
3059            if min_required_version > host_version {
3060                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3061                    .with_tag("required", &min_required_version.to_string())))?;
3062            }
3063        }
3064    }
3065
3066    let payload = session
3067        .peer
3068        .forward_request(session.connection_id, host_connection_id, request)
3069        .await?;
3070    response.send(payload)?;
3071    Ok(())
3072}
3073
3074/// Notify other participants that a new buffer has been created
3075async fn create_buffer_for_peer(
3076    request: proto::CreateBufferForPeer,
3077    session: Session,
3078) -> Result<()> {
3079    session
3080        .db()
3081        .await
3082        .check_user_is_project_host(
3083            ProjectId::from_proto(request.project_id),
3084            session.connection_id,
3085        )
3086        .await?;
3087    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3088    session
3089        .peer
3090        .forward_send(session.connection_id, peer_id.into(), request)?;
3091    Ok(())
3092}
3093
3094/// Notify other participants that a buffer has been updated. This is
3095/// allowed for guests as long as the update is limited to selections.
3096async fn update_buffer(
3097    request: proto::UpdateBuffer,
3098    response: Response<proto::UpdateBuffer>,
3099    session: Session,
3100) -> Result<()> {
3101    let project_id = ProjectId::from_proto(request.project_id);
3102    let mut capability = Capability::ReadOnly;
3103
3104    for op in request.operations.iter() {
3105        match op.variant {
3106            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3107            Some(_) => capability = Capability::ReadWrite,
3108        }
3109    }
3110
3111    let host = {
3112        let guard = session
3113            .db()
3114            .await
3115            .connections_for_buffer_update(
3116                project_id,
3117                session.principal_id(),
3118                session.connection_id,
3119                capability,
3120            )
3121            .await?;
3122
3123        let (host, guests) = &*guard;
3124
3125        broadcast(
3126            Some(session.connection_id),
3127            guests.clone(),
3128            |connection_id| {
3129                session
3130                    .peer
3131                    .forward_send(session.connection_id, connection_id, request.clone())
3132            },
3133        );
3134
3135        *host
3136    };
3137
3138    if host != session.connection_id {
3139        session
3140            .peer
3141            .forward_request(session.connection_id, host, request.clone())
3142            .await?;
3143    }
3144
3145    response.send(proto::Ack {})?;
3146    Ok(())
3147}
3148
3149async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3150    let project_id = ProjectId::from_proto(message.project_id);
3151
3152    let operation = message.operation.as_ref().context("invalid operation")?;
3153    let capability = match operation.variant.as_ref() {
3154        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3155            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3156                match buffer_op.variant {
3157                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3158                        Capability::ReadOnly
3159                    }
3160                    _ => Capability::ReadWrite,
3161                }
3162            } else {
3163                Capability::ReadWrite
3164            }
3165        }
3166        Some(_) => Capability::ReadWrite,
3167        None => Capability::ReadOnly,
3168    };
3169
3170    let guard = session
3171        .db()
3172        .await
3173        .connections_for_buffer_update(
3174            project_id,
3175            session.principal_id(),
3176            session.connection_id,
3177            capability,
3178        )
3179        .await?;
3180
3181    let (host, guests) = &*guard;
3182
3183    broadcast(
3184        Some(session.connection_id),
3185        guests.iter().chain([host]).copied(),
3186        |connection_id| {
3187            session
3188                .peer
3189                .forward_send(session.connection_id, connection_id, message.clone())
3190        },
3191    );
3192
3193    Ok(())
3194}
3195
3196/// Notify other participants that a project has been updated.
3197async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3198    request: T,
3199    session: Session,
3200) -> Result<()> {
3201    let project_id = ProjectId::from_proto(request.remote_entity_id());
3202    let project_connection_ids = session
3203        .db()
3204        .await
3205        .project_connection_ids(project_id, session.connection_id, false)
3206        .await?;
3207
3208    broadcast(
3209        Some(session.connection_id),
3210        project_connection_ids.iter().copied(),
3211        |connection_id| {
3212            session
3213                .peer
3214                .forward_send(session.connection_id, connection_id, request.clone())
3215        },
3216    );
3217    Ok(())
3218}
3219
3220/// Start following another user in a call.
3221async fn follow(
3222    request: proto::Follow,
3223    response: Response<proto::Follow>,
3224    session: UserSession,
3225) -> Result<()> {
3226    let room_id = RoomId::from_proto(request.room_id);
3227    let project_id = request.project_id.map(ProjectId::from_proto);
3228    let leader_id = request
3229        .leader_id
3230        .ok_or_else(|| anyhow!("invalid leader id"))?
3231        .into();
3232    let follower_id = session.connection_id;
3233
3234    session
3235        .db()
3236        .await
3237        .check_room_participants(room_id, leader_id, session.connection_id)
3238        .await?;
3239
3240    let response_payload = session
3241        .peer
3242        .forward_request(session.connection_id, leader_id, request)
3243        .await?;
3244    response.send(response_payload)?;
3245
3246    if let Some(project_id) = project_id {
3247        let room = session
3248            .db()
3249            .await
3250            .follow(room_id, project_id, leader_id, follower_id)
3251            .await?;
3252        room_updated(&room, &session.peer);
3253    }
3254
3255    Ok(())
3256}
3257
3258/// Stop following another user in a call.
3259async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3260    let room_id = RoomId::from_proto(request.room_id);
3261    let project_id = request.project_id.map(ProjectId::from_proto);
3262    let leader_id = request
3263        .leader_id
3264        .ok_or_else(|| anyhow!("invalid leader id"))?
3265        .into();
3266    let follower_id = session.connection_id;
3267
3268    session
3269        .db()
3270        .await
3271        .check_room_participants(room_id, leader_id, session.connection_id)
3272        .await?;
3273
3274    session
3275        .peer
3276        .forward_send(session.connection_id, leader_id, request)?;
3277
3278    if let Some(project_id) = project_id {
3279        let room = session
3280            .db()
3281            .await
3282            .unfollow(room_id, project_id, leader_id, follower_id)
3283            .await?;
3284        room_updated(&room, &session.peer);
3285    }
3286
3287    Ok(())
3288}
3289
3290/// Notify everyone following you of your current location.
3291async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3292    let room_id = RoomId::from_proto(request.room_id);
3293    let database = session.db.lock().await;
3294
3295    let connection_ids = if let Some(project_id) = request.project_id {
3296        let project_id = ProjectId::from_proto(project_id);
3297        database
3298            .project_connection_ids(project_id, session.connection_id, true)
3299            .await?
3300    } else {
3301        database
3302            .room_connection_ids(room_id, session.connection_id)
3303            .await?
3304    };
3305
3306    // For now, don't send view update messages back to that view's current leader.
3307    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3308        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3309        _ => None,
3310    });
3311
3312    for connection_id in connection_ids.iter().cloned() {
3313        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3314            session
3315                .peer
3316                .forward_send(session.connection_id, connection_id, request.clone())?;
3317        }
3318    }
3319    Ok(())
3320}
3321
3322/// Get public data about users.
3323async fn get_users(
3324    request: proto::GetUsers,
3325    response: Response<proto::GetUsers>,
3326    session: Session,
3327) -> Result<()> {
3328    let user_ids = request
3329        .user_ids
3330        .into_iter()
3331        .map(UserId::from_proto)
3332        .collect();
3333    let users = session
3334        .db()
3335        .await
3336        .get_users_by_ids(user_ids)
3337        .await?
3338        .into_iter()
3339        .map(|user| proto::User {
3340            id: user.id.to_proto(),
3341            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3342            github_login: user.github_login,
3343        })
3344        .collect();
3345    response.send(proto::UsersResponse { users })?;
3346    Ok(())
3347}
3348
3349/// Search for users (to invite) buy Github login
3350async fn fuzzy_search_users(
3351    request: proto::FuzzySearchUsers,
3352    response: Response<proto::FuzzySearchUsers>,
3353    session: UserSession,
3354) -> Result<()> {
3355    let query = request.query;
3356    let users = match query.len() {
3357        0 => vec![],
3358        1 | 2 => session
3359            .db()
3360            .await
3361            .get_user_by_github_login(&query)
3362            .await?
3363            .into_iter()
3364            .collect(),
3365        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3366    };
3367    let users = users
3368        .into_iter()
3369        .filter(|user| user.id != session.user_id())
3370        .map(|user| proto::User {
3371            id: user.id.to_proto(),
3372            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3373            github_login: user.github_login,
3374        })
3375        .collect();
3376    response.send(proto::UsersResponse { users })?;
3377    Ok(())
3378}
3379
3380/// Send a contact request to another user.
3381async fn request_contact(
3382    request: proto::RequestContact,
3383    response: Response<proto::RequestContact>,
3384    session: UserSession,
3385) -> Result<()> {
3386    let requester_id = session.user_id();
3387    let responder_id = UserId::from_proto(request.responder_id);
3388    if requester_id == responder_id {
3389        return Err(anyhow!("cannot add yourself as a contact"))?;
3390    }
3391
3392    let notifications = session
3393        .db()
3394        .await
3395        .send_contact_request(requester_id, responder_id)
3396        .await?;
3397
3398    // Update outgoing contact requests of requester
3399    let mut update = proto::UpdateContacts::default();
3400    update.outgoing_requests.push(responder_id.to_proto());
3401    for connection_id in session
3402        .connection_pool()
3403        .await
3404        .user_connection_ids(requester_id)
3405    {
3406        session.peer.send(connection_id, update.clone())?;
3407    }
3408
3409    // Update incoming contact requests of responder
3410    let mut update = proto::UpdateContacts::default();
3411    update
3412        .incoming_requests
3413        .push(proto::IncomingContactRequest {
3414            requester_id: requester_id.to_proto(),
3415        });
3416    let connection_pool = session.connection_pool().await;
3417    for connection_id in connection_pool.user_connection_ids(responder_id) {
3418        session.peer.send(connection_id, update.clone())?;
3419    }
3420
3421    send_notifications(&connection_pool, &session.peer, notifications);
3422
3423    response.send(proto::Ack {})?;
3424    Ok(())
3425}
3426
3427/// Accept or decline a contact request
3428async fn respond_to_contact_request(
3429    request: proto::RespondToContactRequest,
3430    response: Response<proto::RespondToContactRequest>,
3431    session: UserSession,
3432) -> Result<()> {
3433    let responder_id = session.user_id();
3434    let requester_id = UserId::from_proto(request.requester_id);
3435    let db = session.db().await;
3436    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3437        db.dismiss_contact_notification(responder_id, requester_id)
3438            .await?;
3439    } else {
3440        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3441
3442        let notifications = db
3443            .respond_to_contact_request(responder_id, requester_id, accept)
3444            .await?;
3445        let requester_busy = db.is_user_busy(requester_id).await?;
3446        let responder_busy = db.is_user_busy(responder_id).await?;
3447
3448        let pool = session.connection_pool().await;
3449        // Update responder with new contact
3450        let mut update = proto::UpdateContacts::default();
3451        if accept {
3452            update
3453                .contacts
3454                .push(contact_for_user(requester_id, requester_busy, &pool));
3455        }
3456        update
3457            .remove_incoming_requests
3458            .push(requester_id.to_proto());
3459        for connection_id in pool.user_connection_ids(responder_id) {
3460            session.peer.send(connection_id, update.clone())?;
3461        }
3462
3463        // Update requester with new contact
3464        let mut update = proto::UpdateContacts::default();
3465        if accept {
3466            update
3467                .contacts
3468                .push(contact_for_user(responder_id, responder_busy, &pool));
3469        }
3470        update
3471            .remove_outgoing_requests
3472            .push(responder_id.to_proto());
3473
3474        for connection_id in pool.user_connection_ids(requester_id) {
3475            session.peer.send(connection_id, update.clone())?;
3476        }
3477
3478        send_notifications(&pool, &session.peer, notifications);
3479    }
3480
3481    response.send(proto::Ack {})?;
3482    Ok(())
3483}
3484
3485/// Remove a contact.
3486async fn remove_contact(
3487    request: proto::RemoveContact,
3488    response: Response<proto::RemoveContact>,
3489    session: UserSession,
3490) -> Result<()> {
3491    let requester_id = session.user_id();
3492    let responder_id = UserId::from_proto(request.user_id);
3493    let db = session.db().await;
3494    let (contact_accepted, deleted_notification_id) =
3495        db.remove_contact(requester_id, responder_id).await?;
3496
3497    let pool = session.connection_pool().await;
3498    // Update outgoing contact requests of requester
3499    let mut update = proto::UpdateContacts::default();
3500    if contact_accepted {
3501        update.remove_contacts.push(responder_id.to_proto());
3502    } else {
3503        update
3504            .remove_outgoing_requests
3505            .push(responder_id.to_proto());
3506    }
3507    for connection_id in pool.user_connection_ids(requester_id) {
3508        session.peer.send(connection_id, update.clone())?;
3509    }
3510
3511    // Update incoming contact requests of responder
3512    let mut update = proto::UpdateContacts::default();
3513    if contact_accepted {
3514        update.remove_contacts.push(requester_id.to_proto());
3515    } else {
3516        update
3517            .remove_incoming_requests
3518            .push(requester_id.to_proto());
3519    }
3520    for connection_id in pool.user_connection_ids(responder_id) {
3521        session.peer.send(connection_id, update.clone())?;
3522        if let Some(notification_id) = deleted_notification_id {
3523            session.peer.send(
3524                connection_id,
3525                proto::DeleteNotification {
3526                    notification_id: notification_id.to_proto(),
3527                },
3528            )?;
3529        }
3530    }
3531
3532    response.send(proto::Ack {})?;
3533    Ok(())
3534}
3535
3536fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3537    version.0.minor() < 139
3538}
3539
3540async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
3541    let db = session.db().await;
3542    let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
3543
3544    let plan = if session.is_staff() || !active_subscriptions.is_empty() {
3545        proto::Plan::ZedPro
3546    } else {
3547        proto::Plan::Free
3548    };
3549
3550    session
3551        .peer
3552        .send(
3553            session.connection_id,
3554            proto::UpdateUserPlan { plan: plan.into() },
3555        )
3556        .trace_err();
3557
3558    Ok(())
3559}
3560
3561async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3562    subscribe_user_to_channels(
3563        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3564        &session,
3565    )
3566    .await?;
3567    Ok(())
3568}
3569
3570async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3571    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3572    let mut pool = session.connection_pool().await;
3573    for membership in &channels_for_user.channel_memberships {
3574        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3575    }
3576    session.peer.send(
3577        session.connection_id,
3578        build_update_user_channels(&channels_for_user),
3579    )?;
3580    session.peer.send(
3581        session.connection_id,
3582        build_channels_update(channels_for_user),
3583    )?;
3584    Ok(())
3585}
3586
3587/// Creates a new channel.
3588async fn create_channel(
3589    request: proto::CreateChannel,
3590    response: Response<proto::CreateChannel>,
3591    session: UserSession,
3592) -> Result<()> {
3593    let db = session.db().await;
3594
3595    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3596    let (channel, membership) = db
3597        .create_channel(&request.name, parent_id, session.user_id())
3598        .await?;
3599
3600    let root_id = channel.root_id();
3601    let channel = Channel::from_model(channel);
3602
3603    response.send(proto::CreateChannelResponse {
3604        channel: Some(channel.to_proto()),
3605        parent_id: request.parent_id,
3606    })?;
3607
3608    let mut connection_pool = session.connection_pool().await;
3609    if let Some(membership) = membership {
3610        connection_pool.subscribe_to_channel(
3611            membership.user_id,
3612            membership.channel_id,
3613            membership.role,
3614        );
3615        let update = proto::UpdateUserChannels {
3616            channel_memberships: vec![proto::ChannelMembership {
3617                channel_id: membership.channel_id.to_proto(),
3618                role: membership.role.into(),
3619            }],
3620            ..Default::default()
3621        };
3622        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3623            session.peer.send(connection_id, update.clone())?;
3624        }
3625    }
3626
3627    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3628        if !role.can_see_channel(channel.visibility) {
3629            continue;
3630        }
3631
3632        let update = proto::UpdateChannels {
3633            channels: vec![channel.to_proto()],
3634            ..Default::default()
3635        };
3636        session.peer.send(connection_id, update.clone())?;
3637    }
3638
3639    Ok(())
3640}
3641
3642/// Delete a channel
3643async fn delete_channel(
3644    request: proto::DeleteChannel,
3645    response: Response<proto::DeleteChannel>,
3646    session: UserSession,
3647) -> Result<()> {
3648    let db = session.db().await;
3649
3650    let channel_id = request.channel_id;
3651    let (root_channel, removed_channels) = db
3652        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3653        .await?;
3654    response.send(proto::Ack {})?;
3655
3656    // Notify members of removed channels
3657    let mut update = proto::UpdateChannels::default();
3658    update
3659        .delete_channels
3660        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3661
3662    let connection_pool = session.connection_pool().await;
3663    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3664        session.peer.send(connection_id, update.clone())?;
3665    }
3666
3667    Ok(())
3668}
3669
3670/// Invite someone to join a channel.
3671async fn invite_channel_member(
3672    request: proto::InviteChannelMember,
3673    response: Response<proto::InviteChannelMember>,
3674    session: UserSession,
3675) -> Result<()> {
3676    let db = session.db().await;
3677    let channel_id = ChannelId::from_proto(request.channel_id);
3678    let invitee_id = UserId::from_proto(request.user_id);
3679    let InviteMemberResult {
3680        channel,
3681        notifications,
3682    } = db
3683        .invite_channel_member(
3684            channel_id,
3685            invitee_id,
3686            session.user_id(),
3687            request.role().into(),
3688        )
3689        .await?;
3690
3691    let update = proto::UpdateChannels {
3692        channel_invitations: vec![channel.to_proto()],
3693        ..Default::default()
3694    };
3695
3696    let connection_pool = session.connection_pool().await;
3697    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3698        session.peer.send(connection_id, update.clone())?;
3699    }
3700
3701    send_notifications(&connection_pool, &session.peer, notifications);
3702
3703    response.send(proto::Ack {})?;
3704    Ok(())
3705}
3706
3707/// remove someone from a channel
3708async fn remove_channel_member(
3709    request: proto::RemoveChannelMember,
3710    response: Response<proto::RemoveChannelMember>,
3711    session: UserSession,
3712) -> Result<()> {
3713    let db = session.db().await;
3714    let channel_id = ChannelId::from_proto(request.channel_id);
3715    let member_id = UserId::from_proto(request.user_id);
3716
3717    let RemoveChannelMemberResult {
3718        membership_update,
3719        notification_id,
3720    } = db
3721        .remove_channel_member(channel_id, member_id, session.user_id())
3722        .await?;
3723
3724    let mut connection_pool = session.connection_pool().await;
3725    notify_membership_updated(
3726        &mut connection_pool,
3727        membership_update,
3728        member_id,
3729        &session.peer,
3730    );
3731    for connection_id in connection_pool.user_connection_ids(member_id) {
3732        if let Some(notification_id) = notification_id {
3733            session
3734                .peer
3735                .send(
3736                    connection_id,
3737                    proto::DeleteNotification {
3738                        notification_id: notification_id.to_proto(),
3739                    },
3740                )
3741                .trace_err();
3742        }
3743    }
3744
3745    response.send(proto::Ack {})?;
3746    Ok(())
3747}
3748
3749/// Toggle the channel between public and private.
3750/// Care is taken to maintain the invariant that public channels only descend from public channels,
3751/// (though members-only channels can appear at any point in the hierarchy).
3752async fn set_channel_visibility(
3753    request: proto::SetChannelVisibility,
3754    response: Response<proto::SetChannelVisibility>,
3755    session: UserSession,
3756) -> Result<()> {
3757    let db = session.db().await;
3758    let channel_id = ChannelId::from_proto(request.channel_id);
3759    let visibility = request.visibility().into();
3760
3761    let channel_model = db
3762        .set_channel_visibility(channel_id, visibility, session.user_id())
3763        .await?;
3764    let root_id = channel_model.root_id();
3765    let channel = Channel::from_model(channel_model);
3766
3767    let mut connection_pool = session.connection_pool().await;
3768    for (user_id, role) in connection_pool
3769        .channel_user_ids(root_id)
3770        .collect::<Vec<_>>()
3771        .into_iter()
3772    {
3773        let update = if role.can_see_channel(channel.visibility) {
3774            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3775            proto::UpdateChannels {
3776                channels: vec![channel.to_proto()],
3777                ..Default::default()
3778            }
3779        } else {
3780            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3781            proto::UpdateChannels {
3782                delete_channels: vec![channel.id.to_proto()],
3783                ..Default::default()
3784            }
3785        };
3786
3787        for connection_id in connection_pool.user_connection_ids(user_id) {
3788            session.peer.send(connection_id, update.clone())?;
3789        }
3790    }
3791
3792    response.send(proto::Ack {})?;
3793    Ok(())
3794}
3795
3796/// Alter the role for a user in the channel.
3797async fn set_channel_member_role(
3798    request: proto::SetChannelMemberRole,
3799    response: Response<proto::SetChannelMemberRole>,
3800    session: UserSession,
3801) -> Result<()> {
3802    let db = session.db().await;
3803    let channel_id = ChannelId::from_proto(request.channel_id);
3804    let member_id = UserId::from_proto(request.user_id);
3805    let result = db
3806        .set_channel_member_role(
3807            channel_id,
3808            session.user_id(),
3809            member_id,
3810            request.role().into(),
3811        )
3812        .await?;
3813
3814    match result {
3815        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3816            let mut connection_pool = session.connection_pool().await;
3817            notify_membership_updated(
3818                &mut connection_pool,
3819                membership_update,
3820                member_id,
3821                &session.peer,
3822            )
3823        }
3824        db::SetMemberRoleResult::InviteUpdated(channel) => {
3825            let update = proto::UpdateChannels {
3826                channel_invitations: vec![channel.to_proto()],
3827                ..Default::default()
3828            };
3829
3830            for connection_id in session
3831                .connection_pool()
3832                .await
3833                .user_connection_ids(member_id)
3834            {
3835                session.peer.send(connection_id, update.clone())?;
3836            }
3837        }
3838    }
3839
3840    response.send(proto::Ack {})?;
3841    Ok(())
3842}
3843
3844/// Change the name of a channel
3845async fn rename_channel(
3846    request: proto::RenameChannel,
3847    response: Response<proto::RenameChannel>,
3848    session: UserSession,
3849) -> Result<()> {
3850    let db = session.db().await;
3851    let channel_id = ChannelId::from_proto(request.channel_id);
3852    let channel_model = db
3853        .rename_channel(channel_id, session.user_id(), &request.name)
3854        .await?;
3855    let root_id = channel_model.root_id();
3856    let channel = Channel::from_model(channel_model);
3857
3858    response.send(proto::RenameChannelResponse {
3859        channel: Some(channel.to_proto()),
3860    })?;
3861
3862    let connection_pool = session.connection_pool().await;
3863    let update = proto::UpdateChannels {
3864        channels: vec![channel.to_proto()],
3865        ..Default::default()
3866    };
3867    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3868        if role.can_see_channel(channel.visibility) {
3869            session.peer.send(connection_id, update.clone())?;
3870        }
3871    }
3872
3873    Ok(())
3874}
3875
3876/// Move a channel to a new parent.
3877async fn move_channel(
3878    request: proto::MoveChannel,
3879    response: Response<proto::MoveChannel>,
3880    session: UserSession,
3881) -> Result<()> {
3882    let channel_id = ChannelId::from_proto(request.channel_id);
3883    let to = ChannelId::from_proto(request.to);
3884
3885    let (root_id, channels) = session
3886        .db()
3887        .await
3888        .move_channel(channel_id, to, session.user_id())
3889        .await?;
3890
3891    let connection_pool = session.connection_pool().await;
3892    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3893        let channels = channels
3894            .iter()
3895            .filter_map(|channel| {
3896                if role.can_see_channel(channel.visibility) {
3897                    Some(channel.to_proto())
3898                } else {
3899                    None
3900                }
3901            })
3902            .collect::<Vec<_>>();
3903        if channels.is_empty() {
3904            continue;
3905        }
3906
3907        let update = proto::UpdateChannels {
3908            channels,
3909            ..Default::default()
3910        };
3911
3912        session.peer.send(connection_id, update.clone())?;
3913    }
3914
3915    response.send(Ack {})?;
3916    Ok(())
3917}
3918
3919/// Get the list of channel members
3920async fn get_channel_members(
3921    request: proto::GetChannelMembers,
3922    response: Response<proto::GetChannelMembers>,
3923    session: UserSession,
3924) -> Result<()> {
3925    let db = session.db().await;
3926    let channel_id = ChannelId::from_proto(request.channel_id);
3927    let limit = if request.limit == 0 {
3928        u16::MAX as u64
3929    } else {
3930        request.limit
3931    };
3932    let (members, users) = db
3933        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3934        .await?;
3935    response.send(proto::GetChannelMembersResponse { members, users })?;
3936    Ok(())
3937}
3938
3939/// Accept or decline a channel invitation.
3940async fn respond_to_channel_invite(
3941    request: proto::RespondToChannelInvite,
3942    response: Response<proto::RespondToChannelInvite>,
3943    session: UserSession,
3944) -> Result<()> {
3945    let db = session.db().await;
3946    let channel_id = ChannelId::from_proto(request.channel_id);
3947    let RespondToChannelInvite {
3948        membership_update,
3949        notifications,
3950    } = db
3951        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3952        .await?;
3953
3954    let mut connection_pool = session.connection_pool().await;
3955    if let Some(membership_update) = membership_update {
3956        notify_membership_updated(
3957            &mut connection_pool,
3958            membership_update,
3959            session.user_id(),
3960            &session.peer,
3961        );
3962    } else {
3963        let update = proto::UpdateChannels {
3964            remove_channel_invitations: vec![channel_id.to_proto()],
3965            ..Default::default()
3966        };
3967
3968        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3969            session.peer.send(connection_id, update.clone())?;
3970        }
3971    };
3972
3973    send_notifications(&connection_pool, &session.peer, notifications);
3974
3975    response.send(proto::Ack {})?;
3976
3977    Ok(())
3978}
3979
3980/// Join the channels' room
3981async fn join_channel(
3982    request: proto::JoinChannel,
3983    response: Response<proto::JoinChannel>,
3984    session: UserSession,
3985) -> Result<()> {
3986    let channel_id = ChannelId::from_proto(request.channel_id);
3987    join_channel_internal(channel_id, Box::new(response), session).await
3988}
3989
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls delegate to `Response`'s own (presumably inherent) `send` through a
// fully qualified, type-annotated path. This makes explicit that it is not this
// trait's `send` being called, which would recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4003
/// Shared implementation for joining a channel's room.
///
/// `response` is any handle implementing `JoinChannelInternalResponse`, so the same
/// logic can answer both `JoinChannel` and `JoinRoom` requests with a
/// `proto::JoinRoomResponse`.
///
/// Steps:
/// 1. Evict a stale room connection left behind by this user, if any.
/// 2. Join the channel in the database.
/// 3. Mint LiveKit credentials when a LiveKit client is configured.
/// 4. Reply to the caller.
/// 5. Propagate any membership change and the room update.
/// 6. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns early on database failures, on a failed reply, or with
/// "channel not returned" when the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database and connection-pool handles acquired inside this block are
    // released when it ends, before `channel_updated` below re-acquires the pool.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then
            // re-acquire it. Presumably `leave_room_for_session` needs the database
            // itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit credentials are optional. They are `None` when no LiveKit client is
        // configured, or when minting the token fails (the error is logged by
        // `trace_err`). Guests get a guest token and may not publish; every other role
        // gets a room token with publishing allowed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before fanning out updates to other connections.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Broadcast the updated room state via `room_updated`.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A joined room is expected to belong to a channel. If it does not, this fails
    // with "channel not returned".
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4094
4095/// Start editing the channel notes
4096async fn join_channel_buffer(
4097    request: proto::JoinChannelBuffer,
4098    response: Response<proto::JoinChannelBuffer>,
4099    session: UserSession,
4100) -> Result<()> {
4101    let db = session.db().await;
4102    let channel_id = ChannelId::from_proto(request.channel_id);
4103
4104    let open_response = db
4105        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4106        .await?;
4107
4108    let collaborators = open_response.collaborators.clone();
4109    response.send(open_response)?;
4110
4111    let update = UpdateChannelBufferCollaborators {
4112        channel_id: channel_id.to_proto(),
4113        collaborators: collaborators.clone(),
4114    };
4115    channel_buffer_updated(
4116        session.connection_id,
4117        collaborators
4118            .iter()
4119            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4120        &update,
4121        &session.peer,
4122    );
4123
4124    Ok(())
4125}
4126
4127/// Edit the channel notes
4128async fn update_channel_buffer(
4129    request: proto::UpdateChannelBuffer,
4130    session: UserSession,
4131) -> Result<()> {
4132    let db = session.db().await;
4133    let channel_id = ChannelId::from_proto(request.channel_id);
4134
4135    let (collaborators, epoch, version) = db
4136        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4137        .await?;
4138
4139    channel_buffer_updated(
4140        session.connection_id,
4141        collaborators.clone(),
4142        &proto::UpdateChannelBuffer {
4143            channel_id: channel_id.to_proto(),
4144            operations: request.operations,
4145        },
4146        &session.peer,
4147    );
4148
4149    let pool = &*session.connection_pool().await;
4150
4151    let non_collaborators =
4152        pool.channel_connection_ids(channel_id)
4153            .filter_map(|(connection_id, _)| {
4154                if collaborators.contains(&connection_id) {
4155                    None
4156                } else {
4157                    Some(connection_id)
4158                }
4159            });
4160
4161    broadcast(None, non_collaborators, |peer_id| {
4162        session.peer.send(
4163            peer_id,
4164            proto::UpdateChannels {
4165                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4166                    channel_id: channel_id.to_proto(),
4167                    epoch: epoch as u64,
4168                    version: version.clone(),
4169                }],
4170                ..Default::default()
4171            },
4172        )
4173    });
4174
4175    Ok(())
4176}
4177
4178/// Rejoin the channel notes after a connection blip
4179async fn rejoin_channel_buffers(
4180    request: proto::RejoinChannelBuffers,
4181    response: Response<proto::RejoinChannelBuffers>,
4182    session: UserSession,
4183) -> Result<()> {
4184    let db = session.db().await;
4185    let buffers = db
4186        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4187        .await?;
4188
4189    for rejoined_buffer in &buffers {
4190        let collaborators_to_notify = rejoined_buffer
4191            .buffer
4192            .collaborators
4193            .iter()
4194            .filter_map(|c| Some(c.peer_id?.into()));
4195        channel_buffer_updated(
4196            session.connection_id,
4197            collaborators_to_notify,
4198            &proto::UpdateChannelBufferCollaborators {
4199                channel_id: rejoined_buffer.buffer.channel_id,
4200                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4201            },
4202            &session.peer,
4203        );
4204    }
4205
4206    response.send(proto::RejoinChannelBuffersResponse {
4207        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4208    })?;
4209
4210    Ok(())
4211}
4212
4213/// Stop editing the channel notes
4214async fn leave_channel_buffer(
4215    request: proto::LeaveChannelBuffer,
4216    response: Response<proto::LeaveChannelBuffer>,
4217    session: UserSession,
4218) -> Result<()> {
4219    let db = session.db().await;
4220    let channel_id = ChannelId::from_proto(request.channel_id);
4221
4222    let left_buffer = db
4223        .leave_channel_buffer(channel_id, session.connection_id)
4224        .await?;
4225
4226    response.send(Ack {})?;
4227
4228    channel_buffer_updated(
4229        session.connection_id,
4230        left_buffer.connections,
4231        &proto::UpdateChannelBufferCollaborators {
4232            channel_id: channel_id.to_proto(),
4233            collaborators: left_buffer.collaborators,
4234        },
4235        &session.peer,
4236    );
4237
4238    Ok(())
4239}
4240
4241fn channel_buffer_updated<T: EnvelopedMessage>(
4242    sender_id: ConnectionId,
4243    collaborators: impl IntoIterator<Item = ConnectionId>,
4244    message: &T,
4245    peer: &Peer,
4246) {
4247    broadcast(Some(sender_id), collaborators, |peer_id| {
4248        peer.send(peer_id, message.clone())
4249    });
4250}
4251
4252fn send_notifications(
4253    connection_pool: &ConnectionPool,
4254    peer: &Peer,
4255    notifications: db::NotificationBatch,
4256) {
4257    for (user_id, notification) in notifications {
4258        for connection_id in connection_pool.user_connection_ids(user_id) {
4259            if let Err(error) = peer.send(
4260                connection_id,
4261                proto::AddNotification {
4262                    notification: Some(notification.clone()),
4263                },
4264            ) {
4265                tracing::error!(
4266                    "failed to send notification to {:?} {}",
4267                    connection_id,
4268                    error
4269                );
4270            }
4271        }
4272    }
4273}
4274
4275/// Send a message to the channel
4276async fn send_channel_message(
4277    request: proto::SendChannelMessage,
4278    response: Response<proto::SendChannelMessage>,
4279    session: UserSession,
4280) -> Result<()> {
4281    // Validate the message body.
4282    let body = request.body.trim().to_string();
4283    if body.len() > MAX_MESSAGE_LEN {
4284        return Err(anyhow!("message is too long"))?;
4285    }
4286    if body.is_empty() {
4287        return Err(anyhow!("message can't be blank"))?;
4288    }
4289
4290    // TODO: adjust mentions if body is trimmed
4291
4292    let timestamp = OffsetDateTime::now_utc();
4293    let nonce = request
4294        .nonce
4295        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4296
4297    let channel_id = ChannelId::from_proto(request.channel_id);
4298    let CreatedChannelMessage {
4299        message_id,
4300        participant_connection_ids,
4301        notifications,
4302    } = session
4303        .db()
4304        .await
4305        .create_channel_message(
4306            channel_id,
4307            session.user_id(),
4308            &body,
4309            &request.mentions,
4310            timestamp,
4311            nonce.clone().into(),
4312            match request.reply_to_message_id {
4313                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4314                None => None,
4315            },
4316        )
4317        .await?;
4318
4319    let message = proto::ChannelMessage {
4320        sender_id: session.user_id().to_proto(),
4321        id: message_id.to_proto(),
4322        body,
4323        mentions: request.mentions,
4324        timestamp: timestamp.unix_timestamp() as u64,
4325        nonce: Some(nonce),
4326        reply_to_message_id: request.reply_to_message_id,
4327        edited_at: None,
4328    };
4329    broadcast(
4330        Some(session.connection_id),
4331        participant_connection_ids.clone(),
4332        |connection| {
4333            session.peer.send(
4334                connection,
4335                proto::ChannelMessageSent {
4336                    channel_id: channel_id.to_proto(),
4337                    message: Some(message.clone()),
4338                },
4339            )
4340        },
4341    );
4342    response.send(proto::SendChannelMessageResponse {
4343        message: Some(message),
4344    })?;
4345
4346    let pool = &*session.connection_pool().await;
4347    let non_participants =
4348        pool.channel_connection_ids(channel_id)
4349            .filter_map(|(connection_id, _)| {
4350                if participant_connection_ids.contains(&connection_id) {
4351                    None
4352                } else {
4353                    Some(connection_id)
4354                }
4355            });
4356    broadcast(None, non_participants, |peer_id| {
4357        session.peer.send(
4358            peer_id,
4359            proto::UpdateChannels {
4360                latest_channel_message_ids: vec![proto::ChannelMessageId {
4361                    channel_id: channel_id.to_proto(),
4362                    message_id: message_id.to_proto(),
4363                }],
4364                ..Default::default()
4365            },
4366        )
4367    });
4368    send_notifications(pool, &session.peer, notifications);
4369
4370    Ok(())
4371}
4372
4373/// Delete a channel message
4374async fn remove_channel_message(
4375    request: proto::RemoveChannelMessage,
4376    response: Response<proto::RemoveChannelMessage>,
4377    session: UserSession,
4378) -> Result<()> {
4379    let channel_id = ChannelId::from_proto(request.channel_id);
4380    let message_id = MessageId::from_proto(request.message_id);
4381    let (connection_ids, existing_notification_ids) = session
4382        .db()
4383        .await
4384        .remove_channel_message(channel_id, message_id, session.user_id())
4385        .await?;
4386
4387    broadcast(
4388        Some(session.connection_id),
4389        connection_ids,
4390        move |connection| {
4391            session.peer.send(connection, request.clone())?;
4392
4393            for notification_id in &existing_notification_ids {
4394                session.peer.send(
4395                    connection,
4396                    proto::DeleteNotification {
4397                        notification_id: (*notification_id).to_proto(),
4398                    },
4399                )?;
4400            }
4401
4402            Ok(())
4403        },
4404    );
4405    response.send(proto::Ack {})?;
4406    Ok(())
4407}
4408
4409async fn update_channel_message(
4410    request: proto::UpdateChannelMessage,
4411    response: Response<proto::UpdateChannelMessage>,
4412    session: UserSession,
4413) -> Result<()> {
4414    let channel_id = ChannelId::from_proto(request.channel_id);
4415    let message_id = MessageId::from_proto(request.message_id);
4416    let updated_at = OffsetDateTime::now_utc();
4417    let UpdatedChannelMessage {
4418        message_id,
4419        participant_connection_ids,
4420        notifications,
4421        reply_to_message_id,
4422        timestamp,
4423        deleted_mention_notification_ids,
4424        updated_mention_notifications,
4425    } = session
4426        .db()
4427        .await
4428        .update_channel_message(
4429            channel_id,
4430            message_id,
4431            session.user_id(),
4432            request.body.as_str(),
4433            &request.mentions,
4434            updated_at,
4435        )
4436        .await?;
4437
4438    let nonce = request
4439        .nonce
4440        .clone()
4441        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4442
4443    let message = proto::ChannelMessage {
4444        sender_id: session.user_id().to_proto(),
4445        id: message_id.to_proto(),
4446        body: request.body.clone(),
4447        mentions: request.mentions.clone(),
4448        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4449        nonce: Some(nonce),
4450        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4451        edited_at: Some(updated_at.unix_timestamp() as u64),
4452    };
4453
4454    response.send(proto::Ack {})?;
4455
4456    let pool = &*session.connection_pool().await;
4457    broadcast(
4458        Some(session.connection_id),
4459        participant_connection_ids,
4460        |connection| {
4461            session.peer.send(
4462                connection,
4463                proto::ChannelMessageUpdate {
4464                    channel_id: channel_id.to_proto(),
4465                    message: Some(message.clone()),
4466                },
4467            )?;
4468
4469            for notification_id in &deleted_mention_notification_ids {
4470                session.peer.send(
4471                    connection,
4472                    proto::DeleteNotification {
4473                        notification_id: (*notification_id).to_proto(),
4474                    },
4475                )?;
4476            }
4477
4478            for notification in &updated_mention_notifications {
4479                session.peer.send(
4480                    connection,
4481                    proto::UpdateNotification {
4482                        notification: Some(notification.clone()),
4483                    },
4484                )?;
4485            }
4486
4487            Ok(())
4488        },
4489    );
4490
4491    send_notifications(pool, &session.peer, notifications);
4492
4493    Ok(())
4494}
4495
4496/// Mark a channel message as read
4497async fn acknowledge_channel_message(
4498    request: proto::AckChannelMessage,
4499    session: UserSession,
4500) -> Result<()> {
4501    let channel_id = ChannelId::from_proto(request.channel_id);
4502    let message_id = MessageId::from_proto(request.message_id);
4503    let notifications = session
4504        .db()
4505        .await
4506        .observe_channel_message(channel_id, session.user_id(), message_id)
4507        .await?;
4508    send_notifications(
4509        &*session.connection_pool().await,
4510        &session.peer,
4511        notifications,
4512    );
4513    Ok(())
4514}
4515
4516/// Mark a buffer version as synced
4517async fn acknowledge_buffer_version(
4518    request: proto::AckBufferOperation,
4519    session: UserSession,
4520) -> Result<()> {
4521    let buffer_id = BufferId::from_proto(request.buffer_id);
4522    session
4523        .db()
4524        .await
4525        .observe_buffer_version(
4526            buffer_id,
4527            session.user_id(),
4528            request.epoch as i32,
4529            &request.version,
4530        )
4531        .await?;
4532    Ok(())
4533}
4534
4535struct CompleteWithLanguageModelRateLimit;
4536
4537impl RateLimit for CompleteWithLanguageModelRateLimit {
4538    fn capacity() -> usize {
4539        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4540            .ok()
4541            .and_then(|v| v.parse().ok())
4542            .unwrap_or(120) // Picked arbitrarily
4543    }
4544
4545    fn refill_duration() -> chrono::Duration {
4546        chrono::Duration::hours(1)
4547    }
4548
4549    fn db_name() -> &'static str {
4550        "complete-with-language-model"
4551    }
4552}
4553
4554async fn complete_with_language_model(
4555    request: proto::CompleteWithLanguageModel,
4556    response: Response<proto::CompleteWithLanguageModel>,
4557    session: Session,
4558    config: &Config,
4559) -> Result<()> {
4560    let Some(session) = session.for_user() else {
4561        return Err(anyhow!("user not found"))?;
4562    };
4563    authorize_access_to_language_models(&session).await?;
4564
4565    session
4566        .rate_limiter
4567        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4568        .await?;
4569
4570    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4571        Some(proto::LanguageModelProvider::Anthropic) => {
4572            let api_key = config
4573                .anthropic_api_key
4574                .as_ref()
4575                .context("no Anthropic AI API key configured on the server")?;
4576            anthropic::complete(
4577                session.http_client.as_ref(),
4578                anthropic::ANTHROPIC_API_URL,
4579                api_key,
4580                serde_json::from_str(&request.request)?,
4581            )
4582            .await?
4583        }
4584        _ => return Err(anyhow!("unsupported provider"))?,
4585    };
4586
4587    response.send(proto::CompleteWithLanguageModelResponse {
4588        completion: serde_json::to_string(&result)?,
4589    })?;
4590
4591    Ok(())
4592}
4593
4594async fn stream_complete_with_language_model(
4595    request: proto::StreamCompleteWithLanguageModel,
4596    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4597    session: Session,
4598    config: &Config,
4599) -> Result<()> {
4600    let Some(session) = session.for_user() else {
4601        return Err(anyhow!("user not found"))?;
4602    };
4603    authorize_access_to_language_models(&session).await?;
4604
4605    session
4606        .rate_limiter
4607        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4608        .await?;
4609
4610    match proto::LanguageModelProvider::from_i32(request.provider) {
4611        Some(proto::LanguageModelProvider::Anthropic) => {
4612            let api_key = config
4613                .anthropic_api_key
4614                .as_ref()
4615                .context("no Anthropic AI API key configured on the server")?;
4616            let mut chunks = anthropic::stream_completion(
4617                session.http_client.as_ref(),
4618                anthropic::ANTHROPIC_API_URL,
4619                api_key,
4620                serde_json::from_str(&request.request)?,
4621                None,
4622            )
4623            .await?;
4624            while let Some(event) = chunks.next().await {
4625                let chunk = event?;
4626                response.send(proto::StreamCompleteWithLanguageModelResponse {
4627                    event: serde_json::to_string(&chunk)?,
4628                })?;
4629            }
4630        }
4631        Some(proto::LanguageModelProvider::OpenAi) => {
4632            let api_key = config
4633                .openai_api_key
4634                .as_ref()
4635                .context("no OpenAI API key configured on the server")?;
4636            let mut events = open_ai::stream_completion(
4637                session.http_client.as_ref(),
4638                open_ai::OPEN_AI_API_URL,
4639                api_key,
4640                serde_json::from_str(&request.request)?,
4641                None,
4642            )
4643            .await?;
4644            while let Some(event) = events.next().await {
4645                let event = event?;
4646                response.send(proto::StreamCompleteWithLanguageModelResponse {
4647                    event: serde_json::to_string(&event)?,
4648                })?;
4649            }
4650        }
4651        Some(proto::LanguageModelProvider::Google) => {
4652            let api_key = config
4653                .google_ai_api_key
4654                .as_ref()
4655                .context("no Google AI API key configured on the server")?;
4656            let mut events = google_ai::stream_generate_content(
4657                session.http_client.as_ref(),
4658                google_ai::API_URL,
4659                api_key,
4660                serde_json::from_str(&request.request)?,
4661            )
4662            .await?;
4663            while let Some(event) = events.next().await {
4664                let event = event?;
4665                response.send(proto::StreamCompleteWithLanguageModelResponse {
4666                    event: serde_json::to_string(&event)?,
4667                })?;
4668            }
4669        }
4670        None => return Err(anyhow!("unknown provider"))?,
4671    }
4672
4673    Ok(())
4674}
4675
4676async fn count_language_model_tokens(
4677    request: proto::CountLanguageModelTokens,
4678    response: Response<proto::CountLanguageModelTokens>,
4679    session: Session,
4680    config: &Config,
4681) -> Result<()> {
4682    let Some(session) = session.for_user() else {
4683        return Err(anyhow!("user not found"))?;
4684    };
4685    authorize_access_to_language_models(&session).await?;
4686
4687    session
4688        .rate_limiter
4689        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4690        .await?;
4691
4692    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4693        Some(proto::LanguageModelProvider::Google) => {
4694            let api_key = config
4695                .google_ai_api_key
4696                .as_ref()
4697                .context("no Google AI API key configured on the server")?;
4698            google_ai::count_tokens(
4699                session.http_client.as_ref(),
4700                google_ai::API_URL,
4701                api_key,
4702                serde_json::from_str(&request.request)?,
4703            )
4704            .await?
4705        }
4706        _ => return Err(anyhow!("unsupported provider"))?,
4707    };
4708
4709    response.send(proto::CountLanguageModelTokensResponse {
4710        token_count: result.total_tokens as u32,
4711    })?;
4712
4713    Ok(())
4714}
4715
4716struct CountLanguageModelTokensRateLimit;
4717
4718impl RateLimit for CountLanguageModelTokensRateLimit {
4719    fn capacity() -> usize {
4720        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4721            .ok()
4722            .and_then(|v| v.parse().ok())
4723            .unwrap_or(600) // Picked arbitrarily
4724    }
4725
4726    fn refill_duration() -> chrono::Duration {
4727        chrono::Duration::hours(1)
4728    }
4729
4730    fn db_name() -> &'static str {
4731        "count-language-model-tokens"
4732    }
4733}
4734
4735struct ComputeEmbeddingsRateLimit;
4736
4737impl RateLimit for ComputeEmbeddingsRateLimit {
4738    fn capacity() -> usize {
4739        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4740            .ok()
4741            .and_then(|v| v.parse().ok())
4742            .unwrap_or(5000) // Picked arbitrarily
4743    }
4744
4745    fn refill_duration() -> chrono::Duration {
4746        chrono::Duration::hours(1)
4747    }
4748
4749    fn db_name() -> &'static str {
4750        "compute-embeddings"
4751    }
4752}
4753
4754async fn compute_embeddings(
4755    request: proto::ComputeEmbeddings,
4756    response: Response<proto::ComputeEmbeddings>,
4757    session: UserSession,
4758    api_key: Option<Arc<str>>,
4759) -> Result<()> {
4760    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4761    authorize_access_to_language_models(&session).await?;
4762
4763    session
4764        .rate_limiter
4765        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4766        .await?;
4767
4768    let embeddings = match request.model.as_str() {
4769        "openai/text-embedding-3-small" => {
4770            open_ai::embed(
4771                session.http_client.as_ref(),
4772                OPEN_AI_API_URL,
4773                &api_key,
4774                OpenAiEmbeddingModel::TextEmbedding3Small,
4775                request.texts.iter().map(|text| text.as_str()),
4776            )
4777            .await?
4778        }
4779        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4780    };
4781
4782    let embeddings = request
4783        .texts
4784        .iter()
4785        .map(|text| {
4786            let mut hasher = sha2::Sha256::new();
4787            hasher.update(text.as_bytes());
4788            let result = hasher.finalize();
4789            result.to_vec()
4790        })
4791        .zip(
4792            embeddings
4793                .data
4794                .into_iter()
4795                .map(|embedding| embedding.embedding),
4796        )
4797        .collect::<HashMap<_, _>>();
4798
4799    let db = session.db().await;
4800    db.save_embeddings(&request.model, &embeddings)
4801        .await
4802        .context("failed to save embeddings")
4803        .trace_err();
4804
4805    response.send(proto::ComputeEmbeddingsResponse {
4806        embeddings: embeddings
4807            .into_iter()
4808            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4809            .collect(),
4810    })?;
4811    Ok(())
4812}
4813
4814async fn get_cached_embeddings(
4815    request: proto::GetCachedEmbeddings,
4816    response: Response<proto::GetCachedEmbeddings>,
4817    session: UserSession,
4818) -> Result<()> {
4819    authorize_access_to_language_models(&session).await?;
4820
4821    let db = session.db().await;
4822    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4823
4824    response.send(proto::GetCachedEmbeddingsResponse {
4825        embeddings: embeddings
4826            .into_iter()
4827            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4828            .collect(),
4829    })?;
4830    Ok(())
4831}
4832
4833async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4834    let db = session.db().await;
4835    let flags = db.get_user_flags(session.user_id()).await?;
4836    if flags.iter().any(|flag| flag == "language-models") {
4837        Ok(())
4838    } else {
4839        Err(anyhow!("permission denied"))?
4840    }
4841}
4842
4843/// Get a Supermaven API key for the user
4844async fn get_supermaven_api_key(
4845    _request: proto::GetSupermavenApiKey,
4846    response: Response<proto::GetSupermavenApiKey>,
4847    session: UserSession,
4848) -> Result<()> {
4849    let user_id: String = session.user_id().to_string();
4850    if !session.is_staff() {
4851        return Err(anyhow!("supermaven not enabled for this account"))?;
4852    }
4853
4854    let email = session
4855        .email()
4856        .ok_or_else(|| anyhow!("user must have an email"))?;
4857
4858    let supermaven_admin_api = session
4859        .supermaven_client
4860        .as_ref()
4861        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4862
4863    let result = supermaven_admin_api
4864        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4865        .await?;
4866
4867    response.send(proto::GetSupermavenApiKeyResponse {
4868        api_key: result.api_key,
4869    })?;
4870
4871    Ok(())
4872}
4873
4874/// Start receiving chat updates for a channel
4875async fn join_channel_chat(
4876    request: proto::JoinChannelChat,
4877    response: Response<proto::JoinChannelChat>,
4878    session: UserSession,
4879) -> Result<()> {
4880    let channel_id = ChannelId::from_proto(request.channel_id);
4881
4882    let db = session.db().await;
4883    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4884        .await?;
4885    let messages = db
4886        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4887        .await?;
4888    response.send(proto::JoinChannelChatResponse {
4889        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4890        messages,
4891    })?;
4892    Ok(())
4893}
4894
4895/// Stop receiving chat updates for a channel
4896async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4897    let channel_id = ChannelId::from_proto(request.channel_id);
4898    session
4899        .db()
4900        .await
4901        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4902        .await?;
4903    Ok(())
4904}
4905
4906/// Retrieve the chat history for a channel
4907async fn get_channel_messages(
4908    request: proto::GetChannelMessages,
4909    response: Response<proto::GetChannelMessages>,
4910    session: UserSession,
4911) -> Result<()> {
4912    let channel_id = ChannelId::from_proto(request.channel_id);
4913    let messages = session
4914        .db()
4915        .await
4916        .get_channel_messages(
4917            channel_id,
4918            session.user_id(),
4919            MESSAGE_COUNT_PER_PAGE,
4920            Some(MessageId::from_proto(request.before_message_id)),
4921        )
4922        .await?;
4923    response.send(proto::GetChannelMessagesResponse {
4924        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4925        messages,
4926    })?;
4927    Ok(())
4928}
4929
4930/// Retrieve specific chat messages
4931async fn get_channel_messages_by_id(
4932    request: proto::GetChannelMessagesById,
4933    response: Response<proto::GetChannelMessagesById>,
4934    session: UserSession,
4935) -> Result<()> {
4936    let message_ids = request
4937        .message_ids
4938        .iter()
4939        .map(|id| MessageId::from_proto(*id))
4940        .collect::<Vec<_>>();
4941    let messages = session
4942        .db()
4943        .await
4944        .get_channel_messages_by_id(session.user_id(), &message_ids)
4945        .await?;
4946    response.send(proto::GetChannelMessagesResponse {
4947        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4948        messages,
4949    })?;
4950    Ok(())
4951}
4952
4953/// Retrieve the current users notifications
4954async fn get_notifications(
4955    request: proto::GetNotifications,
4956    response: Response<proto::GetNotifications>,
4957    session: UserSession,
4958) -> Result<()> {
4959    let notifications = session
4960        .db()
4961        .await
4962        .get_notifications(
4963            session.user_id(),
4964            NOTIFICATION_COUNT_PER_PAGE,
4965            request
4966                .before_id
4967                .map(|id| db::NotificationId::from_proto(id)),
4968        )
4969        .await?;
4970    response.send(proto::GetNotificationsResponse {
4971        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4972        notifications,
4973    })?;
4974    Ok(())
4975}
4976
4977/// Mark notifications as read
4978async fn mark_notification_as_read(
4979    request: proto::MarkNotificationRead,
4980    response: Response<proto::MarkNotificationRead>,
4981    session: UserSession,
4982) -> Result<()> {
4983    let database = &session.db().await;
4984    let notifications = database
4985        .mark_notification_as_read_by_id(
4986            session.user_id(),
4987            NotificationId::from_proto(request.notification_id),
4988        )
4989        .await?;
4990    send_notifications(
4991        &*session.connection_pool().await,
4992        &session.peer,
4993        notifications,
4994    );
4995    response.send(proto::Ack {})?;
4996    Ok(())
4997}
4998
4999/// Get the current users information
5000async fn get_private_user_info(
5001    _request: proto::GetPrivateUserInfo,
5002    response: Response<proto::GetPrivateUserInfo>,
5003    session: UserSession,
5004) -> Result<()> {
5005    let db = session.db().await;
5006
5007    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5008    let user = db
5009        .get_user_by_id(session.user_id())
5010        .await?
5011        .ok_or_else(|| anyhow!("user not found"))?;
5012    let flags = db.get_user_flags(session.user_id()).await?;
5013
5014    response.send(proto::GetPrivateUserInfoResponse {
5015        metrics_id,
5016        staff: user.admin,
5017        flags,
5018    })?;
5019    Ok(())
5020}
5021
5022fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5023    let message = match message {
5024        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5025        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5026        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5027        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5028        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5029            code: frame.code.into(),
5030            reason: frame.reason,
5031        })),
5032        // We should never receive a frame while reading the message, according
5033        // to the `tungstenite` maintainers:
5034        //
5035        // > It cannot occur when you read messages from the WebSocket, but it
5036        // > can be used when you want to send the raw frames (e.g. you want to
5037        // > send the frames to the WebSocket without composing the full message first).
5038        // >
5039        // > — https://github.com/snapview/tungstenite-rs/issues/268
5040        TungsteniteMessage::Frame(_) => {
5041            bail!("received an unexpected frame while reading the message")
5042        }
5043    };
5044
5045    Ok(message)
5046}
5047
5048fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5049    match message {
5050        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5051        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5052        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5053        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5054        AxumMessage::Close(frame) => {
5055            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5056                code: frame.code.into(),
5057                reason: frame.reason,
5058            }))
5059        }
5060    }
5061}
5062
5063fn notify_membership_updated(
5064    connection_pool: &mut ConnectionPool,
5065    result: MembershipUpdated,
5066    user_id: UserId,
5067    peer: &Peer,
5068) {
5069    for membership in &result.new_channels.channel_memberships {
5070        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5071    }
5072    for channel_id in &result.removed_channels {
5073        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5074    }
5075
5076    let user_channels_update = proto::UpdateUserChannels {
5077        channel_memberships: result
5078            .new_channels
5079            .channel_memberships
5080            .iter()
5081            .map(|cm| proto::ChannelMembership {
5082                channel_id: cm.channel_id.to_proto(),
5083                role: cm.role.into(),
5084            })
5085            .collect(),
5086        ..Default::default()
5087    };
5088
5089    let mut update = build_channels_update(result.new_channels);
5090    update.delete_channels = result
5091        .removed_channels
5092        .into_iter()
5093        .map(|id| id.to_proto())
5094        .collect();
5095    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5096
5097    for connection_id in connection_pool.user_connection_ids(user_id) {
5098        peer.send(connection_id, user_channels_update.clone())
5099            .trace_err();
5100        peer.send(connection_id, update.clone()).trace_err();
5101    }
5102}
5103
5104fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5105    proto::UpdateUserChannels {
5106        channel_memberships: channels
5107            .channel_memberships
5108            .iter()
5109            .map(|m| proto::ChannelMembership {
5110                channel_id: m.channel_id.to_proto(),
5111                role: m.role.into(),
5112            })
5113            .collect(),
5114        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5115        observed_channel_message_id: channels.observed_channel_messages.clone(),
5116    }
5117}
5118
5119fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5120    let mut update = proto::UpdateChannels::default();
5121
5122    for channel in channels.channels {
5123        update.channels.push(channel.to_proto());
5124    }
5125
5126    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5127    update.latest_channel_message_ids = channels.latest_channel_messages;
5128
5129    for (channel_id, participants) in channels.channel_participants {
5130        update
5131            .channel_participants
5132            .push(proto::ChannelParticipants {
5133                channel_id: channel_id.to_proto(),
5134                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5135            });
5136    }
5137
5138    for channel in channels.invited_channels {
5139        update.channel_invitations.push(channel.to_proto());
5140    }
5141
5142    update.hosted_projects = channels.hosted_projects;
5143    update
5144}
5145
5146fn build_initial_contacts_update(
5147    contacts: Vec<db::Contact>,
5148    pool: &ConnectionPool,
5149) -> proto::UpdateContacts {
5150    let mut update = proto::UpdateContacts::default();
5151
5152    for contact in contacts {
5153        match contact {
5154            db::Contact::Accepted { user_id, busy } => {
5155                update.contacts.push(contact_for_user(user_id, busy, &pool));
5156            }
5157            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5158            db::Contact::Incoming { user_id } => {
5159                update
5160                    .incoming_requests
5161                    .push(proto::IncomingContactRequest {
5162                        requester_id: user_id.to_proto(),
5163                    })
5164            }
5165        }
5166    }
5167
5168    update
5169}
5170
5171fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5172    proto::Contact {
5173        user_id: user_id.to_proto(),
5174        online: pool.is_user_online(user_id),
5175        busy,
5176    }
5177}
5178
5179fn room_updated(room: &proto::Room, peer: &Peer) {
5180    broadcast(
5181        None,
5182        room.participants
5183            .iter()
5184            .filter_map(|participant| Some(participant.peer_id?.into())),
5185        |peer_id| {
5186            peer.send(
5187                peer_id,
5188                proto::RoomUpdated {
5189                    room: Some(room.clone()),
5190                },
5191            )
5192        },
5193    );
5194}
5195
5196fn channel_updated(
5197    channel: &db::channel::Model,
5198    room: &proto::Room,
5199    peer: &Peer,
5200    pool: &ConnectionPool,
5201) {
5202    let participants = room
5203        .participants
5204        .iter()
5205        .map(|p| p.user_id)
5206        .collect::<Vec<_>>();
5207
5208    broadcast(
5209        None,
5210        pool.channel_connection_ids(channel.root_id())
5211            .filter_map(|(channel_id, role)| {
5212                role.can_see_channel(channel.visibility).then(|| channel_id)
5213            }),
5214        |peer_id| {
5215            peer.send(
5216                peer_id,
5217                proto::UpdateChannels {
5218                    channel_participants: vec![proto::ChannelParticipants {
5219                        channel_id: channel.id.to_proto(),
5220                        participant_user_ids: participants.clone(),
5221                    }],
5222                    ..Default::default()
5223                },
5224            )
5225        },
5226    );
5227}
5228
5229async fn send_dev_server_projects_update(
5230    user_id: UserId,
5231    mut status: proto::DevServerProjectsUpdate,
5232    session: &Session,
5233) {
5234    let pool = session.connection_pool().await;
5235    for dev_server in &mut status.dev_servers {
5236        dev_server.status =
5237            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5238    }
5239    let connections = pool.user_connection_ids(user_id);
5240    for connection_id in connections {
5241        session.peer.send(connection_id, status.clone()).trace_err();
5242    }
5243}
5244
5245async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5246    let db = session.db().await;
5247
5248    let contacts = db.get_contacts(user_id).await?;
5249    let busy = db.is_user_busy(user_id).await?;
5250
5251    let pool = session.connection_pool().await;
5252    let updated_contact = contact_for_user(user_id, busy, &pool);
5253    for contact in contacts {
5254        if let db::Contact::Accepted {
5255            user_id: contact_user_id,
5256            ..
5257        } = contact
5258        {
5259            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5260                session
5261                    .peer
5262                    .send(
5263                        contact_conn_id,
5264                        proto::UpdateContacts {
5265                            contacts: vec![updated_contact.clone()],
5266                            remove_contacts: Default::default(),
5267                            incoming_requests: Default::default(),
5268                            remove_incoming_requests: Default::default(),
5269                            outgoing_requests: Default::default(),
5270                            remove_outgoing_requests: Default::default(),
5271                        },
5272                    )
5273                    .trace_err();
5274            }
5275        }
5276    }
5277    Ok(())
5278}
5279
5280async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5281    log::info!("lost dev server connection, unsharing projects");
5282    let project_ids = session
5283        .db()
5284        .await
5285        .get_stale_dev_server_projects(session.connection_id)
5286        .await?;
5287
5288    for project_id in project_ids {
5289        // not unshare re-checks the connection ids match, so we get away with no transaction
5290        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5291    }
5292
5293    let user_id = session.dev_server().user_id;
5294    let update = session
5295        .db()
5296        .await
5297        .dev_server_projects_update(user_id)
5298        .await?;
5299
5300    send_dev_server_projects_update(user_id, update, session).await;
5301
5302    Ok(())
5303}
5304
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// If the database reports that the connection was not in a room, this returns `Ok(())`
/// and does nothing else. Otherwise it does the following, in order:
/// - notifies the connections of every project the user left (see `project_left`);
/// - broadcasts the updated room to its participants;
/// - if the room belongs to a channel, broadcasts the channel's new participant list;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - pushes the updated contact status of the leaving user and of those canceled users to
///   their accepted contacts;
/// - removes the user from the LiveKit room, and deletes that room when the database marks
///   the room as deleted.
///
/// # Errors
/// Fails if leaving the room in the database fails, or if any `update_user_contacts` call
/// fails. In the latter case the remaining contact updates and the LiveKit cleanup are
/// skipped. Send failures and LiveKit errors are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. They are declared here so they outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before the room is
        // moved into `room`. The `RoomUpdated` broadcast below therefore carries a defaulted
        // `live_kit_room`; presumably clients don't read it from room updates. Confirm.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // For channel rooms, connections that can see the channel get the new participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // This scope drops the pool guard before `update_user_contacts` below acquires the
    // connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5378
5379async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5380    let left_channel_buffers = session
5381        .db()
5382        .await
5383        .leave_channel_buffers(session.connection_id)
5384        .await?;
5385
5386    for left_buffer in left_channel_buffers {
5387        channel_buffer_updated(
5388            session.connection_id,
5389            left_buffer.connections,
5390            &proto::UpdateChannelBufferCollaborators {
5391                channel_id: left_buffer.channel_id.to_proto(),
5392                collaborators: left_buffer.collaborators,
5393            },
5394            &session.peer,
5395        );
5396    }
5397
5398    Ok(())
5399}
5400
5401fn project_left(project: &db::LeftProject, session: &UserSession) {
5402    for connection_id in &project.connection_ids {
5403        if project.should_unshare {
5404            session
5405                .peer
5406                .send(
5407                    *connection_id,
5408                    proto::UnshareProject {
5409                        project_id: project.id.to_proto(),
5410                    },
5411                )
5412                .trace_err();
5413        } else {
5414            session
5415                .peer
5416                .send(
5417                    *connection_id,
5418                    proto::RemoveProjectCollaborator {
5419                        project_id: project.id.to_proto(),
5420                        peer_id: Some(session.connection_id.into()),
5421                    },
5422                )
5423                .trace_err();
5424        }
5425    }
5426}
5427
5428pub trait ResultExt {
5429    type Ok;
5430
5431    fn trace_err(self) -> Option<Self::Ok>;
5432}
5433
5434impl<T, E> ResultExt for Result<T, E>
5435where
5436    E: std::fmt::Debug,
5437{
5438    type Ok = T;
5439
5440    #[track_caller]
5441    fn trace_err(self) -> Option<T> {
5442        match self {
5443            Ok(value) => Some(value),
5444            Err(error) => {
5445                tracing::error!("{:?}", error);
5446                None
5447            }
5448        }
5449    }
5450}