rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Config, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
/// Reconnection timeout (30 seconds).
/// NOTE(review): not referenced in the reviewed part of the file — confirm where it applies.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources.
/// Delay that `Server::start` sleeps for before it looks up and clears stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for messages. NOTE(review): presumably channel chat messages — usage is not visible
/// in the reviewed part of the file.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a message's length. NOTE(review): unit (bytes vs. characters) and call site are
/// not visible here — confirm.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notifications. NOTE(review): usage is not visible in the reviewed part of the
/// file.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
/// Type-erased message handler, as stored in `Server::handlers`.
///
/// It receives a boxed envelope of any message type together with a `Session`, and returns a
/// boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// State passed to message handlers alongside each message (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end: a user, an impersonated user, or a dev server.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; `Session::db` is the accessor that locks it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; `Session::connection_pool` wraps the lock.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is available.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client, if one is available.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client.
    http_client: Arc<IsahcHttpClient>,
    /// Rate limiter.
    rate_limiter: Arc<RateLimiter>,
    /// Not read in the reviewed part of the file (note the leading underscore).
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The RPC server: holds the peer, the connection pool, the application state and the table of
/// message handlers.
pub struct Server {
    /// Server id behind a mutex; `start` reads it with `*self.id.lock()`.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer, created in `new` from the server id.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Registered handlers keyed by `TypeId` — presumably that of the message type; the
    /// registration methods are outside the reviewed part of the file.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `watch` channel created in `new` with the initial value `false`.
    teardown: watch::Sender<bool>,
}
 384
