rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Config, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// Per-connection context passed to every message handler (see `MessageHandler`).
///
/// Field order is kept as-is: it determines drop order.
#[derive(Clone)]
struct Session {
    /// Identity this connection acts as.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; acquire it through `Session::connection_pool`, which
    /// returns a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional; presumably `None` when LiveKit is not configured — TODO confirm.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Optional; presumably `None` when Supermaven is not configured — TODO confirm.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    /// Underscore-prefixed; not read in the visible code.
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The RPC server: owns the peer, the connection pool, and the table of message
/// handlers registered in `Server::new`.
///
/// Field order is kept as-is: it determines drop order (including the `teardown` sender).
pub struct Server {
    /// Behind a mutex — NOTE(review): presumably so the id can change after construction;
    /// the mutation site is not in view.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers keyed by `TypeId` — presumably the message type's id; the
    /// `add_*_handler` methods are not in view.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `watch` channel created with an initial `false` in `Server::new`.
    teardown: watch::Sender<bool>,
}
 384
/// Lock guard returned by `Session::connection_pool` and embedded in `ServerSnapshot`.
///
/// Wraps a `parking_lot` mutex guard. The `PhantomData<Rc<()>>` marker makes the guard
/// `!Send`, so a `Send` future cannot hold it across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable snapshot of the server: the peer plus the connection pool.
///
/// The pool is serialized through its lock guard via `serialize_deref`. The guard is a
/// field, so the pool stays locked for as long as the snapshot lives. Field order here
/// is the serialized field order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Builds the server and registers every RPC handler, returning it in an `Arc`.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail with
    /// `Error::Internal` when the calling session is not a user. Handlers wrapped in
    /// `dev_server_handler` fail when it is not a dev server. Unwrapped handlers receive
    /// the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): if `ServerId`'s inner value is signed, `as u32` wraps negative
            // ids — confirm server ids are always non-negative.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped right away.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests only a dev-server session may send.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests routed through the `forward_*` helpers (defined outside
            // this view): owner-only, read-only, mutating, and versioned mutating.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model handlers: registered without `user_handler`, so they receive
            // the raw `Session`. Each closure clones `app_state` to reach its `config`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // Embedding handlers (user sessions only).
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    // The OpenAI API key is read from config and cloned per request.
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 665
 666    pub async fn start(&self) -> Result<()> {
 667        let server_id = *self.id.lock();
 668        let app_state = self.app_state.clone();
 669        let peer = self.peer.clone();
 670        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 671        let pool = self.connection_pool.clone();
 672        let live_kit_client = self.app_state.live_kit_client.clone();
 673
 674        let span = info_span!("start server");
 675        self.app_state.executor.spawn_detached(
 676            async move {
 677                tracing::info!("waiting for cleanup timeout");
 678                timeout.await;
 679                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 680                if let Some((room_ids, channel_ids)) = app_state
 681                    .db
 682                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 683                    .await
 684                    .trace_err()
 685                {
 686                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 687                    tracing::info!(
 688                        stale_channel_buffer_count = channel_ids.len(),
 689                        "retrieved stale channel buffers"
 690                    );
 691
 692                    for channel_id in channel_ids {
 693                        if let Some(refreshed_channel_buffer) = app_state
 694                            .db
 695                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 696                            .await
 697                            .trace_err()
 698                        {
 699                            for connection_id in refreshed_channel_buffer.connection_ids {
 700                                peer.send(
 701                                    connection_id,
 702                                    proto::UpdateChannelBufferCollaborators {
 703                                        channel_id: channel_id.to_proto(),
 704                                        collaborators: refreshed_channel_buffer
 705                                            .collaborators
 706                                            .clone(),
 707                                    },
 708                                )
 709                                .trace_err();
 710                            }
 711                        }
 712                    }
 713
 714                    for room_id in room_ids {
 715                        let mut contacts_to_update = HashSet::default();
 716                        let mut canceled_calls_to_user_ids = Vec::new();
 717                        let mut live_kit_room = String::new();
 718                        let mut delete_live_kit_room = false;
 719
 720                        if let Some(mut refreshed_room) = app_state
 721                            .db
 722                            .clear_stale_room_participants(room_id, server_id)
 723                            .await
 724                            .trace_err()
 725                        {
 726                            tracing::info!(
 727                                room_id = room_id.0,
 728                                new_participant_count = refreshed_room.room.participants.len(),
 729                                "refreshed room"
 730                            );
 731                            room_updated(&refreshed_room.room, &peer);
 732                            if let Some(channel) = refreshed_room.channel.as_ref() {
 733                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 734                            }
 735                            contacts_to_update
 736                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 737                            contacts_to_update
 738                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 739                            canceled_calls_to_user_ids =
 740                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 741                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 742                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 743                        }
 744
 745                        {
 746                            let pool = pool.lock();
 747                            for canceled_user_id in canceled_calls_to_user_ids {
 748                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 749                                    peer.send(
 750                                        connection_id,
 751                                        proto::CallCanceled {
 752                                            room_id: room_id.to_proto(),
 753                                        },
 754                                    )
 755                                    .trace_err();
 756                                }
 757                            }
 758                        }
 759
 760                        for user_id in contacts_to_update {
 761                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 762                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 763                            if let Some((busy, contacts)) = busy.zip(contacts) {
 764                                let pool = pool.lock();
 765                                let updated_contact = contact_for_user(user_id, busy, &pool);
 766                                for contact in contacts {
 767                                    if let db::Contact::Accepted {
 768                                        user_id: contact_user_id,
 769                                        ..
 770                                    } = contact
 771                                    {
 772                                        for contact_conn_id in
 773                                            pool.user_connection_ids(contact_user_id)
 774                                        {
 775                                            peer.send(
 776                                                contact_conn_id,
 777                                                proto::UpdateContacts {
 778                                                    contacts: vec![updated_contact.clone()],
 779                                                    remove_contacts: Default::default(),
 780                                                    incoming_requests: Default::default(),
 781                                                    remove_incoming_requests: Default::default(),
 782                                                    outgoing_requests: Default::default(),
 783                                                    remove_outgoing_requests: Default::default(),
 784                                                },
 785                                            )
 786                                            .trace_err();
 787                                        }
 788                                    }
 789                                }
 790                            }
 791                        }
 792
 793                        if let Some(live_kit) = live_kit_client.as_ref() {
 794                            if delete_live_kit_room {
 795                                live_kit.delete_room(live_kit_room).await.trace_err();
 796                            }
 797                        }
 798                    }
 799                }
 800
 801                app_state
 802                    .db
 803                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 804                    .await
 805                    .trace_err();
 806            }
 807            .instrument(span),
 808        );
 809        Ok(())
 810    }
 811
 812    pub fn teardown(&self) {
 813        self.peer.teardown();
 814        self.connection_pool.lock().reset();
 815        let _ = self.teardown.send(true);
 816    }
 817
 818    #[cfg(test)]
 819    pub fn reset(&self, id: ServerId) {
 820        self.teardown();
 821        *self.id.lock() = id;
 822        self.peer.reset(id.0 as u32);
 823        let _ = self.teardown.send(false);
 824    }
 825
 826    #[cfg(test)]
 827    pub fn id(&self) -> ServerId {
 828        *self.id.lock()
 829    }
 830
 831    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 832    where
 833        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 834        Fut: 'static + Send + Future<Output = Result<()>>,
 835        M: EnvelopedMessage,
 836    {
 837        let prev_handler = self.handlers.insert(
 838            TypeId::of::<M>(),
 839            Box::new(move |envelope, session| {
 840                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 841                let received_at = envelope.received_at;
 842                tracing::info!("message received");
 843                let start_time = Instant::now();
 844                let future = (handler)(*envelope, session);
 845                async move {
 846                    let result = future.await;
 847                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 848                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 849                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 850                    let payload_type = M::NAME;
 851
 852                    match result {
 853                        Err(error) => {
 854                            tracing::error!(
 855                                ?error,
 856                                total_duration_ms,
 857                                processing_duration_ms,
 858                                queue_duration_ms,
 859                                payload_type,
 860                                "error handling message"
 861                            )
 862                        }
 863                        Ok(()) => tracing::info!(
 864                            total_duration_ms,
 865                            processing_duration_ms,
 866                            queue_duration_ms,
 867                            "finished handling message"
 868                        ),
 869                    }
 870                }
 871                .boxed()
 872            }),
 873        );
 874        if prev_handler.is_some() {
 875            panic!("registered a handler for the same message twice");
 876        }
 877        self
 878    }
 879
 880    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 881    where
 882        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 883        Fut: 'static + Send + Future<Output = Result<()>>,
 884        M: EnvelopedMessage,
 885    {
 886        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 887        self
 888    }
 889
 890    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 891    where
 892        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 893        Fut: Send + Future<Output = Result<()>>,
 894        M: RequestMessage,
 895    {
 896        let handler = Arc::new(handler);
 897        self.add_handler(move |envelope, session| {
 898            let receipt = envelope.receipt();
 899            let handler = handler.clone();
 900            async move {
 901                let peer = session.peer.clone();
 902                let responded = Arc::new(AtomicBool::default());
 903                let response = Response {
 904                    peer: peer.clone(),
 905                    responded: responded.clone(),
 906                    receipt,
 907                };
 908                match (handler)(envelope.payload, response, session).await {
 909                    Ok(()) => {
 910                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 911                            Ok(())
 912                        } else {
 913                            Err(anyhow!("handler did not send a response"))?
 914                        }
 915                    }
 916                    Err(error) => {
 917                        let proto_err = match &error {
 918                            Error::Internal(err) => err.to_proto(),
 919                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 920                        };
 921                        peer.respond_with_error(receipt, proto_err)?;
 922                        Err(error)
 923                    }
 924                }
 925            }
 926        })
 927    }
 928
 929    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 930    where
 931        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 932        Fut: Send + Future<Output = Result<()>>,
 933        M: RequestMessage,
 934    {
 935        let handler = Arc::new(handler);
 936        self.add_handler(move |envelope, session| {
 937            let receipt = envelope.receipt();
 938            let handler = handler.clone();
 939            async move {
 940                let peer = session.peer.clone();
 941                let response = StreamingResponse {
 942                    peer: peer.clone(),
 943                    receipt,
 944                };
 945                match (handler)(envelope.payload, response, session).await {
 946                    Ok(()) => {
 947                        peer.end_stream(receipt)?;
 948                        Ok(())
 949                    }
 950                    Err(error) => {
 951                        let proto_err = match &error {
 952                            Error::Internal(err) => err.to_proto(),
 953                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 954                        };
 955                        peer.respond_with_error(receipt, proto_err)?;
 956                        Err(error)
 957                    }
 958                }
 959            }
 960        })
 961    }
 962
    /// Builds the future that drives a single client connection for its whole
    /// lifetime.
    ///
    /// When polled, the future registers `connection` with the peer, builds a
    /// `Session` for `principal`, sends the initial client update, and then
    /// dispatches incoming messages to the registered handlers. It stops when
    /// the I/O task finishes, the incoming stream ends, or the server tears
    /// down.
    ///
    /// Messages flagged as background are spawned detached on `executor`.
    /// Foreground messages are polled inside this loop; see the comment on
    /// `foreground_message_handlers`. At most 256 handlers per connection are
    /// in flight at once, enforced by a semaphore permit held until each
    /// handler completes.
    ///
    /// When the loop exits because I/O completed or the stream ended,
    /// `connection_lost` is awaited to clean up.
    ///
    /// NOTE(review): the early returns skip `connection_lost`. These are: the
    /// server already tearing down, HTTP client creation failing, the initial
    /// update failing, and teardown during the loop. Confirm that peer/pool
    /// entries registered by then are cleaned up elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left empty here are filled in once they are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if the server is already tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // Fill in the `connection_id` field declared empty on the span above.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only created when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so reading stalls
                // while 256 handlers from this connection are still in flight.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in the order written: teardown, connection I/O,
                // in-flight foreground handlers, and only then new incoming messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is constructed;
                            // the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only after the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1109
1110    async fn send_initial_client_update(
1111        &self,
1112        connection_id: ConnectionId,
1113        principal: &Principal,
1114        zed_version: ZedVersion,
1115        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1116        session: &Session,
1117    ) -> Result<()> {
1118        self.peer.send(
1119            connection_id,
1120            proto::Hello {
1121                peer_id: Some(connection_id.into()),
1122            },
1123        )?;
1124        tracing::info!("sent hello message");
1125        if let Some(send_connection_id) = send_connection_id.take() {
1126            let _ = send_connection_id.send(connection_id);
1127        }
1128
1129        match principal {
1130            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1131                if !user.connected_once {
1132                    self.peer.send(connection_id, proto::ShowContacts {})?;
1133                    self.app_state
1134                        .db
1135                        .set_user_connected_once(user.id, true)
1136                        .await?;
1137                }
1138
1139                let (contacts, dev_server_projects) = future::try_join(
1140                    self.app_state.db.get_contacts(user.id),
1141                    self.app_state.db.dev_server_projects_update(user.id),
1142                )
1143                .await?;
1144
1145                {
1146                    let mut pool = self.connection_pool.lock();
1147                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1148                    self.peer.send(
1149                        connection_id,
1150                        build_initial_contacts_update(contacts, &pool),
1151                    )?;
1152                }
1153
1154                if should_auto_subscribe_to_channels(zed_version) {
1155                    subscribe_user_to_channels(user.id, session).await?;
1156                }
1157
1158                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1159
1160                if let Some(incoming_call) =
1161                    self.app_state.db.incoming_call_for_user(user.id).await?
1162                {
1163                    self.peer.send(connection_id, incoming_call)?;
1164                }
1165
1166                update_user_contacts(user.id, &session).await?;
1167            }
1168            Principal::DevServer(dev_server) => {
1169                {
1170                    let mut pool = self.connection_pool.lock();
1171                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1172                    {
1173                        self.peer.send(
1174                            stale_connection_id,
1175                            proto::ShutdownDevServer {
1176                                reason: Some(
1177                                    "another dev server connected with the same token".to_string(),
1178                                ),
1179                            },
1180                        )?;
1181                        pool.remove_connection(stale_connection_id)?;
1182                    };
1183                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1184                }
1185
1186                let projects = self
1187                    .app_state
1188                    .db
1189                    .get_projects_for_dev_server(dev_server.id)
1190                    .await?;
1191                self.peer
1192                    .send(connection_id, proto::DevServerInstructions { projects })?;
1193
1194                let status = self
1195                    .app_state
1196                    .db
1197                    .dev_server_projects_update(dev_server.user_id)
1198                    .await?;
1199                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1200            }
1201        }
1202
1203        Ok(())
1204    }
1205
1206    pub async fn invite_code_redeemed(
1207        self: &Arc<Self>,
1208        inviter_id: UserId,
1209        invitee_id: UserId,
1210    ) -> Result<()> {
1211        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1212            if let Some(code) = &user.invite_code {
1213                let pool = self.connection_pool.lock();
1214                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1215                for connection_id in pool.user_connection_ids(inviter_id) {
1216                    self.peer.send(
1217                        connection_id,
1218                        proto::UpdateContacts {
1219                            contacts: vec![invitee_contact.clone()],
1220                            ..Default::default()
1221                        },
1222                    )?;
1223                    self.peer.send(
1224                        connection_id,
1225                        proto::UpdateInviteInfo {
1226                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1227                            count: user.invite_count as u32,
1228                        },
1229                    )?;
1230                }
1231            }
1232        }
1233        Ok(())
1234    }
1235
1236    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1237        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1238            if let Some(invite_code) = &user.invite_code {
1239                let pool = self.connection_pool.lock();
1240                for connection_id in pool.user_connection_ids(user_id) {
1241                    self.peer.send(
1242                        connection_id,
1243                        proto::UpdateInviteInfo {
1244                            url: format!(
1245                                "{}{}",
1246                                self.app_state.config.invite_link_prefix, invite_code
1247                            ),
1248                            count: user.invite_count as u32,
1249                        },
1250                    )?;
1251                }
1252            }
1253        }
1254        Ok(())
1255    }
1256
1257    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1258        ServerSnapshot {
1259            connection_pool: ConnectionPoolGuard {
1260                guard: self.connection_pool.lock(),
1261                _not_send: PhantomData,
1262            },
1263            peer: &self.peer,
1264        }
1265    }
1266}
1267
1268impl<'a> Deref for ConnectionPoolGuard<'a> {
1269    type Target = ConnectionPool;
1270
1271    fn deref(&self) -> &Self::Target {
1272        &self.guard
1273    }
1274}
1275
1276impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1277    fn deref_mut(&mut self) -> &mut Self::Target {
1278        &mut self.guard
1279    }
1280}
1281
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, runs `check_invariants` every time a guard is dropped.
    /// The inner `guard` field is dropped after this method returns, so the
    /// check runs while the pool is still locked. In non-test builds this is
    /// a no-op.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1288