/// A held lock on the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// Zero-sized marker; `Rc` is not `Send`, so its presence makes this guard `!Send`.
    /// NOTE(review): presumably meant to stop the guard from being held across an `.await` in a
    /// `Send` future — confirm.
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable view of a server: a borrowed peer plus a guard on the connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's peer, serialized directly.
    peer: &'a Peer,
    /// The guard is serialized through `serialize_deref`, i.e. as the value it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Builds a server for `id` and `app_state`, registers a handler for every RPC message type
    /// listed below, and returns the server in an `Arc`.
    ///
    /// The connection pool and handler table start out as their `Default` values, and the
    /// teardown channel starts at `false`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates/wraps if the id does not fit in a `u32`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half of the watch channel is kept.
            teardown: watch::channel(false).0,
        };

        // Handlers wrapped in `user_handler` / `user_message_handler` reject non-user principals;
        // those wrapped in `dev_server_handler` reject anything that is not a dev server.
        // Unwrapped handlers receive the plain `Session`.
        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects, dev servers and dev-server projects.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the generic `forward_*` handlers, one
            // instantiation per message type.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer messages, including those handled by `broadcast_project_message_from_host`.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat, notifications and following.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Context messages.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The handlers below need the server config. Each closure owns an `Arc<AppState>`
            // and clones it per call so that the returned future can own its own handle.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    // The OpenAI API key is cloned out of the config for every request.
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 665
 666    pub async fn start(&self) -> Result<()> {
 667        let server_id = *self.id.lock();
 668        let app_state = self.app_state.clone();
 669        let peer = self.peer.clone();
 670        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 671        let pool = self.connection_pool.clone();
 672        let live_kit_client = self.app_state.live_kit_client.clone();
 673
 674        let span = info_span!("start server");
 675        self.app_state.executor.spawn_detached(
 676            async move {
 677                tracing::info!("waiting for cleanup timeout");
 678                timeout.await;
 679                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 680                if let Some((room_ids, channel_ids)) = app_state
 681                    .db
 682                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 683                    .await
 684                    .trace_err()
 685                {
 686                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 687                    tracing::info!(
 688                        stale_channel_buffer_count = channel_ids.len(),
 689                        "retrieved stale channel buffers"
 690                    );
 691
 692                    for channel_id in channel_ids {
 693                        if let Some(refreshed_channel_buffer) = app_state
 694                            .db
 695                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 696                            .await
 697                            .trace_err()
 698                        {
 699                            for connection_id in refreshed_channel_buffer.connection_ids {
 700                                peer.send(
 701                                    connection_id,
 702                                    proto::UpdateChannelBufferCollaborators {
 703                                        channel_id: channel_id.to_proto(),
 704                                        collaborators: refreshed_channel_buffer
 705                                            .collaborators
 706                                            .clone(),
 707                                    },
 708                                )
 709                                .trace_err();
 710                            }
 711                        }
 712                    }
 713
 714                    for room_id in room_ids {
 715                        let mut contacts_to_update = HashSet::default();
 716                        let mut canceled_calls_to_user_ids = Vec::new();
 717                        let mut live_kit_room = String::new();
 718                        let mut delete_live_kit_room = false;
 719
 720                        if let Some(mut refreshed_room) = app_state
 721                            .db
 722                            .clear_stale_room_participants(room_id, server_id)
 723                            .await
 724                            .trace_err()
 725                        {
 726                            tracing::info!(
 727                                room_id = room_id.0,
 728                                new_participant_count = refreshed_room.room.participants.len(),
 729                                "refreshed room"
 730                            );
 731                            room_updated(&refreshed_room.room, &peer);
 732                            if let Some(channel) = refreshed_room.channel.as_ref() {
 733                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 734                            }
 735                            contacts_to_update
 736                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 737                            contacts_to_update
 738                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 739                            canceled_calls_to_user_ids =
 740                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 741                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 742                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 743                        }
 744
 745                        {
 746                            let pool = pool.lock();
 747                            for canceled_user_id in canceled_calls_to_user_ids {
 748                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 749                                    peer.send(
 750                                        connection_id,
 751                                        proto::CallCanceled {
 752                                            room_id: room_id.to_proto(),
 753                                        },
 754                                    )
 755                                    .trace_err();
 756                                }
 757                            }
 758                        }
 759
 760                        for user_id in contacts_to_update {
 761                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 762                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 763                            if let Some((busy, contacts)) = busy.zip(contacts) {
 764                                let pool = pool.lock();
 765                                let updated_contact = contact_for_user(user_id, busy, &pool);
 766                                for contact in contacts {
 767                                    if let db::Contact::Accepted {
 768                                        user_id: contact_user_id,
 769                                        ..
 770                                    } = contact
 771                                    {
 772                                        for contact_conn_id in
 773                                            pool.user_connection_ids(contact_user_id)
 774                                        {
 775                                            peer.send(
 776                                                contact_conn_id,
 777                                                proto::UpdateContacts {
 778                                                    contacts: vec![updated_contact.clone()],
 779                                                    remove_contacts: Default::default(),
 780                                                    incoming_requests: Default::default(),
 781                                                    remove_incoming_requests: Default::default(),
 782                                                    outgoing_requests: Default::default(),
 783                                                    remove_outgoing_requests: Default::default(),
 784                                                },
 785                                            )
 786                                            .trace_err();
 787                                        }
 788                                    }
 789                                }
 790                            }
 791                        }
 792
 793                        if let Some(live_kit) = live_kit_client.as_ref() {
 794                            if delete_live_kit_room {
 795                                live_kit.delete_room(live_kit_room).await.trace_err();
 796                            }
 797                        }
 798                    }
 799                }
 800
 801                app_state
 802                    .db
 803                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 804                    .await
 805                    .trace_err();
 806            }
 807            .instrument(span),
 808        );
 809        Ok(())
 810    }
 811
 812    pub fn teardown(&self) {
 813        self.peer.teardown();
 814        self.connection_pool.lock().reset();
 815        let _ = self.teardown.send(true);
 816    }
 817
 818    #[cfg(test)]
 819    pub fn reset(&self, id: ServerId) {
 820        self.teardown();
 821        *self.id.lock() = id;
 822        self.peer.reset(id.0 as u32);
 823        let _ = self.teardown.send(false);
 824    }
 825
 826    #[cfg(test)]
 827    pub fn id(&self) -> ServerId {
 828        *self.id.lock()
 829    }
 830
 831    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 832    where
 833        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 834        Fut: 'static + Send + Future<Output = Result<()>>,
 835        M: EnvelopedMessage,
 836    {
 837        let prev_handler = self.handlers.insert(
 838            TypeId::of::<M>(),
 839            Box::new(move |envelope, session| {
 840                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 841                let received_at = envelope.received_at;
 842                tracing::info!("message received");
 843                let start_time = Instant::now();
 844                let future = (handler)(*envelope, session);
 845                async move {
 846                    let result = future.await;
 847                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 848                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 849                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 850                    let payload_type = M::NAME;
 851
 852                    match result {
 853                        Err(error) => {
 854                            tracing::error!(
 855                                ?error,
 856                                total_duration_ms,
 857                                processing_duration_ms,
 858                                queue_duration_ms,
 859                                payload_type,
 860                                "error handling message"
 861                            )
 862                        }
 863                        Ok(()) => tracing::info!(
 864                            total_duration_ms,
 865                            processing_duration_ms,
 866                            queue_duration_ms,
 867                            "finished handling message"
 868                        ),
 869                    }
 870                }
 871                .boxed()
 872            }),
 873        );
 874        if prev_handler.is_some() {
 875            panic!("registered a handler for the same message twice");
 876        }
 877        self
 878    }
 879
 880    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 881    where
 882        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 883        Fut: 'static + Send + Future<Output = Result<()>>,
 884        M: EnvelopedMessage,
 885    {
 886        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 887        self
 888    }
 889
 890    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 891    where
 892        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 893        Fut: Send + Future<Output = Result<()>>,
 894        M: RequestMessage,
 895    {
 896        let handler = Arc::new(handler);
 897        self.add_handler(move |envelope, session| {
 898            let receipt = envelope.receipt();
 899            let handler = handler.clone();
 900            async move {
 901                let peer = session.peer.clone();
 902                let responded = Arc::new(AtomicBool::default());
 903                let response = Response {
 904                    peer: peer.clone(),
 905                    responded: responded.clone(),
 906                    receipt,
 907                };
 908                match (handler)(envelope.payload, response, session).await {
 909                    Ok(()) => {
 910                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 911                            Ok(())
 912                        } else {
 913                            Err(anyhow!("handler did not send a response"))?
 914                        }
 915                    }
 916                    Err(error) => {
 917                        let proto_err = match &error {
 918                            Error::Internal(err) => err.to_proto(),
 919                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 920                        };
 921                        peer.respond_with_error(receipt, proto_err)?;
 922                        Err(error)
 923                    }
 924                }
 925            }
 926        })
 927    }
 928
 929    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 930    where
 931        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 932        Fut: Send + Future<Output = Result<()>>,
 933        M: RequestMessage,
 934    {
 935        let handler = Arc::new(handler);
 936        self.add_handler(move |envelope, session| {
 937            let receipt = envelope.receipt();
 938            let handler = handler.clone();
 939            async move {
 940                let peer = session.peer.clone();
 941                let response = StreamingResponse {
 942                    peer: peer.clone(),
 943                    receipt,
 944                };
 945                match (handler)(envelope.payload, response, session).await {
 946                    Ok(()) => {
 947                        peer.end_stream(receipt)?;
 948                        Ok(())
 949                    }
 950                    Err(error) => {
 951                        let proto_err = match &error {
 952                            Error::Internal(err) => err.to_proto(),
 953                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 954                        };
 955                        peer.respond_with_error(receipt, proto_err)?;
 956                        Err(error)
 957                    }
 958                }
 959            }
 960        })
 961    }
 962
    /// Returns the future that drives one client connection from start to
    /// sign-out.
    ///
    /// The future registers `connection` with the peer, builds the HTTP
    /// client and the per-connection `Session`, sends the initial client
    /// update, and then dispatches incoming messages to the registered
    /// handlers until the I/O future finishes or the incoming stream ends.
    /// It then calls `connection_lost` for the session.
    ///
    /// The future returns early, without calling `connection_lost`, when:
    /// * the teardown flag is already set when it starts,
    /// * the HTTP client cannot be built,
    /// * `send_initial_client_update` fails, or
    /// * the teardown channel changes while the message loop is running.
    ///
    /// Parameters:
    /// * `address` - recorded on the tracing spans created here.
    /// * `principal` - stored in the session and used to fill span fields.
    /// * `zed_version`, `send_connection_id` - forwarded to
    ///   `send_initial_client_update`.
    /// * `executor` - provides the sleep function given to the peer, runs
    ///   background message handlers, and is passed to `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared as `field::Empty` are filled in later: the principal
        // fields by `update_span`, `connection_id` once the peer assigns one.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Do not accept the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven client only exists when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection: a permit is taken
            // before the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written, so teardown and
                // I/O completion take priority over handler progress and new messages.
                futures::select_biased! {
                    // Teardown exits immediately and skips the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives the pending foreground handlers; their output is ignored.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is created; the
                            // future itself is instrumented with the same span further down.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages are detached onto the executor; foreground
                                // messages are driven by this loop via the `FuturesUnordered`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers that have not finished are dropped here and make no
            // further progress.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1110
1111    async fn send_initial_client_update(
1112        &self,
1113        connection_id: ConnectionId,
1114        principal: &Principal,
1115        zed_version: ZedVersion,
1116        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1117        session: &Session,
1118    ) -> Result<()> {
1119        self.peer.send(
1120            connection_id,
1121            proto::Hello {
1122                peer_id: Some(connection_id.into()),
1123            },
1124        )?;
1125        tracing::info!("sent hello message");
1126        if let Some(send_connection_id) = send_connection_id.take() {
1127            let _ = send_connection_id.send(connection_id);
1128        }
1129
1130        match principal {
1131            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1132                if !user.connected_once {
1133                    self.peer.send(connection_id, proto::ShowContacts {})?;
1134                    self.app_state
1135                        .db
1136                        .set_user_connected_once(user.id, true)
1137                        .await?;
1138                }
1139
1140                let (contacts, dev_server_projects) = future::try_join(
1141                    self.app_state.db.get_contacts(user.id),
1142                    self.app_state.db.dev_server_projects_update(user.id),
1143                )
1144                .await?;
1145
1146                {
1147                    let mut pool = self.connection_pool.lock();
1148                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1149                    self.peer.send(
1150                        connection_id,
1151                        build_initial_contacts_update(contacts, &pool),
1152                    )?;
1153                }
1154
1155                if should_auto_subscribe_to_channels(zed_version) {
1156                    subscribe_user_to_channels(user.id, session).await?;
1157                }
1158
1159                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1160
1161                if let Some(incoming_call) =
1162                    self.app_state.db.incoming_call_for_user(user.id).await?
1163                {
1164                    self.peer.send(connection_id, incoming_call)?;
1165                }
1166
1167                update_user_contacts(user.id, &session).await?;
1168            }
1169            Principal::DevServer(dev_server) => {
1170                {
1171                    let mut pool = self.connection_pool.lock();
1172                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1173                    {
1174                        self.peer.send(
1175                            stale_connection_id,
1176                            proto::ShutdownDevServer {
1177                                reason: Some(
1178                                    "another dev server connected with the same token".to_string(),
1179                                ),
1180                            },
1181                        )?;
1182                        pool.remove_connection(stale_connection_id)?;
1183                    };
1184                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1185                }
1186
1187                let projects = self
1188                    .app_state
1189                    .db
1190                    .get_projects_for_dev_server(dev_server.id)
1191                    .await?;
1192                self.peer
1193                    .send(connection_id, proto::DevServerInstructions { projects })?;
1194
1195                let status = self
1196                    .app_state
1197                    .db
1198                    .dev_server_projects_update(dev_server.user_id)
1199                    .await?;
1200                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1201            }
1202        }
1203
1204        Ok(())
1205    }
1206
1207    pub async fn invite_code_redeemed(
1208        self: &Arc<Self>,
1209        inviter_id: UserId,
1210        invitee_id: UserId,
1211    ) -> Result<()> {
1212        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1213            if let Some(code) = &user.invite_code {
1214                let pool = self.connection_pool.lock();
1215                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1216                for connection_id in pool.user_connection_ids(inviter_id) {
1217                    self.peer.send(
1218                        connection_id,
1219                        proto::UpdateContacts {
1220                            contacts: vec![invitee_contact.clone()],
1221                            ..Default::default()
1222                        },
1223                    )?;
1224                    self.peer.send(
1225                        connection_id,
1226                        proto::UpdateInviteInfo {
1227                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1228                            count: user.invite_count as u32,
1229                        },
1230                    )?;
1231                }
1232            }
1233        }
1234        Ok(())
1235    }
1236
1237    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1238        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1239            if let Some(invite_code) = &user.invite_code {
1240                let pool = self.connection_pool.lock();
1241                for connection_id in pool.user_connection_ids(user_id) {
1242                    self.peer.send(
1243                        connection_id,
1244                        proto::UpdateInviteInfo {
1245                            url: format!(
1246                                "{}{}",
1247                                self.app_state.config.invite_link_prefix, invite_code
1248                            ),
1249                            count: user.invite_count as u32,
1250                        },
1251                    )?;
1252                }
1253            }
1254        }
1255        Ok(())
1256    }
1257
1258    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1259        ServerSnapshot {
1260            connection_pool: ConnectionPoolGuard {
1261                guard: self.connection_pool.lock(),
1262                _not_send: PhantomData,
1263            },
1264            peer: &self.peer,
1265        }
1266    }
1267}
1268
1269impl<'a> Deref for ConnectionPoolGuard<'a> {
1270    type Target = ConnectionPool;
1271
1272    fn deref(&self) -> &Self::Target {
1273        &self.guard
1274    }
1275}
1276
1277impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1278    fn deref_mut(&mut self) -> &mut Self::Target {
1279        &mut self.guard
1280    }
1281}
1282
1283impl<'a> Drop for ConnectionPoolGuard<'a> {
1284    fn drop(&mut self) {
1285        #[cfg(test)]
1286        self.check_invariants();
1287    }
1288}
1289
1290fn broadcast<F>(
1291    sender_id: Option<ConnectionId>,
1292    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1293    mut f: F,
1294) where
1295    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1296{
1297    for receiver_id in receiver_ids {
1298        if Some(receiver_id) != sender_id {
1299            if let Err(error) = f(receiver_id) {
1300                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1301            }
1302        }
1303    }
1304}
1305
1306pub struct ProtocolVersion(u32);
1307
1308impl Header for ProtocolVersion {
1309    fn name() -> &'static HeaderName {
1310        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1311        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1312    }
1313
1314    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1315    where
1316        Self: Sized,
1317        I: Iterator<Item = &'i axum::http::HeaderValue>,
1318    {
1319        let version = values
1320            .next()
1321            .ok_or_else(axum::headers::Error::invalid)?
1322            .to_str()
1323            .map_err(|_| axum::headers::Error::invalid())?
1324            .parse()
1325            .map_err(|_| axum::headers::Error::invalid())?;
1326        Ok(Self(version))
1327    }
1328
1329    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1330        values.extend([self.0.to_string().parse().unwrap()]);
1331    }
1332}
1333
1334pub struct AppVersionHeader(SemanticVersion);
1335impl Header for AppVersionHeader {
1336    fn name() -> &'static HeaderName {
1337        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1338        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1339    }
1340
1341    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1342    where
1343        Self: Sized,
1344        I: Iterator<Item = &'i axum::http::HeaderValue>,
1345    {
1346        let version = values
1347            .next()
1348            .ok_or_else(axum::headers::Error::invalid)?
1349            .to_str()
1350            .map_err(|_| axum::headers::Error::invalid())?
1351            .parse()
1352            .map_err(|_| axum::headers::Error::invalid())?;
1353        Ok(Self(version))
1354    }
1355
1356    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1357        values.extend([self.0.to_string().parse().unwrap()]);
1358    }
1359}
1360
1361pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1362    Router::new()
1363        .route("/rpc", get(handle_websocket_request))
1364        .layer(
1365            ServiceBuilder::new()
1366                .layer(Extension(server.app_state.clone()))
1367                .layer(middleware::from_fn(auth::validate_header)),
1368        )
1369        .route("/metrics", get(handle_metrics))
1370        .layer(Extension(server))
1371}
1372
1373pub async fn handle_websocket_request(
1374    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1375    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1376    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1377    Extension(server): Extension<Arc<Server>>,
1378    Extension(principal): Extension<Principal>,
1379    ws: WebSocketUpgrade,
1380) -> axum::response::Response {
1381    if protocol_version != rpc::PROTOCOL_VERSION {
1382        return (
1383            StatusCode::UPGRADE_REQUIRED,
1384            "client must be upgraded".to_string(),
1385        )
1386            .into_response();
1387    }
1388
1389    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1390        return (
1391            StatusCode::UPGRADE_REQUIRED,
1392            "no version header found".to_string(),
1393        )
1394            .into_response();
1395    };
1396
1397    if !version.can_collaborate() {
1398        return (
1399            StatusCode::UPGRADE_REQUIRED,
1400            "client must be upgraded".to_string(),
1401        )
1402            .into_response();
1403    }
1404
1405    let socket_address = socket_address.to_string();
1406    ws.on_upgrade(move |socket| {
1407        let socket = socket
1408            .map_ok(to_tungstenite_message)
1409            .err_into()
1410            .with(|message| async move { to_axum_message(message) });
1411        let connection = Connection::new(Box::pin(socket));
1412        async move {
1413            server
1414                .handle_connection(
1415                    connection,
1416                    socket_address,
1417                    principal,
1418                    version,
1419                    None,
1420                    Executor::Production,
1421                )
1422                .await;
1423        }
1424    })
1425}
1426
1427pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1428    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1429    let connections_metric = CONNECTIONS_METRIC
1430        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1431
1432    let connections = server
1433        .connection_pool
1434        .lock()
1435        .connections()
1436        .filter(|connection| !connection.admin)
1437        .count();
1438    connections_metric.set(connections as _);
1439
1440    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1441    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1442        register_int_gauge!(
1443            "shared_projects",
1444            "number of open projects with one or more guests"
1445        )
1446        .unwrap()
1447    });
1448
1449    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1450    shared_projects_metric.set(shared_projects as _);
1451
1452    let encoder = prometheus::TextEncoder::new();
1453    let metric_families = prometheus::gather();
1454    let encoded_metrics = encoder
1455        .encode_to_string(&metric_families)
1456        .map_err(|err| anyhow!("{}", err))?;
1457    Ok(encoded_metrics)
1458}
1459
1460#[instrument(err, skip(executor))]
1461async fn connection_lost(
1462    session: Session,
1463    mut teardown: watch::Receiver<bool>,
1464    executor: Executor,
1465) -> Result<()> {
1466    session.peer.disconnect(session.connection_id);
1467    session
1468        .connection_pool()
1469        .await
1470        .remove_connection(session.connection_id)?;
1471
1472    session
1473        .db()
1474        .await
1475        .connection_lost(session.connection_id)
1476        .await
1477        .trace_err();
1478
1479    futures::select_biased! {
1480        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1481            match &session.principal {
1482                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1483                    let session = session.for_user().unwrap();
1484
1485                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1486                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1487                    leave_channel_buffers_for_session(&session)
1488                        .await
1489                        .trace_err();
1490
1491                    if !session
1492                        .connection_pool()
1493                        .await
1494                        .is_user_online(session.user_id())
1495                    {
1496                        let db = session.db().await;
1497                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1498                            room_updated(&room, &session.peer);
1499                        }
1500                    }
1501
1502                    update_user_contacts(session.user_id(), &session).await?;
1503                },
1504            Principal::DevServer(_) => {
1505                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1506            },
1507        }
1508        },
1509        _ = teardown.changed().fuse() => {}
1510    }
1511
1512    Ok(())
1513}
1514
1515/// Acknowledges a ping from a client, used to keep the connection alive.
1516async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1517    response.send(proto::Ack {})?;
1518    Ok(())
1519}
1520
1521/// Creates a new room for calling (outside of channels)
1522async fn create_room(
1523    _request: proto::CreateRoom,
1524    response: Response<proto::CreateRoom>,
1525    session: UserSession,
1526) -> Result<()> {
1527    let live_kit_room = nanoid::nanoid!(30);
1528
1529    let live_kit_connection_info = util::maybe!(async {
1530        let live_kit = session.live_kit_client.as_ref();
1531        let live_kit = live_kit?;
1532        let user_id = session.user_id().to_string();
1533
1534        let token = live_kit
1535            .room_token(&live_kit_room, &user_id.to_string())
1536            .trace_err()?;
1537
1538        Some(proto::LiveKitConnectionInfo {
1539            server_url: live_kit.url().into(),
1540            token,
1541            can_publish: true,
1542        })
1543    })
1544    .await;
1545
1546    let room = session
1547        .db()
1548        .await
1549        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1550        .await?;
1551
1552    response.send(proto::CreateRoomResponse {
1553        room: Some(room.clone()),
1554        live_kit_connection_info,
1555    })?;
1556
1557    update_user_contacts(session.user_id(), &session).await?;
1558    Ok(())
1559}
1560
1561/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1562async fn join_room(
1563    request: proto::JoinRoom,
1564    response: Response<proto::JoinRoom>,
1565    session: UserSession,
1566) -> Result<()> {
1567    let room_id = RoomId::from_proto(request.id);
1568
1569    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1570
1571    if let Some(channel_id) = channel_id {
1572        return join_channel_internal(channel_id, Box::new(response), session).await;
1573    }
1574
1575    let joined_room = {
1576        let room = session
1577            .db()
1578            .await
1579            .join_room(room_id, session.user_id(), session.connection_id)
1580            .await?;
1581        room_updated(&room.room, &session.peer);
1582        room.into_inner()
1583    };
1584
1585    for connection_id in session
1586        .connection_pool()
1587        .await
1588        .user_connection_ids(session.user_id())
1589    {
1590        session
1591            .peer
1592            .send(
1593                connection_id,
1594                proto::CallCanceled {
1595                    room_id: room_id.to_proto(),
1596                },
1597            )
1598            .trace_err();
1599    }
1600
1601    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1602        if let Some(token) = live_kit
1603            .room_token(
1604                &joined_room.room.live_kit_room,
1605                &session.user_id().to_string(),
1606            )
1607            .trace_err()
1608        {
1609            Some(proto::LiveKitConnectionInfo {
1610                server_url: live_kit.url().into(),
1611                token,
1612                can_publish: true,
1613            })
1614        } else {
1615            None
1616        }
1617    } else {
1618        None
1619    };
1620
1621    response.send(proto::JoinRoomResponse {
1622        room: Some(joined_room.room),
1623        channel_id: None,
1624        live_kit_connection_info,
1625    })?;
1626
1627    update_user_contacts(session.user_id(), &session).await?;
1628    Ok(())
1629}
1630
1631/// Rejoin room is used to reconnect to a room after connection errors.
1632async fn rejoin_room(
1633    request: proto::RejoinRoom,
1634    response: Response<proto::RejoinRoom>,
1635    session: UserSession,
1636) -> Result<()> {
1637    let room;
1638    let channel;
1639    {
1640        let mut rejoined_room = session
1641            .db()
1642            .await
1643            .rejoin_room(request, session.user_id(), session.connection_id)
1644            .await?;
1645
1646        response.send(proto::RejoinRoomResponse {
1647            room: Some(rejoined_room.room.clone()),
1648            reshared_projects: rejoined_room
1649                .reshared_projects
1650                .iter()
1651                .map(|project| proto::ResharedProject {
1652                    id: project.id.to_proto(),
1653                    collaborators: project
1654                        .collaborators
1655                        .iter()
1656                        .map(|collaborator| collaborator.to_proto())
1657                        .collect(),
1658                })
1659                .collect(),
1660            rejoined_projects: rejoined_room
1661                .rejoined_projects
1662                .iter()
1663                .map(|rejoined_project| rejoined_project.to_proto())
1664                .collect(),
1665        })?;
1666        room_updated(&rejoined_room.room, &session.peer);
1667
1668        for project in &rejoined_room.reshared_projects {
1669            for collaborator in &project.collaborators {
1670                session
1671                    .peer
1672                    .send(
1673                        collaborator.connection_id,
1674                        proto::UpdateProjectCollaborator {
1675                            project_id: project.id.to_proto(),
1676                            old_peer_id: Some(project.old_connection_id.into()),
1677                            new_peer_id: Some(session.connection_id.into()),
1678                        },
1679                    )
1680                    .trace_err();
1681            }
1682
1683            broadcast(
1684                Some(session.connection_id),
1685                project
1686                    .collaborators
1687                    .iter()
1688                    .map(|collaborator| collaborator.connection_id),
1689                |connection_id| {
1690                    session.peer.forward_send(
1691                        session.connection_id,
1692                        connection_id,
1693                        proto::UpdateProject {
1694                            project_id: project.id.to_proto(),
1695                            worktrees: project.worktrees.clone(),
1696                        },
1697                    )
1698                },
1699            );
1700        }
1701
1702        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1703
1704        let rejoined_room = rejoined_room.into_inner();
1705
1706        room = rejoined_room.room;
1707        channel = rejoined_room.channel;
1708    }
1709
1710    if let Some(channel) = channel {
1711        channel_updated(
1712            &channel,
1713            &room,
1714            &session.peer,
1715            &*session.connection_pool().await,
1716        );
1717    }
1718
1719    update_user_contacts(session.user_id(), &session).await?;
1720    Ok(())
1721}
1722
1723fn notify_rejoined_projects(
1724    rejoined_projects: &mut Vec<RejoinedProject>,
1725    session: &UserSession,
1726) -> Result<()> {
1727    for project in rejoined_projects.iter() {
1728        for collaborator in &project.collaborators {
1729            session
1730                .peer
1731                .send(
1732                    collaborator.connection_id,
1733                    proto::UpdateProjectCollaborator {
1734                        project_id: project.id.to_proto(),
1735                        old_peer_id: Some(project.old_connection_id.into()),
1736                        new_peer_id: Some(session.connection_id.into()),
1737                    },
1738                )
1739                .trace_err();
1740        }
1741    }
1742
1743    for project in rejoined_projects {
1744        for worktree in mem::take(&mut project.worktrees) {
1745            #[cfg(any(test, feature = "test-support"))]
1746            const MAX_CHUNK_SIZE: usize = 2;
1747            #[cfg(not(any(test, feature = "test-support")))]
1748            const MAX_CHUNK_SIZE: usize = 256;
1749
1750            // Stream this worktree's entries.
1751            let message = proto::UpdateWorktree {
1752                project_id: project.id.to_proto(),
1753                worktree_id: worktree.id,
1754                abs_path: worktree.abs_path.clone(),
1755                root_name: worktree.root_name,
1756                updated_entries: worktree.updated_entries,
1757                removed_entries: worktree.removed_entries,
1758                scan_id: worktree.scan_id,
1759                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1760                updated_repositories: worktree.updated_repositories,
1761                removed_repositories: worktree.removed_repositories,
1762            };
1763            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1764                session.peer.send(session.connection_id, update.clone())?;
1765            }
1766
1767            // Stream this worktree's diagnostics.
1768            for summary in worktree.diagnostic_summaries {
1769                session.peer.send(
1770                    session.connection_id,
1771                    proto::UpdateDiagnosticSummary {
1772                        project_id: project.id.to_proto(),
1773                        worktree_id: worktree.id,
1774                        summary: Some(summary),
1775                    },
1776                )?;
1777            }
1778
1779            for settings_file in worktree.settings_files {
1780                session.peer.send(
1781                    session.connection_id,
1782                    proto::UpdateWorktreeSettings {
1783                        project_id: project.id.to_proto(),
1784                        worktree_id: worktree.id,
1785                        path: settings_file.path,
1786                        content: Some(settings_file.content),
1787                    },
1788                )?;
1789            }
1790        }
1791
1792        for language_server in &project.language_servers {
1793            session.peer.send(
1794                session.connection_id,
1795                proto::UpdateLanguageServer {
1796                    project_id: project.id.to_proto(),
1797                    language_server_id: language_server.id,
1798                    variant: Some(
1799                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1800                            proto::LspDiskBasedDiagnosticsUpdated {},
1801                        ),
1802                    ),
1803                },
1804            )?;
1805        }
1806    }
1807    Ok(())
1808}
1809
1810/// leave room disconnects from the room.
1811async fn leave_room(
1812    _: proto::LeaveRoom,
1813    response: Response<proto::LeaveRoom>,
1814    session: UserSession,
1815) -> Result<()> {
1816    leave_room_for_session(&session, session.connection_id).await?;
1817    response.send(proto::Ack {})?;
1818    Ok(())
1819}
1820
1821/// Updates the permissions of someone else in the room.
1822async fn set_room_participant_role(
1823    request: proto::SetRoomParticipantRole,
1824    response: Response<proto::SetRoomParticipantRole>,
1825    session: UserSession,
1826) -> Result<()> {
1827    let user_id = UserId::from_proto(request.user_id);
1828    let role = ChannelRole::from(request.role());
1829
1830    let (live_kit_room, can_publish) = {
1831        let room = session
1832            .db()
1833            .await
1834            .set_room_participant_role(
1835                session.user_id(),
1836                RoomId::from_proto(request.room_id),
1837                user_id,
1838                role,
1839            )
1840            .await?;
1841
1842        let live_kit_room = room.live_kit_room.clone();
1843        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1844        room_updated(&room, &session.peer);
1845        (live_kit_room, can_publish)
1846    };
1847
1848    if let Some(live_kit) = session.live_kit_client.as_ref() {
1849        live_kit
1850            .update_participant(
1851                live_kit_room.clone(),
1852                request.user_id.to_string(),
1853                live_kit_server::proto::ParticipantPermission {
1854                    can_subscribe: true,
1855                    can_publish,
1856                    can_publish_data: can_publish,
1857                    hidden: false,
1858                    recorder: false,
1859                },
1860            )
1861            .await
1862            .trace_err();
1863    }
1864
1865    response.send(proto::Ack {})?;
1866    Ok(())
1867}
1868
1869/// Call someone else into the current room
1870async fn call(
1871    request: proto::Call,
1872    response: Response<proto::Call>,
1873    session: UserSession,
1874) -> Result<()> {
1875    let room_id = RoomId::from_proto(request.room_id);
1876    let calling_user_id = session.user_id();
1877    let calling_connection_id = session.connection_id;
1878    let called_user_id = UserId::from_proto(request.called_user_id);
1879    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1880    if !session
1881        .db()
1882        .await
1883        .has_contact(calling_user_id, called_user_id)
1884        .await?
1885    {
1886        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1887    }
1888
1889    let incoming_call = {
1890        let (room, incoming_call) = &mut *session
1891            .db()
1892            .await
1893            .call(
1894                room_id,
1895                calling_user_id,
1896                calling_connection_id,
1897                called_user_id,
1898                initial_project_id,
1899            )
1900            .await?;
1901        room_updated(&room, &session.peer);
1902        mem::take(incoming_call)
1903    };
1904    update_user_contacts(called_user_id, &session).await?;
1905
1906    let mut calls = session
1907        .connection_pool()
1908        .await
1909        .user_connection_ids(called_user_id)
1910        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1911        .collect::<FuturesUnordered<_>>();
1912
1913    while let Some(call_response) = calls.next().await {
1914        match call_response.as_ref() {
1915            Ok(_) => {
1916                response.send(proto::Ack {})?;
1917                return Ok(());
1918            }
1919            Err(_) => {
1920                call_response.trace_err();
1921            }
1922        }
1923    }
1924
1925    {
1926        let room = session
1927            .db()
1928            .await
1929            .call_failed(room_id, called_user_id)
1930            .await?;
1931        room_updated(&room, &session.peer);
1932    }
1933    update_user_contacts(called_user_id, &session).await?;
1934
1935    Err(anyhow!("failed to ring user"))?
1936}
1937
1938/// Cancel an outgoing call.
1939async fn cancel_call(
1940    request: proto::CancelCall,
1941    response: Response<proto::CancelCall>,
1942    session: UserSession,
1943) -> Result<()> {
1944    let called_user_id = UserId::from_proto(request.called_user_id);
1945    let room_id = RoomId::from_proto(request.room_id);
1946    {
1947        let room = session
1948            .db()
1949            .await
1950            .cancel_call(room_id, session.connection_id, called_user_id)
1951            .await?;
1952        room_updated(&room, &session.peer);
1953    }
1954
1955    for connection_id in session
1956        .connection_pool()
1957        .await
1958        .user_connection_ids(called_user_id)
1959    {
1960        session
1961            .peer
1962            .send(
1963                connection_id,
1964                proto::CallCanceled {
1965                    room_id: room_id.to_proto(),
1966                },
1967            )
1968            .trace_err();
1969    }
1970    response.send(proto::Ack {})?;
1971
1972    update_user_contacts(called_user_id, &session).await?;
1973    Ok(())
1974}
1975
1976/// Decline an incoming call.
1977async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1978    let room_id = RoomId::from_proto(message.room_id);
1979    {
1980        let room = session
1981            .db()
1982            .await
1983            .decline_call(Some(room_id), session.user_id())
1984            .await?
1985            .ok_or_else(|| anyhow!("failed to decline call"))?;
1986        room_updated(&room, &session.peer);
1987    }
1988
1989    for connection_id in session
1990        .connection_pool()
1991        .await
1992        .user_connection_ids(session.user_id())
1993    {
1994        session
1995            .peer
1996            .send(
1997                connection_id,
1998                proto::CallCanceled {
1999                    room_id: room_id.to_proto(),
2000                },
2001            )
2002            .trace_err();
2003    }
2004    update_user_contacts(session.user_id(), &session).await?;
2005    Ok(())
2006}
2007
2008/// Updates other participants in the room with your current location.
2009async fn update_participant_location(
2010    request: proto::UpdateParticipantLocation,
2011    response: Response<proto::UpdateParticipantLocation>,
2012    session: UserSession,
2013) -> Result<()> {
2014    let room_id = RoomId::from_proto(request.room_id);
2015    let location = request
2016        .location
2017        .ok_or_else(|| anyhow!("invalid location"))?;
2018
2019    let db = session.db().await;
2020    let room = db
2021        .update_room_participant_location(room_id, session.connection_id, location)
2022        .await?;
2023
2024    room_updated(&room, &session.peer);
2025    response.send(proto::Ack {})?;
2026    Ok(())
2027}
2028
2029/// Share a project into the room.
2030async fn share_project(
2031    request: proto::ShareProject,
2032    response: Response<proto::ShareProject>,
2033    session: UserSession,
2034) -> Result<()> {
2035    let (project_id, room) = &*session
2036        .db()
2037        .await
2038        .share_project(
2039            RoomId::from_proto(request.room_id),
2040            session.connection_id,
2041            &request.worktrees,
2042            request
2043                .dev_server_project_id
2044                .map(|id| DevServerProjectId::from_proto(id)),
2045        )
2046        .await?;
2047    response.send(proto::ShareProjectResponse {
2048        project_id: project_id.to_proto(),
2049    })?;
2050    room_updated(&room, &session.peer);
2051
2052    Ok(())
2053}
2054
2055/// Unshare a project from the room.
2056async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2057    let project_id = ProjectId::from_proto(message.project_id);
2058    unshare_project_internal(
2059        project_id,
2060        session.connection_id,
2061        session.user_id(),
2062        &session,
2063    )
2064    .await
2065}
2066
2067async fn unshare_project_internal(
2068    project_id: ProjectId,
2069    connection_id: ConnectionId,
2070    user_id: Option<UserId>,
2071    session: &Session,
2072) -> Result<()> {
2073    let delete = {
2074        let room_guard = session
2075            .db()
2076            .await
2077            .unshare_project(project_id, connection_id, user_id)
2078            .await?;
2079
2080        let (delete, room, guest_connection_ids) = &*room_guard;
2081
2082        let message = proto::UnshareProject {
2083            project_id: project_id.to_proto(),
2084        };
2085
2086        broadcast(
2087            Some(connection_id),
2088            guest_connection_ids.iter().copied(),
2089            |conn_id| session.peer.send(conn_id, message.clone()),
2090        );
2091        if let Some(room) = room {
2092            room_updated(room, &session.peer);
2093        }
2094
2095        *delete
2096    };
2097
2098    if delete {
2099        let db = session.db().await;
2100        db.delete_project(project_id).await?;
2101    }
2102
2103    Ok(())
2104}
2105
2106/// DevServer makes a project available online
2107async fn share_dev_server_project(
2108    request: proto::ShareDevServerProject,
2109    response: Response<proto::ShareDevServerProject>,
2110    session: DevServerSession,
2111) -> Result<()> {
2112    let (dev_server_project, user_id, status) = session
2113        .db()
2114        .await
2115        .share_dev_server_project(
2116            DevServerProjectId::from_proto(request.dev_server_project_id),
2117            session.dev_server_id(),
2118            session.connection_id,
2119            &request.worktrees,
2120        )
2121        .await?;
2122    let Some(project_id) = dev_server_project.project_id else {
2123        return Err(anyhow!("failed to share remote project"))?;
2124    };
2125
2126    send_dev_server_projects_update(user_id, status, &session).await;
2127
2128    response.send(proto::ShareProjectResponse { project_id })?;
2129
2130    Ok(())
2131}
2132
2133/// Join someone elses shared project.
2134async fn join_project(
2135    request: proto::JoinProject,
2136    response: Response<proto::JoinProject>,
2137    session: UserSession,
2138) -> Result<()> {
2139    let project_id = ProjectId::from_proto(request.project_id);
2140
2141    tracing::info!(%project_id, "join project");
2142
2143    let db = session.db().await;
2144    let (project, replica_id) = &mut *db
2145        .join_project(project_id, session.connection_id, session.user_id())
2146        .await?;
2147    drop(db);
2148    tracing::info!(%project_id, "join remote project");
2149    join_project_internal(response, session, project, replica_id)
2150}
2151
/// Abstraction over the responses that `join_project_internal` can answer with a
/// `proto::JoinProjectResponse`, so that it can serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` through it.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so the call names `Response`'s own `send` explicitly
        // rather than this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so the call names `Response`'s own `send` explicitly
        // rather than this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2165