1289fn broadcast<F>(
1290    sender_id: Option<ConnectionId>,
1291    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1292    mut f: F,
1293) where
1294    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1295{
1296    for receiver_id in receiver_ids {
1297        if Some(receiver_id) != sender_id {
1298            if let Err(error) = f(receiver_id) {
1299                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1300            }
1301        }
1302    }
1303}
1304
1305pub struct ProtocolVersion(u32);
1306
1307impl Header for ProtocolVersion {
1308    fn name() -> &'static HeaderName {
1309        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1311    }
1312
1313    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314    where
1315        Self: Sized,
1316        I: Iterator<Item = &'i axum::http::HeaderValue>,
1317    {
1318        let version = values
1319            .next()
1320            .ok_or_else(axum::headers::Error::invalid)?
1321            .to_str()
1322            .map_err(|_| axum::headers::Error::invalid())?
1323            .parse()
1324            .map_err(|_| axum::headers::Error::invalid())?;
1325        Ok(Self(version))
1326    }
1327
1328    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329        values.extend([self.0.to_string().parse().unwrap()]);
1330    }
1331}
1332
1333pub struct AppVersionHeader(SemanticVersion);
1334impl Header for AppVersionHeader {
1335    fn name() -> &'static HeaderName {
1336        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1337        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1338    }
1339
1340    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1341    where
1342        Self: Sized,
1343        I: Iterator<Item = &'i axum::http::HeaderValue>,
1344    {
1345        let version = values
1346            .next()
1347            .ok_or_else(axum::headers::Error::invalid)?
1348            .to_str()
1349            .map_err(|_| axum::headers::Error::invalid())?
1350            .parse()
1351            .map_err(|_| axum::headers::Error::invalid())?;
1352        Ok(Self(version))
1353    }
1354
1355    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1356        values.extend([self.0.to_string().parse().unwrap()]);
1357    }
1358}
1359
1360pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1361    Router::new()
1362        .route("/rpc", get(handle_websocket_request))
1363        .layer(
1364            ServiceBuilder::new()
1365                .layer(Extension(server.app_state.clone()))
1366                .layer(middleware::from_fn(auth::validate_header)),
1367        )
1368        .route("/metrics", get(handle_metrics))
1369        .layer(Extension(server))
1370}
1371
1372pub async fn handle_websocket_request(
1373    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1374    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1375    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1376    Extension(server): Extension<Arc<Server>>,
1377    Extension(principal): Extension<Principal>,
1378    ws: WebSocketUpgrade,
1379) -> axum::response::Response {
1380    if protocol_version != rpc::PROTOCOL_VERSION {
1381        return (
1382            StatusCode::UPGRADE_REQUIRED,
1383            "client must be upgraded".to_string(),
1384        )
1385            .into_response();
1386    }
1387
1388    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1389        return (
1390            StatusCode::UPGRADE_REQUIRED,
1391            "no version header found".to_string(),
1392        )
1393            .into_response();
1394    };
1395
1396    if !version.can_collaborate() {
1397        return (
1398            StatusCode::UPGRADE_REQUIRED,
1399            "client must be upgraded".to_string(),
1400        )
1401            .into_response();
1402    }
1403
1404    let socket_address = socket_address.to_string();
1405    ws.on_upgrade(move |socket| {
1406        let socket = socket
1407            .map_ok(to_tungstenite_message)
1408            .err_into()
1409            .with(|message| async move { to_axum_message(message) });
1410        let connection = Connection::new(Box::pin(socket));
1411        async move {
1412            server
1413                .handle_connection(
1414                    connection,
1415                    socket_address,
1416                    principal,
1417                    version,
1418                    None,
1419                    Executor::Production,
1420                )
1421                .await;
1422        }
1423    })
1424}
1425
1426pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1427    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1428    let connections_metric = CONNECTIONS_METRIC
1429        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1430
1431    let connections = server
1432        .connection_pool
1433        .lock()
1434        .connections()
1435        .filter(|connection| !connection.admin)
1436        .count();
1437    connections_metric.set(connections as _);
1438
1439    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1440    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1441        register_int_gauge!(
1442            "shared_projects",
1443            "number of open projects with one or more guests"
1444        )
1445        .unwrap()
1446    });
1447
1448    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1449    shared_projects_metric.set(shared_projects as _);
1450
1451    let encoder = prometheus::TextEncoder::new();
1452    let metric_families = prometheus::gather();
1453    let encoded_metrics = encoder
1454        .encode_to_string(&metric_families)
1455        .map_err(|err| anyhow!("{}", err))?;
1456    Ok(encoded_metrics)
1457}
1458
/// Cleans up after a client's connection drops.
///
/// Immediately disconnects the peer and removes the connection from the pool
/// (propagating any error from the pool). It then calls the database's
/// `connection_lost`, whose errors are only traced.
///
/// It then waits up to `RECONNECT_TIMEOUT`. If `teardown` changes first,
/// nothing further happens. Otherwise the remaining resources are released:
/// - user sessions (plain or impersonated) leave their room and channel
///   buffers. If the user has no other online connection, `decline_call` is
///   invoked and the room is re-broadcast. Finally their contacts are updated.
/// - dev-server sessions are handed to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the reconnect timer is polled first. If the teardown watch
    // resolves before the timer elapses, the resource cleanup below is skipped.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // NOTE(review): relies on `for_user` returning Some for both
                    // user principal kinds matched here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of
                    // this user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1513
1514/// Acknowledges a ping from a client, used to keep the connection alive.
1515async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1516    response.send(proto::Ack {})?;
1517    Ok(())
1518}
1519
1520/// Creates a new room for calling (outside of channels)
1521async fn create_room(
1522    _request: proto::CreateRoom,
1523    response: Response<proto::CreateRoom>,
1524    session: UserSession,
1525) -> Result<()> {
1526    let live_kit_room = nanoid::nanoid!(30);
1527
1528    let live_kit_connection_info = util::maybe!(async {
1529        let live_kit = session.live_kit_client.as_ref();
1530        let live_kit = live_kit?;
1531        let user_id = session.user_id().to_string();
1532
1533        let token = live_kit
1534            .room_token(&live_kit_room, &user_id.to_string())
1535            .trace_err()?;
1536
1537        Some(proto::LiveKitConnectionInfo {
1538            server_url: live_kit.url().into(),
1539            token,
1540            can_publish: true,
1541        })
1542    })
1543    .await;
1544
1545    let room = session
1546        .db()
1547        .await
1548        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1549        .await?;
1550
1551    response.send(proto::CreateRoomResponse {
1552        room: Some(room.clone()),
1553        live_kit_connection_info,
1554    })?;
1555
1556    update_user_contacts(session.user_id(), &session).await?;
1557    Ok(())
1558}
1559
1560/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1561async fn join_room(
1562    request: proto::JoinRoom,
1563    response: Response<proto::JoinRoom>,
1564    session: UserSession,
1565) -> Result<()> {
1566    let room_id = RoomId::from_proto(request.id);
1567
1568    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1569
1570    if let Some(channel_id) = channel_id {
1571        return join_channel_internal(channel_id, Box::new(response), session).await;
1572    }
1573
1574    let joined_room = {
1575        let room = session
1576            .db()
1577            .await
1578            .join_room(room_id, session.user_id(), session.connection_id)
1579            .await?;
1580        room_updated(&room.room, &session.peer);
1581        room.into_inner()
1582    };
1583
1584    for connection_id in session
1585        .connection_pool()
1586        .await
1587        .user_connection_ids(session.user_id())
1588    {
1589        session
1590            .peer
1591            .send(
1592                connection_id,
1593                proto::CallCanceled {
1594                    room_id: room_id.to_proto(),
1595                },
1596            )
1597            .trace_err();
1598    }
1599
1600    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1601        if let Some(token) = live_kit
1602            .room_token(
1603                &joined_room.room.live_kit_room,
1604                &session.user_id().to_string(),
1605            )
1606            .trace_err()
1607        {
1608            Some(proto::LiveKitConnectionInfo {
1609                server_url: live_kit.url().into(),
1610                token,
1611                can_publish: true,
1612            })
1613        } else {
1614            None
1615        }
1616    } else {
1617        None
1618    };
1619
1620    response.send(proto::JoinRoomResponse {
1621        room: Some(joined_room.room),
1622        channel_id: None,
1623        live_kit_connection_info,
1624    })?;
1625
1626    update_user_contacts(session.user_id(), &session).await?;
1627    Ok(())
1628}
1629
/// Reconnects a client to the room it was in after connection errors.
///
/// The reply lists the room, the projects the caller re-shares (with their
/// collaborators), and the projects it rejoins. Collaborators of re-shared
/// projects learn the caller's new peer id and receive the current worktree
/// metadata. Rejoined projects are refreshed via `notify_rejoined_projects`.
/// Finally the owning channel (if any) and the caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the inner scope so they outlive the value returned by
    // `rejoin_room`, which is only held inside that scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: tell each collaborator the caller's peer id
        // changed from the old connection to this one, then forward the
        // project's worktree list to all collaborators except the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: update collaborators and re-stream project state
        // (worktrees, diagnostics, settings, language servers) to the caller.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, broadcast the channel's new state.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1721
1722fn notify_rejoined_projects(
1723    rejoined_projects: &mut Vec<RejoinedProject>,
1724    session: &UserSession,
1725) -> Result<()> {
1726    for project in rejoined_projects.iter() {
1727        for collaborator in &project.collaborators {
1728            session
1729                .peer
1730                .send(
1731                    collaborator.connection_id,
1732                    proto::UpdateProjectCollaborator {
1733                        project_id: project.id.to_proto(),
1734                        old_peer_id: Some(project.old_connection_id.into()),
1735                        new_peer_id: Some(session.connection_id.into()),
1736                    },
1737                )
1738                .trace_err();
1739        }
1740    }
1741
1742    for project in rejoined_projects {
1743        for worktree in mem::take(&mut project.worktrees) {
1744            #[cfg(any(test, feature = "test-support"))]
1745            const MAX_CHUNK_SIZE: usize = 2;
1746            #[cfg(not(any(test, feature = "test-support")))]
1747            const MAX_CHUNK_SIZE: usize = 256;
1748
1749            // Stream this worktree's entries.
1750            let message = proto::UpdateWorktree {
1751                project_id: project.id.to_proto(),
1752                worktree_id: worktree.id,
1753                abs_path: worktree.abs_path.clone(),
1754                root_name: worktree.root_name,
1755                updated_entries: worktree.updated_entries,
1756                removed_entries: worktree.removed_entries,
1757                scan_id: worktree.scan_id,
1758                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1759                updated_repositories: worktree.updated_repositories,
1760                removed_repositories: worktree.removed_repositories,
1761            };
1762            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1763                session.peer.send(session.connection_id, update.clone())?;
1764            }
1765
1766            // Stream this worktree's diagnostics.
1767            for summary in worktree.diagnostic_summaries {
1768                session.peer.send(
1769                    session.connection_id,
1770                    proto::UpdateDiagnosticSummary {
1771                        project_id: project.id.to_proto(),
1772                        worktree_id: worktree.id,
1773                        summary: Some(summary),
1774                    },
1775                )?;
1776            }
1777
1778            for settings_file in worktree.settings_files {
1779                session.peer.send(
1780                    session.connection_id,
1781                    proto::UpdateWorktreeSettings {
1782                        project_id: project.id.to_proto(),
1783                        worktree_id: worktree.id,
1784                        path: settings_file.path,
1785                        content: Some(settings_file.content),
1786                    },
1787                )?;
1788            }
1789        }
1790
1791        for language_server in &project.language_servers {
1792            session.peer.send(
1793                session.connection_id,
1794                proto::UpdateLanguageServer {
1795                    project_id: project.id.to_proto(),
1796                    language_server_id: language_server.id,
1797                    variant: Some(
1798                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1799                            proto::LspDiskBasedDiagnosticsUpdated {},
1800                        ),
1801                    ),
1802                },
1803            )?;
1804        }
1805    }
1806    Ok(())
1807}
1808
1809/// leave room disconnects from the room.
1810async fn leave_room(
1811    _: proto::LeaveRoom,
1812    response: Response<proto::LeaveRoom>,
1813    session: UserSession,
1814) -> Result<()> {
1815    leave_room_for_session(&session, session.connection_id).await?;
1816    response.send(proto::Ack {})?;
1817    Ok(())
1818}
1819
1820/// Updates the permissions of someone else in the room.
1821async fn set_room_participant_role(
1822    request: proto::SetRoomParticipantRole,
1823    response: Response<proto::SetRoomParticipantRole>,
1824    session: UserSession,
1825) -> Result<()> {
1826    let user_id = UserId::from_proto(request.user_id);
1827    let role = ChannelRole::from(request.role());
1828
1829    let (live_kit_room, can_publish) = {
1830        let room = session
1831            .db()
1832            .await
1833            .set_room_participant_role(
1834                session.user_id(),
1835                RoomId::from_proto(request.room_id),
1836                user_id,
1837                role,
1838            )
1839            .await?;
1840
1841        let live_kit_room = room.live_kit_room.clone();
1842        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1843        room_updated(&room, &session.peer);
1844        (live_kit_room, can_publish)
1845    };
1846
1847    if let Some(live_kit) = session.live_kit_client.as_ref() {
1848        live_kit
1849            .update_participant(
1850                live_kit_room.clone(),
1851                request.user_id.to_string(),
1852                live_kit_server::proto::ParticipantPermission {
1853                    can_subscribe: true,
1854                    can_publish,
1855                    can_publish_data: can_publish,
1856                    hidden: false,
1857                    recorder: false,
1858                },
1859            )
1860            .await
1861            .trace_err();
1862    }
1863
1864    response.send(proto::Ack {})?;
1865    Ok(())
1866}
1867
1868/// Call someone else into the current room
1869async fn call(
1870    request: proto::Call,
1871    response: Response<proto::Call>,
1872    session: UserSession,
1873) -> Result<()> {
1874    let room_id = RoomId::from_proto(request.room_id);
1875    let calling_user_id = session.user_id();
1876    let calling_connection_id = session.connection_id;
1877    let called_user_id = UserId::from_proto(request.called_user_id);
1878    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1879    if !session
1880        .db()
1881        .await
1882        .has_contact(calling_user_id, called_user_id)
1883        .await?
1884    {
1885        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1886    }
1887
1888    let incoming_call = {
1889        let (room, incoming_call) = &mut *session
1890            .db()
1891            .await
1892            .call(
1893                room_id,
1894                calling_user_id,
1895                calling_connection_id,
1896                called_user_id,
1897                initial_project_id,
1898            )
1899            .await?;
1900        room_updated(&room, &session.peer);
1901        mem::take(incoming_call)
1902    };
1903    update_user_contacts(called_user_id, &session).await?;
1904
1905    let mut calls = session
1906        .connection_pool()
1907        .await
1908        .user_connection_ids(called_user_id)
1909        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1910        .collect::<FuturesUnordered<_>>();
1911
1912    while let Some(call_response) = calls.next().await {
1913        match call_response.as_ref() {
1914            Ok(_) => {
1915                response.send(proto::Ack {})?;
1916                return Ok(());
1917            }
1918            Err(_) => {
1919                call_response.trace_err();
1920            }
1921        }
1922    }
1923
1924    {
1925        let room = session
1926            .db()
1927            .await
1928            .call_failed(room_id, called_user_id)
1929            .await?;
1930        room_updated(&room, &session.peer);
1931    }
1932    update_user_contacts(called_user_id, &session).await?;
1933
1934    Err(anyhow!("failed to ring user"))?
1935}
1936
1937/// Cancel an outgoing call.
1938async fn cancel_call(
1939    request: proto::CancelCall,
1940    response: Response<proto::CancelCall>,
1941    session: UserSession,
1942) -> Result<()> {
1943    let called_user_id = UserId::from_proto(request.called_user_id);
1944    let room_id = RoomId::from_proto(request.room_id);
1945    {
1946        let room = session
1947            .db()
1948            .await
1949            .cancel_call(room_id, session.connection_id, called_user_id)
1950            .await?;
1951        room_updated(&room, &session.peer);
1952    }
1953
1954    for connection_id in session
1955        .connection_pool()
1956        .await
1957        .user_connection_ids(called_user_id)
1958    {
1959        session
1960            .peer
1961            .send(
1962                connection_id,
1963                proto::CallCanceled {
1964                    room_id: room_id.to_proto(),
1965                },
1966            )
1967            .trace_err();
1968    }
1969    response.send(proto::Ack {})?;
1970
1971    update_user_contacts(called_user_id, &session).await?;
1972    Ok(())
1973}
1974
1975/// Decline an incoming call.
1976async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1977    let room_id = RoomId::from_proto(message.room_id);
1978    {
1979        let room = session
1980            .db()
1981            .await
1982            .decline_call(Some(room_id), session.user_id())
1983            .await?
1984            .ok_or_else(|| anyhow!("failed to decline call"))?;
1985        room_updated(&room, &session.peer);
1986    }
1987
1988    for connection_id in session
1989        .connection_pool()
1990        .await
1991        .user_connection_ids(session.user_id())
1992    {
1993        session
1994            .peer
1995            .send(
1996                connection_id,
1997                proto::CallCanceled {
1998                    room_id: room_id.to_proto(),
1999                },
2000            )
2001            .trace_err();
2002    }
2003    update_user_contacts(session.user_id(), &session).await?;
2004    Ok(())
2005}
2006
2007/// Updates other participants in the room with your current location.
2008async fn update_participant_location(
2009    request: proto::UpdateParticipantLocation,
2010    response: Response<proto::UpdateParticipantLocation>,
2011    session: UserSession,
2012) -> Result<()> {
2013    let room_id = RoomId::from_proto(request.room_id);
2014    let location = request
2015        .location
2016        .ok_or_else(|| anyhow!("invalid location"))?;
2017
2018    let db = session.db().await;
2019    let room = db
2020        .update_room_participant_location(room_id, session.connection_id, location)
2021        .await?;
2022
2023    room_updated(&room, &session.peer);
2024    response.send(proto::Ack {})?;
2025    Ok(())
2026}
2027
2028/// Share a project into the room.
2029async fn share_project(
2030    request: proto::ShareProject,
2031    response: Response<proto::ShareProject>,
2032    session: UserSession,
2033) -> Result<()> {
2034    let (project_id, room) = &*session
2035        .db()
2036        .await
2037        .share_project(
2038            RoomId::from_proto(request.room_id),
2039            session.connection_id,
2040            &request.worktrees,
2041            request
2042                .dev_server_project_id
2043                .map(|id| DevServerProjectId::from_proto(id)),
2044        )
2045        .await?;
2046    response.send(proto::ShareProjectResponse {
2047        project_id: project_id.to_proto(),
2048    })?;
2049    room_updated(&room, &session.peer);
2050
2051    Ok(())
2052}
2053
2054/// Unshare a project from the room.
2055async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2056    let project_id = ProjectId::from_proto(message.project_id);
2057    unshare_project_internal(
2058        project_id,
2059        session.connection_id,
2060        session.user_id(),
2061        &session,
2062    )
2063    .await
2064}
2065
2066async fn unshare_project_internal(
2067    project_id: ProjectId,
2068    connection_id: ConnectionId,
2069    user_id: Option<UserId>,
2070    session: &Session,
2071) -> Result<()> {
2072    let delete = {
2073        let room_guard = session
2074            .db()
2075            .await
2076            .unshare_project(project_id, connection_id, user_id)
2077            .await?;
2078
2079        let (delete, room, guest_connection_ids) = &*room_guard;
2080
2081        let message = proto::UnshareProject {
2082            project_id: project_id.to_proto(),
2083        };
2084
2085        broadcast(
2086            Some(connection_id),
2087            guest_connection_ids.iter().copied(),
2088            |conn_id| session.peer.send(conn_id, message.clone()),
2089        );
2090        if let Some(room) = room {
2091            room_updated(room, &session.peer);
2092        }
2093
2094        *delete
2095    };
2096
2097    if delete {
2098        let db = session.db().await;
2099        db.delete_project(project_id).await?;
2100    }
2101
2102    Ok(())
2103}
2104
2105/// DevServer makes a project available online
2106async fn share_dev_server_project(
2107    request: proto::ShareDevServerProject,
2108    response: Response<proto::ShareDevServerProject>,
2109    session: DevServerSession,
2110) -> Result<()> {
2111    let (dev_server_project, user_id, status) = session
2112        .db()
2113        .await
2114        .share_dev_server_project(
2115            DevServerProjectId::from_proto(request.dev_server_project_id),
2116            session.dev_server_id(),
2117            session.connection_id,
2118            &request.worktrees,
2119        )
2120        .await?;
2121    let Some(project_id) = dev_server_project.project_id else {
2122        return Err(anyhow!("failed to share remote project"))?;
2123    };
2124
2125    send_dev_server_projects_update(user_id, status, &session).await;
2126
2127    response.send(proto::ShareProjectResponse { project_id })?;
2128
2129    Ok(())
2130}
2131
2132/// Join someone elses shared project.
2133async fn join_project(
2134    request: proto::JoinProject,
2135    response: Response<proto::JoinProject>,
2136    session: UserSession,
2137) -> Result<()> {
2138    let project_id = ProjectId::from_proto(request.project_id);
2139
2140    tracing::info!(%project_id, "join project");
2141
2142    let db = session.db().await;
2143    let (project, replica_id) = &mut *db
2144        .join_project(project_id, session.connection_id, session.user_id())
2145        .await?;
2146    drop(db);
2147    tracing::info!(%project_id, "join remote project");
2148    join_project_internal(response, session, project, replica_id)
2149}
2150
/// A response handle that can complete a project join with a
/// `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Forwards to `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2164
2165fn join_project_internal(
2166    response: impl JoinProjectInternalResponse,
2167    session: UserSession,
2168    project: &mut Project,
2169    replica_id: &ReplicaId,
2170) -> Result<()> {
2171    let collaborators = project
2172        .collaborators
2173        .iter()
2174        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2175        .map(|collaborator| collaborator.to_proto())
2176        .collect::<Vec<_>>();
2177    let project_id = project.id;
2178    let guest_user_id = session.user_id();
2179
2180    let worktrees = project
2181        .worktrees
2182        .iter()
2183        .map(|(id, worktree)| proto::WorktreeMetadata {
2184            id: *id,
2185            root_name: worktree.root_name.clone(),
2186            visible: worktree.visible,
2187            abs_path: worktree.abs_path.clone(),
2188        })
2189        .collect::<Vec<_>>();
2190
2191    let add_project_collaborator = proto::AddProjectCollaborator {
2192        project_id: project_id.to_proto(),
2193        collaborator: Some(proto::Collaborator {
2194            peer_id: Some(session.connection_id.into()),
2195            replica_id: replica_id.0 as u32,
2196            user_id: guest_user_id.to_proto(),
2197        }),
2198    };
2199
2200    for collaborator in &collaborators {
2201        session
2202            .peer
2203            .send(
2204                collaborator.peer_id.unwrap().into(),
2205                add_project_collaborator.clone(),
2206            )
2207            .trace_err();
2208    }
2209
2210    // First, we send the metadata associated with each worktree.
2211    response.send(proto::JoinProjectResponse {
2212        project_id: project.id.0 as u64,
2213        worktrees: worktrees.clone(),
2214        replica_id: replica_id.0 as u32,
2215        collaborators: collaborators.clone(),
2216        language_servers: project.language_servers.clone(),
2217        role: project.role.into(),
2218        dev_server_project_id: project
2219            .dev_server_project_id
2220            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2221    })?;
2222
2223    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2224        #[cfg(any(test, feature = "test-support"))]
2225        const MAX_CHUNK_SIZE: usize = 2;
2226        #[cfg(not(any(test, feature = "test-support")))]
2227        const MAX_CHUNK_SIZE: usize = 256;
2228
2229        // Stream this worktree's entries.
2230        let message = proto::UpdateWorktree {
2231            project_id: project_id.to_proto(),
2232            worktree_id,
2233            abs_path: worktree.abs_path.clone(),
2234            root_name: worktree.root_name,
2235            updated_entries: worktree.entries,
2236            removed_entries: Default::default(),
2237            scan_id: worktree.scan_id,
2238            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2239            updated_repositories: worktree.repository_entries.into_values().collect(),
2240            removed_repositories: Default::default(),
2241        };
2242        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2243            session.peer.send(session.connection_id, update.clone())?;
2244        }
2245
2246        // Stream this worktree's diagnostics.
2247        for summary in worktree.diagnostic_summaries {
2248            session.peer.send(
2249                session.connection_id,
2250                proto::UpdateDiagnosticSummary {
2251                    project_id: project_id.to_proto(),
2252                    worktree_id: worktree.id,
2253                    summary: Some(summary),
2254                },
2255            )?;
2256        }
2257
2258        for settings_file in worktree.settings_files {
2259            session.peer.send(
2260                session.connection_id,
2261                proto::UpdateWorktreeSettings {
2262                    project_id: project_id.to_proto(),
2263                    worktree_id: worktree.id,
2264                    path: settings_file.path,
2265                    content: Some(settings_file.content),
2266                },
2267            )?;
2268        }
2269    }
2270
2271    for language_server in &project.language_servers {
2272        session.peer.send(
2273            session.connection_id,
2274            proto::UpdateLanguageServer {
2275                project_id: project_id.to_proto(),
2276                language_server_id: language_server.id,
2277                variant: Some(
2278                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2279                        proto::LspDiskBasedDiagnosticsUpdated {},
2280                    ),
2281                ),
2282            },
2283        )?;
2284    }
2285
2286    Ok(())
2287}
2288
2289/// Leave someone elses shared project.
2290async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2291    let sender_id = session.connection_id;
2292    let project_id = ProjectId::from_proto(request.project_id);
2293    let db = session.db().await;
2294    if db.is_hosted_project(project_id).await? {
2295        let project = db.leave_hosted_project(project_id, sender_id).await?;
2296        project_left(&project, &session);
2297        return Ok(());
2298    }
2299
2300    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2301    tracing::info!(
2302        %project_id,
2303        "leave project"
2304    );
2305
2306    project_left(&project, &session);
2307    if let Some(room) = room {
2308        room_updated(&room, &session.peer);
2309    }
2310
2311    Ok(())
2312}
2313
2314async fn join_hosted_project(
2315    request: proto::JoinHostedProject,
2316    response: Response<proto::JoinHostedProject>,
2317    session: UserSession,
2318) -> Result<()> {
2319    let (mut project, replica_id) = session
2320        .db()
2321        .await
2322        .join_hosted_project(
2323            ProjectId(request.project_id as i32),
2324            session.user_id(),
2325            session.connection_id,
2326        )
2327        .await?;
2328
2329    join_project_internal(response, session, &mut project, &replica_id)
2330}
2331
2332async fn list_remote_directory(
2333    request: proto::ListRemoteDirectory,
2334    response: Response<proto::ListRemoteDirectory>,
2335    session: UserSession,
2336) -> Result<()> {
2337    let dev_server_id = DevServerId(request.dev_server_id as i32);
2338    let dev_server_connection_id = session
2339        .connection_pool()
2340        .await
2341        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2342
2343    session
2344        .db()
2345        .await
2346        .get_dev_server_for_user(dev_server_id, session.user_id())
2347        .await?;
2348
2349    response.send(
2350        session
2351            .peer
2352            .forward_request(session.connection_id, dev_server_connection_id, request)
2353            .await?,
2354    )?;
2355    Ok(())
2356}
2357
2358async fn update_dev_server_project(
2359    request: proto::UpdateDevServerProject,
2360    response: Response<proto::UpdateDevServerProject>,
2361    session: UserSession,
2362) -> Result<()> {
2363    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2364
2365    let (dev_server_project, update) = session
2366        .db()
2367        .await
2368        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2369        .await?;
2370
2371    let projects = session
2372        .db()
2373        .await
2374        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2375        .await?;
2376
2377    let dev_server_connection_id = session
2378        .connection_pool()
2379        .await
2380        .dev_server_connection_id_supporting(
2381            dev_server_project.dev_server_id,
2382            ZedVersion::with_list_directory(),
2383        )?;
2384
2385    session.peer.send(
2386        dev_server_connection_id,
2387        proto::DevServerInstructions { projects },
2388    )?;
2389
2390    send_dev_server_projects_update(session.user_id(), update, &session).await;
2391
2392    response.send(proto::Ack {})
2393}
2394
2395async fn create_dev_server_project(
2396    request: proto::CreateDevServerProject,
2397    response: Response<proto::CreateDevServerProject>,
2398    session: UserSession,
2399) -> Result<()> {
2400    let dev_server_id = DevServerId(request.dev_server_id as i32);
2401    let dev_server_connection_id = session
2402        .connection_pool()
2403        .await
2404        .dev_server_connection_id(dev_server_id);
2405    let Some(dev_server_connection_id) = dev_server_connection_id else {
2406        Err(ErrorCode::DevServerOffline
2407            .message("Cannot create a remote project when the dev server is offline".to_string())
2408            .anyhow())?
2409    };
2410
2411    let path = request.path.clone();
2412    //Check that the path exists on the dev server
2413    session
2414        .peer
2415        .forward_request(
2416            session.connection_id,
2417            dev_server_connection_id,
2418            proto::ValidateDevServerProjectRequest { path: path.clone() },
2419        )
2420        .await?;
2421
2422    let (dev_server_project, update) = session
2423        .db()
2424        .await
2425        .create_dev_server_project(
2426            DevServerId(request.dev_server_id as i32),
2427            &request.path,
2428            session.user_id(),
2429        )
2430        .await?;
2431
2432    let projects = session
2433        .db()
2434        .await
2435        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2436        .await?;
2437
2438    session.peer.send(
2439        dev_server_connection_id,
2440        proto::DevServerInstructions { projects },
2441    )?;
2442
2443    send_dev_server_projects_update(session.user_id(), update, &session).await;
2444
2445    response.send(proto::CreateDevServerProjectResponse {
2446        dev_server_project: Some(dev_server_project.to_proto(None)),
2447    })?;
2448    Ok(())
2449}
2450
2451async fn create_dev_server(
2452    request: proto::CreateDevServer,
2453    response: Response<proto::CreateDevServer>,
2454    session: UserSession,
2455) -> Result<()> {
2456    let access_token = auth::random_token();
2457    let hashed_access_token = auth::hash_access_token(&access_token);
2458
2459    if request.name.is_empty() {
2460        return Err(proto::ErrorCode::Forbidden
2461            .message("Dev server name cannot be empty".to_string())
2462            .anyhow())?;
2463    }
2464
2465    let (dev_server, status) = session
2466        .db()
2467        .await
2468        .create_dev_server(
2469            &request.name,
2470            request.ssh_connection_string.as_deref(),
2471            &hashed_access_token,
2472            session.user_id(),
2473        )
2474        .await?;
2475
2476    send_dev_server_projects_update(session.user_id(), status, &session).await;
2477
2478    response.send(proto::CreateDevServerResponse {
2479        dev_server_id: dev_server.id.0 as u64,
2480        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2481        name: request.name,
2482    })?;
2483    Ok(())
2484}
2485
2486async fn regenerate_dev_server_token(
2487    request: proto::RegenerateDevServerToken,
2488    response: Response<proto::RegenerateDevServerToken>,
2489    session: UserSession,
2490) -> Result<()> {
2491    let dev_server_id = DevServerId(request.dev_server_id as i32);
2492    let access_token = auth::random_token();
2493    let hashed_access_token = auth::hash_access_token(&access_token);
2494
2495    let connection_id = session
2496        .connection_pool()
2497        .await
2498        .dev_server_connection_id(dev_server_id);
2499    if let Some(connection_id) = connection_id {
2500        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2501        session.peer.send(
2502            connection_id,
2503            proto::ShutdownDevServer {
2504                reason: Some("dev server token was regenerated".to_string()),
2505            },
2506        )?;
2507        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2508    }
2509
2510    let status = session
2511        .db()
2512        .await
2513        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2514        .await?;
2515
2516    send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518    response.send(proto::RegenerateDevServerTokenResponse {
2519        dev_server_id: dev_server_id.to_proto(),
2520        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2521    })?;
2522    Ok(())
2523}
2524
2525async fn rename_dev_server(
2526    request: proto::RenameDevServer,
2527    response: Response<proto::RenameDevServer>,
2528    session: UserSession,
2529) -> Result<()> {
2530    if request.name.trim().is_empty() {
2531        return Err(proto::ErrorCode::Forbidden
2532            .message("Dev server name cannot be empty".to_string())
2533            .anyhow())?;
2534    }
2535
2536    let dev_server_id = DevServerId(request.dev_server_id as i32);
2537    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2538    if dev_server.user_id != session.user_id() {
2539        return Err(anyhow!(ErrorCode::Forbidden))?;
2540    }
2541
2542    let status = session
2543        .db()
2544        .await
2545        .rename_dev_server(
2546            dev_server_id,
2547            &request.name,
2548            request.ssh_connection_string.as_deref(),
2549            session.user_id(),
2550        )
2551        .await?;
2552
2553    send_dev_server_projects_update(session.user_id(), status, &session).await;
2554
2555    response.send(proto::Ack {})?;
2556    Ok(())
2557}
2558
2559async fn delete_dev_server(
2560    request: proto::DeleteDevServer,
2561    response: Response<proto::DeleteDevServer>,
2562    session: UserSession,
2563) -> Result<()> {
2564    let dev_server_id = DevServerId(request.dev_server_id as i32);
2565    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2566    if dev_server.user_id != session.user_id() {
2567        return Err(anyhow!(ErrorCode::Forbidden))?;
2568    }
2569
2570    let connection_id = session
2571        .connection_pool()
2572        .await
2573        .dev_server_connection_id(dev_server_id);
2574    if let Some(connection_id) = connection_id {
2575        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2576        session.peer.send(
2577            connection_id,
2578            proto::ShutdownDevServer {
2579                reason: Some("dev server was deleted".to_string()),
2580            },
2581        )?;
2582        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2583    }
2584
2585    let status = session
2586        .db()
2587        .await
2588        .delete_dev_server(dev_server_id, session.user_id())
2589        .await?;
2590
2591    send_dev_server_projects_update(session.user_id(), status, &session).await;
2592
2593    response.send(proto::Ack {})?;
2594    Ok(())
2595}
2596
2597async fn delete_dev_server_project(
2598    request: proto::DeleteDevServerProject,
2599    response: Response<proto::DeleteDevServerProject>,
2600    session: UserSession,
2601) -> Result<()> {
2602    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2603    let dev_server_project = session
2604        .db()
2605        .await
2606        .get_dev_server_project(dev_server_project_id)
2607        .await?;
2608
2609    let dev_server = session
2610        .db()
2611        .await
2612        .get_dev_server(dev_server_project.dev_server_id)
2613        .await?;
2614    if dev_server.user_id != session.user_id() {
2615        return Err(anyhow!(ErrorCode::Forbidden))?;
2616    }
2617
2618    let dev_server_connection_id = session
2619        .connection_pool()
2620        .await
2621        .dev_server_connection_id(dev_server.id);
2622
2623    if let Some(dev_server_connection_id) = dev_server_connection_id {
2624        let project = session
2625            .db()
2626            .await
2627            .find_dev_server_project(dev_server_project_id)
2628            .await;
2629        if let Ok(project) = project {
2630            unshare_project_internal(
2631                project.id,
2632                dev_server_connection_id,
2633                Some(session.user_id()),
2634                &session,
2635            )
2636            .await?;
2637        }
2638    }
2639
2640    let (projects, status) = session
2641        .db()
2642        .await
2643        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2644        .await?;
2645
2646    if let Some(dev_server_connection_id) = dev_server_connection_id {
2647        session.peer.send(
2648            dev_server_connection_id,
2649            proto::DevServerInstructions { projects },
2650        )?;
2651    }
2652
2653    send_dev_server_projects_update(session.user_id(), status, &session).await;
2654
2655    response.send(proto::Ack {})?;
2656    Ok(())
2657}
2658
2659async fn rejoin_dev_server_projects(
2660    request: proto::RejoinRemoteProjects,
2661    response: Response<proto::RejoinRemoteProjects>,
2662    session: UserSession,
2663) -> Result<()> {
2664    let mut rejoined_projects = {
2665        let db = session.db().await;
2666        db.rejoin_dev_server_projects(
2667            &request.rejoined_projects,
2668            session.user_id(),
2669            session.0.connection_id,
2670        )
2671        .await?
2672    };
2673    response.send(proto::RejoinRemoteProjectsResponse {
2674        rejoined_projects: rejoined_projects
2675            .iter()
2676            .map(|project| project.to_proto())
2677            .collect(),
2678    })?;
2679    notify_rejoined_projects(&mut rejoined_projects, &session)
2680}
2681
2682async fn reconnect_dev_server(
2683    request: proto::ReconnectDevServer,
2684    response: Response<proto::ReconnectDevServer>,
2685    session: DevServerSession,
2686) -> Result<()> {
2687    let reshared_projects = {
2688        let db = session.db().await;
2689        db.reshare_dev_server_projects(
2690            &request.reshared_projects,
2691            session.dev_server_id(),
2692            session.0.connection_id,
2693        )
2694        .await?
2695    };
2696
2697    for project in &reshared_projects {
2698        for collaborator in &project.collaborators {
2699            session
2700                .peer
2701                .send(
2702                    collaborator.connection_id,
2703                    proto::UpdateProjectCollaborator {
2704                        project_id: project.id.to_proto(),
2705                        old_peer_id: Some(project.old_connection_id.into()),
2706                        new_peer_id: Some(session.connection_id.into()),
2707                    },
2708                )
2709                .trace_err();
2710        }
2711
2712        broadcast(
2713            Some(session.connection_id),
2714            project
2715                .collaborators
2716                .iter()
2717                .map(|collaborator| collaborator.connection_id),
2718            |connection_id| {
2719                session.peer.forward_send(
2720                    session.connection_id,
2721                    connection_id,
2722                    proto::UpdateProject {
2723                        project_id: project.id.to_proto(),
2724                        worktrees: project.worktrees.clone(),
2725                    },
2726                )
2727            },
2728        );
2729    }
2730
2731    response.send(proto::ReconnectDevServerResponse {
2732        reshared_projects: reshared_projects
2733            .iter()
2734            .map(|project| proto::ResharedProject {
2735                id: project.id.to_proto(),
2736                collaborators: project
2737                    .collaborators
2738                    .iter()
2739                    .map(|collaborator| collaborator.to_proto())
2740                    .collect(),
2741            })
2742            .collect(),
2743    })?;
2744
2745    Ok(())
2746}
2747
2748async fn shutdown_dev_server(
2749    _: proto::ShutdownDevServer,
2750    response: Response<proto::ShutdownDevServer>,
2751    session: DevServerSession,
2752) -> Result<()> {
2753    response.send(proto::Ack {})?;
2754    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2755    remove_dev_server_connection(session.dev_server_id(), &session).await
2756}
2757
2758async fn shutdown_dev_server_internal(
2759    dev_server_id: DevServerId,
2760    connection_id: ConnectionId,
2761    session: &Session,
2762) -> Result<()> {
2763    let (dev_server_projects, dev_server) = {
2764        let db = session.db().await;
2765        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2766        let dev_server = db.get_dev_server(dev_server_id).await?;
2767        (dev_server_projects, dev_server)
2768    };
2769
2770    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2771        unshare_project_internal(
2772            ProjectId::from_proto(project_id),
2773            connection_id,
2774            None,
2775            session,
2776        )
2777        .await?;
2778    }
2779
2780    session
2781        .connection_pool()
2782        .await
2783        .set_dev_server_offline(dev_server_id);
2784
2785    let status = session
2786        .db()
2787        .await
2788        .dev_server_projects_update(dev_server.user_id)
2789        .await?;
2790    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2791
2792    Ok(())
2793}
2794
2795async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2796    let dev_server_connection = session
2797        .connection_pool()
2798        .await
2799        .dev_server_connection_id(dev_server_id);
2800
2801    if let Some(dev_server_connection) = dev_server_connection {
2802        session
2803            .connection_pool()
2804            .await
2805            .remove_connection(dev_server_connection)?;
2806    }
2807    Ok(())
2808}
2809
2810/// Updates other participants with changes to the project
2811async fn update_project(
2812    request: proto::UpdateProject,
2813    response: Response<proto::UpdateProject>,
2814    session: Session,
2815) -> Result<()> {
2816    let project_id = ProjectId::from_proto(request.project_id);
2817    let (room, guest_connection_ids) = &*session
2818        .db()
2819        .await
2820        .update_project(project_id, session.connection_id, &request.worktrees)
2821        .await?;
2822    broadcast(
2823        Some(session.connection_id),
2824        guest_connection_ids.iter().copied(),
2825        |connection_id| {
2826            session
2827                .peer
2828                .forward_send(session.connection_id, connection_id, request.clone())
2829        },
2830    );
2831    if let Some(room) = room {
2832        room_updated(&room, &session.peer);
2833    }
2834    response.send(proto::Ack {})?;
2835
2836    Ok(())
2837}
2838
2839/// Updates other participants with changes to the worktree
2840async fn update_worktree(
2841    request: proto::UpdateWorktree,
2842    response: Response<proto::UpdateWorktree>,
2843    session: Session,
2844) -> Result<()> {
2845    let guest_connection_ids = session
2846        .db()
2847        .await
2848        .update_worktree(&request, session.connection_id)
2849        .await?;
2850
2851    broadcast(
2852        Some(session.connection_id),
2853        guest_connection_ids.iter().copied(),
2854        |connection_id| {
2855            session
2856                .peer
2857                .forward_send(session.connection_id, connection_id, request.clone())
2858        },
2859    );
2860    response.send(proto::Ack {})?;
2861    Ok(())
2862}
2863
2864/// Updates other participants with changes to the diagnostics
2865async fn update_diagnostic_summary(
2866    message: proto::UpdateDiagnosticSummary,
2867    session: Session,
2868) -> Result<()> {
2869    let guest_connection_ids = session
2870        .db()
2871        .await
2872        .update_diagnostic_summary(&message, session.connection_id)
2873        .await?;
2874
2875    broadcast(
2876        Some(session.connection_id),
2877        guest_connection_ids.iter().copied(),
2878        |connection_id| {
2879            session
2880                .peer
2881                .forward_send(session.connection_id, connection_id, message.clone())
2882        },
2883    );
2884
2885    Ok(())
2886}
2887
2888/// Updates other participants with changes to the worktree settings
2889async fn update_worktree_settings(
2890    message: proto::UpdateWorktreeSettings,
2891    session: Session,
2892) -> Result<()> {
2893    let guest_connection_ids = session
2894        .db()
2895        .await
2896        .update_worktree_settings(&message, session.connection_id)
2897        .await?;
2898
2899    broadcast(
2900        Some(session.connection_id),
2901        guest_connection_ids.iter().copied(),
2902        |connection_id| {
2903            session
2904                .peer
2905                .forward_send(session.connection_id, connection_id, message.clone())
2906        },
2907    );
2908
2909    Ok(())
2910}
2911
2912/// Notify other participants that a  language server has started.
2913async fn start_language_server(
2914    request: proto::StartLanguageServer,
2915    session: Session,
2916) -> Result<()> {
2917    let guest_connection_ids = session
2918        .db()
2919        .await
2920        .start_language_server(&request, session.connection_id)
2921        .await?;
2922
2923    broadcast(
2924        Some(session.connection_id),
2925        guest_connection_ids.iter().copied(),
2926        |connection_id| {
2927            session
2928                .peer
2929                .forward_send(session.connection_id, connection_id, request.clone())
2930        },
2931    );
2932    Ok(())
2933}
2934
2935/// Notify other participants that a language server has changed.
2936async fn update_language_server(
2937    request: proto::UpdateLanguageServer,
2938    session: Session,
2939) -> Result<()> {
2940    let project_id = ProjectId::from_proto(request.project_id);
2941    let project_connection_ids = session
2942        .db()
2943        .await
2944        .project_connection_ids(project_id, session.connection_id, true)
2945        .await?;
2946    broadcast(
2947        Some(session.connection_id),
2948        project_connection_ids.iter().copied(),
2949        |connection_id| {
2950            session
2951                .peer
2952                .forward_send(session.connection_id, connection_id, request.clone())
2953        },
2954    );
2955    Ok(())
2956}
2957
2958/// forward a project request to the host. These requests should be read only
2959/// as guests are allowed to send them.
2960async fn forward_read_only_project_request<T>(
2961    request: T,
2962    response: Response<T>,
2963    session: UserSession,
2964) -> Result<()>
2965where
2966    T: EntityMessage + RequestMessage,
2967{
2968    let project_id = ProjectId::from_proto(request.remote_entity_id());
2969    let host_connection_id = session
2970        .db()
2971        .await
2972        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2973        .await?;
2974    let payload = session
2975        .peer
2976        .forward_request(session.connection_id, host_connection_id, request)
2977        .await?;
2978    response.send(payload)?;
2979    Ok(())
2980}
2981
2982/// forward a project request to the dev server. Only allowed
2983/// if it's your dev server.
2984async fn forward_project_request_for_owner<T>(
2985    request: T,
2986    response: Response<T>,
2987    session: UserSession,
2988) -> Result<()>
2989where
2990    T: EntityMessage + RequestMessage,
2991{
2992    let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994    let host_connection_id = session
2995        .db()
2996        .await
2997        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2998        .await?;
2999    let payload = session
3000        .peer
3001        .forward_request(session.connection_id, host_connection_id, request)
3002        .await?;
3003    response.send(payload)?;
3004    Ok(())
3005}
3006
3007/// forward a project request to the host. These requests are disallowed
3008/// for guests.
3009async fn forward_mutating_project_request<T>(
3010    request: T,
3011    response: Response<T>,
3012    session: UserSession,
3013) -> Result<()>
3014where
3015    T: EntityMessage + RequestMessage,
3016{
3017    let project_id = ProjectId::from_proto(request.remote_entity_id());
3018
3019    let host_connection_id = session
3020        .db()
3021        .await
3022        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3023        .await?;
3024    let payload = session
3025        .peer
3026        .forward_request(session.connection_id, host_connection_id, request)
3027        .await?;
3028    response.send(payload)?;
3029    Ok(())
3030}
3031
3032/// forward a project request to the host. These requests are disallowed
3033/// for guests.
3034async fn forward_versioned_mutating_project_request<T>(
3035    request: T,
3036    response: Response<T>,
3037    session: UserSession,
3038) -> Result<()>
3039where
3040    T: EntityMessage + RequestMessage + VersionedMessage,
3041{
3042    let project_id = ProjectId::from_proto(request.remote_entity_id());
3043
3044    let host_connection_id = session
3045        .db()
3046        .await
3047        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3048        .await?;
3049    if let Some(host_version) = session
3050        .connection_pool()
3051        .await
3052        .connection(host_connection_id)
3053        .map(|c| c.zed_version)
3054    {
3055        if let Some(min_required_version) = request.required_host_version() {
3056            if min_required_version > host_version {
3057                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3058                    .with_tag("required", &min_required_version.to_string())))?;
3059            }
3060        }
3061    }
3062
3063    let payload = session
3064        .peer
3065        .forward_request(session.connection_id, host_connection_id, request)
3066        .await?;
3067    response.send(payload)?;
3068    Ok(())
3069}
3070
3071/// Notify other participants that a new buffer has been created
3072async fn create_buffer_for_peer(
3073    request: proto::CreateBufferForPeer,
3074    session: Session,
3075) -> Result<()> {
3076    session
3077        .db()
3078        .await
3079        .check_user_is_project_host(
3080            ProjectId::from_proto(request.project_id),
3081            session.connection_id,
3082        )
3083        .await?;
3084    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3085    session
3086        .peer
3087        .forward_send(session.connection_id, peer_id.into(), request)?;
3088    Ok(())
3089}
3090
3091/// Notify other participants that a buffer has been updated. This is
3092/// allowed for guests as long as the update is limited to selections.
3093async fn update_buffer(
3094    request: proto::UpdateBuffer,
3095    response: Response<proto::UpdateBuffer>,
3096    session: Session,
3097) -> Result<()> {
3098    let project_id = ProjectId::from_proto(request.project_id);
3099    let mut capability = Capability::ReadOnly;
3100
3101    for op in request.operations.iter() {
3102        match op.variant {
3103            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3104            Some(_) => capability = Capability::ReadWrite,
3105        }
3106    }
3107
3108    let host = {
3109        let guard = session
3110            .db()
3111            .await
3112            .connections_for_buffer_update(
3113                project_id,
3114                session.principal_id(),
3115                session.connection_id,
3116                capability,
3117            )
3118            .await?;
3119
3120        let (host, guests) = &*guard;
3121
3122        broadcast(
3123            Some(session.connection_id),
3124            guests.clone(),
3125            |connection_id| {
3126                session
3127                    .peer
3128                    .forward_send(session.connection_id, connection_id, request.clone())
3129            },
3130        );
3131
3132        *host
3133    };
3134
3135    if host != session.connection_id {
3136        session
3137            .peer
3138            .forward_request(session.connection_id, host, request.clone())
3139            .await?;
3140    }
3141
3142    response.send(proto::Ack {})?;
3143    Ok(())
3144}
3145
3146async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3147    let project_id = ProjectId::from_proto(message.project_id);
3148
3149    let operation = message.operation.as_ref().context("invalid operation")?;
3150    let capability = match operation.variant.as_ref() {
3151        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3152            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3153                match buffer_op.variant {
3154                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3155                        Capability::ReadOnly
3156                    }
3157                    _ => Capability::ReadWrite,
3158                }
3159            } else {
3160                Capability::ReadWrite
3161            }
3162        }
3163        Some(_) => Capability::ReadWrite,
3164        None => Capability::ReadOnly,
3165    };
3166
3167    let guard = session
3168        .db()
3169        .await
3170        .connections_for_buffer_update(
3171            project_id,
3172            session.principal_id(),
3173            session.connection_id,
3174            capability,
3175        )
3176        .await?;
3177
3178    let (host, guests) = &*guard;
3179
3180    broadcast(
3181        Some(session.connection_id),
3182        guests.iter().chain([host]).copied(),
3183        |connection_id| {
3184            session
3185                .peer
3186                .forward_send(session.connection_id, connection_id, message.clone())
3187        },
3188    );
3189
3190    Ok(())
3191}
3192
3193/// Notify other participants that a project has been updated.
3194async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3195    request: T,
3196    session: Session,
3197) -> Result<()> {
3198    let project_id = ProjectId::from_proto(request.remote_entity_id());
3199    let project_connection_ids = session
3200        .db()
3201        .await
3202        .project_connection_ids(project_id, session.connection_id, false)
3203        .await?;
3204
3205    broadcast(
3206        Some(session.connection_id),
3207        project_connection_ids.iter().copied(),
3208        |connection_id| {
3209            session
3210                .peer
3211                .forward_send(session.connection_id, connection_id, request.clone())
3212        },
3213    );
3214    Ok(())
3215}
3216
3217/// Start following another user in a call.
3218async fn follow(
3219    request: proto::Follow,
3220    response: Response<proto::Follow>,
3221    session: UserSession,
3222) -> Result<()> {
3223    let room_id = RoomId::from_proto(request.room_id);
3224    let project_id = request.project_id.map(ProjectId::from_proto);
3225    let leader_id = request
3226        .leader_id
3227        .ok_or_else(|| anyhow!("invalid leader id"))?
3228        .into();
3229    let follower_id = session.connection_id;
3230
3231    session
3232        .db()
3233        .await
3234        .check_room_participants(room_id, leader_id, session.connection_id)
3235        .await?;
3236
3237    let response_payload = session
3238        .peer
3239        .forward_request(session.connection_id, leader_id, request)
3240        .await?;
3241    response.send(response_payload)?;
3242
3243    if let Some(project_id) = project_id {
3244        let room = session
3245            .db()
3246            .await
3247            .follow(room_id, project_id, leader_id, follower_id)
3248            .await?;
3249        room_updated(&room, &session.peer);
3250    }
3251
3252    Ok(())
3253}
3254
3255/// Stop following another user in a call.
3256async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3257    let room_id = RoomId::from_proto(request.room_id);
3258    let project_id = request.project_id.map(ProjectId::from_proto);
3259    let leader_id = request
3260        .leader_id
3261        .ok_or_else(|| anyhow!("invalid leader id"))?
3262        .into();
3263    let follower_id = session.connection_id;
3264
3265    session
3266        .db()
3267        .await
3268        .check_room_participants(room_id, leader_id, session.connection_id)
3269        .await?;
3270
3271    session
3272        .peer
3273        .forward_send(session.connection_id, leader_id, request)?;
3274
3275    if let Some(project_id) = project_id {
3276        let room = session
3277            .db()
3278            .await
3279            .unfollow(room_id, project_id, leader_id, follower_id)
3280            .await?;
3281        room_updated(&room, &session.peer);
3282    }
3283
3284    Ok(())
3285}
3286
3287/// Notify everyone following you of your current location.
3288async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3289    let room_id = RoomId::from_proto(request.room_id);
3290    let database = session.db.lock().await;
3291
3292    let connection_ids = if let Some(project_id) = request.project_id {
3293        let project_id = ProjectId::from_proto(project_id);
3294        database
3295            .project_connection_ids(project_id, session.connection_id, true)
3296            .await?
3297    } else {
3298        database
3299            .room_connection_ids(room_id, session.connection_id)
3300            .await?
3301    };
3302
3303    // For now, don't send view update messages back to that view's current leader.
3304    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3305        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3306        _ => None,
3307    });
3308
3309    for connection_id in connection_ids.iter().cloned() {
3310        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3311            session
3312                .peer
3313                .forward_send(session.connection_id, connection_id, request.clone())?;
3314        }
3315    }
3316    Ok(())
3317}
3318
3319/// Get public data about users.
3320async fn get_users(
3321    request: proto::GetUsers,
3322    response: Response<proto::GetUsers>,
3323    session: Session,
3324) -> Result<()> {
3325    let user_ids = request
3326        .user_ids
3327        .into_iter()
3328        .map(UserId::from_proto)
3329        .collect();
3330    let users = session
3331        .db()
3332        .await
3333        .get_users_by_ids(user_ids)
3334        .await?
3335        .into_iter()
3336        .map(|user| proto::User {
3337            id: user.id.to_proto(),
3338            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3339            github_login: user.github_login,
3340        })
3341        .collect();
3342    response.send(proto::UsersResponse { users })?;
3343    Ok(())
3344}
3345
3346/// Search for users (to invite) buy Github login
3347async fn fuzzy_search_users(
3348    request: proto::FuzzySearchUsers,
3349    response: Response<proto::FuzzySearchUsers>,
3350    session: UserSession,
3351) -> Result<()> {
3352    let query = request.query;
3353    let users = match query.len() {
3354        0 => vec![],
3355        1 | 2 => session
3356            .db()
3357            .await
3358            .get_user_by_github_login(&query)
3359            .await?
3360            .into_iter()
3361            .collect(),
3362        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3363    };
3364    let users = users
3365        .into_iter()
3366        .filter(|user| user.id != session.user_id())
3367        .map(|user| proto::User {
3368            id: user.id.to_proto(),
3369            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3370            github_login: user.github_login,
3371        })
3372        .collect();
3373    response.send(proto::UsersResponse { users })?;
3374    Ok(())
3375}
3376
3377/// Send a contact request to another user.
3378async fn request_contact(
3379    request: proto::RequestContact,
3380    response: Response<proto::RequestContact>,
3381    session: UserSession,
3382) -> Result<()> {
3383    let requester_id = session.user_id();
3384    let responder_id = UserId::from_proto(request.responder_id);
3385    if requester_id == responder_id {
3386        return Err(anyhow!("cannot add yourself as a contact"))?;
3387    }
3388
3389    let notifications = session
3390        .db()
3391        .await
3392        .send_contact_request(requester_id, responder_id)
3393        .await?;
3394
3395    // Update outgoing contact requests of requester
3396    let mut update = proto::UpdateContacts::default();
3397    update.outgoing_requests.push(responder_id.to_proto());
3398    for connection_id in session
3399        .connection_pool()
3400        .await
3401        .user_connection_ids(requester_id)
3402    {
3403        session.peer.send(connection_id, update.clone())?;
3404    }
3405
3406    // Update incoming contact requests of responder
3407    let mut update = proto::UpdateContacts::default();
3408    update
3409        .incoming_requests
3410        .push(proto::IncomingContactRequest {
3411            requester_id: requester_id.to_proto(),
3412        });
3413    let connection_pool = session.connection_pool().await;
3414    for connection_id in connection_pool.user_connection_ids(responder_id) {
3415        session.peer.send(connection_id, update.clone())?;
3416    }
3417
3418    send_notifications(&connection_pool, &session.peer, notifications);
3419
3420    response.send(proto::Ack {})?;
3421    Ok(())
3422}
3423
3424/// Accept or decline a contact request
3425async fn respond_to_contact_request(
3426    request: proto::RespondToContactRequest,
3427    response: Response<proto::RespondToContactRequest>,
3428    session: UserSession,
3429) -> Result<()> {
3430    let responder_id = session.user_id();
3431    let requester_id = UserId::from_proto(request.requester_id);
3432    let db = session.db().await;
3433    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3434        db.dismiss_contact_notification(responder_id, requester_id)
3435            .await?;
3436    } else {
3437        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3438
3439        let notifications = db
3440            .respond_to_contact_request(responder_id, requester_id, accept)
3441            .await?;
3442        let requester_busy = db.is_user_busy(requester_id).await?;
3443        let responder_busy = db.is_user_busy(responder_id).await?;
3444
3445        let pool = session.connection_pool().await;
3446        // Update responder with new contact
3447        let mut update = proto::UpdateContacts::default();
3448        if accept {
3449            update
3450                .contacts
3451                .push(contact_for_user(requester_id, requester_busy, &pool));
3452        }
3453        update
3454            .remove_incoming_requests
3455            .push(requester_id.to_proto());
3456        for connection_id in pool.user_connection_ids(responder_id) {
3457            session.peer.send(connection_id, update.clone())?;
3458        }
3459
3460        // Update requester with new contact
3461        let mut update = proto::UpdateContacts::default();
3462        if accept {
3463            update
3464                .contacts
3465                .push(contact_for_user(responder_id, responder_busy, &pool));
3466        }
3467        update
3468            .remove_outgoing_requests
3469            .push(responder_id.to_proto());
3470
3471        for connection_id in pool.user_connection_ids(requester_id) {
3472            session.peer.send(connection_id, update.clone())?;
3473        }
3474
3475        send_notifications(&pool, &session.peer, notifications);
3476    }
3477
3478    response.send(proto::Ack {})?;
3479    Ok(())
3480}
3481
3482/// Remove a contact.
3483async fn remove_contact(
3484    request: proto::RemoveContact,
3485    response: Response<proto::RemoveContact>,
3486    session: UserSession,
3487) -> Result<()> {
3488    let requester_id = session.user_id();
3489    let responder_id = UserId::from_proto(request.user_id);
3490    let db = session.db().await;
3491    let (contact_accepted, deleted_notification_id) =
3492        db.remove_contact(requester_id, responder_id).await?;
3493
3494    let pool = session.connection_pool().await;
3495    // Update outgoing contact requests of requester
3496    let mut update = proto::UpdateContacts::default();
3497    if contact_accepted {
3498        update.remove_contacts.push(responder_id.to_proto());
3499    } else {
3500        update
3501            .remove_outgoing_requests
3502            .push(responder_id.to_proto());
3503    }
3504    for connection_id in pool.user_connection_ids(requester_id) {
3505        session.peer.send(connection_id, update.clone())?;
3506    }
3507
3508    // Update incoming contact requests of responder
3509    let mut update = proto::UpdateContacts::default();
3510    if contact_accepted {
3511        update.remove_contacts.push(requester_id.to_proto());
3512    } else {
3513        update
3514            .remove_incoming_requests
3515            .push(requester_id.to_proto());
3516    }
3517    for connection_id in pool.user_connection_ids(responder_id) {
3518        session.peer.send(connection_id, update.clone())?;
3519        if let Some(notification_id) = deleted_notification_id {
3520            session.peer.send(
3521                connection_id,
3522                proto::DeleteNotification {
3523                    notification_id: notification_id.to_proto(),
3524                },
3525            )?;
3526        }
3527    }
3528
3529    response.send(proto::Ack {})?;
3530    Ok(())
3531}
3532
3533fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3534    version.0.minor() < 139
3535}
3536
3537async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3538    subscribe_user_to_channels(
3539        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3540        &session,
3541    )
3542    .await?;
3543    Ok(())
3544}
3545
3546async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3547    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3548    let mut pool = session.connection_pool().await;
3549    for membership in &channels_for_user.channel_memberships {
3550        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3551    }
3552    session.peer.send(
3553        session.connection_id,
3554        build_update_user_channels(&channels_for_user),
3555    )?;
3556    session.peer.send(
3557        session.connection_id,
3558        build_channels_update(channels_for_user),
3559    )?;
3560    Ok(())
3561}
3562
3563/// Creates a new channel.
3564async fn create_channel(
3565    request: proto::CreateChannel,
3566    response: Response<proto::CreateChannel>,
3567    session: UserSession,
3568) -> Result<()> {
3569    let db = session.db().await;
3570
3571    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3572    let (channel, membership) = db
3573        .create_channel(&request.name, parent_id, session.user_id())
3574        .await?;
3575
3576    let root_id = channel.root_id();
3577    let channel = Channel::from_model(channel);
3578
3579    response.send(proto::CreateChannelResponse {
3580        channel: Some(channel.to_proto()),
3581        parent_id: request.parent_id,
3582    })?;
3583
3584    let mut connection_pool = session.connection_pool().await;
3585    if let Some(membership) = membership {
3586        connection_pool.subscribe_to_channel(
3587            membership.user_id,
3588            membership.channel_id,
3589            membership.role,
3590        );
3591        let update = proto::UpdateUserChannels {
3592            channel_memberships: vec![proto::ChannelMembership {
3593                channel_id: membership.channel_id.to_proto(),
3594                role: membership.role.into(),
3595            }],
3596            ..Default::default()
3597        };
3598        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3599            session.peer.send(connection_id, update.clone())?;
3600        }
3601    }
3602
3603    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3604        if !role.can_see_channel(channel.visibility) {
3605            continue;
3606        }
3607
3608        let update = proto::UpdateChannels {
3609            channels: vec![channel.to_proto()],
3610            ..Default::default()
3611        };
3612        session.peer.send(connection_id, update.clone())?;
3613    }
3614
3615    Ok(())
3616}
3617
3618/// Delete a channel
3619async fn delete_channel(
3620    request: proto::DeleteChannel,
3621    response: Response<proto::DeleteChannel>,
3622    session: UserSession,
3623) -> Result<()> {
3624    let db = session.db().await;
3625
3626    let channel_id = request.channel_id;
3627    let (root_channel, removed_channels) = db
3628        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3629        .await?;
3630    response.send(proto::Ack {})?;
3631
3632    // Notify members of removed channels
3633    let mut update = proto::UpdateChannels::default();
3634    update
3635        .delete_channels
3636        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3637
3638    let connection_pool = session.connection_pool().await;
3639    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3640        session.peer.send(connection_id, update.clone())?;
3641    }
3642
3643    Ok(())
3644}
3645
3646/// Invite someone to join a channel.
3647async fn invite_channel_member(
3648    request: proto::InviteChannelMember,
3649    response: Response<proto::InviteChannelMember>,
3650    session: UserSession,
3651) -> Result<()> {
3652    let db = session.db().await;
3653    let channel_id = ChannelId::from_proto(request.channel_id);
3654    let invitee_id = UserId::from_proto(request.user_id);
3655    let InviteMemberResult {
3656        channel,
3657        notifications,
3658    } = db
3659        .invite_channel_member(
3660            channel_id,
3661            invitee_id,
3662            session.user_id(),
3663            request.role().into(),
3664        )
3665        .await?;
3666
3667    let update = proto::UpdateChannels {
3668        channel_invitations: vec![channel.to_proto()],
3669        ..Default::default()
3670    };
3671
3672    let connection_pool = session.connection_pool().await;
3673    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3674        session.peer.send(connection_id, update.clone())?;
3675    }
3676
3677    send_notifications(&connection_pool, &session.peer, notifications);
3678
3679    response.send(proto::Ack {})?;
3680    Ok(())
3681}
3682
3683/// remove someone from a channel
3684async fn remove_channel_member(
3685    request: proto::RemoveChannelMember,
3686    response: Response<proto::RemoveChannelMember>,
3687    session: UserSession,
3688) -> Result<()> {
3689    let db = session.db().await;
3690    let channel_id = ChannelId::from_proto(request.channel_id);
3691    let member_id = UserId::from_proto(request.user_id);
3692
3693    let RemoveChannelMemberResult {
3694        membership_update,
3695        notification_id,
3696    } = db
3697        .remove_channel_member(channel_id, member_id, session.user_id())
3698        .await?;
3699
3700    let mut connection_pool = session.connection_pool().await;
3701    notify_membership_updated(
3702        &mut connection_pool,
3703        membership_update,
3704        member_id,
3705        &session.peer,
3706    );
3707    for connection_id in connection_pool.user_connection_ids(member_id) {
3708        if let Some(notification_id) = notification_id {
3709            session
3710                .peer
3711                .send(
3712                    connection_id,
3713                    proto::DeleteNotification {
3714                        notification_id: notification_id.to_proto(),
3715                    },
3716                )
3717                .trace_err();
3718        }
3719    }
3720
3721    response.send(proto::Ack {})?;
3722    Ok(())
3723}
3724
3725/// Toggle the channel between public and private.
3726/// Care is taken to maintain the invariant that public channels only descend from public channels,
3727/// (though members-only channels can appear at any point in the hierarchy).
3728async fn set_channel_visibility(
3729    request: proto::SetChannelVisibility,
3730    response: Response<proto::SetChannelVisibility>,
3731    session: UserSession,
3732) -> Result<()> {
3733    let db = session.db().await;
3734    let channel_id = ChannelId::from_proto(request.channel_id);
3735    let visibility = request.visibility().into();
3736
3737    let channel_model = db
3738        .set_channel_visibility(channel_id, visibility, session.user_id())
3739        .await?;
3740    let root_id = channel_model.root_id();
3741    let channel = Channel::from_model(channel_model);
3742
3743    let mut connection_pool = session.connection_pool().await;
3744    for (user_id, role) in connection_pool
3745        .channel_user_ids(root_id)
3746        .collect::<Vec<_>>()
3747        .into_iter()
3748    {
3749        let update = if role.can_see_channel(channel.visibility) {
3750            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3751            proto::UpdateChannels {
3752                channels: vec![channel.to_proto()],
3753                ..Default::default()
3754            }
3755        } else {
3756            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3757            proto::UpdateChannels {
3758                delete_channels: vec![channel.id.to_proto()],
3759                ..Default::default()
3760            }
3761        };
3762
3763        for connection_id in connection_pool.user_connection_ids(user_id) {
3764            session.peer.send(connection_id, update.clone())?;
3765        }
3766    }
3767
3768    response.send(proto::Ack {})?;
3769    Ok(())
3770}
3771
3772/// Alter the role for a user in the channel.
3773async fn set_channel_member_role(
3774    request: proto::SetChannelMemberRole,
3775    response: Response<proto::SetChannelMemberRole>,
3776    session: UserSession,
3777) -> Result<()> {
3778    let db = session.db().await;
3779    let channel_id = ChannelId::from_proto(request.channel_id);
3780    let member_id = UserId::from_proto(request.user_id);
3781    let result = db
3782        .set_channel_member_role(
3783            channel_id,
3784            session.user_id(),
3785            member_id,
3786            request.role().into(),
3787        )
3788        .await?;
3789
3790    match result {
3791        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3792            let mut connection_pool = session.connection_pool().await;
3793            notify_membership_updated(
3794                &mut connection_pool,
3795                membership_update,
3796                member_id,
3797                &session.peer,
3798            )
3799        }
3800        db::SetMemberRoleResult::InviteUpdated(channel) => {
3801            let update = proto::UpdateChannels {
3802                channel_invitations: vec![channel.to_proto()],
3803                ..Default::default()
3804            };
3805
3806            for connection_id in session
3807                .connection_pool()
3808                .await
3809                .user_connection_ids(member_id)
3810            {
3811                session.peer.send(connection_id, update.clone())?;
3812            }
3813        }
3814    }
3815
3816    response.send(proto::Ack {})?;
3817    Ok(())
3818}
3819
3820/// Change the name of a channel
3821async fn rename_channel(
3822    request: proto::RenameChannel,
3823    response: Response<proto::RenameChannel>,
3824    session: UserSession,
3825) -> Result<()> {
3826    let db = session.db().await;
3827    let channel_id = ChannelId::from_proto(request.channel_id);
3828    let channel_model = db
3829        .rename_channel(channel_id, session.user_id(), &request.name)
3830        .await?;
3831    let root_id = channel_model.root_id();
3832    let channel = Channel::from_model(channel_model);
3833
3834    response.send(proto::RenameChannelResponse {
3835        channel: Some(channel.to_proto()),
3836    })?;
3837
3838    let connection_pool = session.connection_pool().await;
3839    let update = proto::UpdateChannels {
3840        channels: vec![channel.to_proto()],
3841        ..Default::default()
3842    };
3843    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3844        if role.can_see_channel(channel.visibility) {
3845            session.peer.send(connection_id, update.clone())?;
3846        }
3847    }
3848
3849    Ok(())
3850}
3851
3852/// Move a channel to a new parent.
3853async fn move_channel(
3854    request: proto::MoveChannel,
3855    response: Response<proto::MoveChannel>,
3856    session: UserSession,
3857) -> Result<()> {
3858    let channel_id = ChannelId::from_proto(request.channel_id);
3859    let to = ChannelId::from_proto(request.to);
3860
3861    let (root_id, channels) = session
3862        .db()
3863        .await
3864        .move_channel(channel_id, to, session.user_id())
3865        .await?;
3866
3867    let connection_pool = session.connection_pool().await;
3868    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3869        let channels = channels
3870            .iter()
3871            .filter_map(|channel| {
3872                if role.can_see_channel(channel.visibility) {
3873                    Some(channel.to_proto())
3874                } else {
3875                    None
3876                }
3877            })
3878            .collect::<Vec<_>>();
3879        if channels.is_empty() {
3880            continue;
3881        }
3882
3883        let update = proto::UpdateChannels {
3884            channels,
3885            ..Default::default()
3886        };
3887
3888        session.peer.send(connection_id, update.clone())?;
3889    }
3890
3891    response.send(Ack {})?;
3892    Ok(())
3893}
3894
3895/// Get the list of channel members
3896async fn get_channel_members(
3897    request: proto::GetChannelMembers,
3898    response: Response<proto::GetChannelMembers>,
3899    session: UserSession,
3900) -> Result<()> {
3901    let db = session.db().await;
3902    let channel_id = ChannelId::from_proto(request.channel_id);
3903    let limit = if request.limit == 0 {
3904        u16::MAX as u64
3905    } else {
3906        request.limit
3907    };
3908    let (members, users) = db
3909        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3910        .await?;
3911    response.send(proto::GetChannelMembersResponse { members, users })?;
3912    Ok(())
3913}
3914
3915/// Accept or decline a channel invitation.
3916async fn respond_to_channel_invite(
3917    request: proto::RespondToChannelInvite,
3918    response: Response<proto::RespondToChannelInvite>,
3919    session: UserSession,
3920) -> Result<()> {
3921    let db = session.db().await;
3922    let channel_id = ChannelId::from_proto(request.channel_id);
3923    let RespondToChannelInvite {
3924        membership_update,
3925        notifications,
3926    } = db
3927        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3928        .await?;
3929
3930    let mut connection_pool = session.connection_pool().await;
3931    if let Some(membership_update) = membership_update {
3932        notify_membership_updated(
3933            &mut connection_pool,
3934            membership_update,
3935            session.user_id(),
3936            &session.peer,
3937        );
3938    } else {
3939        let update = proto::UpdateChannels {
3940            remove_channel_invitations: vec![channel_id.to_proto()],
3941            ..Default::default()
3942        };
3943
3944        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3945            session.peer.send(connection_id, update.clone())?;
3946        }
3947    };
3948
3949    send_notifications(&connection_pool, &session.peer, notifications);
3950
3951    response.send(proto::Ack {})?;
3952
3953    Ok(())
3954}
3955
3956/// Join the channels' room
3957async fn join_channel(
3958    request: proto::JoinChannel,
3959    response: Response<proto::JoinChannel>,
3960    session: UserSession,
3961) -> Result<()> {
3962    let channel_id = ChannelId::from_proto(request.channel_id);
3963    join_channel_internal(channel_id, Box::new(response), session).await
3964}
3965
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can reply through either kind of handle.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call presumably resolves to `Response`'s inherent `send`.
        // Inherent associated items take precedence in a `Type::item` path, so this does not
        // recurse into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation to `Response`'s own `send`, here for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3979