2166fn join_project_internal(
2167    response: impl JoinProjectInternalResponse,
2168    session: UserSession,
2169    project: &mut Project,
2170    replica_id: &ReplicaId,
2171) -> Result<()> {
2172    let collaborators = project
2173        .collaborators
2174        .iter()
2175        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2176        .map(|collaborator| collaborator.to_proto())
2177        .collect::<Vec<_>>();
2178    let project_id = project.id;
2179    let guest_user_id = session.user_id();
2180
2181    let worktrees = project
2182        .worktrees
2183        .iter()
2184        .map(|(id, worktree)| proto::WorktreeMetadata {
2185            id: *id,
2186            root_name: worktree.root_name.clone(),
2187            visible: worktree.visible,
2188            abs_path: worktree.abs_path.clone(),
2189        })
2190        .collect::<Vec<_>>();
2191
2192    let add_project_collaborator = proto::AddProjectCollaborator {
2193        project_id: project_id.to_proto(),
2194        collaborator: Some(proto::Collaborator {
2195            peer_id: Some(session.connection_id.into()),
2196            replica_id: replica_id.0 as u32,
2197            user_id: guest_user_id.to_proto(),
2198        }),
2199    };
2200
2201    for collaborator in &collaborators {
2202        session
2203            .peer
2204            .send(
2205                collaborator.peer_id.unwrap().into(),
2206                add_project_collaborator.clone(),
2207            )
2208            .trace_err();
2209    }
2210
2211    // First, we send the metadata associated with each worktree.
2212    response.send(proto::JoinProjectResponse {
2213        project_id: project.id.0 as u64,
2214        worktrees: worktrees.clone(),
2215        replica_id: replica_id.0 as u32,
2216        collaborators: collaborators.clone(),
2217        language_servers: project.language_servers.clone(),
2218        role: project.role.into(),
2219        dev_server_project_id: project
2220            .dev_server_project_id
2221            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2222    })?;
2223
2224    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2225        #[cfg(any(test, feature = "test-support"))]
2226        const MAX_CHUNK_SIZE: usize = 2;
2227        #[cfg(not(any(test, feature = "test-support")))]
2228        const MAX_CHUNK_SIZE: usize = 256;
2229
2230        // Stream this worktree's entries.
2231        let message = proto::UpdateWorktree {
2232            project_id: project_id.to_proto(),
2233            worktree_id,
2234            abs_path: worktree.abs_path.clone(),
2235            root_name: worktree.root_name,
2236            updated_entries: worktree.entries,
2237            removed_entries: Default::default(),
2238            scan_id: worktree.scan_id,
2239            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2240            updated_repositories: worktree.repository_entries.into_values().collect(),
2241            removed_repositories: Default::default(),
2242        };
2243        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2244            session.peer.send(session.connection_id, update.clone())?;
2245        }
2246
2247        // Stream this worktree's diagnostics.
2248        for summary in worktree.diagnostic_summaries {
2249            session.peer.send(
2250                session.connection_id,
2251                proto::UpdateDiagnosticSummary {
2252                    project_id: project_id.to_proto(),
2253                    worktree_id: worktree.id,
2254                    summary: Some(summary),
2255                },
2256            )?;
2257        }
2258
2259        for settings_file in worktree.settings_files {
2260            session.peer.send(
2261                session.connection_id,
2262                proto::UpdateWorktreeSettings {
2263                    project_id: project_id.to_proto(),
2264                    worktree_id: worktree.id,
2265                    path: settings_file.path,
2266                    content: Some(settings_file.content),
2267                },
2268            )?;
2269        }
2270    }
2271
2272    for language_server in &project.language_servers {
2273        session.peer.send(
2274            session.connection_id,
2275            proto::UpdateLanguageServer {
2276                project_id: project_id.to_proto(),
2277                language_server_id: language_server.id,
2278                variant: Some(
2279                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2280                        proto::LspDiskBasedDiagnosticsUpdated {},
2281                    ),
2282                ),
2283            },
2284        )?;
2285    }
2286
2287    Ok(())
2288}
2289
2290/// Leave someone elses shared project.
2291async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2292    let sender_id = session.connection_id;
2293    let project_id = ProjectId::from_proto(request.project_id);
2294    let db = session.db().await;
2295    if db.is_hosted_project(project_id).await? {
2296        let project = db.leave_hosted_project(project_id, sender_id).await?;
2297        project_left(&project, &session);
2298        return Ok(());
2299    }
2300
2301    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2302    tracing::info!(
2303        %project_id,
2304        "leave project"
2305    );
2306
2307    project_left(&project, &session);
2308    if let Some(room) = room {
2309        room_updated(&room, &session.peer);
2310    }
2311
2312    Ok(())
2313}
2314
2315async fn join_hosted_project(
2316    request: proto::JoinHostedProject,
2317    response: Response<proto::JoinHostedProject>,
2318    session: UserSession,
2319) -> Result<()> {
2320    let (mut project, replica_id) = session
2321        .db()
2322        .await
2323        .join_hosted_project(
2324            ProjectId(request.project_id as i32),
2325            session.user_id(),
2326            session.connection_id,
2327        )
2328        .await?;
2329
2330    join_project_internal(response, session, &mut project, &replica_id)
2331}
2332
2333async fn list_remote_directory(
2334    request: proto::ListRemoteDirectory,
2335    response: Response<proto::ListRemoteDirectory>,
2336    session: UserSession,
2337) -> Result<()> {
2338    let dev_server_id = DevServerId(request.dev_server_id as i32);
2339    let dev_server_connection_id = session
2340        .connection_pool()
2341        .await
2342        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2343
2344    session
2345        .db()
2346        .await
2347        .get_dev_server_for_user(dev_server_id, session.user_id())
2348        .await?;
2349
2350    response.send(
2351        session
2352            .peer
2353            .forward_request(session.connection_id, dev_server_connection_id, request)
2354            .await?,
2355    )?;
2356    Ok(())
2357}
2358
2359async fn update_dev_server_project(
2360    request: proto::UpdateDevServerProject,
2361    response: Response<proto::UpdateDevServerProject>,
2362    session: UserSession,
2363) -> Result<()> {
2364    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2365
2366    let (dev_server_project, update) = session
2367        .db()
2368        .await
2369        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2370        .await?;
2371
2372    let projects = session
2373        .db()
2374        .await
2375        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2376        .await?;
2377
2378    let dev_server_connection_id = session
2379        .connection_pool()
2380        .await
2381        .dev_server_connection_id_supporting(
2382            dev_server_project.dev_server_id,
2383            ZedVersion::with_list_directory(),
2384        )?;
2385
2386    session.peer.send(
2387        dev_server_connection_id,
2388        proto::DevServerInstructions { projects },
2389    )?;
2390
2391    send_dev_server_projects_update(session.user_id(), update, &session).await;
2392
2393    response.send(proto::Ack {})
2394}
2395
2396async fn create_dev_server_project(
2397    request: proto::CreateDevServerProject,
2398    response: Response<proto::CreateDevServerProject>,
2399    session: UserSession,
2400) -> Result<()> {
2401    let dev_server_id = DevServerId(request.dev_server_id as i32);
2402    let dev_server_connection_id = session
2403        .connection_pool()
2404        .await
2405        .dev_server_connection_id(dev_server_id);
2406    let Some(dev_server_connection_id) = dev_server_connection_id else {
2407        Err(ErrorCode::DevServerOffline
2408            .message("Cannot create a remote project when the dev server is offline".to_string())
2409            .anyhow())?
2410    };
2411
2412    let path = request.path.clone();
2413    //Check that the path exists on the dev server
2414    session
2415        .peer
2416        .forward_request(
2417            session.connection_id,
2418            dev_server_connection_id,
2419            proto::ValidateDevServerProjectRequest { path: path.clone() },
2420        )
2421        .await?;
2422
2423    let (dev_server_project, update) = session
2424        .db()
2425        .await
2426        .create_dev_server_project(
2427            DevServerId(request.dev_server_id as i32),
2428            &request.path,
2429            session.user_id(),
2430        )
2431        .await?;
2432
2433    let projects = session
2434        .db()
2435        .await
2436        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2437        .await?;
2438
2439    session.peer.send(
2440        dev_server_connection_id,
2441        proto::DevServerInstructions { projects },
2442    )?;
2443
2444    send_dev_server_projects_update(session.user_id(), update, &session).await;
2445
2446    response.send(proto::CreateDevServerProjectResponse {
2447        dev_server_project: Some(dev_server_project.to_proto(None)),
2448    })?;
2449    Ok(())
2450}
2451
2452async fn create_dev_server(
2453    request: proto::CreateDevServer,
2454    response: Response<proto::CreateDevServer>,
2455    session: UserSession,
2456) -> Result<()> {
2457    let access_token = auth::random_token();
2458    let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460    if request.name.is_empty() {
2461        return Err(proto::ErrorCode::Forbidden
2462            .message("Dev server name cannot be empty".to_string())
2463            .anyhow())?;
2464    }
2465
2466    let (dev_server, status) = session
2467        .db()
2468        .await
2469        .create_dev_server(
2470            &request.name,
2471            request.ssh_connection_string.as_deref(),
2472            &hashed_access_token,
2473            session.user_id(),
2474        )
2475        .await?;
2476
2477    send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479    response.send(proto::CreateDevServerResponse {
2480        dev_server_id: dev_server.id.0 as u64,
2481        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2482        name: request.name,
2483    })?;
2484    Ok(())
2485}
2486
2487async fn regenerate_dev_server_token(
2488    request: proto::RegenerateDevServerToken,
2489    response: Response<proto::RegenerateDevServerToken>,
2490    session: UserSession,
2491) -> Result<()> {
2492    let dev_server_id = DevServerId(request.dev_server_id as i32);
2493    let access_token = auth::random_token();
2494    let hashed_access_token = auth::hash_access_token(&access_token);
2495
2496    let connection_id = session
2497        .connection_pool()
2498        .await
2499        .dev_server_connection_id(dev_server_id);
2500    if let Some(connection_id) = connection_id {
2501        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2502        session.peer.send(
2503            connection_id,
2504            proto::ShutdownDevServer {
2505                reason: Some("dev server token was regenerated".to_string()),
2506            },
2507        )?;
2508        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2509    }
2510
2511    let status = session
2512        .db()
2513        .await
2514        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2515        .await?;
2516
2517    send_dev_server_projects_update(session.user_id(), status, &session).await;
2518
2519    response.send(proto::RegenerateDevServerTokenResponse {
2520        dev_server_id: dev_server_id.to_proto(),
2521        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2522    })?;
2523    Ok(())
2524}
2525
2526async fn rename_dev_server(
2527    request: proto::RenameDevServer,
2528    response: Response<proto::RenameDevServer>,
2529    session: UserSession,
2530) -> Result<()> {
2531    if request.name.trim().is_empty() {
2532        return Err(proto::ErrorCode::Forbidden
2533            .message("Dev server name cannot be empty".to_string())
2534            .anyhow())?;
2535    }
2536
2537    let dev_server_id = DevServerId(request.dev_server_id as i32);
2538    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2539    if dev_server.user_id != session.user_id() {
2540        return Err(anyhow!(ErrorCode::Forbidden))?;
2541    }
2542
2543    let status = session
2544        .db()
2545        .await
2546        .rename_dev_server(
2547            dev_server_id,
2548            &request.name,
2549            request.ssh_connection_string.as_deref(),
2550            session.user_id(),
2551        )
2552        .await?;
2553
2554    send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556    response.send(proto::Ack {})?;
2557    Ok(())
2558}
2559
2560async fn delete_dev_server(
2561    request: proto::DeleteDevServer,
2562    response: Response<proto::DeleteDevServer>,
2563    session: UserSession,
2564) -> Result<()> {
2565    let dev_server_id = DevServerId(request.dev_server_id as i32);
2566    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2567    if dev_server.user_id != session.user_id() {
2568        return Err(anyhow!(ErrorCode::Forbidden))?;
2569    }
2570
2571    let connection_id = session
2572        .connection_pool()
2573        .await
2574        .dev_server_connection_id(dev_server_id);
2575    if let Some(connection_id) = connection_id {
2576        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2577        session.peer.send(
2578            connection_id,
2579            proto::ShutdownDevServer {
2580                reason: Some("dev server was deleted".to_string()),
2581            },
2582        )?;
2583        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2584    }
2585
2586    let status = session
2587        .db()
2588        .await
2589        .delete_dev_server(dev_server_id, session.user_id())
2590        .await?;
2591
2592    send_dev_server_projects_update(session.user_id(), status, &session).await;
2593
2594    response.send(proto::Ack {})?;
2595    Ok(())
2596}
2597
2598async fn delete_dev_server_project(
2599    request: proto::DeleteDevServerProject,
2600    response: Response<proto::DeleteDevServerProject>,
2601    session: UserSession,
2602) -> Result<()> {
2603    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2604    let dev_server_project = session
2605        .db()
2606        .await
2607        .get_dev_server_project(dev_server_project_id)
2608        .await?;
2609
2610    let dev_server = session
2611        .db()
2612        .await
2613        .get_dev_server(dev_server_project.dev_server_id)
2614        .await?;
2615    if dev_server.user_id != session.user_id() {
2616        return Err(anyhow!(ErrorCode::Forbidden))?;
2617    }
2618
2619    let dev_server_connection_id = session
2620        .connection_pool()
2621        .await
2622        .dev_server_connection_id(dev_server.id);
2623
2624    if let Some(dev_server_connection_id) = dev_server_connection_id {
2625        let project = session
2626            .db()
2627            .await
2628            .find_dev_server_project(dev_server_project_id)
2629            .await;
2630        if let Ok(project) = project {
2631            unshare_project_internal(
2632                project.id,
2633                dev_server_connection_id,
2634                Some(session.user_id()),
2635                &session,
2636            )
2637            .await?;
2638        }
2639    }
2640
2641    let (projects, status) = session
2642        .db()
2643        .await
2644        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2645        .await?;
2646
2647    if let Some(dev_server_connection_id) = dev_server_connection_id {
2648        session.peer.send(
2649            dev_server_connection_id,
2650            proto::DevServerInstructions { projects },
2651        )?;
2652    }
2653
2654    send_dev_server_projects_update(session.user_id(), status, &session).await;
2655
2656    response.send(proto::Ack {})?;
2657    Ok(())
2658}
2659
2660async fn rejoin_dev_server_projects(
2661    request: proto::RejoinRemoteProjects,
2662    response: Response<proto::RejoinRemoteProjects>,
2663    session: UserSession,
2664) -> Result<()> {
2665    let mut rejoined_projects = {
2666        let db = session.db().await;
2667        db.rejoin_dev_server_projects(
2668            &request.rejoined_projects,
2669            session.user_id(),
2670            session.0.connection_id,
2671        )
2672        .await?
2673    };
2674    response.send(proto::RejoinRemoteProjectsResponse {
2675        rejoined_projects: rejoined_projects
2676            .iter()
2677            .map(|project| project.to_proto())
2678            .collect(),
2679    })?;
2680    notify_rejoined_projects(&mut rejoined_projects, &session)
2681}
2682
2683async fn reconnect_dev_server(
2684    request: proto::ReconnectDevServer,
2685    response: Response<proto::ReconnectDevServer>,
2686    session: DevServerSession,
2687) -> Result<()> {
2688    let reshared_projects = {
2689        let db = session.db().await;
2690        db.reshare_dev_server_projects(
2691            &request.reshared_projects,
2692            session.dev_server_id(),
2693            session.0.connection_id,
2694        )
2695        .await?
2696    };
2697
2698    for project in &reshared_projects {
2699        for collaborator in &project.collaborators {
2700            session
2701                .peer
2702                .send(
2703                    collaborator.connection_id,
2704                    proto::UpdateProjectCollaborator {
2705                        project_id: project.id.to_proto(),
2706                        old_peer_id: Some(project.old_connection_id.into()),
2707                        new_peer_id: Some(session.connection_id.into()),
2708                    },
2709                )
2710                .trace_err();
2711        }
2712
2713        broadcast(
2714            Some(session.connection_id),
2715            project
2716                .collaborators
2717                .iter()
2718                .map(|collaborator| collaborator.connection_id),
2719            |connection_id| {
2720                session.peer.forward_send(
2721                    session.connection_id,
2722                    connection_id,
2723                    proto::UpdateProject {
2724                        project_id: project.id.to_proto(),
2725                        worktrees: project.worktrees.clone(),
2726                    },
2727                )
2728            },
2729        );
2730    }
2731
2732    response.send(proto::ReconnectDevServerResponse {
2733        reshared_projects: reshared_projects
2734            .iter()
2735            .map(|project| proto::ResharedProject {
2736                id: project.id.to_proto(),
2737                collaborators: project
2738                    .collaborators
2739                    .iter()
2740                    .map(|collaborator| collaborator.to_proto())
2741                    .collect(),
2742            })
2743            .collect(),
2744    })?;
2745
2746    Ok(())
2747}
2748
2749async fn shutdown_dev_server(
2750    _: proto::ShutdownDevServer,
2751    response: Response<proto::ShutdownDevServer>,
2752    session: DevServerSession,
2753) -> Result<()> {
2754    response.send(proto::Ack {})?;
2755    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2756    remove_dev_server_connection(session.dev_server_id(), &session).await
2757}
2758
2759async fn shutdown_dev_server_internal(
2760    dev_server_id: DevServerId,
2761    connection_id: ConnectionId,
2762    session: &Session,
2763) -> Result<()> {
2764    let (dev_server_projects, dev_server) = {
2765        let db = session.db().await;
2766        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2767        let dev_server = db.get_dev_server(dev_server_id).await?;
2768        (dev_server_projects, dev_server)
2769    };
2770
2771    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2772        unshare_project_internal(
2773            ProjectId::from_proto(project_id),
2774            connection_id,
2775            None,
2776            session,
2777        )
2778        .await?;
2779    }
2780
2781    session
2782        .connection_pool()
2783        .await
2784        .set_dev_server_offline(dev_server_id);
2785
2786    let status = session
2787        .db()
2788        .await
2789        .dev_server_projects_update(dev_server.user_id)
2790        .await?;
2791    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2792
2793    Ok(())
2794}
2795
2796async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2797    let dev_server_connection = session
2798        .connection_pool()
2799        .await
2800        .dev_server_connection_id(dev_server_id);
2801
2802    if let Some(dev_server_connection) = dev_server_connection {
2803        session
2804            .connection_pool()
2805            .await
2806            .remove_connection(dev_server_connection)?;
2807    }
2808    Ok(())
2809}
2810
2811/// Updates other participants with changes to the project
2812async fn update_project(
2813    request: proto::UpdateProject,
2814    response: Response<proto::UpdateProject>,
2815    session: Session,
2816) -> Result<()> {
2817    let project_id = ProjectId::from_proto(request.project_id);
2818    let (room, guest_connection_ids) = &*session
2819        .db()
2820        .await
2821        .update_project(project_id, session.connection_id, &request.worktrees)
2822        .await?;
2823    broadcast(
2824        Some(session.connection_id),
2825        guest_connection_ids.iter().copied(),
2826        |connection_id| {
2827            session
2828                .peer
2829                .forward_send(session.connection_id, connection_id, request.clone())
2830        },
2831    );
2832    if let Some(room) = room {
2833        room_updated(&room, &session.peer);
2834    }
2835    response.send(proto::Ack {})?;
2836
2837    Ok(())
2838}
2839
2840/// Updates other participants with changes to the worktree
2841async fn update_worktree(
2842    request: proto::UpdateWorktree,
2843    response: Response<proto::UpdateWorktree>,
2844    session: Session,
2845) -> Result<()> {
2846    let guest_connection_ids = session
2847        .db()
2848        .await
2849        .update_worktree(&request, session.connection_id)
2850        .await?;
2851
2852    broadcast(
2853        Some(session.connection_id),
2854        guest_connection_ids.iter().copied(),
2855        |connection_id| {
2856            session
2857                .peer
2858                .forward_send(session.connection_id, connection_id, request.clone())
2859        },
2860    );
2861    response.send(proto::Ack {})?;
2862    Ok(())
2863}
2864
2865/// Updates other participants with changes to the diagnostics
2866async fn update_diagnostic_summary(
2867    message: proto::UpdateDiagnosticSummary,
2868    session: Session,
2869) -> Result<()> {
2870    let guest_connection_ids = session
2871        .db()
2872        .await
2873        .update_diagnostic_summary(&message, session.connection_id)
2874        .await?;
2875
2876    broadcast(
2877        Some(session.connection_id),
2878        guest_connection_ids.iter().copied(),
2879        |connection_id| {
2880            session
2881                .peer
2882                .forward_send(session.connection_id, connection_id, message.clone())
2883        },
2884    );
2885
2886    Ok(())
2887}
2888
2889/// Updates other participants with changes to the worktree settings
2890async fn update_worktree_settings(
2891    message: proto::UpdateWorktreeSettings,
2892    session: Session,
2893) -> Result<()> {
2894    let guest_connection_ids = session
2895        .db()
2896        .await
2897        .update_worktree_settings(&message, session.connection_id)
2898        .await?;
2899
2900    broadcast(
2901        Some(session.connection_id),
2902        guest_connection_ids.iter().copied(),
2903        |connection_id| {
2904            session
2905                .peer
2906                .forward_send(session.connection_id, connection_id, message.clone())
2907        },
2908    );
2909
2910    Ok(())
2911}
2912
2913/// Notify other participants that a  language server has started.
2914async fn start_language_server(
2915    request: proto::StartLanguageServer,
2916    session: Session,
2917) -> Result<()> {
2918    let guest_connection_ids = session
2919        .db()
2920        .await
2921        .start_language_server(&request, session.connection_id)
2922        .await?;
2923
2924    broadcast(
2925        Some(session.connection_id),
2926        guest_connection_ids.iter().copied(),
2927        |connection_id| {
2928            session
2929                .peer
2930                .forward_send(session.connection_id, connection_id, request.clone())
2931        },
2932    );
2933    Ok(())
2934}
2935
2936/// Notify other participants that a language server has changed.
2937async fn update_language_server(
2938    request: proto::UpdateLanguageServer,
2939    session: Session,
2940) -> Result<()> {
2941    let project_id = ProjectId::from_proto(request.project_id);
2942    let project_connection_ids = session
2943        .db()
2944        .await
2945        .project_connection_ids(project_id, session.connection_id, true)
2946        .await?;
2947    broadcast(
2948        Some(session.connection_id),
2949        project_connection_ids.iter().copied(),
2950        |connection_id| {
2951            session
2952                .peer
2953                .forward_send(session.connection_id, connection_id, request.clone())
2954        },
2955    );
2956    Ok(())
2957}
2958
2959/// forward a project request to the host. These requests should be read only
2960/// as guests are allowed to send them.
2961async fn forward_read_only_project_request<T>(
2962    request: T,
2963    response: Response<T>,
2964    session: UserSession,
2965) -> Result<()>
2966where
2967    T: EntityMessage + RequestMessage,
2968{
2969    let project_id = ProjectId::from_proto(request.remote_entity_id());
2970    let host_connection_id = session
2971        .db()
2972        .await
2973        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2974        .await?;
2975    let payload = session
2976        .peer
2977        .forward_request(session.connection_id, host_connection_id, request)
2978        .await?;
2979    response.send(payload)?;
2980    Ok(())
2981}
2982
2983/// forward a project request to the dev server. Only allowed
2984/// if it's your dev server.
2985async fn forward_project_request_for_owner<T>(
2986    request: T,
2987    response: Response<T>,
2988    session: UserSession,
2989) -> Result<()>
2990where
2991    T: EntityMessage + RequestMessage,
2992{
2993    let project_id = ProjectId::from_proto(request.remote_entity_id());
2994
2995    let host_connection_id = session
2996        .db()
2997        .await
2998        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2999        .await?;
3000    let payload = session
3001        .peer
3002        .forward_request(session.connection_id, host_connection_id, request)
3003        .await?;
3004    response.send(payload)?;
3005    Ok(())
3006}
3007
3008/// forward a project request to the host. These requests are disallowed
3009/// for guests.
3010async fn forward_mutating_project_request<T>(
3011    request: T,
3012    response: Response<T>,
3013    session: UserSession,
3014) -> Result<()>
3015where
3016    T: EntityMessage + RequestMessage,
3017{
3018    let project_id = ProjectId::from_proto(request.remote_entity_id());
3019
3020    let host_connection_id = session
3021        .db()
3022        .await
3023        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3024        .await?;
3025    let payload = session
3026        .peer
3027        .forward_request(session.connection_id, host_connection_id, request)
3028        .await?;
3029    response.send(payload)?;
3030    Ok(())
3031}
3032
3033/// forward a project request to the host. These requests are disallowed
3034/// for guests.
3035async fn forward_versioned_mutating_project_request<T>(
3036    request: T,
3037    response: Response<T>,
3038    session: UserSession,
3039) -> Result<()>
3040where
3041    T: EntityMessage + RequestMessage + VersionedMessage,
3042{
3043    let project_id = ProjectId::from_proto(request.remote_entity_id());
3044
3045    let host_connection_id = session
3046        .db()
3047        .await
3048        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3049        .await?;
3050    if let Some(host_version) = session
3051        .connection_pool()
3052        .await
3053        .connection(host_connection_id)
3054        .map(|c| c.zed_version)
3055    {
3056        if let Some(min_required_version) = request.required_host_version() {
3057            if min_required_version > host_version {
3058                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3059                    .with_tag("required", &min_required_version.to_string())))?;
3060            }
3061        }
3062    }
3063
3064    let payload = session
3065        .peer
3066        .forward_request(session.connection_id, host_connection_id, request)
3067        .await?;
3068    response.send(payload)?;
3069    Ok(())
3070}
3071
3072/// Notify other participants that a new buffer has been created
3073async fn create_buffer_for_peer(
3074    request: proto::CreateBufferForPeer,
3075    session: Session,
3076) -> Result<()> {
3077    session
3078        .db()
3079        .await
3080        .check_user_is_project_host(
3081            ProjectId::from_proto(request.project_id),
3082            session.connection_id,
3083        )
3084        .await?;
3085    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3086    session
3087        .peer
3088        .forward_send(session.connection_id, peer_id.into(), request)?;
3089    Ok(())
3090}
3091
3092/// Notify other participants that a buffer has been updated. This is
3093/// allowed for guests as long as the update is limited to selections.
3094async fn update_buffer(
3095    request: proto::UpdateBuffer,
3096    response: Response<proto::UpdateBuffer>,
3097    session: Session,
3098) -> Result<()> {
3099    let project_id = ProjectId::from_proto(request.project_id);
3100    let mut capability = Capability::ReadOnly;
3101
3102    for op in request.operations.iter() {
3103        match op.variant {
3104            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3105            Some(_) => capability = Capability::ReadWrite,
3106        }
3107    }
3108
3109    let host = {
3110        let guard = session
3111            .db()
3112            .await
3113            .connections_for_buffer_update(
3114                project_id,
3115                session.principal_id(),
3116                session.connection_id,
3117                capability,
3118            )
3119            .await?;
3120
3121        let (host, guests) = &*guard;
3122
3123        broadcast(
3124            Some(session.connection_id),
3125            guests.clone(),
3126            |connection_id| {
3127                session
3128                    .peer
3129                    .forward_send(session.connection_id, connection_id, request.clone())
3130            },
3131        );
3132
3133        *host
3134    };
3135
3136    if host != session.connection_id {
3137        session
3138            .peer
3139            .forward_request(session.connection_id, host, request.clone())
3140            .await?;
3141    }
3142
3143    response.send(proto::Ack {})?;
3144    Ok(())
3145}
3146
3147async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3148    let project_id = ProjectId::from_proto(message.project_id);
3149
3150    let operation = message.operation.as_ref().context("invalid operation")?;
3151    let capability = match operation.variant.as_ref() {
3152        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3153            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3154                match buffer_op.variant {
3155                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3156                        Capability::ReadOnly
3157                    }
3158                    _ => Capability::ReadWrite,
3159                }
3160            } else {
3161                Capability::ReadWrite
3162            }
3163        }
3164        Some(_) => Capability::ReadWrite,
3165        None => Capability::ReadOnly,
3166    };
3167
3168    let guard = session
3169        .db()
3170        .await
3171        .connections_for_buffer_update(
3172            project_id,
3173            session.principal_id(),
3174            session.connection_id,
3175            capability,
3176        )
3177        .await?;
3178
3179    let (host, guests) = &*guard;
3180
3181    broadcast(
3182        Some(session.connection_id),
3183        guests.iter().chain([host]).copied(),
3184        |connection_id| {
3185            session
3186                .peer
3187                .forward_send(session.connection_id, connection_id, message.clone())
3188        },
3189    );
3190
3191    Ok(())
3192}
3193
3194/// Notify other participants that a project has been updated.
3195async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3196    request: T,
3197    session: Session,
3198) -> Result<()> {
3199    let project_id = ProjectId::from_proto(request.remote_entity_id());
3200    let project_connection_ids = session
3201        .db()
3202        .await
3203        .project_connection_ids(project_id, session.connection_id, false)
3204        .await?;
3205
3206    broadcast(
3207        Some(session.connection_id),
3208        project_connection_ids.iter().copied(),
3209        |connection_id| {
3210            session
3211                .peer
3212                .forward_send(session.connection_id, connection_id, request.clone())
3213        },
3214    );
3215    Ok(())
3216}
3217
3218/// Start following another user in a call.
3219async fn follow(
3220    request: proto::Follow,
3221    response: Response<proto::Follow>,
3222    session: UserSession,
3223) -> Result<()> {
3224    let room_id = RoomId::from_proto(request.room_id);
3225    let project_id = request.project_id.map(ProjectId::from_proto);
3226    let leader_id = request
3227        .leader_id
3228        .ok_or_else(|| anyhow!("invalid leader id"))?
3229        .into();
3230    let follower_id = session.connection_id;
3231
3232    session
3233        .db()
3234        .await
3235        .check_room_participants(room_id, leader_id, session.connection_id)
3236        .await?;
3237
3238    let response_payload = session
3239        .peer
3240        .forward_request(session.connection_id, leader_id, request)
3241        .await?;
3242    response.send(response_payload)?;
3243
3244    if let Some(project_id) = project_id {
3245        let room = session
3246            .db()
3247            .await
3248            .follow(room_id, project_id, leader_id, follower_id)
3249            .await?;
3250        room_updated(&room, &session.peer);
3251    }
3252
3253    Ok(())
3254}
3255
3256/// Stop following another user in a call.
3257async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3258    let room_id = RoomId::from_proto(request.room_id);
3259    let project_id = request.project_id.map(ProjectId::from_proto);
3260    let leader_id = request
3261        .leader_id
3262        .ok_or_else(|| anyhow!("invalid leader id"))?
3263        .into();
3264    let follower_id = session.connection_id;
3265
3266    session
3267        .db()
3268        .await
3269        .check_room_participants(room_id, leader_id, session.connection_id)
3270        .await?;
3271
3272    session
3273        .peer
3274        .forward_send(session.connection_id, leader_id, request)?;
3275
3276    if let Some(project_id) = project_id {
3277        let room = session
3278            .db()
3279            .await
3280            .unfollow(room_id, project_id, leader_id, follower_id)
3281            .await?;
3282        room_updated(&room, &session.peer);
3283    }
3284
3285    Ok(())
3286}
3287
3288/// Notify everyone following you of your current location.
3289async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3290    let room_id = RoomId::from_proto(request.room_id);
3291    let database = session.db.lock().await;
3292
3293    let connection_ids = if let Some(project_id) = request.project_id {
3294        let project_id = ProjectId::from_proto(project_id);
3295        database
3296            .project_connection_ids(project_id, session.connection_id, true)
3297            .await?
3298    } else {
3299        database
3300            .room_connection_ids(room_id, session.connection_id)
3301            .await?
3302    };
3303
3304    // For now, don't send view update messages back to that view's current leader.
3305    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3306        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3307        _ => None,
3308    });
3309
3310    for connection_id in connection_ids.iter().cloned() {
3311        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3312            session
3313                .peer
3314                .forward_send(session.connection_id, connection_id, request.clone())?;
3315        }
3316    }
3317    Ok(())
3318}
3319
3320/// Get public data about users.
3321async fn get_users(
3322    request: proto::GetUsers,
3323    response: Response<proto::GetUsers>,
3324    session: Session,
3325) -> Result<()> {
3326    let user_ids = request
3327        .user_ids
3328        .into_iter()
3329        .map(UserId::from_proto)
3330        .collect();
3331    let users = session
3332        .db()
3333        .await
3334        .get_users_by_ids(user_ids)
3335        .await?
3336        .into_iter()
3337        .map(|user| proto::User {
3338            id: user.id.to_proto(),
3339            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3340            github_login: user.github_login,
3341        })
3342        .collect();
3343    response.send(proto::UsersResponse { users })?;
3344    Ok(())
3345}
3346
3347/// Search for users (to invite) buy Github login
3348async fn fuzzy_search_users(
3349    request: proto::FuzzySearchUsers,
3350    response: Response<proto::FuzzySearchUsers>,
3351    session: UserSession,
3352) -> Result<()> {
3353    let query = request.query;
3354    let users = match query.len() {
3355        0 => vec![],
3356        1 | 2 => session
3357            .db()
3358            .await
3359            .get_user_by_github_login(&query)
3360            .await?
3361            .into_iter()
3362            .collect(),
3363        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3364    };
3365    let users = users
3366        .into_iter()
3367        .filter(|user| user.id != session.user_id())
3368        .map(|user| proto::User {
3369            id: user.id.to_proto(),
3370            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3371            github_login: user.github_login,
3372        })
3373        .collect();
3374    response.send(proto::UsersResponse { users })?;
3375    Ok(())
3376}
3377
3378/// Send a contact request to another user.
3379async fn request_contact(
3380    request: proto::RequestContact,
3381    response: Response<proto::RequestContact>,
3382    session: UserSession,
3383) -> Result<()> {
3384    let requester_id = session.user_id();
3385    let responder_id = UserId::from_proto(request.responder_id);
3386    if requester_id == responder_id {
3387        return Err(anyhow!("cannot add yourself as a contact"))?;
3388    }
3389
3390    let notifications = session
3391        .db()
3392        .await
3393        .send_contact_request(requester_id, responder_id)
3394        .await?;
3395
3396    // Update outgoing contact requests of requester
3397    let mut update = proto::UpdateContacts::default();
3398    update.outgoing_requests.push(responder_id.to_proto());
3399    for connection_id in session
3400        .connection_pool()
3401        .await
3402        .user_connection_ids(requester_id)
3403    {
3404        session.peer.send(connection_id, update.clone())?;
3405    }
3406
3407    // Update incoming contact requests of responder
3408    let mut update = proto::UpdateContacts::default();
3409    update
3410        .incoming_requests
3411        .push(proto::IncomingContactRequest {
3412            requester_id: requester_id.to_proto(),
3413        });
3414    let connection_pool = session.connection_pool().await;
3415    for connection_id in connection_pool.user_connection_ids(responder_id) {
3416        session.peer.send(connection_id, update.clone())?;
3417    }
3418
3419    send_notifications(&connection_pool, &session.peer, notifications);
3420
3421    response.send(proto::Ack {})?;
3422    Ok(())
3423}
3424
3425/// Accept or decline a contact request
3426async fn respond_to_contact_request(
3427    request: proto::RespondToContactRequest,
3428    response: Response<proto::RespondToContactRequest>,
3429    session: UserSession,
3430) -> Result<()> {
3431    let responder_id = session.user_id();
3432    let requester_id = UserId::from_proto(request.requester_id);
3433    let db = session.db().await;
3434    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3435        db.dismiss_contact_notification(responder_id, requester_id)
3436            .await?;
3437    } else {
3438        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3439
3440        let notifications = db
3441            .respond_to_contact_request(responder_id, requester_id, accept)
3442            .await?;
3443        let requester_busy = db.is_user_busy(requester_id).await?;
3444        let responder_busy = db.is_user_busy(responder_id).await?;
3445
3446        let pool = session.connection_pool().await;
3447        // Update responder with new contact
3448        let mut update = proto::UpdateContacts::default();
3449        if accept {
3450            update
3451                .contacts
3452                .push(contact_for_user(requester_id, requester_busy, &pool));
3453        }
3454        update
3455            .remove_incoming_requests
3456            .push(requester_id.to_proto());
3457        for connection_id in pool.user_connection_ids(responder_id) {
3458            session.peer.send(connection_id, update.clone())?;
3459        }
3460
3461        // Update requester with new contact
3462        let mut update = proto::UpdateContacts::default();
3463        if accept {
3464            update
3465                .contacts
3466                .push(contact_for_user(responder_id, responder_busy, &pool));
3467        }
3468        update
3469            .remove_outgoing_requests
3470            .push(responder_id.to_proto());
3471
3472        for connection_id in pool.user_connection_ids(requester_id) {
3473            session.peer.send(connection_id, update.clone())?;
3474        }
3475
3476        send_notifications(&pool, &session.peer, notifications);
3477    }
3478
3479    response.send(proto::Ack {})?;
3480    Ok(())
3481}
3482
3483/// Remove a contact.
3484async fn remove_contact(
3485    request: proto::RemoveContact,
3486    response: Response<proto::RemoveContact>,
3487    session: UserSession,
3488) -> Result<()> {
3489    let requester_id = session.user_id();
3490    let responder_id = UserId::from_proto(request.user_id);
3491    let db = session.db().await;
3492    let (contact_accepted, deleted_notification_id) =
3493        db.remove_contact(requester_id, responder_id).await?;
3494
3495    let pool = session.connection_pool().await;
3496    // Update outgoing contact requests of requester
3497    let mut update = proto::UpdateContacts::default();
3498    if contact_accepted {
3499        update.remove_contacts.push(responder_id.to_proto());
3500    } else {
3501        update
3502            .remove_outgoing_requests
3503            .push(responder_id.to_proto());
3504    }
3505    for connection_id in pool.user_connection_ids(requester_id) {
3506        session.peer.send(connection_id, update.clone())?;
3507    }
3508
3509    // Update incoming contact requests of responder
3510    let mut update = proto::UpdateContacts::default();
3511    if contact_accepted {
3512        update.remove_contacts.push(requester_id.to_proto());
3513    } else {
3514        update
3515            .remove_incoming_requests
3516            .push(requester_id.to_proto());
3517    }
3518    for connection_id in pool.user_connection_ids(responder_id) {
3519        session.peer.send(connection_id, update.clone())?;
3520        if let Some(notification_id) = deleted_notification_id {
3521            session.peer.send(
3522                connection_id,
3523                proto::DeleteNotification {
3524                    notification_id: notification_id.to_proto(),
3525                },
3526            )?;
3527        }
3528    }
3529
3530    response.send(proto::Ack {})?;
3531    Ok(())
3532}
3533
3534fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3535    version.0.minor() < 139
3536}
3537
3538async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3539    subscribe_user_to_channels(
3540        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3541        &session,
3542    )
3543    .await?;
3544    Ok(())
3545}
3546
3547async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3548    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3549    let mut pool = session.connection_pool().await;
3550    for membership in &channels_for_user.channel_memberships {
3551        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3552    }
3553    session.peer.send(
3554        session.connection_id,
3555        build_update_user_channels(&channels_for_user),
3556    )?;
3557    session.peer.send(
3558        session.connection_id,
3559        build_channels_update(channels_for_user),
3560    )?;
3561    Ok(())
3562}
3563
3564/// Creates a new channel.
3565async fn create_channel(
3566    request: proto::CreateChannel,
3567    response: Response<proto::CreateChannel>,
3568    session: UserSession,
3569) -> Result<()> {
3570    let db = session.db().await;
3571
3572    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3573    let (channel, membership) = db
3574        .create_channel(&request.name, parent_id, session.user_id())
3575        .await?;
3576
3577    let root_id = channel.root_id();
3578    let channel = Channel::from_model(channel);
3579
3580    response.send(proto::CreateChannelResponse {
3581        channel: Some(channel.to_proto()),
3582        parent_id: request.parent_id,
3583    })?;
3584
3585    let mut connection_pool = session.connection_pool().await;
3586    if let Some(membership) = membership {
3587        connection_pool.subscribe_to_channel(
3588            membership.user_id,
3589            membership.channel_id,
3590            membership.role,
3591        );
3592        let update = proto::UpdateUserChannels {
3593            channel_memberships: vec![proto::ChannelMembership {
3594                channel_id: membership.channel_id.to_proto(),
3595                role: membership.role.into(),
3596            }],
3597            ..Default::default()
3598        };
3599        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3600            session.peer.send(connection_id, update.clone())?;
3601        }
3602    }
3603
3604    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3605        if !role.can_see_channel(channel.visibility) {
3606            continue;
3607        }
3608
3609        let update = proto::UpdateChannels {
3610            channels: vec![channel.to_proto()],
3611            ..Default::default()
3612        };
3613        session.peer.send(connection_id, update.clone())?;
3614    }
3615
3616    Ok(())
3617}
3618
3619/// Delete a channel
3620async fn delete_channel(
3621    request: proto::DeleteChannel,
3622    response: Response<proto::DeleteChannel>,
3623    session: UserSession,
3624) -> Result<()> {
3625    let db = session.db().await;
3626
3627    let channel_id = request.channel_id;
3628    let (root_channel, removed_channels) = db
3629        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3630        .await?;
3631    response.send(proto::Ack {})?;
3632
3633    // Notify members of removed channels
3634    let mut update = proto::UpdateChannels::default();
3635    update
3636        .delete_channels
3637        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3638
3639    let connection_pool = session.connection_pool().await;
3640    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3641        session.peer.send(connection_id, update.clone())?;
3642    }
3643
3644    Ok(())
3645}
3646
3647/// Invite someone to join a channel.
3648async fn invite_channel_member(
3649    request: proto::InviteChannelMember,
3650    response: Response<proto::InviteChannelMember>,
3651    session: UserSession,
3652) -> Result<()> {
3653    let db = session.db().await;
3654    let channel_id = ChannelId::from_proto(request.channel_id);
3655    let invitee_id = UserId::from_proto(request.user_id);
3656    let InviteMemberResult {
3657        channel,
3658        notifications,
3659    } = db
3660        .invite_channel_member(
3661            channel_id,
3662            invitee_id,
3663            session.user_id(),
3664            request.role().into(),
3665        )
3666        .await?;
3667
3668    let update = proto::UpdateChannels {
3669        channel_invitations: vec![channel.to_proto()],
3670        ..Default::default()
3671    };
3672
3673    let connection_pool = session.connection_pool().await;
3674    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3675        session.peer.send(connection_id, update.clone())?;
3676    }
3677
3678    send_notifications(&connection_pool, &session.peer, notifications);
3679
3680    response.send(proto::Ack {})?;
3681    Ok(())
3682}
3683
3684/// remove someone from a channel
3685async fn remove_channel_member(
3686    request: proto::RemoveChannelMember,
3687    response: Response<proto::RemoveChannelMember>,
3688    session: UserSession,
3689) -> Result<()> {
3690    let db = session.db().await;
3691    let channel_id = ChannelId::from_proto(request.channel_id);
3692    let member_id = UserId::from_proto(request.user_id);
3693
3694    let RemoveChannelMemberResult {
3695        membership_update,
3696        notification_id,
3697    } = db
3698        .remove_channel_member(channel_id, member_id, session.user_id())
3699        .await?;
3700
3701    let mut connection_pool = session.connection_pool().await;
3702    notify_membership_updated(
3703        &mut connection_pool,
3704        membership_update,
3705        member_id,
3706        &session.peer,
3707    );
3708    for connection_id in connection_pool.user_connection_ids(member_id) {
3709        if let Some(notification_id) = notification_id {
3710            session
3711                .peer
3712                .send(
3713                    connection_id,
3714                    proto::DeleteNotification {
3715                        notification_id: notification_id.to_proto(),
3716                    },
3717                )
3718                .trace_err();
3719        }
3720    }
3721
3722    response.send(proto::Ack {})?;
3723    Ok(())
3724}
3725
3726/// Toggle the channel between public and private.
3727/// Care is taken to maintain the invariant that public channels only descend from public channels,
3728/// (though members-only channels can appear at any point in the hierarchy).
3729async fn set_channel_visibility(
3730    request: proto::SetChannelVisibility,
3731    response: Response<proto::SetChannelVisibility>,
3732    session: UserSession,
3733) -> Result<()> {
3734    let db = session.db().await;
3735    let channel_id = ChannelId::from_proto(request.channel_id);
3736    let visibility = request.visibility().into();
3737
3738    let channel_model = db
3739        .set_channel_visibility(channel_id, visibility, session.user_id())
3740        .await?;
3741    let root_id = channel_model.root_id();
3742    let channel = Channel::from_model(channel_model);
3743
3744    let mut connection_pool = session.connection_pool().await;
3745    for (user_id, role) in connection_pool
3746        .channel_user_ids(root_id)
3747        .collect::<Vec<_>>()
3748        .into_iter()
3749    {
3750        let update = if role.can_see_channel(channel.visibility) {
3751            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3752            proto::UpdateChannels {
3753                channels: vec![channel.to_proto()],
3754                ..Default::default()
3755            }
3756        } else {
3757            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3758            proto::UpdateChannels {
3759                delete_channels: vec![channel.id.to_proto()],
3760                ..Default::default()
3761            }
3762        };
3763
3764        for connection_id in connection_pool.user_connection_ids(user_id) {
3765            session.peer.send(connection_id, update.clone())?;
3766        }
3767    }
3768
3769    response.send(proto::Ack {})?;
3770    Ok(())
3771}
3772
3773/// Alter the role for a user in the channel.
3774async fn set_channel_member_role(
3775    request: proto::SetChannelMemberRole,
3776    response: Response<proto::SetChannelMemberRole>,
3777    session: UserSession,
3778) -> Result<()> {
3779    let db = session.db().await;
3780    let channel_id = ChannelId::from_proto(request.channel_id);
3781    let member_id = UserId::from_proto(request.user_id);
3782    let result = db
3783        .set_channel_member_role(
3784            channel_id,
3785            session.user_id(),
3786            member_id,
3787            request.role().into(),
3788        )
3789        .await?;
3790
3791    match result {
3792        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3793            let mut connection_pool = session.connection_pool().await;
3794            notify_membership_updated(
3795                &mut connection_pool,
3796                membership_update,
3797                member_id,
3798                &session.peer,
3799            )
3800        }
3801        db::SetMemberRoleResult::InviteUpdated(channel) => {
3802            let update = proto::UpdateChannels {
3803                channel_invitations: vec![channel.to_proto()],
3804                ..Default::default()
3805            };
3806
3807            for connection_id in session
3808                .connection_pool()
3809                .await
3810                .user_connection_ids(member_id)
3811            {
3812                session.peer.send(connection_id, update.clone())?;
3813            }
3814        }
3815    }
3816
3817    response.send(proto::Ack {})?;
3818    Ok(())
3819}
3820
3821/// Change the name of a channel
3822async fn rename_channel(
3823    request: proto::RenameChannel,
3824    response: Response<proto::RenameChannel>,
3825    session: UserSession,
3826) -> Result<()> {
3827    let db = session.db().await;
3828    let channel_id = ChannelId::from_proto(request.channel_id);
3829    let channel_model = db
3830        .rename_channel(channel_id, session.user_id(), &request.name)
3831        .await?;
3832    let root_id = channel_model.root_id();
3833    let channel = Channel::from_model(channel_model);
3834
3835    response.send(proto::RenameChannelResponse {
3836        channel: Some(channel.to_proto()),
3837    })?;
3838
3839    let connection_pool = session.connection_pool().await;
3840    let update = proto::UpdateChannels {
3841        channels: vec![channel.to_proto()],
3842        ..Default::default()
3843    };
3844    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3845        if role.can_see_channel(channel.visibility) {
3846            session.peer.send(connection_id, update.clone())?;
3847        }
3848    }
3849
3850    Ok(())
3851}
3852
3853/// Move a channel to a new parent.
3854async fn move_channel(
3855    request: proto::MoveChannel,
3856    response: Response<proto::MoveChannel>,
3857    session: UserSession,
3858) -> Result<()> {
3859    let channel_id = ChannelId::from_proto(request.channel_id);
3860    let to = ChannelId::from_proto(request.to);
3861
3862    let (root_id, channels) = session
3863        .db()
3864        .await
3865        .move_channel(channel_id, to, session.user_id())
3866        .await?;
3867
3868    let connection_pool = session.connection_pool().await;
3869    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3870        let channels = channels
3871            .iter()
3872            .filter_map(|channel| {
3873                if role.can_see_channel(channel.visibility) {
3874                    Some(channel.to_proto())
3875                } else {
3876                    None
3877                }
3878            })
3879            .collect::<Vec<_>>();
3880        if channels.is_empty() {
3881            continue;
3882        }
3883
3884        let update = proto::UpdateChannels {
3885            channels,
3886            ..Default::default()
3887        };
3888
3889        session.peer.send(connection_id, update.clone())?;
3890    }
3891
3892    response.send(Ack {})?;
3893    Ok(())
3894}
3895
3896/// Get the list of channel members
3897async fn get_channel_members(
3898    request: proto::GetChannelMembers,
3899    response: Response<proto::GetChannelMembers>,
3900    session: UserSession,
3901) -> Result<()> {
3902    let db = session.db().await;
3903    let channel_id = ChannelId::from_proto(request.channel_id);
3904    let limit = if request.limit == 0 {
3905        u16::MAX as u64
3906    } else {
3907        request.limit
3908    };
3909    let (members, users) = db
3910        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3911        .await?;
3912    response.send(proto::GetChannelMembersResponse { members, users })?;
3913    Ok(())
3914}
3915
3916/// Accept or decline a channel invitation.
3917async fn respond_to_channel_invite(
3918    request: proto::RespondToChannelInvite,
3919    response: Response<proto::RespondToChannelInvite>,
3920    session: UserSession,
3921) -> Result<()> {
3922    let db = session.db().await;
3923    let channel_id = ChannelId::from_proto(request.channel_id);
3924    let RespondToChannelInvite {
3925        membership_update,
3926        notifications,
3927    } = db
3928        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3929        .await?;
3930
3931    let mut connection_pool = session.connection_pool().await;
3932    if let Some(membership_update) = membership_update {
3933        notify_membership_updated(
3934            &mut connection_pool,
3935            membership_update,
3936            session.user_id(),
3937            &session.peer,
3938        );
3939    } else {
3940        let update = proto::UpdateChannels {
3941            remove_channel_invitations: vec![channel_id.to_proto()],
3942            ..Default::default()
3943        };
3944
3945        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3946            session.peer.send(connection_id, update.clone())?;
3947        }
3948    };
3949
3950    send_notifications(&connection_pool, &session.peer, notifications);
3951
3952    response.send(proto::Ack {})?;
3953
3954    Ok(())
3955}
3956
3957/// Join the channels' room
3958async fn join_channel(
3959    request: proto::JoinChannel,
3960    response: Response<proto::JoinChannel>,
3961    session: UserSession,
3962) -> Result<()> {
3963    let channel_id = ChannelId::from_proto(request.channel_id);
3964    join_channel_internal(channel_id, Box::new(response), session).await
3965}
3966
/// Abstraction over response handles that can deliver a
/// `proto::JoinRoomResponse`, so that `join_channel_internal` can reply to
/// more than one request type.
trait JoinChannelInternalResponse {
    /// Consume the response handle and send `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified call to `Response::send` — presumably the handle's
        // own inherent method rather than this trait method; confirm.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified call to `Response::send` — presumably the handle's
        // own inherent method rather than this trait method; confirm.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3980
/// Shared implementation for joining the room that belongs to a channel.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Send the `JoinRoomResponse` to the caller.
/// 5. Propagate any membership change, then call `room_updated`,
///    `channel_updated` and `update_user_contacts`.
///
/// `response` is any handle implementing `JoinChannelInternalResponse`.
///
/// Returns an error if any database call or the response send fails, or if
/// the database did not return a channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // `db` and `connection_pool` are declared inside this block, so both go out
    // of scope before the connection pool is requested again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then take
            // it again — presumably `leave_room_for_session` needs the database
            // itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a
        // room token and may publish. `trace_err()?` turns a failed token request
        // into `None`, so the response is still sent, just without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have changed the user's channel membership; propagate it if so.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A joined channel room is expected to carry its channel; treat its absence as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4071
4072/// Start editing the channel notes
4073async fn join_channel_buffer(
4074    request: proto::JoinChannelBuffer,
4075    response: Response<proto::JoinChannelBuffer>,
4076    session: UserSession,
4077) -> Result<()> {
4078    let db = session.db().await;
4079    let channel_id = ChannelId::from_proto(request.channel_id);
4080
4081    let open_response = db
4082        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4083        .await?;
4084
4085    let collaborators = open_response.collaborators.clone();
4086    response.send(open_response)?;
4087
4088    let update = UpdateChannelBufferCollaborators {
4089        channel_id: channel_id.to_proto(),
4090        collaborators: collaborators.clone(),
4091    };
4092    channel_buffer_updated(
4093        session.connection_id,
4094        collaborators
4095            .iter()
4096            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4097        &update,
4098        &session.peer,
4099    );
4100
4101    Ok(())
4102}
4103
4104/// Edit the channel notes
4105async fn update_channel_buffer(
4106    request: proto::UpdateChannelBuffer,
4107    session: UserSession,
4108) -> Result<()> {
4109    let db = session.db().await;
4110    let channel_id = ChannelId::from_proto(request.channel_id);
4111
4112    let (collaborators, epoch, version) = db
4113        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4114        .await?;
4115
4116    channel_buffer_updated(
4117        session.connection_id,
4118        collaborators.clone(),
4119        &proto::UpdateChannelBuffer {
4120            channel_id: channel_id.to_proto(),
4121            operations: request.operations,
4122        },
4123        &session.peer,
4124    );
4125
4126    let pool = &*session.connection_pool().await;
4127
4128    let non_collaborators =
4129        pool.channel_connection_ids(channel_id)
4130            .filter_map(|(connection_id, _)| {
4131                if collaborators.contains(&connection_id) {
4132                    None
4133                } else {
4134                    Some(connection_id)
4135                }
4136            });
4137
4138    broadcast(None, non_collaborators, |peer_id| {
4139        session.peer.send(
4140            peer_id,
4141            proto::UpdateChannels {
4142                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4143                    channel_id: channel_id.to_proto(),
4144                    epoch: epoch as u64,
4145                    version: version.clone(),
4146                }],
4147                ..Default::default()
4148            },
4149        )
4150    });
4151
4152    Ok(())
4153}
4154
4155/// Rejoin the channel notes after a connection blip
4156async fn rejoin_channel_buffers(
4157    request: proto::RejoinChannelBuffers,
4158    response: Response<proto::RejoinChannelBuffers>,
4159    session: UserSession,
4160) -> Result<()> {
4161    let db = session.db().await;
4162    let buffers = db
4163        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4164        .await?;
4165
4166    for rejoined_buffer in &buffers {
4167        let collaborators_to_notify = rejoined_buffer
4168            .buffer
4169            .collaborators
4170            .iter()
4171            .filter_map(|c| Some(c.peer_id?.into()));
4172        channel_buffer_updated(
4173            session.connection_id,
4174            collaborators_to_notify,
4175            &proto::UpdateChannelBufferCollaborators {
4176                channel_id: rejoined_buffer.buffer.channel_id,
4177                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4178            },
4179            &session.peer,
4180        );
4181    }
4182
4183    response.send(proto::RejoinChannelBuffersResponse {
4184        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4185    })?;
4186
4187    Ok(())
4188}
4189
4190/// Stop editing the channel notes
4191async fn leave_channel_buffer(
4192    request: proto::LeaveChannelBuffer,
4193    response: Response<proto::LeaveChannelBuffer>,
4194    session: UserSession,
4195) -> Result<()> {
4196    let db = session.db().await;
4197    let channel_id = ChannelId::from_proto(request.channel_id);
4198
4199    let left_buffer = db
4200        .leave_channel_buffer(channel_id, session.connection_id)
4201        .await?;
4202
4203    response.send(Ack {})?;
4204
4205    channel_buffer_updated(
4206        session.connection_id,
4207        left_buffer.connections,
4208        &proto::UpdateChannelBufferCollaborators {
4209            channel_id: channel_id.to_proto(),
4210            collaborators: left_buffer.collaborators,
4211        },
4212        &session.peer,
4213    );
4214
4215    Ok(())
4216}
4217
4218fn channel_buffer_updated<T: EnvelopedMessage>(
4219    sender_id: ConnectionId,
4220    collaborators: impl IntoIterator<Item = ConnectionId>,
4221    message: &T,
4222    peer: &Peer,
4223) {
4224    broadcast(Some(sender_id), collaborators, |peer_id| {
4225        peer.send(peer_id, message.clone())
4226    });
4227}
4228
4229fn send_notifications(
4230    connection_pool: &ConnectionPool,
4231    peer: &Peer,
4232    notifications: db::NotificationBatch,
4233) {
4234    for (user_id, notification) in notifications {
4235        for connection_id in connection_pool.user_connection_ids(user_id) {
4236            if let Err(error) = peer.send(
4237                connection_id,
4238                proto::AddNotification {
4239                    notification: Some(notification.clone()),
4240                },
4241            ) {
4242                tracing::error!(
4243                    "failed to send notification to {:?} {}",
4244                    connection_id,
4245                    error
4246                );
4247            }
4248        }
4249    }
4250}
4251
4252/// Send a message to the channel
4253async fn send_channel_message(
4254    request: proto::SendChannelMessage,
4255    response: Response<proto::SendChannelMessage>,
4256    session: UserSession,
4257) -> Result<()> {
4258    // Validate the message body.
4259    let body = request.body.trim().to_string();
4260    if body.len() > MAX_MESSAGE_LEN {
4261        return Err(anyhow!("message is too long"))?;
4262    }
4263    if body.is_empty() {
4264        return Err(anyhow!("message can't be blank"))?;
4265    }
4266
4267    // TODO: adjust mentions if body is trimmed
4268
4269    let timestamp = OffsetDateTime::now_utc();
4270    let nonce = request
4271        .nonce
4272        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4273
4274    let channel_id = ChannelId::from_proto(request.channel_id);
4275    let CreatedChannelMessage {
4276        message_id,
4277        participant_connection_ids,
4278        notifications,
4279    } = session
4280        .db()
4281        .await
4282        .create_channel_message(
4283            channel_id,
4284            session.user_id(),
4285            &body,
4286            &request.mentions,
4287            timestamp,
4288            nonce.clone().into(),
4289            match request.reply_to_message_id {
4290                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4291                None => None,
4292            },
4293        )
4294        .await?;
4295
4296    let message = proto::ChannelMessage {
4297        sender_id: session.user_id().to_proto(),
4298        id: message_id.to_proto(),
4299        body,
4300        mentions: request.mentions,
4301        timestamp: timestamp.unix_timestamp() as u64,
4302        nonce: Some(nonce),
4303        reply_to_message_id: request.reply_to_message_id,
4304        edited_at: None,
4305    };
4306    broadcast(
4307        Some(session.connection_id),
4308        participant_connection_ids.clone(),
4309        |connection| {
4310            session.peer.send(
4311                connection,
4312                proto::ChannelMessageSent {
4313                    channel_id: channel_id.to_proto(),
4314                    message: Some(message.clone()),
4315                },
4316            )
4317        },
4318    );
4319    response.send(proto::SendChannelMessageResponse {
4320        message: Some(message),
4321    })?;
4322
4323    let pool = &*session.connection_pool().await;
4324    let non_participants =
4325        pool.channel_connection_ids(channel_id)
4326            .filter_map(|(connection_id, _)| {
4327                if participant_connection_ids.contains(&connection_id) {
4328                    None
4329                } else {
4330                    Some(connection_id)
4331                }
4332            });
4333    broadcast(None, non_participants, |peer_id| {
4334        session.peer.send(
4335            peer_id,
4336            proto::UpdateChannels {
4337                latest_channel_message_ids: vec![proto::ChannelMessageId {
4338                    channel_id: channel_id.to_proto(),
4339                    message_id: message_id.to_proto(),
4340                }],
4341                ..Default::default()
4342            },
4343        )
4344    });
4345    send_notifications(pool, &session.peer, notifications);
4346
4347    Ok(())
4348}
4349
4350/// Delete a channel message
4351async fn remove_channel_message(
4352    request: proto::RemoveChannelMessage,
4353    response: Response<proto::RemoveChannelMessage>,
4354    session: UserSession,
4355) -> Result<()> {
4356    let channel_id = ChannelId::from_proto(request.channel_id);
4357    let message_id = MessageId::from_proto(request.message_id);
4358    let (connection_ids, existing_notification_ids) = session
4359        .db()
4360        .await
4361        .remove_channel_message(channel_id, message_id, session.user_id())
4362        .await?;
4363
4364    broadcast(
4365        Some(session.connection_id),
4366        connection_ids,
4367        move |connection| {
4368            session.peer.send(connection, request.clone())?;
4369
4370            for notification_id in &existing_notification_ids {
4371                session.peer.send(
4372                    connection,
4373                    proto::DeleteNotification {
4374                        notification_id: (*notification_id).to_proto(),
4375                    },
4376                )?;
4377            }
4378
4379            Ok(())
4380        },
4381    );
4382    response.send(proto::Ack {})?;
4383    Ok(())
4384}
4385
4386async fn update_channel_message(
4387    request: proto::UpdateChannelMessage,
4388    response: Response<proto::UpdateChannelMessage>,
4389    session: UserSession,
4390) -> Result<()> {
4391    let channel_id = ChannelId::from_proto(request.channel_id);
4392    let message_id = MessageId::from_proto(request.message_id);
4393    let updated_at = OffsetDateTime::now_utc();
4394    let UpdatedChannelMessage {
4395        message_id,
4396        participant_connection_ids,
4397        notifications,
4398        reply_to_message_id,
4399        timestamp,
4400        deleted_mention_notification_ids,
4401        updated_mention_notifications,
4402    } = session
4403        .db()
4404        .await
4405        .update_channel_message(
4406            channel_id,
4407            message_id,
4408            session.user_id(),
4409            request.body.as_str(),
4410            &request.mentions,
4411            updated_at,
4412        )
4413        .await?;
4414
4415    let nonce = request
4416        .nonce
4417        .clone()
4418        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4419
4420    let message = proto::ChannelMessage {
4421        sender_id: session.user_id().to_proto(),
4422        id: message_id.to_proto(),
4423        body: request.body.clone(),
4424        mentions: request.mentions.clone(),
4425        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4426        nonce: Some(nonce),
4427        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4428        edited_at: Some(updated_at.unix_timestamp() as u64),
4429    };
4430
4431    response.send(proto::Ack {})?;
4432
4433    let pool = &*session.connection_pool().await;
4434    broadcast(
4435        Some(session.connection_id),
4436        participant_connection_ids,
4437        |connection| {
4438            session.peer.send(
4439                connection,
4440                proto::ChannelMessageUpdate {
4441                    channel_id: channel_id.to_proto(),
4442                    message: Some(message.clone()),
4443                },
4444            )?;
4445
4446            for notification_id in &deleted_mention_notification_ids {
4447                session.peer.send(
4448                    connection,
4449                    proto::DeleteNotification {
4450                        notification_id: (*notification_id).to_proto(),
4451                    },
4452                )?;
4453            }
4454
4455            for notification in &updated_mention_notifications {
4456                session.peer.send(
4457                    connection,
4458                    proto::UpdateNotification {
4459                        notification: Some(notification.clone()),
4460                    },
4461                )?;
4462            }
4463
4464            Ok(())
4465        },
4466    );
4467
4468    send_notifications(pool, &session.peer, notifications);
4469
4470    Ok(())
4471}
4472
4473/// Mark a channel message as read
4474async fn acknowledge_channel_message(
4475    request: proto::AckChannelMessage,
4476    session: UserSession,
4477) -> Result<()> {
4478    let channel_id = ChannelId::from_proto(request.channel_id);
4479    let message_id = MessageId::from_proto(request.message_id);
4480    let notifications = session
4481        .db()
4482        .await
4483        .observe_channel_message(channel_id, session.user_id(), message_id)
4484        .await?;
4485    send_notifications(
4486        &*session.connection_pool().await,
4487        &session.peer,
4488        notifications,
4489    );
4490    Ok(())
4491}
4492
4493/// Mark a buffer version as synced
4494async fn acknowledge_buffer_version(
4495    request: proto::AckBufferOperation,
4496    session: UserSession,
4497) -> Result<()> {
4498    let buffer_id = BufferId::from_proto(request.buffer_id);
4499    session
4500        .db()
4501        .await
4502        .observe_buffer_version(
4503            buffer_id,
4504            session.user_id(),
4505            request.epoch as i32,
4506            &request.version,
4507        )
4508        .await?;
4509    Ok(())
4510}
4511
4512struct CompleteWithLanguageModelRateLimit;
4513
4514impl RateLimit for CompleteWithLanguageModelRateLimit {
4515    fn capacity() -> usize {
4516        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4517            .ok()
4518            .and_then(|v| v.parse().ok())
4519            .unwrap_or(120) // Picked arbitrarily
4520    }
4521
4522    fn refill_duration() -> chrono::Duration {
4523        chrono::Duration::hours(1)
4524    }
4525
4526    fn db_name() -> &'static str {
4527        "complete-with-language-model"
4528    }
4529}
4530
4531async fn complete_with_language_model(
4532    request: proto::CompleteWithLanguageModel,
4533    response: Response<proto::CompleteWithLanguageModel>,
4534    session: Session,
4535    config: &Config,
4536) -> Result<()> {
4537    let Some(session) = session.for_user() else {
4538        return Err(anyhow!("user not found"))?;
4539    };
4540    authorize_access_to_language_models(&session).await?;
4541
4542    session
4543        .rate_limiter
4544        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4545        .await?;
4546
4547    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4548        Some(proto::LanguageModelProvider::Anthropic) => {
4549            let api_key = config
4550                .anthropic_api_key
4551                .as_ref()
4552                .context("no Anthropic AI API key configured on the server")?;
4553            anthropic::complete(
4554                session.http_client.as_ref(),
4555                anthropic::ANTHROPIC_API_URL,
4556                api_key,
4557                serde_json::from_str(&request.request)?,
4558            )
4559            .await?
4560        }
4561        _ => return Err(anyhow!("unsupported provider"))?,
4562    };
4563
4564    response.send(proto::CompleteWithLanguageModelResponse {
4565        completion: serde_json::to_string(&result)?,
4566    })?;
4567
4568    Ok(())
4569}
4570
4571async fn stream_complete_with_language_model(
4572    request: proto::StreamCompleteWithLanguageModel,
4573    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4574    session: Session,
4575    config: &Config,
4576) -> Result<()> {
4577    let Some(session) = session.for_user() else {
4578        return Err(anyhow!("user not found"))?;
4579    };
4580    authorize_access_to_language_models(&session).await?;
4581
4582    session
4583        .rate_limiter
4584        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4585        .await?;
4586
4587    match proto::LanguageModelProvider::from_i32(request.provider) {
4588        Some(proto::LanguageModelProvider::Anthropic) => {
4589            let api_key = config
4590                .anthropic_api_key
4591                .as_ref()
4592                .context("no Anthropic AI API key configured on the server")?;
4593            let mut chunks = anthropic::stream_completion(
4594                session.http_client.as_ref(),
4595                anthropic::ANTHROPIC_API_URL,
4596                api_key,
4597                serde_json::from_str(&request.request)?,
4598                None,
4599            )
4600            .await?;
4601            while let Some(event) = chunks.next().await {
4602                let chunk = event?;
4603                response.send(proto::StreamCompleteWithLanguageModelResponse {
4604                    event: serde_json::to_string(&chunk)?,
4605                })?;
4606            }
4607        }
4608        Some(proto::LanguageModelProvider::OpenAi) => {
4609            let api_key = config
4610                .openai_api_key
4611                .as_ref()
4612                .context("no OpenAI API key configured on the server")?;
4613            let mut events = open_ai::stream_completion(
4614                session.http_client.as_ref(),
4615                open_ai::OPEN_AI_API_URL,
4616                api_key,
4617                serde_json::from_str(&request.request)?,
4618                None,
4619            )
4620            .await?;
4621            while let Some(event) = events.next().await {
4622                let event = event?;
4623                response.send(proto::StreamCompleteWithLanguageModelResponse {
4624                    event: serde_json::to_string(&event)?,
4625                })?;
4626            }
4627        }
4628        Some(proto::LanguageModelProvider::Google) => {
4629            let api_key = config
4630                .google_ai_api_key
4631                .as_ref()
4632                .context("no Google AI API key configured on the server")?;
4633            let mut events = google_ai::stream_generate_content(
4634                session.http_client.as_ref(),
4635                google_ai::API_URL,
4636                api_key,
4637                serde_json::from_str(&request.request)?,
4638            )
4639            .await?;
4640            while let Some(event) = events.next().await {
4641                let event = event?;
4642                response.send(proto::StreamCompleteWithLanguageModelResponse {
4643                    event: serde_json::to_string(&event)?,
4644                })?;
4645            }
4646        }
4647        None => return Err(anyhow!("unknown provider"))?,
4648    }
4649
4650    Ok(())
4651}
4652
4653async fn count_language_model_tokens(
4654    request: proto::CountLanguageModelTokens,
4655    response: Response<proto::CountLanguageModelTokens>,
4656    session: Session,
4657    config: &Config,
4658) -> Result<()> {
4659    let Some(session) = session.for_user() else {
4660        return Err(anyhow!("user not found"))?;
4661    };
4662    authorize_access_to_language_models(&session).await?;
4663
4664    session
4665        .rate_limiter
4666        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4667        .await?;
4668
4669    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4670        Some(proto::LanguageModelProvider::Google) => {
4671            let api_key = config
4672                .google_ai_api_key
4673                .as_ref()
4674                .context("no Google AI API key configured on the server")?;
4675            google_ai::count_tokens(
4676                session.http_client.as_ref(),
4677                google_ai::API_URL,
4678                api_key,
4679                serde_json::from_str(&request.request)?,
4680            )
4681            .await?
4682        }
4683        _ => return Err(anyhow!("unsupported provider"))?,
4684    };
4685
4686    response.send(proto::CountLanguageModelTokensResponse {
4687        token_count: result.total_tokens as u32,
4688    })?;
4689
4690    Ok(())
4691}
4692
4693struct CountLanguageModelTokensRateLimit;
4694
4695impl RateLimit for CountLanguageModelTokensRateLimit {
4696    fn capacity() -> usize {
4697        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4698            .ok()
4699            .and_then(|v| v.parse().ok())
4700            .unwrap_or(600) // Picked arbitrarily
4701    }
4702
4703    fn refill_duration() -> chrono::Duration {
4704        chrono::Duration::hours(1)
4705    }
4706
4707    fn db_name() -> &'static str {
4708        "count-language-model-tokens"
4709    }
4710}
4711
4712struct ComputeEmbeddingsRateLimit;
4713
4714impl RateLimit for ComputeEmbeddingsRateLimit {
4715    fn capacity() -> usize {
4716        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4717            .ok()
4718            .and_then(|v| v.parse().ok())
4719            .unwrap_or(5000) // Picked arbitrarily
4720    }
4721
4722    fn refill_duration() -> chrono::Duration {
4723        chrono::Duration::hours(1)
4724    }
4725
4726    fn db_name() -> &'static str {
4727        "compute-embeddings"
4728    }
4729}
4730
4731async fn compute_embeddings(
4732    request: proto::ComputeEmbeddings,
4733    response: Response<proto::ComputeEmbeddings>,
4734    session: UserSession,
4735    api_key: Option<Arc<str>>,
4736) -> Result<()> {
4737    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4738    authorize_access_to_language_models(&session).await?;
4739
4740    session
4741        .rate_limiter
4742        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4743        .await?;
4744
4745    let embeddings = match request.model.as_str() {
4746        "openai/text-embedding-3-small" => {
4747            open_ai::embed(
4748                session.http_client.as_ref(),
4749                OPEN_AI_API_URL,
4750                &api_key,
4751                OpenAiEmbeddingModel::TextEmbedding3Small,
4752                request.texts.iter().map(|text| text.as_str()),
4753            )
4754            .await?
4755        }
4756        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4757    };
4758
4759    let embeddings = request
4760        .texts
4761        .iter()
4762        .map(|text| {
4763            let mut hasher = sha2::Sha256::new();
4764            hasher.update(text.as_bytes());
4765            let result = hasher.finalize();
4766            result.to_vec()
4767        })
4768        .zip(
4769            embeddings
4770                .data
4771                .into_iter()
4772                .map(|embedding| embedding.embedding),
4773        )
4774        .collect::<HashMap<_, _>>();
4775
4776    let db = session.db().await;
4777    db.save_embeddings(&request.model, &embeddings)
4778        .await
4779        .context("failed to save embeddings")
4780        .trace_err();
4781
4782    response.send(proto::ComputeEmbeddingsResponse {
4783        embeddings: embeddings
4784            .into_iter()
4785            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4786            .collect(),
4787    })?;
4788    Ok(())
4789}
4790
4791async fn get_cached_embeddings(
4792    request: proto::GetCachedEmbeddings,
4793    response: Response<proto::GetCachedEmbeddings>,
4794    session: UserSession,
4795) -> Result<()> {
4796    authorize_access_to_language_models(&session).await?;
4797
4798    let db = session.db().await;
4799    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4800
4801    response.send(proto::GetCachedEmbeddingsResponse {
4802        embeddings: embeddings
4803            .into_iter()
4804            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4805            .collect(),
4806    })?;
4807    Ok(())
4808}
4809
4810async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4811    let db = session.db().await;
4812    let flags = db.get_user_flags(session.user_id()).await?;
4813    if flags.iter().any(|flag| flag == "language-models") {
4814        Ok(())
4815    } else {
4816        Err(anyhow!("permission denied"))?
4817    }
4818}
4819
4820/// Get a Supermaven API key for the user
4821async fn get_supermaven_api_key(
4822    _request: proto::GetSupermavenApiKey,
4823    response: Response<proto::GetSupermavenApiKey>,
4824    session: UserSession,
4825) -> Result<()> {
4826    let user_id: String = session.user_id().to_string();
4827    if !session.is_staff() {
4828        return Err(anyhow!("supermaven not enabled for this account"))?;
4829    }
4830
4831    let email = session
4832        .email()
4833        .ok_or_else(|| anyhow!("user must have an email"))?;
4834
4835    let supermaven_admin_api = session
4836        .supermaven_client
4837        .as_ref()
4838        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4839
4840    let result = supermaven_admin_api
4841        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4842        .await?;
4843
4844    response.send(proto::GetSupermavenApiKeyResponse {
4845        api_key: result.api_key,
4846    })?;
4847
4848    Ok(())
4849}
4850
4851/// Start receiving chat updates for a channel
4852async fn join_channel_chat(
4853    request: proto::JoinChannelChat,
4854    response: Response<proto::JoinChannelChat>,
4855    session: UserSession,
4856) -> Result<()> {
4857    let channel_id = ChannelId::from_proto(request.channel_id);
4858
4859    let db = session.db().await;
4860    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4861        .await?;
4862    let messages = db
4863        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4864        .await?;
4865    response.send(proto::JoinChannelChatResponse {
4866        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4867        messages,
4868    })?;
4869    Ok(())
4870}
4871
4872/// Stop receiving chat updates for a channel
4873async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4874    let channel_id = ChannelId::from_proto(request.channel_id);
4875    session
4876        .db()
4877        .await
4878        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4879        .await?;
4880    Ok(())
4881}
4882
4883/// Retrieve the chat history for a channel
4884async fn get_channel_messages(
4885    request: proto::GetChannelMessages,
4886    response: Response<proto::GetChannelMessages>,
4887    session: UserSession,
4888) -> Result<()> {
4889    let channel_id = ChannelId::from_proto(request.channel_id);
4890    let messages = session
4891        .db()
4892        .await
4893        .get_channel_messages(
4894            channel_id,
4895            session.user_id(),
4896            MESSAGE_COUNT_PER_PAGE,
4897            Some(MessageId::from_proto(request.before_message_id)),
4898        )
4899        .await?;
4900    response.send(proto::GetChannelMessagesResponse {
4901        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4902        messages,
4903    })?;
4904    Ok(())
4905}
4906
4907/// Retrieve specific chat messages
4908async fn get_channel_messages_by_id(
4909    request: proto::GetChannelMessagesById,
4910    response: Response<proto::GetChannelMessagesById>,
4911    session: UserSession,
4912) -> Result<()> {
4913    let message_ids = request
4914        .message_ids
4915        .iter()
4916        .map(|id| MessageId::from_proto(*id))
4917        .collect::<Vec<_>>();
4918    let messages = session
4919        .db()
4920        .await
4921        .get_channel_messages_by_id(session.user_id(), &message_ids)
4922        .await?;
4923    response.send(proto::GetChannelMessagesResponse {
4924        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4925        messages,
4926    })?;
4927    Ok(())
4928}
4929
4930/// Retrieve the current users notifications
4931async fn get_notifications(
4932    request: proto::GetNotifications,
4933    response: Response<proto::GetNotifications>,
4934    session: UserSession,
4935) -> Result<()> {
4936    let notifications = session
4937        .db()
4938        .await
4939        .get_notifications(
4940            session.user_id(),
4941            NOTIFICATION_COUNT_PER_PAGE,
4942            request
4943                .before_id
4944                .map(|id| db::NotificationId::from_proto(id)),
4945        )
4946        .await?;
4947    response.send(proto::GetNotificationsResponse {
4948        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4949        notifications,
4950    })?;
4951    Ok(())
4952}
4953
4954/// Mark notifications as read
4955async fn mark_notification_as_read(
4956    request: proto::MarkNotificationRead,
4957    response: Response<proto::MarkNotificationRead>,
4958    session: UserSession,
4959) -> Result<()> {
4960    let database = &session.db().await;
4961    let notifications = database
4962        .mark_notification_as_read_by_id(
4963            session.user_id(),
4964            NotificationId::from_proto(request.notification_id),
4965        )
4966        .await?;
4967    send_notifications(
4968        &*session.connection_pool().await,
4969        &session.peer,
4970        notifications,
4971    );
4972    response.send(proto::Ack {})?;
4973    Ok(())
4974}
4975
4976/// Get the current users information
4977async fn get_private_user_info(
4978    _request: proto::GetPrivateUserInfo,
4979    response: Response<proto::GetPrivateUserInfo>,
4980    session: UserSession,
4981) -> Result<()> {
4982    let db = session.db().await;
4983
4984    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4985    let user = db
4986        .get_user_by_id(session.user_id())
4987        .await?
4988        .ok_or_else(|| anyhow!("user not found"))?;
4989    let flags = db.get_user_flags(session.user_id()).await?;
4990
4991    response.send(proto::GetPrivateUserInfoResponse {
4992        metrics_id,
4993        staff: user.admin,
4994        flags,
4995    })?;
4996    Ok(())
4997}
4998
4999fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5000    let message = match message {
5001        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5002        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5003        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5004        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5005        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5006            code: frame.code.into(),
5007            reason: frame.reason,
5008        })),
5009        // We should never receive a frame while reading the message, according
5010        // to the `tungstenite` maintainers:
5011        //
5012        // > It cannot occur when you read messages from the WebSocket, but it
5013        // > can be used when you want to send the raw frames (e.g. you want to
5014        // > send the frames to the WebSocket without composing the full message first).
5015        // >
5016        // > — https://github.com/snapview/tungstenite-rs/issues/268
5017        TungsteniteMessage::Frame(_) => {
5018            bail!("received an unexpected frame while reading the message")
5019        }
5020    };
5021
5022    Ok(message)
5023}
5024
5025fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5026    match message {
5027        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5028        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5029        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5030        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5031        AxumMessage::Close(frame) => {
5032            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5033                code: frame.code.into(),
5034                reason: frame.reason,
5035            }))
5036        }
5037    }
5038}
5039
5040fn notify_membership_updated(
5041    connection_pool: &mut ConnectionPool,
5042    result: MembershipUpdated,
5043    user_id: UserId,
5044    peer: &Peer,
5045) {
5046    for membership in &result.new_channels.channel_memberships {
5047        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5048    }
5049    for channel_id in &result.removed_channels {
5050        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5051    }
5052
5053    let user_channels_update = proto::UpdateUserChannels {
5054        channel_memberships: result
5055            .new_channels
5056            .channel_memberships
5057            .iter()
5058            .map(|cm| proto::ChannelMembership {
5059                channel_id: cm.channel_id.to_proto(),
5060                role: cm.role.into(),
5061            })
5062            .collect(),
5063        ..Default::default()
5064    };
5065
5066    let mut update = build_channels_update(result.new_channels);
5067    update.delete_channels = result
5068        .removed_channels
5069        .into_iter()
5070        .map(|id| id.to_proto())
5071        .collect();
5072    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5073
5074    for connection_id in connection_pool.user_connection_ids(user_id) {
5075        peer.send(connection_id, user_channels_update.clone())
5076            .trace_err();
5077        peer.send(connection_id, update.clone()).trace_err();
5078    }
5079}
5080
5081fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5082    proto::UpdateUserChannels {
5083        channel_memberships: channels
5084            .channel_memberships
5085            .iter()
5086            .map(|m| proto::ChannelMembership {
5087                channel_id: m.channel_id.to_proto(),
5088                role: m.role.into(),
5089            })
5090            .collect(),
5091        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5092        observed_channel_message_id: channels.observed_channel_messages.clone(),
5093    }
5094}
5095
5096fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5097    let mut update = proto::UpdateChannels::default();
5098
5099    for channel in channels.channels {
5100        update.channels.push(channel.to_proto());
5101    }
5102
5103    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5104    update.latest_channel_message_ids = channels.latest_channel_messages;
5105
5106    for (channel_id, participants) in channels.channel_participants {
5107        update
5108            .channel_participants
5109            .push(proto::ChannelParticipants {
5110                channel_id: channel_id.to_proto(),
5111                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5112            });
5113    }
5114
5115    for channel in channels.invited_channels {
5116        update.channel_invitations.push(channel.to_proto());
5117    }
5118
5119    update.hosted_projects = channels.hosted_projects;
5120    update
5121}
5122
5123fn build_initial_contacts_update(
5124    contacts: Vec<db::Contact>,
5125    pool: &ConnectionPool,
5126) -> proto::UpdateContacts {
5127    let mut update = proto::UpdateContacts::default();
5128
5129    for contact in contacts {
5130        match contact {
5131            db::Contact::Accepted { user_id, busy } => {
5132                update.contacts.push(contact_for_user(user_id, busy, &pool));
5133            }
5134            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5135            db::Contact::Incoming { user_id } => {
5136                update
5137                    .incoming_requests
5138                    .push(proto::IncomingContactRequest {
5139                        requester_id: user_id.to_proto(),
5140                    })
5141            }
5142        }
5143    }
5144
5145    update
5146}
5147
5148fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5149    proto::Contact {
5150        user_id: user_id.to_proto(),
5151        online: pool.is_user_online(user_id),
5152        busy,
5153    }
5154}
5155
5156fn room_updated(room: &proto::Room, peer: &Peer) {
5157    broadcast(
5158        None,
5159        room.participants
5160            .iter()
5161            .filter_map(|participant| Some(participant.peer_id?.into())),
5162        |peer_id| {
5163            peer.send(
5164                peer_id,
5165                proto::RoomUpdated {
5166                    room: Some(room.clone()),
5167                },
5168            )
5169        },
5170    );
5171}
5172
5173fn channel_updated(
5174    channel: &db::channel::Model,
5175    room: &proto::Room,
5176    peer: &Peer,
5177    pool: &ConnectionPool,
5178) {
5179    let participants = room
5180        .participants
5181        .iter()
5182        .map(|p| p.user_id)
5183        .collect::<Vec<_>>();
5184
5185    broadcast(
5186        None,
5187        pool.channel_connection_ids(channel.root_id())
5188            .filter_map(|(channel_id, role)| {
5189                role.can_see_channel(channel.visibility).then(|| channel_id)
5190            }),
5191        |peer_id| {
5192            peer.send(
5193                peer_id,
5194                proto::UpdateChannels {
5195                    channel_participants: vec![proto::ChannelParticipants {
5196                        channel_id: channel.id.to_proto(),
5197                        participant_user_ids: participants.clone(),
5198                    }],
5199                    ..Default::default()
5200                },
5201            )
5202        },
5203    );
5204}
5205
5206async fn send_dev_server_projects_update(
5207    user_id: UserId,
5208    mut status: proto::DevServerProjectsUpdate,
5209    session: &Session,
5210) {
5211    let pool = session.connection_pool().await;
5212    for dev_server in &mut status.dev_servers {
5213        dev_server.status =
5214            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5215    }
5216    let connections = pool.user_connection_ids(user_id);
5217    for connection_id in connections {
5218        session.peer.send(connection_id, status.clone()).trace_err();
5219    }
5220}
5221
5222async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5223    let db = session.db().await;
5224
5225    let contacts = db.get_contacts(user_id).await?;
5226    let busy = db.is_user_busy(user_id).await?;
5227
5228    let pool = session.connection_pool().await;
5229    let updated_contact = contact_for_user(user_id, busy, &pool);
5230    for contact in contacts {
5231        if let db::Contact::Accepted {
5232            user_id: contact_user_id,
5233            ..
5234        } = contact
5235        {
5236            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5237                session
5238                    .peer
5239                    .send(
5240                        contact_conn_id,
5241                        proto::UpdateContacts {
5242                            contacts: vec![updated_contact.clone()],
5243                            remove_contacts: Default::default(),
5244                            incoming_requests: Default::default(),
5245                            remove_incoming_requests: Default::default(),
5246                            outgoing_requests: Default::default(),
5247                            remove_outgoing_requests: Default::default(),
5248                        },
5249                    )
5250                    .trace_err();
5251            }
5252        }
5253    }
5254    Ok(())
5255}
5256
5257async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5258    log::info!("lost dev server connection, unsharing projects");
5259    let project_ids = session
5260        .db()
5261        .await
5262        .get_stale_dev_server_projects(session.connection_id)
5263        .await?;
5264
5265    for project_id in project_ids {
5266        // not unshare re-checks the connection ids match, so we get away with no transaction
5267        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5268    }
5269
5270    let user_id = session.dev_server().user_id;
5271    let update = session
5272        .db()
5273        .await
5274        .dev_server_projects_update(user_id)
5275        .await?;
5276
5277    send_dev_server_projects_update(user_id, update, session).await;
5278
5279    Ok(())
5280}
5281
5282async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5283    let mut contacts_to_update = HashSet::default();
5284
5285    let room_id;
5286    let canceled_calls_to_user_ids;
5287    let live_kit_room;
5288    let delete_live_kit_room;
5289    let room;
5290    let channel;
5291
5292    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5293        contacts_to_update.insert(session.user_id());
5294
5295        for project in left_room.left_projects.values() {
5296            project_left(project, session);
5297        }
5298
5299        room_id = RoomId::from_proto(left_room.room.id);
5300        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5301        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5302        delete_live_kit_room = left_room.deleted;
5303        room = mem::take(&mut left_room.room);
5304        channel = mem::take(&mut left_room.channel);
5305
5306        room_updated(&room, &session.peer);
5307    } else {
5308        return Ok(());
5309    }
5310
5311    if let Some(channel) = channel {
5312        channel_updated(
5313            &channel,
5314            &room,
5315            &session.peer,
5316            &*session.connection_pool().await,
5317        );
5318    }
5319
5320    {
5321        let pool = session.connection_pool().await;
5322        for canceled_user_id in canceled_calls_to_user_ids {
5323            for connection_id in pool.user_connection_ids(canceled_user_id) {
5324                session
5325                    .peer
5326                    .send(
5327                        connection_id,
5328                        proto::CallCanceled {
5329                            room_id: room_id.to_proto(),
5330                        },
5331                    )
5332                    .trace_err();
5333            }
5334            contacts_to_update.insert(canceled_user_id);
5335        }
5336    }
5337
5338    for contact_user_id in contacts_to_update {
5339        update_user_contacts(contact_user_id, &session).await?;
5340    }
5341
5342    if let Some(live_kit) = session.live_kit_client.as_ref() {
5343        live_kit
5344            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5345            .await
5346            .trace_err();
5347
5348        if delete_live_kit_room {
5349            live_kit.delete_room(live_kit_room).await.trace_err();
5350        }
5351    }
5352
5353    Ok(())
5354}
5355
5356async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5357    let left_channel_buffers = session
5358        .db()
5359        .await
5360        .leave_channel_buffers(session.connection_id)
5361        .await?;
5362
5363    for left_buffer in left_channel_buffers {
5364        channel_buffer_updated(
5365            session.connection_id,
5366            left_buffer.connections,
5367            &proto::UpdateChannelBufferCollaborators {
5368                channel_id: left_buffer.channel_id.to_proto(),
5369                collaborators: left_buffer.collaborators,
5370            },
5371            &session.peer,
5372        );
5373    }
5374
5375    Ok(())
5376}
5377
5378fn project_left(project: &db::LeftProject, session: &UserSession) {
5379    for connection_id in &project.connection_ids {
5380        if project.should_unshare {
5381            session
5382                .peer
5383                .send(
5384                    *connection_id,
5385                    proto::UnshareProject {
5386                        project_id: project.id.to_proto(),
5387                    },
5388                )
5389                .trace_err();
5390        } else {
5391            session
5392                .peer
5393                .send(
5394                    *connection_id,
5395                    proto::RemoveProjectCollaborator {
5396                        project_id: project.id.to_proto(),
5397                        peer_id: Some(session.connection_id.into()),
5398                    },
5399                )
5400                .trace_err();
5401        }
5402    }
5403}
5404
5405pub trait ResultExt {
5406    type Ok;
5407
5408    fn trace_err(self) -> Option<Self::Ok>;
5409}
5410
5411impl<T, E> ResultExt for Result<T, E>
5412where
5413    E: std::fmt::Debug,
5414{
5415    type Ok = T;
5416
5417    #[track_caller]
5418    fn trace_err(self) -> Option<T> {
5419        match self {
5420            Ok(value) => Some(value),
5421            Err(error) => {
5422                tracing::error!("{:?}", error);
5423                None
5424            }
5425        }
5426    }
5427}