/// Joins the current user's connection to the room of `channel_id` and answers `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that connection leaves
///    its room first.
/// 2. The channel's room is joined in the database. This yields the joined room, an optional
///    membership change, and the user's channel role.
/// 3. If a LiveKit client is configured, credentials are minted. Guests get a guest token
///    with `can_publish: false`; everyone else gets a room token with `can_publish: true`.
///    A failing token call is passed through `trace_err`, and the result is no connection
///    info rather than an error.
/// 4. The requester gets its reply, any membership change is announced, and the room update
///    is broadcast.
/// 5. After the database and pool handles from the block are released, the channel update is
///    broadcast and the user's contacts are refreshed.
///
/// # Errors
/// Returns an error if:
/// - any database call fails;
/// - leaving the stale room fails;
/// - sending the reply fails;
/// - refreshing contacts fails;
/// - the joined room has no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the database handle and the pool guard below are both dropped before the
    // channel broadcast at the end of the function.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room, then reacquire it.
            // NOTE(review): presumably `leave_room_for_session` needs the database itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token fails: the
        // `?` after `trace_err()` exits the closure with `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard from the block above has been dropped; this takes a fresh one.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4070
4071/// Start editing the channel notes
4072async fn join_channel_buffer(
4073    request: proto::JoinChannelBuffer,
4074    response: Response<proto::JoinChannelBuffer>,
4075    session: UserSession,
4076) -> Result<()> {
4077    let db = session.db().await;
4078    let channel_id = ChannelId::from_proto(request.channel_id);
4079
4080    let open_response = db
4081        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4082        .await?;
4083
4084    let collaborators = open_response.collaborators.clone();
4085    response.send(open_response)?;
4086
4087    let update = UpdateChannelBufferCollaborators {
4088        channel_id: channel_id.to_proto(),
4089        collaborators: collaborators.clone(),
4090    };
4091    channel_buffer_updated(
4092        session.connection_id,
4093        collaborators
4094            .iter()
4095            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4096        &update,
4097        &session.peer,
4098    );
4099
4100    Ok(())
4101}
4102
4103/// Edit the channel notes
4104async fn update_channel_buffer(
4105    request: proto::UpdateChannelBuffer,
4106    session: UserSession,
4107) -> Result<()> {
4108    let db = session.db().await;
4109    let channel_id = ChannelId::from_proto(request.channel_id);
4110
4111    let (collaborators, epoch, version) = db
4112        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4113        .await?;
4114
4115    channel_buffer_updated(
4116        session.connection_id,
4117        collaborators.clone(),
4118        &proto::UpdateChannelBuffer {
4119            channel_id: channel_id.to_proto(),
4120            operations: request.operations,
4121        },
4122        &session.peer,
4123    );
4124
4125    let pool = &*session.connection_pool().await;
4126
4127    let non_collaborators =
4128        pool.channel_connection_ids(channel_id)
4129            .filter_map(|(connection_id, _)| {
4130                if collaborators.contains(&connection_id) {
4131                    None
4132                } else {
4133                    Some(connection_id)
4134                }
4135            });
4136
4137    broadcast(None, non_collaborators, |peer_id| {
4138        session.peer.send(
4139            peer_id,
4140            proto::UpdateChannels {
4141                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4142                    channel_id: channel_id.to_proto(),
4143                    epoch: epoch as u64,
4144                    version: version.clone(),
4145                }],
4146                ..Default::default()
4147            },
4148        )
4149    });
4150
4151    Ok(())
4152}
4153
4154/// Rejoin the channel notes after a connection blip
4155async fn rejoin_channel_buffers(
4156    request: proto::RejoinChannelBuffers,
4157    response: Response<proto::RejoinChannelBuffers>,
4158    session: UserSession,
4159) -> Result<()> {
4160    let db = session.db().await;
4161    let buffers = db
4162        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4163        .await?;
4164
4165    for rejoined_buffer in &buffers {
4166        let collaborators_to_notify = rejoined_buffer
4167            .buffer
4168            .collaborators
4169            .iter()
4170            .filter_map(|c| Some(c.peer_id?.into()));
4171        channel_buffer_updated(
4172            session.connection_id,
4173            collaborators_to_notify,
4174            &proto::UpdateChannelBufferCollaborators {
4175                channel_id: rejoined_buffer.buffer.channel_id,
4176                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4177            },
4178            &session.peer,
4179        );
4180    }
4181
4182    response.send(proto::RejoinChannelBuffersResponse {
4183        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4184    })?;
4185
4186    Ok(())
4187}
4188
4189/// Stop editing the channel notes
4190async fn leave_channel_buffer(
4191    request: proto::LeaveChannelBuffer,
4192    response: Response<proto::LeaveChannelBuffer>,
4193    session: UserSession,
4194) -> Result<()> {
4195    let db = session.db().await;
4196    let channel_id = ChannelId::from_proto(request.channel_id);
4197
4198    let left_buffer = db
4199        .leave_channel_buffer(channel_id, session.connection_id)
4200        .await?;
4201
4202    response.send(Ack {})?;
4203
4204    channel_buffer_updated(
4205        session.connection_id,
4206        left_buffer.connections,
4207        &proto::UpdateChannelBufferCollaborators {
4208            channel_id: channel_id.to_proto(),
4209            collaborators: left_buffer.collaborators,
4210        },
4211        &session.peer,
4212    );
4213
4214    Ok(())
4215}
4216
4217fn channel_buffer_updated<T: EnvelopedMessage>(
4218    sender_id: ConnectionId,
4219    collaborators: impl IntoIterator<Item = ConnectionId>,
4220    message: &T,
4221    peer: &Peer,
4222) {
4223    broadcast(Some(sender_id), collaborators, |peer_id| {
4224        peer.send(peer_id, message.clone())
4225    });
4226}
4227
4228fn send_notifications(
4229    connection_pool: &ConnectionPool,
4230    peer: &Peer,
4231    notifications: db::NotificationBatch,
4232) {
4233    for (user_id, notification) in notifications {
4234        for connection_id in connection_pool.user_connection_ids(user_id) {
4235            if let Err(error) = peer.send(
4236                connection_id,
4237                proto::AddNotification {
4238                    notification: Some(notification.clone()),
4239                },
4240            ) {
4241                tracing::error!(
4242                    "failed to send notification to {:?} {}",
4243                    connection_id,
4244                    error
4245                );
4246            }
4247        }
4248    }
4249}
4250
4251/// Send a message to the channel
4252async fn send_channel_message(
4253    request: proto::SendChannelMessage,
4254    response: Response<proto::SendChannelMessage>,
4255    session: UserSession,
4256) -> Result<()> {
4257    // Validate the message body.
4258    let body = request.body.trim().to_string();
4259    if body.len() > MAX_MESSAGE_LEN {
4260        return Err(anyhow!("message is too long"))?;
4261    }
4262    if body.is_empty() {
4263        return Err(anyhow!("message can't be blank"))?;
4264    }
4265
4266    // TODO: adjust mentions if body is trimmed
4267
4268    let timestamp = OffsetDateTime::now_utc();
4269    let nonce = request
4270        .nonce
4271        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4272
4273    let channel_id = ChannelId::from_proto(request.channel_id);
4274    let CreatedChannelMessage {
4275        message_id,
4276        participant_connection_ids,
4277        notifications,
4278    } = session
4279        .db()
4280        .await
4281        .create_channel_message(
4282            channel_id,
4283            session.user_id(),
4284            &body,
4285            &request.mentions,
4286            timestamp,
4287            nonce.clone().into(),
4288            match request.reply_to_message_id {
4289                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4290                None => None,
4291            },
4292        )
4293        .await?;
4294
4295    let message = proto::ChannelMessage {
4296        sender_id: session.user_id().to_proto(),
4297        id: message_id.to_proto(),
4298        body,
4299        mentions: request.mentions,
4300        timestamp: timestamp.unix_timestamp() as u64,
4301        nonce: Some(nonce),
4302        reply_to_message_id: request.reply_to_message_id,
4303        edited_at: None,
4304    };
4305    broadcast(
4306        Some(session.connection_id),
4307        participant_connection_ids.clone(),
4308        |connection| {
4309            session.peer.send(
4310                connection,
4311                proto::ChannelMessageSent {
4312                    channel_id: channel_id.to_proto(),
4313                    message: Some(message.clone()),
4314                },
4315            )
4316        },
4317    );
4318    response.send(proto::SendChannelMessageResponse {
4319        message: Some(message),
4320    })?;
4321
4322    let pool = &*session.connection_pool().await;
4323    let non_participants =
4324        pool.channel_connection_ids(channel_id)
4325            .filter_map(|(connection_id, _)| {
4326                if participant_connection_ids.contains(&connection_id) {
4327                    None
4328                } else {
4329                    Some(connection_id)
4330                }
4331            });
4332    broadcast(None, non_participants, |peer_id| {
4333        session.peer.send(
4334            peer_id,
4335            proto::UpdateChannels {
4336                latest_channel_message_ids: vec![proto::ChannelMessageId {
4337                    channel_id: channel_id.to_proto(),
4338                    message_id: message_id.to_proto(),
4339                }],
4340                ..Default::default()
4341            },
4342        )
4343    });
4344    send_notifications(pool, &session.peer, notifications);
4345
4346    Ok(())
4347}
4348
4349/// Delete a channel message
4350async fn remove_channel_message(
4351    request: proto::RemoveChannelMessage,
4352    response: Response<proto::RemoveChannelMessage>,
4353    session: UserSession,
4354) -> Result<()> {
4355    let channel_id = ChannelId::from_proto(request.channel_id);
4356    let message_id = MessageId::from_proto(request.message_id);
4357    let (connection_ids, existing_notification_ids) = session
4358        .db()
4359        .await
4360        .remove_channel_message(channel_id, message_id, session.user_id())
4361        .await?;
4362
4363    broadcast(
4364        Some(session.connection_id),
4365        connection_ids,
4366        move |connection| {
4367            session.peer.send(connection, request.clone())?;
4368
4369            for notification_id in &existing_notification_ids {
4370                session.peer.send(
4371                    connection,
4372                    proto::DeleteNotification {
4373                        notification_id: (*notification_id).to_proto(),
4374                    },
4375                )?;
4376            }
4377
4378            Ok(())
4379        },
4380    );
4381    response.send(proto::Ack {})?;
4382    Ok(())
4383}
4384
4385async fn update_channel_message(
4386    request: proto::UpdateChannelMessage,
4387    response: Response<proto::UpdateChannelMessage>,
4388    session: UserSession,
4389) -> Result<()> {
4390    let channel_id = ChannelId::from_proto(request.channel_id);
4391    let message_id = MessageId::from_proto(request.message_id);
4392    let updated_at = OffsetDateTime::now_utc();
4393    let UpdatedChannelMessage {
4394        message_id,
4395        participant_connection_ids,
4396        notifications,
4397        reply_to_message_id,
4398        timestamp,
4399        deleted_mention_notification_ids,
4400        updated_mention_notifications,
4401    } = session
4402        .db()
4403        .await
4404        .update_channel_message(
4405            channel_id,
4406            message_id,
4407            session.user_id(),
4408            request.body.as_str(),
4409            &request.mentions,
4410            updated_at,
4411        )
4412        .await?;
4413
4414    let nonce = request
4415        .nonce
4416        .clone()
4417        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4418
4419    let message = proto::ChannelMessage {
4420        sender_id: session.user_id().to_proto(),
4421        id: message_id.to_proto(),
4422        body: request.body.clone(),
4423        mentions: request.mentions.clone(),
4424        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4425        nonce: Some(nonce),
4426        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4427        edited_at: Some(updated_at.unix_timestamp() as u64),
4428    };
4429
4430    response.send(proto::Ack {})?;
4431
4432    let pool = &*session.connection_pool().await;
4433    broadcast(
4434        Some(session.connection_id),
4435        participant_connection_ids,
4436        |connection| {
4437            session.peer.send(
4438                connection,
4439                proto::ChannelMessageUpdate {
4440                    channel_id: channel_id.to_proto(),
4441                    message: Some(message.clone()),
4442                },
4443            )?;
4444
4445            for notification_id in &deleted_mention_notification_ids {
4446                session.peer.send(
4447                    connection,
4448                    proto::DeleteNotification {
4449                        notification_id: (*notification_id).to_proto(),
4450                    },
4451                )?;
4452            }
4453
4454            for notification in &updated_mention_notifications {
4455                session.peer.send(
4456                    connection,
4457                    proto::UpdateNotification {
4458                        notification: Some(notification.clone()),
4459                    },
4460                )?;
4461            }
4462
4463            Ok(())
4464        },
4465    );
4466
4467    send_notifications(pool, &session.peer, notifications);
4468
4469    Ok(())
4470}
4471
4472/// Mark a channel message as read
4473async fn acknowledge_channel_message(
4474    request: proto::AckChannelMessage,
4475    session: UserSession,
4476) -> Result<()> {
4477    let channel_id = ChannelId::from_proto(request.channel_id);
4478    let message_id = MessageId::from_proto(request.message_id);
4479    let notifications = session
4480        .db()
4481        .await
4482        .observe_channel_message(channel_id, session.user_id(), message_id)
4483        .await?;
4484    send_notifications(
4485        &*session.connection_pool().await,
4486        &session.peer,
4487        notifications,
4488    );
4489    Ok(())
4490}
4491
4492/// Mark a buffer version as synced
4493async fn acknowledge_buffer_version(
4494    request: proto::AckBufferOperation,
4495    session: UserSession,
4496) -> Result<()> {
4497    let buffer_id = BufferId::from_proto(request.buffer_id);
4498    session
4499        .db()
4500        .await
4501        .observe_buffer_version(
4502            buffer_id,
4503            session.user_id(),
4504            request.epoch as i32,
4505            &request.version,
4506        )
4507        .await?;
4508    Ok(())
4509}
4510
4511struct CompleteWithLanguageModelRateLimit;
4512
4513impl RateLimit for CompleteWithLanguageModelRateLimit {
4514    fn capacity() -> usize {
4515        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4516            .ok()
4517            .and_then(|v| v.parse().ok())
4518            .unwrap_or(120) // Picked arbitrarily
4519    }
4520
4521    fn refill_duration() -> chrono::Duration {
4522        chrono::Duration::hours(1)
4523    }
4524
4525    fn db_name() -> &'static str {
4526        "complete-with-language-model"
4527    }
4528}
4529
4530async fn complete_with_language_model(
4531    request: proto::CompleteWithLanguageModel,
4532    response: Response<proto::CompleteWithLanguageModel>,
4533    session: Session,
4534    config: &Config,
4535) -> Result<()> {
4536    let Some(session) = session.for_user() else {
4537        return Err(anyhow!("user not found"))?;
4538    };
4539    authorize_access_to_language_models(&session).await?;
4540
4541    session
4542        .rate_limiter
4543        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4544        .await?;
4545
4546    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4547        Some(proto::LanguageModelProvider::Anthropic) => {
4548            let api_key = config
4549                .anthropic_api_key
4550                .as_ref()
4551                .context("no Anthropic AI API key configured on the server")?;
4552            anthropic::complete(
4553                session.http_client.as_ref(),
4554                anthropic::ANTHROPIC_API_URL,
4555                api_key,
4556                serde_json::from_str(&request.request)?,
4557            )
4558            .await?
4559        }
4560        _ => return Err(anyhow!("unsupported provider"))?,
4561    };
4562
4563    response.send(proto::CompleteWithLanguageModelResponse {
4564        completion: serde_json::to_string(&result)?,
4565    })?;
4566
4567    Ok(())
4568}
4569
4570async fn stream_complete_with_language_model(
4571    request: proto::StreamCompleteWithLanguageModel,
4572    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4573    session: Session,
4574    config: &Config,
4575) -> Result<()> {
4576    let Some(session) = session.for_user() else {
4577        return Err(anyhow!("user not found"))?;
4578    };
4579    authorize_access_to_language_models(&session).await?;
4580
4581    session
4582        .rate_limiter
4583        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4584        .await?;
4585
4586    match proto::LanguageModelProvider::from_i32(request.provider) {
4587        Some(proto::LanguageModelProvider::Anthropic) => {
4588            let api_key = config
4589                .anthropic_api_key
4590                .as_ref()
4591                .context("no Anthropic AI API key configured on the server")?;
4592            let mut chunks = anthropic::stream_completion(
4593                session.http_client.as_ref(),
4594                anthropic::ANTHROPIC_API_URL,
4595                api_key,
4596                serde_json::from_str(&request.request)?,
4597                None,
4598            )
4599            .await?;
4600            while let Some(event) = chunks.next().await {
4601                let chunk = event?;
4602                response.send(proto::StreamCompleteWithLanguageModelResponse {
4603                    event: serde_json::to_string(&chunk)?,
4604                })?;
4605            }
4606        }
4607        Some(proto::LanguageModelProvider::OpenAi) => {
4608            let api_key = config
4609                .openai_api_key
4610                .as_ref()
4611                .context("no OpenAI API key configured on the server")?;
4612            let mut events = open_ai::stream_completion(
4613                session.http_client.as_ref(),
4614                open_ai::OPEN_AI_API_URL,
4615                api_key,
4616                serde_json::from_str(&request.request)?,
4617                None,
4618            )
4619            .await?;
4620            while let Some(event) = events.next().await {
4621                let event = event?;
4622                response.send(proto::StreamCompleteWithLanguageModelResponse {
4623                    event: serde_json::to_string(&event)?,
4624                })?;
4625            }
4626        }
4627        Some(proto::LanguageModelProvider::Google) => {
4628            let api_key = config
4629                .google_ai_api_key
4630                .as_ref()
4631                .context("no Google AI API key configured on the server")?;
4632            let mut events = google_ai::stream_generate_content(
4633                session.http_client.as_ref(),
4634                google_ai::API_URL,
4635                api_key,
4636                serde_json::from_str(&request.request)?,
4637            )
4638            .await?;
4639            while let Some(event) = events.next().await {
4640                let event = event?;
4641                response.send(proto::StreamCompleteWithLanguageModelResponse {
4642                    event: serde_json::to_string(&event)?,
4643                })?;
4644            }
4645        }
4646        None => return Err(anyhow!("unknown provider"))?,
4647    }
4648
4649    Ok(())
4650}
4651
4652async fn count_language_model_tokens(
4653    request: proto::CountLanguageModelTokens,
4654    response: Response<proto::CountLanguageModelTokens>,
4655    session: Session,
4656    config: &Config,
4657) -> Result<()> {
4658    let Some(session) = session.for_user() else {
4659        return Err(anyhow!("user not found"))?;
4660    };
4661    authorize_access_to_language_models(&session).await?;
4662
4663    session
4664        .rate_limiter
4665        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4666        .await?;
4667
4668    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4669        Some(proto::LanguageModelProvider::Google) => {
4670            let api_key = config
4671                .google_ai_api_key
4672                .as_ref()
4673                .context("no Google AI API key configured on the server")?;
4674            google_ai::count_tokens(
4675                session.http_client.as_ref(),
4676                google_ai::API_URL,
4677                api_key,
4678                serde_json::from_str(&request.request)?,
4679            )
4680            .await?
4681        }
4682        _ => return Err(anyhow!("unsupported provider"))?,
4683    };
4684
4685    response.send(proto::CountLanguageModelTokensResponse {
4686        token_count: result.total_tokens as u32,
4687    })?;
4688
4689    Ok(())
4690}
4691
4692struct CountLanguageModelTokensRateLimit;
4693
4694impl RateLimit for CountLanguageModelTokensRateLimit {
4695    fn capacity() -> usize {
4696        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4697            .ok()
4698            .and_then(|v| v.parse().ok())
4699            .unwrap_or(600) // Picked arbitrarily
4700    }
4701
4702    fn refill_duration() -> chrono::Duration {
4703        chrono::Duration::hours(1)
4704    }
4705
4706    fn db_name() -> &'static str {
4707        "count-language-model-tokens"
4708    }
4709}
4710
4711struct ComputeEmbeddingsRateLimit;
4712
4713impl RateLimit for ComputeEmbeddingsRateLimit {
4714    fn capacity() -> usize {
4715        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4716            .ok()
4717            .and_then(|v| v.parse().ok())
4718            .unwrap_or(5000) // Picked arbitrarily
4719    }
4720
4721    fn refill_duration() -> chrono::Duration {
4722        chrono::Duration::hours(1)
4723    }
4724
4725    fn db_name() -> &'static str {
4726        "compute-embeddings"
4727    }
4728}
4729
4730async fn compute_embeddings(
4731    request: proto::ComputeEmbeddings,
4732    response: Response<proto::ComputeEmbeddings>,
4733    session: UserSession,
4734    api_key: Option<Arc<str>>,
4735) -> Result<()> {
4736    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4737    authorize_access_to_language_models(&session).await?;
4738
4739    session
4740        .rate_limiter
4741        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4742        .await?;
4743
4744    let embeddings = match request.model.as_str() {
4745        "openai/text-embedding-3-small" => {
4746            open_ai::embed(
4747                session.http_client.as_ref(),
4748                OPEN_AI_API_URL,
4749                &api_key,
4750                OpenAiEmbeddingModel::TextEmbedding3Small,
4751                request.texts.iter().map(|text| text.as_str()),
4752            )
4753            .await?
4754        }
4755        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4756    };
4757
4758    let embeddings = request
4759        .texts
4760        .iter()
4761        .map(|text| {
4762            let mut hasher = sha2::Sha256::new();
4763            hasher.update(text.as_bytes());
4764            let result = hasher.finalize();
4765            result.to_vec()
4766        })
4767        .zip(
4768            embeddings
4769                .data
4770                .into_iter()
4771                .map(|embedding| embedding.embedding),
4772        )
4773        .collect::<HashMap<_, _>>();
4774
4775    let db = session.db().await;
4776    db.save_embeddings(&request.model, &embeddings)
4777        .await
4778        .context("failed to save embeddings")
4779        .trace_err();
4780
4781    response.send(proto::ComputeEmbeddingsResponse {
4782        embeddings: embeddings
4783            .into_iter()
4784            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4785            .collect(),
4786    })?;
4787    Ok(())
4788}
4789
4790async fn get_cached_embeddings(
4791    request: proto::GetCachedEmbeddings,
4792    response: Response<proto::GetCachedEmbeddings>,
4793    session: UserSession,
4794) -> Result<()> {
4795    authorize_access_to_language_models(&session).await?;
4796
4797    let db = session.db().await;
4798    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4799
4800    response.send(proto::GetCachedEmbeddingsResponse {
4801        embeddings: embeddings
4802            .into_iter()
4803            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4804            .collect(),
4805    })?;
4806    Ok(())
4807}
4808
4809async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4810    let db = session.db().await;
4811    let flags = db.get_user_flags(session.user_id()).await?;
4812    if flags.iter().any(|flag| flag == "language-models") {
4813        Ok(())
4814    } else {
4815        Err(anyhow!("permission denied"))?
4816    }
4817}
4818
4819/// Get a Supermaven API key for the user
4820async fn get_supermaven_api_key(
4821    _request: proto::GetSupermavenApiKey,
4822    response: Response<proto::GetSupermavenApiKey>,
4823    session: UserSession,
4824) -> Result<()> {
4825    let user_id: String = session.user_id().to_string();
4826    if !session.is_staff() {
4827        return Err(anyhow!("supermaven not enabled for this account"))?;
4828    }
4829
4830    let email = session
4831        .email()
4832        .ok_or_else(|| anyhow!("user must have an email"))?;
4833
4834    let supermaven_admin_api = session
4835        .supermaven_client
4836        .as_ref()
4837        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4838
4839    let result = supermaven_admin_api
4840        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4841        .await?;
4842
4843    response.send(proto::GetSupermavenApiKeyResponse {
4844        api_key: result.api_key,
4845    })?;
4846
4847    Ok(())
4848}
4849
4850/// Start receiving chat updates for a channel
4851async fn join_channel_chat(
4852    request: proto::JoinChannelChat,
4853    response: Response<proto::JoinChannelChat>,
4854    session: UserSession,
4855) -> Result<()> {
4856    let channel_id = ChannelId::from_proto(request.channel_id);
4857
4858    let db = session.db().await;
4859    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4860        .await?;
4861    let messages = db
4862        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4863        .await?;
4864    response.send(proto::JoinChannelChatResponse {
4865        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4866        messages,
4867    })?;
4868    Ok(())
4869}
4870
4871/// Stop receiving chat updates for a channel
4872async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4873    let channel_id = ChannelId::from_proto(request.channel_id);
4874    session
4875        .db()
4876        .await
4877        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4878        .await?;
4879    Ok(())
4880}
4881
4882/// Retrieve the chat history for a channel
4883async fn get_channel_messages(
4884    request: proto::GetChannelMessages,
4885    response: Response<proto::GetChannelMessages>,
4886    session: UserSession,
4887) -> Result<()> {
4888    let channel_id = ChannelId::from_proto(request.channel_id);
4889    let messages = session
4890        .db()
4891        .await
4892        .get_channel_messages(
4893            channel_id,
4894            session.user_id(),
4895            MESSAGE_COUNT_PER_PAGE,
4896            Some(MessageId::from_proto(request.before_message_id)),
4897        )
4898        .await?;
4899    response.send(proto::GetChannelMessagesResponse {
4900        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4901        messages,
4902    })?;
4903    Ok(())
4904}
4905
4906/// Retrieve specific chat messages
4907async fn get_channel_messages_by_id(
4908    request: proto::GetChannelMessagesById,
4909    response: Response<proto::GetChannelMessagesById>,
4910    session: UserSession,
4911) -> Result<()> {
4912    let message_ids = request
4913        .message_ids
4914        .iter()
4915        .map(|id| MessageId::from_proto(*id))
4916        .collect::<Vec<_>>();
4917    let messages = session
4918        .db()
4919        .await
4920        .get_channel_messages_by_id(session.user_id(), &message_ids)
4921        .await?;
4922    response.send(proto::GetChannelMessagesResponse {
4923        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4924        messages,
4925    })?;
4926    Ok(())
4927}
4928
4929/// Retrieve the current users notifications
4930async fn get_notifications(
4931    request: proto::GetNotifications,
4932    response: Response<proto::GetNotifications>,
4933    session: UserSession,
4934) -> Result<()> {
4935    let notifications = session
4936        .db()
4937        .await
4938        .get_notifications(
4939            session.user_id(),
4940            NOTIFICATION_COUNT_PER_PAGE,
4941            request
4942                .before_id
4943                .map(|id| db::NotificationId::from_proto(id)),
4944        )
4945        .await?;
4946    response.send(proto::GetNotificationsResponse {
4947        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4948        notifications,
4949    })?;
4950    Ok(())
4951}
4952
4953/// Mark notifications as read
4954async fn mark_notification_as_read(
4955    request: proto::MarkNotificationRead,
4956    response: Response<proto::MarkNotificationRead>,
4957    session: UserSession,
4958) -> Result<()> {
4959    let database = &session.db().await;
4960    let notifications = database
4961        .mark_notification_as_read_by_id(
4962            session.user_id(),
4963            NotificationId::from_proto(request.notification_id),
4964        )
4965        .await?;
4966    send_notifications(
4967        &*session.connection_pool().await,
4968        &session.peer,
4969        notifications,
4970    );
4971    response.send(proto::Ack {})?;
4972    Ok(())
4973}
4974
4975/// Get the current users information
4976async fn get_private_user_info(
4977    _request: proto::GetPrivateUserInfo,
4978    response: Response<proto::GetPrivateUserInfo>,
4979    session: UserSession,
4980) -> Result<()> {
4981    let db = session.db().await;
4982
4983    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4984    let user = db
4985        .get_user_by_id(session.user_id())
4986        .await?
4987        .ok_or_else(|| anyhow!("user not found"))?;
4988    let flags = db.get_user_flags(session.user_id()).await?;
4989
4990    response.send(proto::GetPrivateUserInfoResponse {
4991        metrics_id,
4992        staff: user.admin,
4993        flags,
4994    })?;
4995    Ok(())
4996}
4997
4998fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4999    let message = match message {
5000        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5001        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5002        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5003        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5004        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5005            code: frame.code.into(),
5006            reason: frame.reason,
5007        })),
5008        // We should never receive a frame while reading the message, according
5009        // to the `tungstenite` maintainers:
5010        //
5011        // > It cannot occur when you read messages from the WebSocket, but it
5012        // > can be used when you want to send the raw frames (e.g. you want to
5013        // > send the frames to the WebSocket without composing the full message first).
5014        // >
5015        // > — https://github.com/snapview/tungstenite-rs/issues/268
5016        TungsteniteMessage::Frame(_) => {
5017            bail!("received an unexpected frame while reading the message")
5018        }
5019    };
5020
5021    Ok(message)
5022}
5023
5024fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5025    match message {
5026        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5027        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5028        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5029        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5030        AxumMessage::Close(frame) => {
5031            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5032                code: frame.code.into(),
5033                reason: frame.reason,
5034            }))
5035        }
5036    }
5037}
5038
5039fn notify_membership_updated(
5040    connection_pool: &mut ConnectionPool,
5041    result: MembershipUpdated,
5042    user_id: UserId,
5043    peer: &Peer,
5044) {
5045    for membership in &result.new_channels.channel_memberships {
5046        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5047    }
5048    for channel_id in &result.removed_channels {
5049        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5050    }
5051
5052    let user_channels_update = proto::UpdateUserChannels {
5053        channel_memberships: result
5054            .new_channels
5055            .channel_memberships
5056            .iter()
5057            .map(|cm| proto::ChannelMembership {
5058                channel_id: cm.channel_id.to_proto(),
5059                role: cm.role.into(),
5060            })
5061            .collect(),
5062        ..Default::default()
5063    };
5064
5065    let mut update = build_channels_update(result.new_channels);
5066    update.delete_channels = result
5067        .removed_channels
5068        .into_iter()
5069        .map(|id| id.to_proto())
5070        .collect();
5071    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5072
5073    for connection_id in connection_pool.user_connection_ids(user_id) {
5074        peer.send(connection_id, user_channels_update.clone())
5075            .trace_err();
5076        peer.send(connection_id, update.clone()).trace_err();
5077    }
5078}
5079
5080fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5081    proto::UpdateUserChannels {
5082        channel_memberships: channels
5083            .channel_memberships
5084            .iter()
5085            .map(|m| proto::ChannelMembership {
5086                channel_id: m.channel_id.to_proto(),
5087                role: m.role.into(),
5088            })
5089            .collect(),
5090        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5091        observed_channel_message_id: channels.observed_channel_messages.clone(),
5092    }
5093}
5094
5095fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5096    let mut update = proto::UpdateChannels::default();
5097
5098    for channel in channels.channels {
5099        update.channels.push(channel.to_proto());
5100    }
5101
5102    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5103    update.latest_channel_message_ids = channels.latest_channel_messages;
5104
5105    for (channel_id, participants) in channels.channel_participants {
5106        update
5107            .channel_participants
5108            .push(proto::ChannelParticipants {
5109                channel_id: channel_id.to_proto(),
5110                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5111            });
5112    }
5113
5114    for channel in channels.invited_channels {
5115        update.channel_invitations.push(channel.to_proto());
5116    }
5117
5118    update.hosted_projects = channels.hosted_projects;
5119    update
5120}
5121
5122fn build_initial_contacts_update(
5123    contacts: Vec<db::Contact>,
5124    pool: &ConnectionPool,
5125) -> proto::UpdateContacts {
5126    let mut update = proto::UpdateContacts::default();
5127
5128    for contact in contacts {
5129        match contact {
5130            db::Contact::Accepted { user_id, busy } => {
5131                update.contacts.push(contact_for_user(user_id, busy, &pool));
5132            }
5133            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5134            db::Contact::Incoming { user_id } => {
5135                update
5136                    .incoming_requests
5137                    .push(proto::IncomingContactRequest {
5138                        requester_id: user_id.to_proto(),
5139                    })
5140            }
5141        }
5142    }
5143
5144    update
5145}
5146
5147fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5148    proto::Contact {
5149        user_id: user_id.to_proto(),
5150        online: pool.is_user_online(user_id),
5151        busy,
5152    }
5153}
5154
5155fn room_updated(room: &proto::Room, peer: &Peer) {
5156    broadcast(
5157        None,
5158        room.participants
5159            .iter()
5160            .filter_map(|participant| Some(participant.peer_id?.into())),
5161        |peer_id| {
5162            peer.send(
5163                peer_id,
5164                proto::RoomUpdated {
5165                    room: Some(room.clone()),
5166                },
5167            )
5168        },
5169    );
5170}
5171
5172fn channel_updated(
5173    channel: &db::channel::Model,
5174    room: &proto::Room,
5175    peer: &Peer,
5176    pool: &ConnectionPool,
5177) {
5178    let participants = room
5179        .participants
5180        .iter()
5181        .map(|p| p.user_id)
5182        .collect::<Vec<_>>();
5183
5184    broadcast(
5185        None,
5186        pool.channel_connection_ids(channel.root_id())
5187            .filter_map(|(channel_id, role)| {
5188                role.can_see_channel(channel.visibility).then(|| channel_id)
5189            }),
5190        |peer_id| {
5191            peer.send(
5192                peer_id,
5193                proto::UpdateChannels {
5194                    channel_participants: vec![proto::ChannelParticipants {
5195                        channel_id: channel.id.to_proto(),
5196                        participant_user_ids: participants.clone(),
5197                    }],
5198                    ..Default::default()
5199                },
5200            )
5201        },
5202    );
5203}
5204
5205async fn send_dev_server_projects_update(
5206    user_id: UserId,
5207    mut status: proto::DevServerProjectsUpdate,
5208    session: &Session,
5209) {
5210    let pool = session.connection_pool().await;
5211    for dev_server in &mut status.dev_servers {
5212        dev_server.status =
5213            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5214    }
5215    let connections = pool.user_connection_ids(user_id);
5216    for connection_id in connections {
5217        session.peer.send(connection_id, status.clone()).trace_err();
5218    }
5219}
5220
5221async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5222    let db = session.db().await;
5223
5224    let contacts = db.get_contacts(user_id).await?;
5225    let busy = db.is_user_busy(user_id).await?;
5226
5227    let pool = session.connection_pool().await;
5228    let updated_contact = contact_for_user(user_id, busy, &pool);
5229    for contact in contacts {
5230        if let db::Contact::Accepted {
5231            user_id: contact_user_id,
5232            ..
5233        } = contact
5234        {
5235            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5236                session
5237                    .peer
5238                    .send(
5239                        contact_conn_id,
5240                        proto::UpdateContacts {
5241                            contacts: vec![updated_contact.clone()],
5242                            remove_contacts: Default::default(),
5243                            incoming_requests: Default::default(),
5244                            remove_incoming_requests: Default::default(),
5245                            outgoing_requests: Default::default(),
5246                            remove_outgoing_requests: Default::default(),
5247                        },
5248                    )
5249                    .trace_err();
5250            }
5251        }
5252    }
5253    Ok(())
5254}
5255
5256async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5257    log::info!("lost dev server connection, unsharing projects");
5258    let project_ids = session
5259        .db()
5260        .await
5261        .get_stale_dev_server_projects(session.connection_id)
5262        .await?;
5263
5264    for project_id in project_ids {
5265        // not unshare re-checks the connection ids match, so we get away with no transaction
5266        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5267    }
5268
5269    let user_id = session.dev_server().user_id;
5270    let update = session
5271        .db()
5272        .await
5273        .dev_server_projects_update(user_id)
5274        .await?;
5275
5276    send_dev_server_projects_update(user_id, update, session).await;
5277
5278    Ok(())
5279}
5280
/// Removes `connection_id` from the room it is in and notifies everyone affected.
///
/// Steps, in order:
/// - records the leave in the database (returns `Ok(())` immediately if the
///   database reports no room was left);
/// - notifies collaborators of every project left (see `project_left`);
/// - broadcasts the updated room to its participants and, if the room belongs
///   to a channel, the channel's participant list to channel members;
/// - sends `CallCanceled` to every connection of users whose pending call to
///   this room was canceled;
/// - refreshes the contact entries of the leaving user and of the canceled callees;
/// - removes the user from the LiveKit room, deleting that room if the database
///   marked it as deleted.
///
/// # Errors
/// Propagates database errors from leaving the room and from `update_user_contacts`.
/// LiveKit and message-send failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned in the `if let` arm below; the
    // `else` arm returns, so they are always initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before the
        // room itself is moved into `room`, so the `room` broadcast below carries a
        // default (empty) `live_kit_room`. Presumably clients do not rely on that
        // field in `RoomUpdated` — confirm before reordering.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection pool handle so it is released before
    // `update_user_contacts`, which acquires the connection pool itself
    // (presumably a lock guard — holding it across that call could deadlock).
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5354
5355async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5356    let left_channel_buffers = session
5357        .db()
5358        .await
5359        .leave_channel_buffers(session.connection_id)
5360        .await?;
5361
5362    for left_buffer in left_channel_buffers {
5363        channel_buffer_updated(
5364            session.connection_id,
5365            left_buffer.connections,
5366            &proto::UpdateChannelBufferCollaborators {
5367                channel_id: left_buffer.channel_id.to_proto(),
5368                collaborators: left_buffer.collaborators,
5369            },
5370            &session.peer,
5371        );
5372    }
5373
5374    Ok(())
5375}
5376
5377fn project_left(project: &db::LeftProject, session: &UserSession) {
5378    for connection_id in &project.connection_ids {
5379        if project.should_unshare {
5380            session
5381                .peer
5382                .send(
5383                    *connection_id,
5384                    proto::UnshareProject {
5385                        project_id: project.id.to_proto(),
5386                    },
5387                )
5388                .trace_err();
5389        } else {
5390            session
5391                .peer
5392                .send(
5393                    *connection_id,
5394                    proto::RemoveProjectCollaborator {
5395                        project_id: project.id.to_proto(),
5396                        peer_id: Some(session.connection_id.into()),
5397                    },
5398                )
5399                .trace_err();
5400        }
5401    }
5402}
5403
/// Extension trait for converting a fallible value into an `Option`, reporting
/// the failure instead of propagating it. The `Result` implementation logs the
/// error at `error` level.
pub trait ResultExt {
    /// The success value type.
    type Ok;

    /// Returns `Some` with the success value, or reports the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5409
5410impl<T, E> ResultExt for Result<T, E>
5411where
5412    E: std::fmt::Debug,
5413{
5414    type Ok = T;
5415
5416    #[track_caller]
5417    fn trace_err(self) -> Option<T> {
5418        match self {
5419            Ok(value) => Some(value),
5420            Err(error) => {
5421                tracing::error!("{:?}", error);
5422                None
5423            }
5424        }
5425    }
5426}