rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Config, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(update_dev_server_project))
 435            .add_request_handler(user_handler(delete_dev_server_project))
 436            .add_request_handler(user_handler(create_dev_server))
 437            .add_request_handler(user_handler(regenerate_dev_server_token))
 438            .add_request_handler(user_handler(rename_dev_server))
 439            .add_request_handler(user_handler(delete_dev_server))
 440            .add_request_handler(user_handler(list_remote_directory))
 441            .add_request_handler(dev_server_handler(share_dev_server_project))
 442            .add_request_handler(dev_server_handler(shutdown_dev_server))
 443            .add_request_handler(dev_server_handler(reconnect_dev_server))
 444            .add_message_handler(user_message_handler(leave_project))
 445            .add_request_handler(update_project)
 446            .add_request_handler(update_worktree)
 447            .add_message_handler(start_language_server)
 448            .add_message_handler(update_language_server)
 449            .add_message_handler(update_diagnostic_summary)
 450            .add_message_handler(update_worktree_settings)
 451            .add_request_handler(user_handler(
 452                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 453            ))
 454            .add_request_handler(user_handler(
 455                forward_project_request_for_owner::<proto::TaskTemplates>,
 456            ))
 457            .add_request_handler(user_handler(
 458                forward_read_only_project_request::<proto::GetHover>,
 459            ))
 460            .add_request_handler(user_handler(
 461                forward_read_only_project_request::<proto::GetDefinition>,
 462            ))
 463            .add_request_handler(user_handler(
 464                forward_read_only_project_request::<proto::GetTypeDefinition>,
 465            ))
 466            .add_request_handler(user_handler(
 467                forward_read_only_project_request::<proto::GetReferences>,
 468            ))
 469            .add_request_handler(user_handler(
 470                forward_read_only_project_request::<proto::SearchProject>,
 471            ))
 472            .add_request_handler(user_handler(
 473                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_read_only_project_request::<proto::GetProjectSymbols>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 480            ))
 481            .add_request_handler(user_handler(
 482                forward_read_only_project_request::<proto::OpenBufferById>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_read_only_project_request::<proto::InlayHints>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_read_only_project_request::<proto::OpenBufferByPath>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_mutating_project_request::<proto::GetCompletions>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_mutating_project_request::<proto::GetCodeActions>,
 507            ))
 508            .add_request_handler(user_handler(
 509                forward_mutating_project_request::<proto::ApplyCodeAction>,
 510            ))
 511            .add_request_handler(user_handler(
 512                forward_mutating_project_request::<proto::PrepareRename>,
 513            ))
 514            .add_request_handler(user_handler(
 515                forward_mutating_project_request::<proto::PerformRename>,
 516            ))
 517            .add_request_handler(user_handler(
 518                forward_mutating_project_request::<proto::ReloadBuffers>,
 519            ))
 520            .add_request_handler(user_handler(
 521                forward_mutating_project_request::<proto::FormatBuffers>,
 522            ))
 523            .add_request_handler(user_handler(
 524                forward_mutating_project_request::<proto::CreateProjectEntry>,
 525            ))
 526            .add_request_handler(user_handler(
 527                forward_mutating_project_request::<proto::RenameProjectEntry>,
 528            ))
 529            .add_request_handler(user_handler(
 530                forward_mutating_project_request::<proto::CopyProjectEntry>,
 531            ))
 532            .add_request_handler(user_handler(
 533                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 534            ))
 535            .add_request_handler(user_handler(
 536                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 537            ))
 538            .add_request_handler(user_handler(
 539                forward_mutating_project_request::<proto::OnTypeFormatting>,
 540            ))
 541            .add_request_handler(user_handler(
 542                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 543            ))
 544            .add_request_handler(user_handler(
 545                forward_mutating_project_request::<proto::BlameBuffer>,
 546            ))
 547            .add_request_handler(user_handler(
 548                forward_mutating_project_request::<proto::MultiLspQuery>,
 549            ))
 550            .add_request_handler(user_handler(
 551                forward_mutating_project_request::<proto::RestartLanguageServers>,
 552            ))
 553            .add_request_handler(user_handler(
 554                forward_mutating_project_request::<proto::LinkedEditingRange>,
 555            ))
 556            .add_message_handler(create_buffer_for_peer)
 557            .add_request_handler(update_buffer)
 558            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 559            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 560            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 561            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 562            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 563            .add_request_handler(get_users)
 564            .add_request_handler(user_handler(fuzzy_search_users))
 565            .add_request_handler(user_handler(request_contact))
 566            .add_request_handler(user_handler(remove_contact))
 567            .add_request_handler(user_handler(respond_to_contact_request))
 568            .add_message_handler(subscribe_to_channels)
 569            .add_request_handler(user_handler(create_channel))
 570            .add_request_handler(user_handler(delete_channel))
 571            .add_request_handler(user_handler(invite_channel_member))
 572            .add_request_handler(user_handler(remove_channel_member))
 573            .add_request_handler(user_handler(set_channel_member_role))
 574            .add_request_handler(user_handler(set_channel_visibility))
 575            .add_request_handler(user_handler(rename_channel))
 576            .add_request_handler(user_handler(join_channel_buffer))
 577            .add_request_handler(user_handler(leave_channel_buffer))
 578            .add_message_handler(user_message_handler(update_channel_buffer))
 579            .add_request_handler(user_handler(rejoin_channel_buffers))
 580            .add_request_handler(user_handler(get_channel_members))
 581            .add_request_handler(user_handler(respond_to_channel_invite))
 582            .add_request_handler(user_handler(join_channel))
 583            .add_request_handler(user_handler(join_channel_chat))
 584            .add_message_handler(user_message_handler(leave_channel_chat))
 585            .add_request_handler(user_handler(send_channel_message))
 586            .add_request_handler(user_handler(remove_channel_message))
 587            .add_request_handler(user_handler(update_channel_message))
 588            .add_request_handler(user_handler(get_channel_messages))
 589            .add_request_handler(user_handler(get_channel_messages_by_id))
 590            .add_request_handler(user_handler(get_notifications))
 591            .add_request_handler(user_handler(mark_notification_as_read))
 592            .add_request_handler(user_handler(move_channel))
 593            .add_request_handler(user_handler(follow))
 594            .add_message_handler(user_message_handler(unfollow))
 595            .add_message_handler(user_message_handler(update_followers))
 596            .add_request_handler(user_handler(get_private_user_info))
 597            .add_message_handler(user_message_handler(acknowledge_channel_message))
 598            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 599            .add_request_handler(user_handler(get_supermaven_api_key))
 600            .add_request_handler(user_handler(
 601                forward_mutating_project_request::<proto::OpenContext>,
 602            ))
 603            .add_request_handler(user_handler(
 604                forward_mutating_project_request::<proto::SynchronizeContexts>,
 605            ))
 606            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 607            .add_message_handler(update_context)
 608            .add_request_handler({
 609                let app_state = app_state.clone();
 610                move |request, response, session| {
 611                    let app_state = app_state.clone();
 612                    async move {
 613                        complete_with_language_model(request, response, session, &app_state.config)
 614                            .await
 615                    }
 616                }
 617            })
 618            .add_streaming_request_handler({
 619                let app_state = app_state.clone();
 620                move |request, response, session| {
 621                    let app_state = app_state.clone();
 622                    async move {
 623                        stream_complete_with_language_model(
 624                            request,
 625                            response,
 626                            session,
 627                            &app_state.config,
 628                        )
 629                        .await
 630                    }
 631                }
 632            })
 633            .add_request_handler({
 634                let app_state = app_state.clone();
 635                move |request, response, session| {
 636                    let app_state = app_state.clone();
 637                    async move {
 638                        count_language_model_tokens(request, response, session, &app_state.config)
 639                            .await
 640                    }
 641                }
 642            })
 643            .add_request_handler({
 644                user_handler(move |request, response, session| {
 645                    get_cached_embeddings(request, response, session)
 646                })
 647            })
 648            .add_request_handler({
 649                let app_state = app_state.clone();
 650                user_handler(move |request, response, session| {
 651                    compute_embeddings(
 652                        request,
 653                        response,
 654                        session,
 655                        app_state.config.openai_api_key.clone(),
 656                    )
 657                })
 658            });
 659
 660        Arc::new(server)
 661    }
 662
 663    pub async fn start(&self) -> Result<()> {
 664        let server_id = *self.id.lock();
 665        let app_state = self.app_state.clone();
 666        let peer = self.peer.clone();
 667        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 668        let pool = self.connection_pool.clone();
 669        let live_kit_client = self.app_state.live_kit_client.clone();
 670
 671        let span = info_span!("start server");
 672        self.app_state.executor.spawn_detached(
 673            async move {
 674                tracing::info!("waiting for cleanup timeout");
 675                timeout.await;
 676                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 677                if let Some((room_ids, channel_ids)) = app_state
 678                    .db
 679                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 680                    .await
 681                    .trace_err()
 682                {
 683                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 684                    tracing::info!(
 685                        stale_channel_buffer_count = channel_ids.len(),
 686                        "retrieved stale channel buffers"
 687                    );
 688
 689                    for channel_id in channel_ids {
 690                        if let Some(refreshed_channel_buffer) = app_state
 691                            .db
 692                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 693                            .await
 694                            .trace_err()
 695                        {
 696                            for connection_id in refreshed_channel_buffer.connection_ids {
 697                                peer.send(
 698                                    connection_id,
 699                                    proto::UpdateChannelBufferCollaborators {
 700                                        channel_id: channel_id.to_proto(),
 701                                        collaborators: refreshed_channel_buffer
 702                                            .collaborators
 703                                            .clone(),
 704                                    },
 705                                )
 706                                .trace_err();
 707                            }
 708                        }
 709                    }
 710
 711                    for room_id in room_ids {
 712                        let mut contacts_to_update = HashSet::default();
 713                        let mut canceled_calls_to_user_ids = Vec::new();
 714                        let mut live_kit_room = String::new();
 715                        let mut delete_live_kit_room = false;
 716
 717                        if let Some(mut refreshed_room) = app_state
 718                            .db
 719                            .clear_stale_room_participants(room_id, server_id)
 720                            .await
 721                            .trace_err()
 722                        {
 723                            tracing::info!(
 724                                room_id = room_id.0,
 725                                new_participant_count = refreshed_room.room.participants.len(),
 726                                "refreshed room"
 727                            );
 728                            room_updated(&refreshed_room.room, &peer);
 729                            if let Some(channel) = refreshed_room.channel.as_ref() {
 730                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 731                            }
 732                            contacts_to_update
 733                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 734                            contacts_to_update
 735                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 736                            canceled_calls_to_user_ids =
 737                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 738                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 739                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 740                        }
 741
 742                        {
 743                            let pool = pool.lock();
 744                            for canceled_user_id in canceled_calls_to_user_ids {
 745                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 746                                    peer.send(
 747                                        connection_id,
 748                                        proto::CallCanceled {
 749                                            room_id: room_id.to_proto(),
 750                                        },
 751                                    )
 752                                    .trace_err();
 753                                }
 754                            }
 755                        }
 756
 757                        for user_id in contacts_to_update {
 758                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 759                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 760                            if let Some((busy, contacts)) = busy.zip(contacts) {
 761                                let pool = pool.lock();
 762                                let updated_contact = contact_for_user(user_id, busy, &pool);
 763                                for contact in contacts {
 764                                    if let db::Contact::Accepted {
 765                                        user_id: contact_user_id,
 766                                        ..
 767                                    } = contact
 768                                    {
 769                                        for contact_conn_id in
 770                                            pool.user_connection_ids(contact_user_id)
 771                                        {
 772                                            peer.send(
 773                                                contact_conn_id,
 774                                                proto::UpdateContacts {
 775                                                    contacts: vec![updated_contact.clone()],
 776                                                    remove_contacts: Default::default(),
 777                                                    incoming_requests: Default::default(),
 778                                                    remove_incoming_requests: Default::default(),
 779                                                    outgoing_requests: Default::default(),
 780                                                    remove_outgoing_requests: Default::default(),
 781                                                },
 782                                            )
 783                                            .trace_err();
 784                                        }
 785                                    }
 786                                }
 787                            }
 788                        }
 789
 790                        if let Some(live_kit) = live_kit_client.as_ref() {
 791                            if delete_live_kit_room {
 792                                live_kit.delete_room(live_kit_room).await.trace_err();
 793                            }
 794                        }
 795                    }
 796                }
 797
 798                app_state
 799                    .db
 800                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 801                    .await
 802                    .trace_err();
 803            }
 804            .instrument(span),
 805        );
 806        Ok(())
 807    }
 808
 809    pub fn teardown(&self) {
 810        self.peer.teardown();
 811        self.connection_pool.lock().reset();
 812        let _ = self.teardown.send(true);
 813    }
 814
 815    #[cfg(test)]
 816    pub fn reset(&self, id: ServerId) {
 817        self.teardown();
 818        *self.id.lock() = id;
 819        self.peer.reset(id.0 as u32);
 820        let _ = self.teardown.send(false);
 821    }
 822
 823    #[cfg(test)]
 824    pub fn id(&self) -> ServerId {
 825        *self.id.lock()
 826    }
 827
 828    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 829    where
 830        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 831        Fut: 'static + Send + Future<Output = Result<()>>,
 832        M: EnvelopedMessage,
 833    {
 834        let prev_handler = self.handlers.insert(
 835            TypeId::of::<M>(),
 836            Box::new(move |envelope, session| {
 837                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 838                let received_at = envelope.received_at;
 839                tracing::info!("message received");
 840                let start_time = Instant::now();
 841                let future = (handler)(*envelope, session);
 842                async move {
 843                    let result = future.await;
 844                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 845                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 846                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 847                    let payload_type = M::NAME;
 848
 849                    match result {
 850                        Err(error) => {
 851                            tracing::error!(
 852                                ?error,
 853                                total_duration_ms,
 854                                processing_duration_ms,
 855                                queue_duration_ms,
 856                                payload_type,
 857                                "error handling message"
 858                            )
 859                        }
 860                        Ok(()) => tracing::info!(
 861                            total_duration_ms,
 862                            processing_duration_ms,
 863                            queue_duration_ms,
 864                            "finished handling message"
 865                        ),
 866                    }
 867                }
 868                .boxed()
 869            }),
 870        );
 871        if prev_handler.is_some() {
 872            panic!("registered a handler for the same message twice");
 873        }
 874        self
 875    }
 876
 877    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 878    where
 879        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 880        Fut: 'static + Send + Future<Output = Result<()>>,
 881        M: EnvelopedMessage,
 882    {
 883        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 884        self
 885    }
 886
 887    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 888    where
 889        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 890        Fut: Send + Future<Output = Result<()>>,
 891        M: RequestMessage,
 892    {
 893        let handler = Arc::new(handler);
 894        self.add_handler(move |envelope, session| {
 895            let receipt = envelope.receipt();
 896            let handler = handler.clone();
 897            async move {
 898                let peer = session.peer.clone();
 899                let responded = Arc::new(AtomicBool::default());
 900                let response = Response {
 901                    peer: peer.clone(),
 902                    responded: responded.clone(),
 903                    receipt,
 904                };
 905                match (handler)(envelope.payload, response, session).await {
 906                    Ok(()) => {
 907                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 908                            Ok(())
 909                        } else {
 910                            Err(anyhow!("handler did not send a response"))?
 911                        }
 912                    }
 913                    Err(error) => {
 914                        let proto_err = match &error {
 915                            Error::Internal(err) => err.to_proto(),
 916                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 917                        };
 918                        peer.respond_with_error(receipt, proto_err)?;
 919                        Err(error)
 920                    }
 921                }
 922            }
 923        })
 924    }
 925
 926    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 927    where
 928        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 929        Fut: Send + Future<Output = Result<()>>,
 930        M: RequestMessage,
 931    {
 932        let handler = Arc::new(handler);
 933        self.add_handler(move |envelope, session| {
 934            let receipt = envelope.receipt();
 935            let handler = handler.clone();
 936            async move {
 937                let peer = session.peer.clone();
 938                let response = StreamingResponse {
 939                    peer: peer.clone(),
 940                    receipt,
 941                };
 942                match (handler)(envelope.payload, response, session).await {
 943                    Ok(()) => {
 944                        peer.end_stream(receipt)?;
 945                        Ok(())
 946                    }
 947                    Err(error) => {
 948                        let proto_err = match &error {
 949                            Error::Internal(err) => err.to_proto(),
 950                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 951                        };
 952                        peer.respond_with_error(receipt, proto_err)?;
 953                        Err(error)
 954                    }
 955                }
 956            }
 957        })
 958    }
 959
 960    #[allow(clippy::too_many_arguments)]
 961    pub fn handle_connection(
 962        self: &Arc<Self>,
 963        connection: Connection,
 964        address: String,
 965        principal: Principal,
 966        zed_version: ZedVersion,
 967        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 968        executor: Executor,
 969    ) -> impl Future<Output = ()> {
 970        let this = self.clone();
 971        let span = info_span!("handle connection", %address,
 972            connection_id=field::Empty,
 973            user_id=field::Empty,
 974            login=field::Empty,
 975            impersonator=field::Empty,
 976            dev_server_id=field::Empty
 977        );
 978        principal.update_span(&span);
 979
 980        let mut teardown = self.teardown.subscribe();
 981        async move {
 982            if *teardown.borrow() {
 983                tracing::error!("server is tearing down");
 984                return
 985            }
 986            let (connection_id, handle_io, mut incoming_rx) = this
 987                .peer
 988                .add_connection(connection, {
 989                    let executor = executor.clone();
 990                    move |duration| executor.sleep(duration)
 991                });
 992            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 993            tracing::info!("connection opened");
 994
 995            let http_client = match IsahcHttpClient::new() {
 996                Ok(http_client) => Arc::new(http_client),
 997                Err(error) => {
 998                    tracing::error!(?error, "failed to create HTTP client");
 999                    return;
1000                }
1001            };
1002
1003            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1004                Some(Arc::new(SupermavenAdminApi::new(
1005                    supermaven_admin_api_key.to_string(),
1006                    http_client.clone(),
1007                )))
1008            } else {
1009                None
1010            };
1011
1012            let session = Session {
1013                principal: principal.clone(),
1014                connection_id,
1015                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1016                peer: this.peer.clone(),
1017                connection_pool: this.connection_pool.clone(),
1018                live_kit_client: this.app_state.live_kit_client.clone(),
1019                http_client,
1020                rate_limiter: this.app_state.rate_limiter.clone(),
1021                _executor: executor.clone(),
1022                supermaven_client,
1023            };
1024
1025            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1026                tracing::error!(?error, "failed to send initial client update");
1027                return;
1028            }
1029
1030            let handle_io = handle_io.fuse();
1031            futures::pin_mut!(handle_io);
1032
1033            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1034            // This prevents deadlocks when e.g., client A performs a request to client B and
1035            // client B performs a request to client A. If both clients stop processing further
1036            // messages until their respective request completes, they won't have a chance to
1037            // respond to the other client's request and cause a deadlock.
1038            //
1039            // This arrangement ensures we will attempt to process earlier messages first, but fall
1040            // back to processing messages arrived later in the spirit of making progress.
1041            let mut foreground_message_handlers = FuturesUnordered::new();
1042            let concurrent_handlers = Arc::new(Semaphore::new(256));
1043            loop {
1044                let next_message = async {
1045                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1046                    let message = incoming_rx.next().await;
1047                    (permit, message)
1048                }.fuse();
1049                futures::pin_mut!(next_message);
1050                futures::select_biased! {
1051                    _ = teardown.changed().fuse() => return,
1052                    result = handle_io => {
1053                        if let Err(error) = result {
1054                            tracing::error!(?error, "error handling I/O");
1055                        }
1056                        break;
1057                    }
1058                    _ = foreground_message_handlers.next() => {}
1059                    next_message = next_message => {
1060                        let (permit, message) = next_message;
1061                        if let Some(message) = message {
1062                            let type_name = message.payload_type_name();
1063                            // note: we copy all the fields from the parent span so we can query them in the logs.
1064                            // (https://github.com/tokio-rs/tracing/issues/2670).
1065                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1066                                user_id=field::Empty,
1067                                login=field::Empty,
1068                                impersonator=field::Empty,
1069                                dev_server_id=field::Empty
1070                            );
1071                            principal.update_span(&span);
1072                            let span_enter = span.enter();
1073                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1074                                let is_background = message.is_background();
1075                                let handle_message = (handler)(message, session.clone());
1076                                drop(span_enter);
1077
1078                                let handle_message = async move {
1079                                    handle_message.await;
1080                                    drop(permit);
1081                                }.instrument(span);
1082                                if is_background {
1083                                    executor.spawn_detached(handle_message);
1084                                } else {
1085                                    foreground_message_handlers.push(handle_message);
1086                                }
1087                            } else {
1088                                tracing::error!("no message handler");
1089                            }
1090                        } else {
1091                            tracing::info!("connection closed");
1092                            break;
1093                        }
1094                    }
1095                }
1096            }
1097
1098            drop(foreground_message_handlers);
1099            tracing::info!("signing out");
1100            if let Err(error) = connection_lost(session, teardown, executor).await {
1101                tracing::error!(?error, "error signing out");
1102            }
1103
1104        }.instrument(span)
1105    }
1106
1107    async fn send_initial_client_update(
1108        &self,
1109        connection_id: ConnectionId,
1110        principal: &Principal,
1111        zed_version: ZedVersion,
1112        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1113        session: &Session,
1114    ) -> Result<()> {
1115        self.peer.send(
1116            connection_id,
1117            proto::Hello {
1118                peer_id: Some(connection_id.into()),
1119            },
1120        )?;
1121        tracing::info!("sent hello message");
1122        if let Some(send_connection_id) = send_connection_id.take() {
1123            let _ = send_connection_id.send(connection_id);
1124        }
1125
1126        match principal {
1127            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1128                if !user.connected_once {
1129                    self.peer.send(connection_id, proto::ShowContacts {})?;
1130                    self.app_state
1131                        .db
1132                        .set_user_connected_once(user.id, true)
1133                        .await?;
1134                }
1135
1136                let (contacts, dev_server_projects) = future::try_join(
1137                    self.app_state.db.get_contacts(user.id),
1138                    self.app_state.db.dev_server_projects_update(user.id),
1139                )
1140                .await?;
1141
1142                {
1143                    let mut pool = self.connection_pool.lock();
1144                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1145                    self.peer.send(
1146                        connection_id,
1147                        build_initial_contacts_update(contacts, &pool),
1148                    )?;
1149                }
1150
1151                if should_auto_subscribe_to_channels(zed_version) {
1152                    subscribe_user_to_channels(user.id, session).await?;
1153                }
1154
1155                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1156
1157                if let Some(incoming_call) =
1158                    self.app_state.db.incoming_call_for_user(user.id).await?
1159                {
1160                    self.peer.send(connection_id, incoming_call)?;
1161                }
1162
1163                update_user_contacts(user.id, &session).await?;
1164            }
1165            Principal::DevServer(dev_server) => {
1166                {
1167                    let mut pool = self.connection_pool.lock();
1168                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1169                    {
1170                        self.peer.send(
1171                            stale_connection_id,
1172                            proto::ShutdownDevServer {
1173                                reason: Some(
1174                                    "another dev server connected with the same token".to_string(),
1175                                ),
1176                            },
1177                        )?;
1178                        pool.remove_connection(stale_connection_id)?;
1179                    };
1180                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1181                }
1182
1183                let projects = self
1184                    .app_state
1185                    .db
1186                    .get_projects_for_dev_server(dev_server.id)
1187                    .await?;
1188                self.peer
1189                    .send(connection_id, proto::DevServerInstructions { projects })?;
1190
1191                let status = self
1192                    .app_state
1193                    .db
1194                    .dev_server_projects_update(dev_server.user_id)
1195                    .await?;
1196                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1197            }
1198        }
1199
1200        Ok(())
1201    }
1202
1203    pub async fn invite_code_redeemed(
1204        self: &Arc<Self>,
1205        inviter_id: UserId,
1206        invitee_id: UserId,
1207    ) -> Result<()> {
1208        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1209            if let Some(code) = &user.invite_code {
1210                let pool = self.connection_pool.lock();
1211                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1212                for connection_id in pool.user_connection_ids(inviter_id) {
1213                    self.peer.send(
1214                        connection_id,
1215                        proto::UpdateContacts {
1216                            contacts: vec![invitee_contact.clone()],
1217                            ..Default::default()
1218                        },
1219                    )?;
1220                    self.peer.send(
1221                        connection_id,
1222                        proto::UpdateInviteInfo {
1223                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1224                            count: user.invite_count as u32,
1225                        },
1226                    )?;
1227                }
1228            }
1229        }
1230        Ok(())
1231    }
1232
1233    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1234        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1235            if let Some(invite_code) = &user.invite_code {
1236                let pool = self.connection_pool.lock();
1237                for connection_id in pool.user_connection_ids(user_id) {
1238                    self.peer.send(
1239                        connection_id,
1240                        proto::UpdateInviteInfo {
1241                            url: format!(
1242                                "{}{}",
1243                                self.app_state.config.invite_link_prefix, invite_code
1244                            ),
1245                            count: user.invite_count as u32,
1246                        },
1247                    )?;
1248                }
1249            }
1250        }
1251        Ok(())
1252    }
1253
1254    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1255        ServerSnapshot {
1256            connection_pool: ConnectionPoolGuard {
1257                guard: self.connection_pool.lock(),
1258                _not_send: PhantomData,
1259            },
1260            peer: &self.peer,
1261        }
1262    }
1263}
1264
1265impl<'a> Deref for ConnectionPoolGuard<'a> {
1266    type Target = ConnectionPool;
1267
1268    fn deref(&self) -> &Self::Target {
1269        &self.guard
1270    }
1271}
1272
1273impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1274    fn deref_mut(&mut self) -> &mut Self::Target {
1275        &mut self.guard
1276    }
1277}
1278
1279impl<'a> Drop for ConnectionPoolGuard<'a> {
1280    fn drop(&mut self) {
1281        #[cfg(test)]
1282        self.check_invariants();
1283    }
1284}
1285
1286fn broadcast<F>(
1287    sender_id: Option<ConnectionId>,
1288    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1289    mut f: F,
1290) where
1291    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1292{
1293    for receiver_id in receiver_ids {
1294        if Some(receiver_id) != sender_id {
1295            if let Err(error) = f(receiver_id) {
1296                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1297            }
1298        }
1299    }
1300}
1301
1302pub struct ProtocolVersion(u32);
1303
1304impl Header for ProtocolVersion {
1305    fn name() -> &'static HeaderName {
1306        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1307        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1308    }
1309
1310    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1311    where
1312        Self: Sized,
1313        I: Iterator<Item = &'i axum::http::HeaderValue>,
1314    {
1315        let version = values
1316            .next()
1317            .ok_or_else(axum::headers::Error::invalid)?
1318            .to_str()
1319            .map_err(|_| axum::headers::Error::invalid())?
1320            .parse()
1321            .map_err(|_| axum::headers::Error::invalid())?;
1322        Ok(Self(version))
1323    }
1324
1325    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1326        values.extend([self.0.to_string().parse().unwrap()]);
1327    }
1328}
1329
1330pub struct AppVersionHeader(SemanticVersion);
1331impl Header for AppVersionHeader {
1332    fn name() -> &'static HeaderName {
1333        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1334        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1335    }
1336
1337    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1338    where
1339        Self: Sized,
1340        I: Iterator<Item = &'i axum::http::HeaderValue>,
1341    {
1342        let version = values
1343            .next()
1344            .ok_or_else(axum::headers::Error::invalid)?
1345            .to_str()
1346            .map_err(|_| axum::headers::Error::invalid())?
1347            .parse()
1348            .map_err(|_| axum::headers::Error::invalid())?;
1349        Ok(Self(version))
1350    }
1351
1352    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1353        values.extend([self.0.to_string().parse().unwrap()]);
1354    }
1355}
1356
1357pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1358    Router::new()
1359        .route("/rpc", get(handle_websocket_request))
1360        .layer(
1361            ServiceBuilder::new()
1362                .layer(Extension(server.app_state.clone()))
1363                .layer(middleware::from_fn(auth::validate_header)),
1364        )
1365        .route("/metrics", get(handle_metrics))
1366        .layer(Extension(server))
1367}
1368
1369pub async fn handle_websocket_request(
1370    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1371    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1372    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1373    Extension(server): Extension<Arc<Server>>,
1374    Extension(principal): Extension<Principal>,
1375    ws: WebSocketUpgrade,
1376) -> axum::response::Response {
1377    if protocol_version != rpc::PROTOCOL_VERSION {
1378        return (
1379            StatusCode::UPGRADE_REQUIRED,
1380            "client must be upgraded".to_string(),
1381        )
1382            .into_response();
1383    }
1384
1385    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1386        return (
1387            StatusCode::UPGRADE_REQUIRED,
1388            "no version header found".to_string(),
1389        )
1390            .into_response();
1391    };
1392
1393    if !version.can_collaborate() {
1394        return (
1395            StatusCode::UPGRADE_REQUIRED,
1396            "client must be upgraded".to_string(),
1397        )
1398            .into_response();
1399    }
1400
1401    let socket_address = socket_address.to_string();
1402    ws.on_upgrade(move |socket| {
1403        let socket = socket
1404            .map_ok(to_tungstenite_message)
1405            .err_into()
1406            .with(|message| async move { to_axum_message(message) });
1407        let connection = Connection::new(Box::pin(socket));
1408        async move {
1409            server
1410                .handle_connection(
1411                    connection,
1412                    socket_address,
1413                    principal,
1414                    version,
1415                    None,
1416                    Executor::Production,
1417                )
1418                .await;
1419        }
1420    })
1421}
1422
1423pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1424    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1425    let connections_metric = CONNECTIONS_METRIC
1426        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1427
1428    let connections = server
1429        .connection_pool
1430        .lock()
1431        .connections()
1432        .filter(|connection| !connection.admin)
1433        .count();
1434    connections_metric.set(connections as _);
1435
1436    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1437    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1438        register_int_gauge!(
1439            "shared_projects",
1440            "number of open projects with one or more guests"
1441        )
1442        .unwrap()
1443    });
1444
1445    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1446    shared_projects_metric.set(shared_projects as _);
1447
1448    let encoder = prometheus::TextEncoder::new();
1449    let metric_families = prometheus::gather();
1450    let encoded_metrics = encoder
1451        .encode_to_string(&metric_families)
1452        .map_err(|err| anyhow!("{}", err))?;
1453    Ok(encoded_metrics)
1454}
1455
/// Cleans up after a connection's message loop has ended.
///
/// Done immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (returning early with an error if it is not
///   registered there);
/// - calls the database's `connection_lost` (errors are only traced).
///
/// It then waits `RECONNECT_TIMEOUT` before releasing the principal's longer-lived
/// resources; presumably this grace period lets a quickly reconnecting client keep them
/// (TODO confirm). If the server's teardown signal changes first, it returns without doing
/// that work. After the timeout:
/// - Users / impersonated users: leave the current room and channel buffers. If the user has
///   no other online connection, decline any pending call and broadcast the room update.
///   Finally call `update_user_contacts`.
/// - Dev servers: run `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is traced but does not abort the rest of cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the timeout arm is polled first; a teardown signal skips the deferred cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1510
1511/// Acknowledges a ping from a client, used to keep the connection alive.
1512async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1513    response.send(proto::Ack {})?;
1514    Ok(())
1515}
1516
1517/// Creates a new room for calling (outside of channels)
1518async fn create_room(
1519    _request: proto::CreateRoom,
1520    response: Response<proto::CreateRoom>,
1521    session: UserSession,
1522) -> Result<()> {
1523    let live_kit_room = nanoid::nanoid!(30);
1524
1525    let live_kit_connection_info = util::maybe!(async {
1526        let live_kit = session.live_kit_client.as_ref();
1527        let live_kit = live_kit?;
1528        let user_id = session.user_id().to_string();
1529
1530        let token = live_kit
1531            .room_token(&live_kit_room, &user_id.to_string())
1532            .trace_err()?;
1533
1534        Some(proto::LiveKitConnectionInfo {
1535            server_url: live_kit.url().into(),
1536            token,
1537            can_publish: true,
1538        })
1539    })
1540    .await;
1541
1542    let room = session
1543        .db()
1544        .await
1545        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1546        .await?;
1547
1548    response.send(proto::CreateRoomResponse {
1549        room: Some(room.clone()),
1550        live_kit_connection_info,
1551    })?;
1552
1553    update_user_contacts(session.user_id(), &session).await?;
1554    Ok(())
1555}
1556
1557/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1558async fn join_room(
1559    request: proto::JoinRoom,
1560    response: Response<proto::JoinRoom>,
1561    session: UserSession,
1562) -> Result<()> {
1563    let room_id = RoomId::from_proto(request.id);
1564
1565    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1566
1567    if let Some(channel_id) = channel_id {
1568        return join_channel_internal(channel_id, Box::new(response), session).await;
1569    }
1570
1571    let joined_room = {
1572        let room = session
1573            .db()
1574            .await
1575            .join_room(room_id, session.user_id(), session.connection_id)
1576            .await?;
1577        room_updated(&room.room, &session.peer);
1578        room.into_inner()
1579    };
1580
1581    for connection_id in session
1582        .connection_pool()
1583        .await
1584        .user_connection_ids(session.user_id())
1585    {
1586        session
1587            .peer
1588            .send(
1589                connection_id,
1590                proto::CallCanceled {
1591                    room_id: room_id.to_proto(),
1592                },
1593            )
1594            .trace_err();
1595    }
1596
1597    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1598        if let Some(token) = live_kit
1599            .room_token(
1600                &joined_room.room.live_kit_room,
1601                &session.user_id().to_string(),
1602            )
1603            .trace_err()
1604        {
1605            Some(proto::LiveKitConnectionInfo {
1606                server_url: live_kit.url().into(),
1607                token,
1608                can_publish: true,
1609            })
1610        } else {
1611            None
1612        }
1613    } else {
1614        None
1615    };
1616
1617    response.send(proto::JoinRoomResponse {
1618        room: Some(joined_room.room),
1619        channel_id: None,
1620        live_kit_connection_info,
1621    })?;
1622
1623    update_user_contacts(session.user_id(), &session).await?;
1624    Ok(())
1625}
1626
1627/// Rejoin room is used to reconnect to a room after connection errors.
1628async fn rejoin_room(
1629    request: proto::RejoinRoom,
1630    response: Response<proto::RejoinRoom>,
1631    session: UserSession,
1632) -> Result<()> {
1633    let room;
1634    let channel;
1635    {
1636        let mut rejoined_room = session
1637            .db()
1638            .await
1639            .rejoin_room(request, session.user_id(), session.connection_id)
1640            .await?;
1641
1642        response.send(proto::RejoinRoomResponse {
1643            room: Some(rejoined_room.room.clone()),
1644            reshared_projects: rejoined_room
1645                .reshared_projects
1646                .iter()
1647                .map(|project| proto::ResharedProject {
1648                    id: project.id.to_proto(),
1649                    collaborators: project
1650                        .collaborators
1651                        .iter()
1652                        .map(|collaborator| collaborator.to_proto())
1653                        .collect(),
1654                })
1655                .collect(),
1656            rejoined_projects: rejoined_room
1657                .rejoined_projects
1658                .iter()
1659                .map(|rejoined_project| rejoined_project.to_proto())
1660                .collect(),
1661        })?;
1662        room_updated(&rejoined_room.room, &session.peer);
1663
1664        for project in &rejoined_room.reshared_projects {
1665            for collaborator in &project.collaborators {
1666                session
1667                    .peer
1668                    .send(
1669                        collaborator.connection_id,
1670                        proto::UpdateProjectCollaborator {
1671                            project_id: project.id.to_proto(),
1672                            old_peer_id: Some(project.old_connection_id.into()),
1673                            new_peer_id: Some(session.connection_id.into()),
1674                        },
1675                    )
1676                    .trace_err();
1677            }
1678
1679            broadcast(
1680                Some(session.connection_id),
1681                project
1682                    .collaborators
1683                    .iter()
1684                    .map(|collaborator| collaborator.connection_id),
1685                |connection_id| {
1686                    session.peer.forward_send(
1687                        session.connection_id,
1688                        connection_id,
1689                        proto::UpdateProject {
1690                            project_id: project.id.to_proto(),
1691                            worktrees: project.worktrees.clone(),
1692                        },
1693                    )
1694                },
1695            );
1696        }
1697
1698        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1699
1700        let rejoined_room = rejoined_room.into_inner();
1701
1702        room = rejoined_room.room;
1703        channel = rejoined_room.channel;
1704    }
1705
1706    if let Some(channel) = channel {
1707        channel_updated(
1708            &channel,
1709            &room,
1710            &session.peer,
1711            &*session.connection_pool().await,
1712        );
1713    }
1714
1715    update_user_contacts(session.user_id(), &session).await?;
1716    Ok(())
1717}
1718
1719fn notify_rejoined_projects(
1720    rejoined_projects: &mut Vec<RejoinedProject>,
1721    session: &UserSession,
1722) -> Result<()> {
1723    for project in rejoined_projects.iter() {
1724        for collaborator in &project.collaborators {
1725            session
1726                .peer
1727                .send(
1728                    collaborator.connection_id,
1729                    proto::UpdateProjectCollaborator {
1730                        project_id: project.id.to_proto(),
1731                        old_peer_id: Some(project.old_connection_id.into()),
1732                        new_peer_id: Some(session.connection_id.into()),
1733                    },
1734                )
1735                .trace_err();
1736        }
1737    }
1738
1739    for project in rejoined_projects {
1740        for worktree in mem::take(&mut project.worktrees) {
1741            #[cfg(any(test, feature = "test-support"))]
1742            const MAX_CHUNK_SIZE: usize = 2;
1743            #[cfg(not(any(test, feature = "test-support")))]
1744            const MAX_CHUNK_SIZE: usize = 256;
1745
1746            // Stream this worktree's entries.
1747            let message = proto::UpdateWorktree {
1748                project_id: project.id.to_proto(),
1749                worktree_id: worktree.id,
1750                abs_path: worktree.abs_path.clone(),
1751                root_name: worktree.root_name,
1752                updated_entries: worktree.updated_entries,
1753                removed_entries: worktree.removed_entries,
1754                scan_id: worktree.scan_id,
1755                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1756                updated_repositories: worktree.updated_repositories,
1757                removed_repositories: worktree.removed_repositories,
1758            };
1759            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1760                session.peer.send(session.connection_id, update.clone())?;
1761            }
1762
1763            // Stream this worktree's diagnostics.
1764            for summary in worktree.diagnostic_summaries {
1765                session.peer.send(
1766                    session.connection_id,
1767                    proto::UpdateDiagnosticSummary {
1768                        project_id: project.id.to_proto(),
1769                        worktree_id: worktree.id,
1770                        summary: Some(summary),
1771                    },
1772                )?;
1773            }
1774
1775            for settings_file in worktree.settings_files {
1776                session.peer.send(
1777                    session.connection_id,
1778                    proto::UpdateWorktreeSettings {
1779                        project_id: project.id.to_proto(),
1780                        worktree_id: worktree.id,
1781                        path: settings_file.path,
1782                        content: Some(settings_file.content),
1783                    },
1784                )?;
1785            }
1786        }
1787
1788        for language_server in &project.language_servers {
1789            session.peer.send(
1790                session.connection_id,
1791                proto::UpdateLanguageServer {
1792                    project_id: project.id.to_proto(),
1793                    language_server_id: language_server.id,
1794                    variant: Some(
1795                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1796                            proto::LspDiskBasedDiagnosticsUpdated {},
1797                        ),
1798                    ),
1799                },
1800            )?;
1801        }
1802    }
1803    Ok(())
1804}
1805
1806/// leave room disconnects from the room.
1807async fn leave_room(
1808    _: proto::LeaveRoom,
1809    response: Response<proto::LeaveRoom>,
1810    session: UserSession,
1811) -> Result<()> {
1812    leave_room_for_session(&session, session.connection_id).await?;
1813    response.send(proto::Ack {})?;
1814    Ok(())
1815}
1816
1817/// Updates the permissions of someone else in the room.
1818async fn set_room_participant_role(
1819    request: proto::SetRoomParticipantRole,
1820    response: Response<proto::SetRoomParticipantRole>,
1821    session: UserSession,
1822) -> Result<()> {
1823    let user_id = UserId::from_proto(request.user_id);
1824    let role = ChannelRole::from(request.role());
1825
1826    let (live_kit_room, can_publish) = {
1827        let room = session
1828            .db()
1829            .await
1830            .set_room_participant_role(
1831                session.user_id(),
1832                RoomId::from_proto(request.room_id),
1833                user_id,
1834                role,
1835            )
1836            .await?;
1837
1838        let live_kit_room = room.live_kit_room.clone();
1839        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1840        room_updated(&room, &session.peer);
1841        (live_kit_room, can_publish)
1842    };
1843
1844    if let Some(live_kit) = session.live_kit_client.as_ref() {
1845        live_kit
1846            .update_participant(
1847                live_kit_room.clone(),
1848                request.user_id.to_string(),
1849                live_kit_server::proto::ParticipantPermission {
1850                    can_subscribe: true,
1851                    can_publish,
1852                    can_publish_data: can_publish,
1853                    hidden: false,
1854                    recorder: false,
1855                },
1856            )
1857            .await
1858            .trace_err();
1859    }
1860
1861    response.send(proto::Ack {})?;
1862    Ok(())
1863}
1864
1865/// Call someone else into the current room
1866async fn call(
1867    request: proto::Call,
1868    response: Response<proto::Call>,
1869    session: UserSession,
1870) -> Result<()> {
1871    let room_id = RoomId::from_proto(request.room_id);
1872    let calling_user_id = session.user_id();
1873    let calling_connection_id = session.connection_id;
1874    let called_user_id = UserId::from_proto(request.called_user_id);
1875    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1876    if !session
1877        .db()
1878        .await
1879        .has_contact(calling_user_id, called_user_id)
1880        .await?
1881    {
1882        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1883    }
1884
1885    let incoming_call = {
1886        let (room, incoming_call) = &mut *session
1887            .db()
1888            .await
1889            .call(
1890                room_id,
1891                calling_user_id,
1892                calling_connection_id,
1893                called_user_id,
1894                initial_project_id,
1895            )
1896            .await?;
1897        room_updated(&room, &session.peer);
1898        mem::take(incoming_call)
1899    };
1900    update_user_contacts(called_user_id, &session).await?;
1901
1902    let mut calls = session
1903        .connection_pool()
1904        .await
1905        .user_connection_ids(called_user_id)
1906        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1907        .collect::<FuturesUnordered<_>>();
1908
1909    while let Some(call_response) = calls.next().await {
1910        match call_response.as_ref() {
1911            Ok(_) => {
1912                response.send(proto::Ack {})?;
1913                return Ok(());
1914            }
1915            Err(_) => {
1916                call_response.trace_err();
1917            }
1918        }
1919    }
1920
1921    {
1922        let room = session
1923            .db()
1924            .await
1925            .call_failed(room_id, called_user_id)
1926            .await?;
1927        room_updated(&room, &session.peer);
1928    }
1929    update_user_contacts(called_user_id, &session).await?;
1930
1931    Err(anyhow!("failed to ring user"))?
1932}
1933
1934/// Cancel an outgoing call.
1935async fn cancel_call(
1936    request: proto::CancelCall,
1937    response: Response<proto::CancelCall>,
1938    session: UserSession,
1939) -> Result<()> {
1940    let called_user_id = UserId::from_proto(request.called_user_id);
1941    let room_id = RoomId::from_proto(request.room_id);
1942    {
1943        let room = session
1944            .db()
1945            .await
1946            .cancel_call(room_id, session.connection_id, called_user_id)
1947            .await?;
1948        room_updated(&room, &session.peer);
1949    }
1950
1951    for connection_id in session
1952        .connection_pool()
1953        .await
1954        .user_connection_ids(called_user_id)
1955    {
1956        session
1957            .peer
1958            .send(
1959                connection_id,
1960                proto::CallCanceled {
1961                    room_id: room_id.to_proto(),
1962                },
1963            )
1964            .trace_err();
1965    }
1966    response.send(proto::Ack {})?;
1967
1968    update_user_contacts(called_user_id, &session).await?;
1969    Ok(())
1970}
1971
1972/// Decline an incoming call.
1973async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1974    let room_id = RoomId::from_proto(message.room_id);
1975    {
1976        let room = session
1977            .db()
1978            .await
1979            .decline_call(Some(room_id), session.user_id())
1980            .await?
1981            .ok_or_else(|| anyhow!("failed to decline call"))?;
1982        room_updated(&room, &session.peer);
1983    }
1984
1985    for connection_id in session
1986        .connection_pool()
1987        .await
1988        .user_connection_ids(session.user_id())
1989    {
1990        session
1991            .peer
1992            .send(
1993                connection_id,
1994                proto::CallCanceled {
1995                    room_id: room_id.to_proto(),
1996                },
1997            )
1998            .trace_err();
1999    }
2000    update_user_contacts(session.user_id(), &session).await?;
2001    Ok(())
2002}
2003
2004/// Updates other participants in the room with your current location.
2005async fn update_participant_location(
2006    request: proto::UpdateParticipantLocation,
2007    response: Response<proto::UpdateParticipantLocation>,
2008    session: UserSession,
2009) -> Result<()> {
2010    let room_id = RoomId::from_proto(request.room_id);
2011    let location = request
2012        .location
2013        .ok_or_else(|| anyhow!("invalid location"))?;
2014
2015    let db = session.db().await;
2016    let room = db
2017        .update_room_participant_location(room_id, session.connection_id, location)
2018        .await?;
2019
2020    room_updated(&room, &session.peer);
2021    response.send(proto::Ack {})?;
2022    Ok(())
2023}
2024
2025/// Share a project into the room.
2026async fn share_project(
2027    request: proto::ShareProject,
2028    response: Response<proto::ShareProject>,
2029    session: UserSession,
2030) -> Result<()> {
2031    let (project_id, room) = &*session
2032        .db()
2033        .await
2034        .share_project(
2035            RoomId::from_proto(request.room_id),
2036            session.connection_id,
2037            &request.worktrees,
2038            request
2039                .dev_server_project_id
2040                .map(|id| DevServerProjectId::from_proto(id)),
2041        )
2042        .await?;
2043    response.send(proto::ShareProjectResponse {
2044        project_id: project_id.to_proto(),
2045    })?;
2046    room_updated(&room, &session.peer);
2047
2048    Ok(())
2049}
2050
2051/// Unshare a project from the room.
2052async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2053    let project_id = ProjectId::from_proto(message.project_id);
2054    unshare_project_internal(
2055        project_id,
2056        session.connection_id,
2057        session.user_id(),
2058        &session,
2059    )
2060    .await
2061}
2062
2063async fn unshare_project_internal(
2064    project_id: ProjectId,
2065    connection_id: ConnectionId,
2066    user_id: Option<UserId>,
2067    session: &Session,
2068) -> Result<()> {
2069    let delete = {
2070        let room_guard = session
2071            .db()
2072            .await
2073            .unshare_project(project_id, connection_id, user_id)
2074            .await?;
2075
2076        let (delete, room, guest_connection_ids) = &*room_guard;
2077
2078        let message = proto::UnshareProject {
2079            project_id: project_id.to_proto(),
2080        };
2081
2082        broadcast(
2083            Some(connection_id),
2084            guest_connection_ids.iter().copied(),
2085            |conn_id| session.peer.send(conn_id, message.clone()),
2086        );
2087        if let Some(room) = room {
2088            room_updated(room, &session.peer);
2089        }
2090
2091        *delete
2092    };
2093
2094    if delete {
2095        let db = session.db().await;
2096        db.delete_project(project_id).await?;
2097    }
2098
2099    Ok(())
2100}
2101
2102/// DevServer makes a project available online
2103async fn share_dev_server_project(
2104    request: proto::ShareDevServerProject,
2105    response: Response<proto::ShareDevServerProject>,
2106    session: DevServerSession,
2107) -> Result<()> {
2108    let (dev_server_project, user_id, status) = session
2109        .db()
2110        .await
2111        .share_dev_server_project(
2112            DevServerProjectId::from_proto(request.dev_server_project_id),
2113            session.dev_server_id(),
2114            session.connection_id,
2115            &request.worktrees,
2116        )
2117        .await?;
2118    let Some(project_id) = dev_server_project.project_id else {
2119        return Err(anyhow!("failed to share remote project"))?;
2120    };
2121
2122    send_dev_server_projects_update(user_id, status, &session).await;
2123
2124    response.send(proto::ShareProjectResponse { project_id })?;
2125
2126    Ok(())
2127}
2128
2129/// Join someone elses shared project.
2130async fn join_project(
2131    request: proto::JoinProject,
2132    response: Response<proto::JoinProject>,
2133    session: UserSession,
2134) -> Result<()> {
2135    let project_id = ProjectId::from_proto(request.project_id);
2136
2137    tracing::info!(%project_id, "join project");
2138
2139    let db = session.db().await;
2140    let (project, replica_id) = &mut *db
2141        .join_project(project_id, session.connection_id, session.user_id())
2142        .await?;
2143    drop(db);
2144    tracing::info!(%project_id, "join remote project");
2145    join_project_internal(response, session, project, replica_id)
2146}
2147
2148trait JoinProjectInternalResponse {
2149    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2150}
2151impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2152    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2153        Response::<proto::JoinProject>::send(self, result)
2154    }
2155}
2156impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2157    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2158        Response::<proto::JoinHostedProject>::send(self, result)
2159    }
2160}
2161
2162fn join_project_internal(
2163    response: impl JoinProjectInternalResponse,
2164    session: UserSession,
2165    project: &mut Project,
2166    replica_id: &ReplicaId,
2167) -> Result<()> {
2168    let collaborators = project
2169        .collaborators
2170        .iter()
2171        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2172        .map(|collaborator| collaborator.to_proto())
2173        .collect::<Vec<_>>();
2174    let project_id = project.id;
2175    let guest_user_id = session.user_id();
2176
2177    let worktrees = project
2178        .worktrees
2179        .iter()
2180        .map(|(id, worktree)| proto::WorktreeMetadata {
2181            id: *id,
2182            root_name: worktree.root_name.clone(),
2183            visible: worktree.visible,
2184            abs_path: worktree.abs_path.clone(),
2185        })
2186        .collect::<Vec<_>>();
2187
2188    let add_project_collaborator = proto::AddProjectCollaborator {
2189        project_id: project_id.to_proto(),
2190        collaborator: Some(proto::Collaborator {
2191            peer_id: Some(session.connection_id.into()),
2192            replica_id: replica_id.0 as u32,
2193            user_id: guest_user_id.to_proto(),
2194        }),
2195    };
2196
2197    for collaborator in &collaborators {
2198        session
2199            .peer
2200            .send(
2201                collaborator.peer_id.unwrap().into(),
2202                add_project_collaborator.clone(),
2203            )
2204            .trace_err();
2205    }
2206
2207    // First, we send the metadata associated with each worktree.
2208    response.send(proto::JoinProjectResponse {
2209        project_id: project.id.0 as u64,
2210        worktrees: worktrees.clone(),
2211        replica_id: replica_id.0 as u32,
2212        collaborators: collaborators.clone(),
2213        language_servers: project.language_servers.clone(),
2214        role: project.role.into(),
2215        dev_server_project_id: project
2216            .dev_server_project_id
2217            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2218    })?;
2219
2220    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2221        #[cfg(any(test, feature = "test-support"))]
2222        const MAX_CHUNK_SIZE: usize = 2;
2223        #[cfg(not(any(test, feature = "test-support")))]
2224        const MAX_CHUNK_SIZE: usize = 256;
2225
2226        // Stream this worktree's entries.
2227        let message = proto::UpdateWorktree {
2228            project_id: project_id.to_proto(),
2229            worktree_id,
2230            abs_path: worktree.abs_path.clone(),
2231            root_name: worktree.root_name,
2232            updated_entries: worktree.entries,
2233            removed_entries: Default::default(),
2234            scan_id: worktree.scan_id,
2235            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2236            updated_repositories: worktree.repository_entries.into_values().collect(),
2237            removed_repositories: Default::default(),
2238        };
2239        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2240            session.peer.send(session.connection_id, update.clone())?;
2241        }
2242
2243        // Stream this worktree's diagnostics.
2244        for summary in worktree.diagnostic_summaries {
2245            session.peer.send(
2246                session.connection_id,
2247                proto::UpdateDiagnosticSummary {
2248                    project_id: project_id.to_proto(),
2249                    worktree_id: worktree.id,
2250                    summary: Some(summary),
2251                },
2252            )?;
2253        }
2254
2255        for settings_file in worktree.settings_files {
2256            session.peer.send(
2257                session.connection_id,
2258                proto::UpdateWorktreeSettings {
2259                    project_id: project_id.to_proto(),
2260                    worktree_id: worktree.id,
2261                    path: settings_file.path,
2262                    content: Some(settings_file.content),
2263                },
2264            )?;
2265        }
2266    }
2267
2268    for language_server in &project.language_servers {
2269        session.peer.send(
2270            session.connection_id,
2271            proto::UpdateLanguageServer {
2272                project_id: project_id.to_proto(),
2273                language_server_id: language_server.id,
2274                variant: Some(
2275                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2276                        proto::LspDiskBasedDiagnosticsUpdated {},
2277                    ),
2278                ),
2279            },
2280        )?;
2281    }
2282
2283    Ok(())
2284}
2285
2286/// Leave someone elses shared project.
2287async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2288    let sender_id = session.connection_id;
2289    let project_id = ProjectId::from_proto(request.project_id);
2290    let db = session.db().await;
2291    if db.is_hosted_project(project_id).await? {
2292        let project = db.leave_hosted_project(project_id, sender_id).await?;
2293        project_left(&project, &session);
2294        return Ok(());
2295    }
2296
2297    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2298    tracing::info!(
2299        %project_id,
2300        "leave project"
2301    );
2302
2303    project_left(&project, &session);
2304    if let Some(room) = room {
2305        room_updated(&room, &session.peer);
2306    }
2307
2308    Ok(())
2309}
2310
2311async fn join_hosted_project(
2312    request: proto::JoinHostedProject,
2313    response: Response<proto::JoinHostedProject>,
2314    session: UserSession,
2315) -> Result<()> {
2316    let (mut project, replica_id) = session
2317        .db()
2318        .await
2319        .join_hosted_project(
2320            ProjectId(request.project_id as i32),
2321            session.user_id(),
2322            session.connection_id,
2323        )
2324        .await?;
2325
2326    join_project_internal(response, session, &mut project, &replica_id)
2327}
2328
2329async fn list_remote_directory(
2330    request: proto::ListRemoteDirectory,
2331    response: Response<proto::ListRemoteDirectory>,
2332    session: UserSession,
2333) -> Result<()> {
2334    let dev_server_id = DevServerId(request.dev_server_id as i32);
2335    let dev_server_connection_id = session
2336        .connection_pool()
2337        .await
2338        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2339
2340    session
2341        .db()
2342        .await
2343        .get_dev_server_for_user(dev_server_id, session.user_id())
2344        .await?;
2345
2346    response.send(
2347        session
2348            .peer
2349            .forward_request(session.connection_id, dev_server_connection_id, request)
2350            .await?,
2351    )?;
2352    Ok(())
2353}
2354
2355async fn update_dev_server_project(
2356    request: proto::UpdateDevServerProject,
2357    response: Response<proto::UpdateDevServerProject>,
2358    session: UserSession,
2359) -> Result<()> {
2360    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2361
2362    let (dev_server_project, update) = session
2363        .db()
2364        .await
2365        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2366        .await?;
2367
2368    let projects = session
2369        .db()
2370        .await
2371        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2372        .await?;
2373
2374    let dev_server_connection_id = session
2375        .connection_pool()
2376        .await
2377        .dev_server_connection_id_supporting(
2378            dev_server_project.dev_server_id,
2379            ZedVersion::with_list_directory(),
2380        )?;
2381
2382    session.peer.send(
2383        dev_server_connection_id,
2384        proto::DevServerInstructions { projects },
2385    )?;
2386
2387    send_dev_server_projects_update(session.user_id(), update, &session).await;
2388
2389    response.send(proto::Ack {})
2390}
2391
2392async fn create_dev_server_project(
2393    request: proto::CreateDevServerProject,
2394    response: Response<proto::CreateDevServerProject>,
2395    session: UserSession,
2396) -> Result<()> {
2397    let dev_server_id = DevServerId(request.dev_server_id as i32);
2398    let dev_server_connection_id = session
2399        .connection_pool()
2400        .await
2401        .dev_server_connection_id(dev_server_id);
2402    let Some(dev_server_connection_id) = dev_server_connection_id else {
2403        Err(ErrorCode::DevServerOffline
2404            .message("Cannot create a remote project when the dev server is offline".to_string())
2405            .anyhow())?
2406    };
2407
2408    let path = request.path.clone();
2409    //Check that the path exists on the dev server
2410    session
2411        .peer
2412        .forward_request(
2413            session.connection_id,
2414            dev_server_connection_id,
2415            proto::ValidateDevServerProjectRequest { path: path.clone() },
2416        )
2417        .await?;
2418
2419    let (dev_server_project, update) = session
2420        .db()
2421        .await
2422        .create_dev_server_project(
2423            DevServerId(request.dev_server_id as i32),
2424            &request.path,
2425            session.user_id(),
2426        )
2427        .await?;
2428
2429    let projects = session
2430        .db()
2431        .await
2432        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2433        .await?;
2434
2435    session.peer.send(
2436        dev_server_connection_id,
2437        proto::DevServerInstructions { projects },
2438    )?;
2439
2440    send_dev_server_projects_update(session.user_id(), update, &session).await;
2441
2442    response.send(proto::CreateDevServerProjectResponse {
2443        dev_server_project: Some(dev_server_project.to_proto(None)),
2444    })?;
2445    Ok(())
2446}
2447
2448async fn create_dev_server(
2449    request: proto::CreateDevServer,
2450    response: Response<proto::CreateDevServer>,
2451    session: UserSession,
2452) -> Result<()> {
2453    let access_token = auth::random_token();
2454    let hashed_access_token = auth::hash_access_token(&access_token);
2455
2456    if request.name.is_empty() {
2457        return Err(proto::ErrorCode::Forbidden
2458            .message("Dev server name cannot be empty".to_string())
2459            .anyhow())?;
2460    }
2461
2462    let (dev_server, status) = session
2463        .db()
2464        .await
2465        .create_dev_server(
2466            &request.name,
2467            request.ssh_connection_string.as_deref(),
2468            &hashed_access_token,
2469            session.user_id(),
2470        )
2471        .await?;
2472
2473    send_dev_server_projects_update(session.user_id(), status, &session).await;
2474
2475    response.send(proto::CreateDevServerResponse {
2476        dev_server_id: dev_server.id.0 as u64,
2477        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2478        name: request.name,
2479    })?;
2480    Ok(())
2481}
2482
2483async fn regenerate_dev_server_token(
2484    request: proto::RegenerateDevServerToken,
2485    response: Response<proto::RegenerateDevServerToken>,
2486    session: UserSession,
2487) -> Result<()> {
2488    let dev_server_id = DevServerId(request.dev_server_id as i32);
2489    let access_token = auth::random_token();
2490    let hashed_access_token = auth::hash_access_token(&access_token);
2491
2492    let connection_id = session
2493        .connection_pool()
2494        .await
2495        .dev_server_connection_id(dev_server_id);
2496    if let Some(connection_id) = connection_id {
2497        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2498        session.peer.send(
2499            connection_id,
2500            proto::ShutdownDevServer {
2501                reason: Some("dev server token was regenerated".to_string()),
2502            },
2503        )?;
2504        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2505    }
2506
2507    let status = session
2508        .db()
2509        .await
2510        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2511        .await?;
2512
2513    send_dev_server_projects_update(session.user_id(), status, &session).await;
2514
2515    response.send(proto::RegenerateDevServerTokenResponse {
2516        dev_server_id: dev_server_id.to_proto(),
2517        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2518    })?;
2519    Ok(())
2520}
2521
2522async fn rename_dev_server(
2523    request: proto::RenameDevServer,
2524    response: Response<proto::RenameDevServer>,
2525    session: UserSession,
2526) -> Result<()> {
2527    if request.name.trim().is_empty() {
2528        return Err(proto::ErrorCode::Forbidden
2529            .message("Dev server name cannot be empty".to_string())
2530            .anyhow())?;
2531    }
2532
2533    let dev_server_id = DevServerId(request.dev_server_id as i32);
2534    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2535    if dev_server.user_id != session.user_id() {
2536        return Err(anyhow!(ErrorCode::Forbidden))?;
2537    }
2538
2539    let status = session
2540        .db()
2541        .await
2542        .rename_dev_server(
2543            dev_server_id,
2544            &request.name,
2545            request.ssh_connection_string.as_deref(),
2546            session.user_id(),
2547        )
2548        .await?;
2549
2550    send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552    response.send(proto::Ack {})?;
2553    Ok(())
2554}
2555
2556async fn delete_dev_server(
2557    request: proto::DeleteDevServer,
2558    response: Response<proto::DeleteDevServer>,
2559    session: UserSession,
2560) -> Result<()> {
2561    let dev_server_id = DevServerId(request.dev_server_id as i32);
2562    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2563    if dev_server.user_id != session.user_id() {
2564        return Err(anyhow!(ErrorCode::Forbidden))?;
2565    }
2566
2567    let connection_id = session
2568        .connection_pool()
2569        .await
2570        .dev_server_connection_id(dev_server_id);
2571    if let Some(connection_id) = connection_id {
2572        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2573        session.peer.send(
2574            connection_id,
2575            proto::ShutdownDevServer {
2576                reason: Some("dev server was deleted".to_string()),
2577            },
2578        )?;
2579        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2580    }
2581
2582    let status = session
2583        .db()
2584        .await
2585        .delete_dev_server(dev_server_id, session.user_id())
2586        .await?;
2587
2588    send_dev_server_projects_update(session.user_id(), status, &session).await;
2589
2590    response.send(proto::Ack {})?;
2591    Ok(())
2592}
2593
2594async fn delete_dev_server_project(
2595    request: proto::DeleteDevServerProject,
2596    response: Response<proto::DeleteDevServerProject>,
2597    session: UserSession,
2598) -> Result<()> {
2599    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2600    let dev_server_project = session
2601        .db()
2602        .await
2603        .get_dev_server_project(dev_server_project_id)
2604        .await?;
2605
2606    let dev_server = session
2607        .db()
2608        .await
2609        .get_dev_server(dev_server_project.dev_server_id)
2610        .await?;
2611    if dev_server.user_id != session.user_id() {
2612        return Err(anyhow!(ErrorCode::Forbidden))?;
2613    }
2614
2615    let dev_server_connection_id = session
2616        .connection_pool()
2617        .await
2618        .dev_server_connection_id(dev_server.id);
2619
2620    if let Some(dev_server_connection_id) = dev_server_connection_id {
2621        let project = session
2622            .db()
2623            .await
2624            .find_dev_server_project(dev_server_project_id)
2625            .await;
2626        if let Ok(project) = project {
2627            unshare_project_internal(
2628                project.id,
2629                dev_server_connection_id,
2630                Some(session.user_id()),
2631                &session,
2632            )
2633            .await?;
2634        }
2635    }
2636
2637    let (projects, status) = session
2638        .db()
2639        .await
2640        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2641        .await?;
2642
2643    if let Some(dev_server_connection_id) = dev_server_connection_id {
2644        session.peer.send(
2645            dev_server_connection_id,
2646            proto::DevServerInstructions { projects },
2647        )?;
2648    }
2649
2650    send_dev_server_projects_update(session.user_id(), status, &session).await;
2651
2652    response.send(proto::Ack {})?;
2653    Ok(())
2654}
2655
2656async fn rejoin_dev_server_projects(
2657    request: proto::RejoinRemoteProjects,
2658    response: Response<proto::RejoinRemoteProjects>,
2659    session: UserSession,
2660) -> Result<()> {
2661    let mut rejoined_projects = {
2662        let db = session.db().await;
2663        db.rejoin_dev_server_projects(
2664            &request.rejoined_projects,
2665            session.user_id(),
2666            session.0.connection_id,
2667        )
2668        .await?
2669    };
2670    response.send(proto::RejoinRemoteProjectsResponse {
2671        rejoined_projects: rejoined_projects
2672            .iter()
2673            .map(|project| project.to_proto())
2674            .collect(),
2675    })?;
2676    notify_rejoined_projects(&mut rejoined_projects, &session)
2677}
2678
2679async fn reconnect_dev_server(
2680    request: proto::ReconnectDevServer,
2681    response: Response<proto::ReconnectDevServer>,
2682    session: DevServerSession,
2683) -> Result<()> {
2684    let reshared_projects = {
2685        let db = session.db().await;
2686        db.reshare_dev_server_projects(
2687            &request.reshared_projects,
2688            session.dev_server_id(),
2689            session.0.connection_id,
2690        )
2691        .await?
2692    };
2693
2694    for project in &reshared_projects {
2695        for collaborator in &project.collaborators {
2696            session
2697                .peer
2698                .send(
2699                    collaborator.connection_id,
2700                    proto::UpdateProjectCollaborator {
2701                        project_id: project.id.to_proto(),
2702                        old_peer_id: Some(project.old_connection_id.into()),
2703                        new_peer_id: Some(session.connection_id.into()),
2704                    },
2705                )
2706                .trace_err();
2707        }
2708
2709        broadcast(
2710            Some(session.connection_id),
2711            project
2712                .collaborators
2713                .iter()
2714                .map(|collaborator| collaborator.connection_id),
2715            |connection_id| {
2716                session.peer.forward_send(
2717                    session.connection_id,
2718                    connection_id,
2719                    proto::UpdateProject {
2720                        project_id: project.id.to_proto(),
2721                        worktrees: project.worktrees.clone(),
2722                    },
2723                )
2724            },
2725        );
2726    }
2727
2728    response.send(proto::ReconnectDevServerResponse {
2729        reshared_projects: reshared_projects
2730            .iter()
2731            .map(|project| proto::ResharedProject {
2732                id: project.id.to_proto(),
2733                collaborators: project
2734                    .collaborators
2735                    .iter()
2736                    .map(|collaborator| collaborator.to_proto())
2737                    .collect(),
2738            })
2739            .collect(),
2740    })?;
2741
2742    Ok(())
2743}
2744
2745async fn shutdown_dev_server(
2746    _: proto::ShutdownDevServer,
2747    response: Response<proto::ShutdownDevServer>,
2748    session: DevServerSession,
2749) -> Result<()> {
2750    response.send(proto::Ack {})?;
2751    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2752    remove_dev_server_connection(session.dev_server_id(), &session).await
2753}
2754
2755async fn shutdown_dev_server_internal(
2756    dev_server_id: DevServerId,
2757    connection_id: ConnectionId,
2758    session: &Session,
2759) -> Result<()> {
2760    let (dev_server_projects, dev_server) = {
2761        let db = session.db().await;
2762        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2763        let dev_server = db.get_dev_server(dev_server_id).await?;
2764        (dev_server_projects, dev_server)
2765    };
2766
2767    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2768        unshare_project_internal(
2769            ProjectId::from_proto(project_id),
2770            connection_id,
2771            None,
2772            session,
2773        )
2774        .await?;
2775    }
2776
2777    session
2778        .connection_pool()
2779        .await
2780        .set_dev_server_offline(dev_server_id);
2781
2782    let status = session
2783        .db()
2784        .await
2785        .dev_server_projects_update(dev_server.user_id)
2786        .await?;
2787    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2788
2789    Ok(())
2790}
2791
2792async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2793    let dev_server_connection = session
2794        .connection_pool()
2795        .await
2796        .dev_server_connection_id(dev_server_id);
2797
2798    if let Some(dev_server_connection) = dev_server_connection {
2799        session
2800            .connection_pool()
2801            .await
2802            .remove_connection(dev_server_connection)?;
2803    }
2804    Ok(())
2805}
2806
2807/// Updates other participants with changes to the project
2808async fn update_project(
2809    request: proto::UpdateProject,
2810    response: Response<proto::UpdateProject>,
2811    session: Session,
2812) -> Result<()> {
2813    let project_id = ProjectId::from_proto(request.project_id);
2814    let (room, guest_connection_ids) = &*session
2815        .db()
2816        .await
2817        .update_project(project_id, session.connection_id, &request.worktrees)
2818        .await?;
2819    broadcast(
2820        Some(session.connection_id),
2821        guest_connection_ids.iter().copied(),
2822        |connection_id| {
2823            session
2824                .peer
2825                .forward_send(session.connection_id, connection_id, request.clone())
2826        },
2827    );
2828    if let Some(room) = room {
2829        room_updated(&room, &session.peer);
2830    }
2831    response.send(proto::Ack {})?;
2832
2833    Ok(())
2834}
2835
2836/// Updates other participants with changes to the worktree
2837async fn update_worktree(
2838    request: proto::UpdateWorktree,
2839    response: Response<proto::UpdateWorktree>,
2840    session: Session,
2841) -> Result<()> {
2842    let guest_connection_ids = session
2843        .db()
2844        .await
2845        .update_worktree(&request, session.connection_id)
2846        .await?;
2847
2848    broadcast(
2849        Some(session.connection_id),
2850        guest_connection_ids.iter().copied(),
2851        |connection_id| {
2852            session
2853                .peer
2854                .forward_send(session.connection_id, connection_id, request.clone())
2855        },
2856    );
2857    response.send(proto::Ack {})?;
2858    Ok(())
2859}
2860
2861/// Updates other participants with changes to the diagnostics
2862async fn update_diagnostic_summary(
2863    message: proto::UpdateDiagnosticSummary,
2864    session: Session,
2865) -> Result<()> {
2866    let guest_connection_ids = session
2867        .db()
2868        .await
2869        .update_diagnostic_summary(&message, session.connection_id)
2870        .await?;
2871
2872    broadcast(
2873        Some(session.connection_id),
2874        guest_connection_ids.iter().copied(),
2875        |connection_id| {
2876            session
2877                .peer
2878                .forward_send(session.connection_id, connection_id, message.clone())
2879        },
2880    );
2881
2882    Ok(())
2883}
2884
2885/// Updates other participants with changes to the worktree settings
2886async fn update_worktree_settings(
2887    message: proto::UpdateWorktreeSettings,
2888    session: Session,
2889) -> Result<()> {
2890    let guest_connection_ids = session
2891        .db()
2892        .await
2893        .update_worktree_settings(&message, session.connection_id)
2894        .await?;
2895
2896    broadcast(
2897        Some(session.connection_id),
2898        guest_connection_ids.iter().copied(),
2899        |connection_id| {
2900            session
2901                .peer
2902                .forward_send(session.connection_id, connection_id, message.clone())
2903        },
2904    );
2905
2906    Ok(())
2907}
2908
2909/// Notify other participants that a  language server has started.
2910async fn start_language_server(
2911    request: proto::StartLanguageServer,
2912    session: Session,
2913) -> Result<()> {
2914    let guest_connection_ids = session
2915        .db()
2916        .await
2917        .start_language_server(&request, session.connection_id)
2918        .await?;
2919
2920    broadcast(
2921        Some(session.connection_id),
2922        guest_connection_ids.iter().copied(),
2923        |connection_id| {
2924            session
2925                .peer
2926                .forward_send(session.connection_id, connection_id, request.clone())
2927        },
2928    );
2929    Ok(())
2930}
2931
2932/// Notify other participants that a language server has changed.
2933async fn update_language_server(
2934    request: proto::UpdateLanguageServer,
2935    session: Session,
2936) -> Result<()> {
2937    let project_id = ProjectId::from_proto(request.project_id);
2938    let project_connection_ids = session
2939        .db()
2940        .await
2941        .project_connection_ids(project_id, session.connection_id, true)
2942        .await?;
2943    broadcast(
2944        Some(session.connection_id),
2945        project_connection_ids.iter().copied(),
2946        |connection_id| {
2947            session
2948                .peer
2949                .forward_send(session.connection_id, connection_id, request.clone())
2950        },
2951    );
2952    Ok(())
2953}
2954
2955/// forward a project request to the host. These requests should be read only
2956/// as guests are allowed to send them.
2957async fn forward_read_only_project_request<T>(
2958    request: T,
2959    response: Response<T>,
2960    session: UserSession,
2961) -> Result<()>
2962where
2963    T: EntityMessage + RequestMessage,
2964{
2965    let project_id = ProjectId::from_proto(request.remote_entity_id());
2966    let host_connection_id = session
2967        .db()
2968        .await
2969        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2970        .await?;
2971    let payload = session
2972        .peer
2973        .forward_request(session.connection_id, host_connection_id, request)
2974        .await?;
2975    response.send(payload)?;
2976    Ok(())
2977}
2978
2979/// forward a project request to the dev server. Only allowed
2980/// if it's your dev server.
2981async fn forward_project_request_for_owner<T>(
2982    request: T,
2983    response: Response<T>,
2984    session: UserSession,
2985) -> Result<()>
2986where
2987    T: EntityMessage + RequestMessage,
2988{
2989    let project_id = ProjectId::from_proto(request.remote_entity_id());
2990
2991    let host_connection_id = session
2992        .db()
2993        .await
2994        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2995        .await?;
2996    let payload = session
2997        .peer
2998        .forward_request(session.connection_id, host_connection_id, request)
2999        .await?;
3000    response.send(payload)?;
3001    Ok(())
3002}
3003
3004/// forward a project request to the host. These requests are disallowed
3005/// for guests.
3006async fn forward_mutating_project_request<T>(
3007    request: T,
3008    response: Response<T>,
3009    session: UserSession,
3010) -> Result<()>
3011where
3012    T: EntityMessage + RequestMessage,
3013{
3014    let project_id = ProjectId::from_proto(request.remote_entity_id());
3015
3016    let host_connection_id = session
3017        .db()
3018        .await
3019        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3020        .await?;
3021    let payload = session
3022        .peer
3023        .forward_request(session.connection_id, host_connection_id, request)
3024        .await?;
3025    response.send(payload)?;
3026    Ok(())
3027}
3028
3029/// forward a project request to the host. These requests are disallowed
3030/// for guests.
3031async fn forward_versioned_mutating_project_request<T>(
3032    request: T,
3033    response: Response<T>,
3034    session: UserSession,
3035) -> Result<()>
3036where
3037    T: EntityMessage + RequestMessage + VersionedMessage,
3038{
3039    let project_id = ProjectId::from_proto(request.remote_entity_id());
3040
3041    let host_connection_id = session
3042        .db()
3043        .await
3044        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3045        .await?;
3046    if let Some(host_version) = session
3047        .connection_pool()
3048        .await
3049        .connection(host_connection_id)
3050        .map(|c| c.zed_version)
3051    {
3052        if let Some(min_required_version) = request.required_host_version() {
3053            if min_required_version > host_version {
3054                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3055                    .with_tag("required", &min_required_version.to_string())))?;
3056            }
3057        }
3058    }
3059
3060    let payload = session
3061        .peer
3062        .forward_request(session.connection_id, host_connection_id, request)
3063        .await?;
3064    response.send(payload)?;
3065    Ok(())
3066}
3067
3068/// Notify other participants that a new buffer has been created
3069async fn create_buffer_for_peer(
3070    request: proto::CreateBufferForPeer,
3071    session: Session,
3072) -> Result<()> {
3073    session
3074        .db()
3075        .await
3076        .check_user_is_project_host(
3077            ProjectId::from_proto(request.project_id),
3078            session.connection_id,
3079        )
3080        .await?;
3081    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3082    session
3083        .peer
3084        .forward_send(session.connection_id, peer_id.into(), request)?;
3085    Ok(())
3086}
3087
3088/// Notify other participants that a buffer has been updated. This is
3089/// allowed for guests as long as the update is limited to selections.
3090async fn update_buffer(
3091    request: proto::UpdateBuffer,
3092    response: Response<proto::UpdateBuffer>,
3093    session: Session,
3094) -> Result<()> {
3095    let project_id = ProjectId::from_proto(request.project_id);
3096    let mut capability = Capability::ReadOnly;
3097
3098    for op in request.operations.iter() {
3099        match op.variant {
3100            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3101            Some(_) => capability = Capability::ReadWrite,
3102        }
3103    }
3104
3105    let host = {
3106        let guard = session
3107            .db()
3108            .await
3109            .connections_for_buffer_update(
3110                project_id,
3111                session.principal_id(),
3112                session.connection_id,
3113                capability,
3114            )
3115            .await?;
3116
3117        let (host, guests) = &*guard;
3118
3119        broadcast(
3120            Some(session.connection_id),
3121            guests.clone(),
3122            |connection_id| {
3123                session
3124                    .peer
3125                    .forward_send(session.connection_id, connection_id, request.clone())
3126            },
3127        );
3128
3129        *host
3130    };
3131
3132    if host != session.connection_id {
3133        session
3134            .peer
3135            .forward_request(session.connection_id, host, request.clone())
3136            .await?;
3137    }
3138
3139    response.send(proto::Ack {})?;
3140    Ok(())
3141}
3142
3143async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3144    let project_id = ProjectId::from_proto(message.project_id);
3145
3146    let operation = message.operation.as_ref().context("invalid operation")?;
3147    let capability = match operation.variant.as_ref() {
3148        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3149            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3150                match buffer_op.variant {
3151                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3152                        Capability::ReadOnly
3153                    }
3154                    _ => Capability::ReadWrite,
3155                }
3156            } else {
3157                Capability::ReadWrite
3158            }
3159        }
3160        Some(_) => Capability::ReadWrite,
3161        None => Capability::ReadOnly,
3162    };
3163
3164    let guard = session
3165        .db()
3166        .await
3167        .connections_for_buffer_update(
3168            project_id,
3169            session.principal_id(),
3170            session.connection_id,
3171            capability,
3172        )
3173        .await?;
3174
3175    let (host, guests) = &*guard;
3176
3177    broadcast(
3178        Some(session.connection_id),
3179        guests.iter().chain([host]).copied(),
3180        |connection_id| {
3181            session
3182                .peer
3183                .forward_send(session.connection_id, connection_id, message.clone())
3184        },
3185    );
3186
3187    Ok(())
3188}
3189
3190/// Notify other participants that a project has been updated.
3191async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3192    request: T,
3193    session: Session,
3194) -> Result<()> {
3195    let project_id = ProjectId::from_proto(request.remote_entity_id());
3196    let project_connection_ids = session
3197        .db()
3198        .await
3199        .project_connection_ids(project_id, session.connection_id, false)
3200        .await?;
3201
3202    broadcast(
3203        Some(session.connection_id),
3204        project_connection_ids.iter().copied(),
3205        |connection_id| {
3206            session
3207                .peer
3208                .forward_send(session.connection_id, connection_id, request.clone())
3209        },
3210    );
3211    Ok(())
3212}
3213
3214/// Start following another user in a call.
3215async fn follow(
3216    request: proto::Follow,
3217    response: Response<proto::Follow>,
3218    session: UserSession,
3219) -> Result<()> {
3220    let room_id = RoomId::from_proto(request.room_id);
3221    let project_id = request.project_id.map(ProjectId::from_proto);
3222    let leader_id = request
3223        .leader_id
3224        .ok_or_else(|| anyhow!("invalid leader id"))?
3225        .into();
3226    let follower_id = session.connection_id;
3227
3228    session
3229        .db()
3230        .await
3231        .check_room_participants(room_id, leader_id, session.connection_id)
3232        .await?;
3233
3234    let response_payload = session
3235        .peer
3236        .forward_request(session.connection_id, leader_id, request)
3237        .await?;
3238    response.send(response_payload)?;
3239
3240    if let Some(project_id) = project_id {
3241        let room = session
3242            .db()
3243            .await
3244            .follow(room_id, project_id, leader_id, follower_id)
3245            .await?;
3246        room_updated(&room, &session.peer);
3247    }
3248
3249    Ok(())
3250}
3251
3252/// Stop following another user in a call.
3253async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3254    let room_id = RoomId::from_proto(request.room_id);
3255    let project_id = request.project_id.map(ProjectId::from_proto);
3256    let leader_id = request
3257        .leader_id
3258        .ok_or_else(|| anyhow!("invalid leader id"))?
3259        .into();
3260    let follower_id = session.connection_id;
3261
3262    session
3263        .db()
3264        .await
3265        .check_room_participants(room_id, leader_id, session.connection_id)
3266        .await?;
3267
3268    session
3269        .peer
3270        .forward_send(session.connection_id, leader_id, request)?;
3271
3272    if let Some(project_id) = project_id {
3273        let room = session
3274            .db()
3275            .await
3276            .unfollow(room_id, project_id, leader_id, follower_id)
3277            .await?;
3278        room_updated(&room, &session.peer);
3279    }
3280
3281    Ok(())
3282}
3283
3284/// Notify everyone following you of your current location.
3285async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3286    let room_id = RoomId::from_proto(request.room_id);
3287    let database = session.db.lock().await;
3288
3289    let connection_ids = if let Some(project_id) = request.project_id {
3290        let project_id = ProjectId::from_proto(project_id);
3291        database
3292            .project_connection_ids(project_id, session.connection_id, true)
3293            .await?
3294    } else {
3295        database
3296            .room_connection_ids(room_id, session.connection_id)
3297            .await?
3298    };
3299
3300    // For now, don't send view update messages back to that view's current leader.
3301    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3302        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3303        _ => None,
3304    });
3305
3306    for connection_id in connection_ids.iter().cloned() {
3307        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3308            session
3309                .peer
3310                .forward_send(session.connection_id, connection_id, request.clone())?;
3311        }
3312    }
3313    Ok(())
3314}
3315
3316/// Get public data about users.
3317async fn get_users(
3318    request: proto::GetUsers,
3319    response: Response<proto::GetUsers>,
3320    session: Session,
3321) -> Result<()> {
3322    let user_ids = request
3323        .user_ids
3324        .into_iter()
3325        .map(UserId::from_proto)
3326        .collect();
3327    let users = session
3328        .db()
3329        .await
3330        .get_users_by_ids(user_ids)
3331        .await?
3332        .into_iter()
3333        .map(|user| proto::User {
3334            id: user.id.to_proto(),
3335            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3336            github_login: user.github_login,
3337        })
3338        .collect();
3339    response.send(proto::UsersResponse { users })?;
3340    Ok(())
3341}
3342
3343/// Search for users (to invite) buy Github login
3344async fn fuzzy_search_users(
3345    request: proto::FuzzySearchUsers,
3346    response: Response<proto::FuzzySearchUsers>,
3347    session: UserSession,
3348) -> Result<()> {
3349    let query = request.query;
3350    let users = match query.len() {
3351        0 => vec![],
3352        1 | 2 => session
3353            .db()
3354            .await
3355            .get_user_by_github_login(&query)
3356            .await?
3357            .into_iter()
3358            .collect(),
3359        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3360    };
3361    let users = users
3362        .into_iter()
3363        .filter(|user| user.id != session.user_id())
3364        .map(|user| proto::User {
3365            id: user.id.to_proto(),
3366            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3367            github_login: user.github_login,
3368        })
3369        .collect();
3370    response.send(proto::UsersResponse { users })?;
3371    Ok(())
3372}
3373
3374/// Send a contact request to another user.
3375async fn request_contact(
3376    request: proto::RequestContact,
3377    response: Response<proto::RequestContact>,
3378    session: UserSession,
3379) -> Result<()> {
3380    let requester_id = session.user_id();
3381    let responder_id = UserId::from_proto(request.responder_id);
3382    if requester_id == responder_id {
3383        return Err(anyhow!("cannot add yourself as a contact"))?;
3384    }
3385
3386    let notifications = session
3387        .db()
3388        .await
3389        .send_contact_request(requester_id, responder_id)
3390        .await?;
3391
3392    // Update outgoing contact requests of requester
3393    let mut update = proto::UpdateContacts::default();
3394    update.outgoing_requests.push(responder_id.to_proto());
3395    for connection_id in session
3396        .connection_pool()
3397        .await
3398        .user_connection_ids(requester_id)
3399    {
3400        session.peer.send(connection_id, update.clone())?;
3401    }
3402
3403    // Update incoming contact requests of responder
3404    let mut update = proto::UpdateContacts::default();
3405    update
3406        .incoming_requests
3407        .push(proto::IncomingContactRequest {
3408            requester_id: requester_id.to_proto(),
3409        });
3410    let connection_pool = session.connection_pool().await;
3411    for connection_id in connection_pool.user_connection_ids(responder_id) {
3412        session.peer.send(connection_id, update.clone())?;
3413    }
3414
3415    send_notifications(&connection_pool, &session.peer, notifications);
3416
3417    response.send(proto::Ack {})?;
3418    Ok(())
3419}
3420
3421/// Accept or decline a contact request
3422async fn respond_to_contact_request(
3423    request: proto::RespondToContactRequest,
3424    response: Response<proto::RespondToContactRequest>,
3425    session: UserSession,
3426) -> Result<()> {
3427    let responder_id = session.user_id();
3428    let requester_id = UserId::from_proto(request.requester_id);
3429    let db = session.db().await;
3430    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3431        db.dismiss_contact_notification(responder_id, requester_id)
3432            .await?;
3433    } else {
3434        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3435
3436        let notifications = db
3437            .respond_to_contact_request(responder_id, requester_id, accept)
3438            .await?;
3439        let requester_busy = db.is_user_busy(requester_id).await?;
3440        let responder_busy = db.is_user_busy(responder_id).await?;
3441
3442        let pool = session.connection_pool().await;
3443        // Update responder with new contact
3444        let mut update = proto::UpdateContacts::default();
3445        if accept {
3446            update
3447                .contacts
3448                .push(contact_for_user(requester_id, requester_busy, &pool));
3449        }
3450        update
3451            .remove_incoming_requests
3452            .push(requester_id.to_proto());
3453        for connection_id in pool.user_connection_ids(responder_id) {
3454            session.peer.send(connection_id, update.clone())?;
3455        }
3456
3457        // Update requester with new contact
3458        let mut update = proto::UpdateContacts::default();
3459        if accept {
3460            update
3461                .contacts
3462                .push(contact_for_user(responder_id, responder_busy, &pool));
3463        }
3464        update
3465            .remove_outgoing_requests
3466            .push(responder_id.to_proto());
3467
3468        for connection_id in pool.user_connection_ids(requester_id) {
3469            session.peer.send(connection_id, update.clone())?;
3470        }
3471
3472        send_notifications(&pool, &session.peer, notifications);
3473    }
3474
3475    response.send(proto::Ack {})?;
3476    Ok(())
3477}
3478
3479/// Remove a contact.
3480async fn remove_contact(
3481    request: proto::RemoveContact,
3482    response: Response<proto::RemoveContact>,
3483    session: UserSession,
3484) -> Result<()> {
3485    let requester_id = session.user_id();
3486    let responder_id = UserId::from_proto(request.user_id);
3487    let db = session.db().await;
3488    let (contact_accepted, deleted_notification_id) =
3489        db.remove_contact(requester_id, responder_id).await?;
3490
3491    let pool = session.connection_pool().await;
3492    // Update outgoing contact requests of requester
3493    let mut update = proto::UpdateContacts::default();
3494    if contact_accepted {
3495        update.remove_contacts.push(responder_id.to_proto());
3496    } else {
3497        update
3498            .remove_outgoing_requests
3499            .push(responder_id.to_proto());
3500    }
3501    for connection_id in pool.user_connection_ids(requester_id) {
3502        session.peer.send(connection_id, update.clone())?;
3503    }
3504
3505    // Update incoming contact requests of responder
3506    let mut update = proto::UpdateContacts::default();
3507    if contact_accepted {
3508        update.remove_contacts.push(requester_id.to_proto());
3509    } else {
3510        update
3511            .remove_incoming_requests
3512            .push(requester_id.to_proto());
3513    }
3514    for connection_id in pool.user_connection_ids(responder_id) {
3515        session.peer.send(connection_id, update.clone())?;
3516        if let Some(notification_id) = deleted_notification_id {
3517            session.peer.send(
3518                connection_id,
3519                proto::DeleteNotification {
3520                    notification_id: notification_id.to_proto(),
3521                },
3522            )?;
3523        }
3524    }
3525
3526    response.send(proto::Ack {})?;
3527    Ok(())
3528}
3529
3530fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3531    version.0.minor() < 139
3532}
3533
3534async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3535    subscribe_user_to_channels(
3536        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3537        &session,
3538    )
3539    .await?;
3540    Ok(())
3541}
3542
3543async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3544    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3545    let mut pool = session.connection_pool().await;
3546    for membership in &channels_for_user.channel_memberships {
3547        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3548    }
3549    session.peer.send(
3550        session.connection_id,
3551        build_update_user_channels(&channels_for_user),
3552    )?;
3553    session.peer.send(
3554        session.connection_id,
3555        build_channels_update(channels_for_user),
3556    )?;
3557    Ok(())
3558}
3559
3560/// Creates a new channel.
3561async fn create_channel(
3562    request: proto::CreateChannel,
3563    response: Response<proto::CreateChannel>,
3564    session: UserSession,
3565) -> Result<()> {
3566    let db = session.db().await;
3567
3568    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3569    let (channel, membership) = db
3570        .create_channel(&request.name, parent_id, session.user_id())
3571        .await?;
3572
3573    let root_id = channel.root_id();
3574    let channel = Channel::from_model(channel);
3575
3576    response.send(proto::CreateChannelResponse {
3577        channel: Some(channel.to_proto()),
3578        parent_id: request.parent_id,
3579    })?;
3580
3581    let mut connection_pool = session.connection_pool().await;
3582    if let Some(membership) = membership {
3583        connection_pool.subscribe_to_channel(
3584            membership.user_id,
3585            membership.channel_id,
3586            membership.role,
3587        );
3588        let update = proto::UpdateUserChannels {
3589            channel_memberships: vec![proto::ChannelMembership {
3590                channel_id: membership.channel_id.to_proto(),
3591                role: membership.role.into(),
3592            }],
3593            ..Default::default()
3594        };
3595        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3596            session.peer.send(connection_id, update.clone())?;
3597        }
3598    }
3599
3600    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3601        if !role.can_see_channel(channel.visibility) {
3602            continue;
3603        }
3604
3605        let update = proto::UpdateChannels {
3606            channels: vec![channel.to_proto()],
3607            ..Default::default()
3608        };
3609        session.peer.send(connection_id, update.clone())?;
3610    }
3611
3612    Ok(())
3613}
3614
3615/// Delete a channel
3616async fn delete_channel(
3617    request: proto::DeleteChannel,
3618    response: Response<proto::DeleteChannel>,
3619    session: UserSession,
3620) -> Result<()> {
3621    let db = session.db().await;
3622
3623    let channel_id = request.channel_id;
3624    let (root_channel, removed_channels) = db
3625        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3626        .await?;
3627    response.send(proto::Ack {})?;
3628
3629    // Notify members of removed channels
3630    let mut update = proto::UpdateChannels::default();
3631    update
3632        .delete_channels
3633        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3634
3635    let connection_pool = session.connection_pool().await;
3636    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3637        session.peer.send(connection_id, update.clone())?;
3638    }
3639
3640    Ok(())
3641}
3642
3643/// Invite someone to join a channel.
3644async fn invite_channel_member(
3645    request: proto::InviteChannelMember,
3646    response: Response<proto::InviteChannelMember>,
3647    session: UserSession,
3648) -> Result<()> {
3649    let db = session.db().await;
3650    let channel_id = ChannelId::from_proto(request.channel_id);
3651    let invitee_id = UserId::from_proto(request.user_id);
3652    let InviteMemberResult {
3653        channel,
3654        notifications,
3655    } = db
3656        .invite_channel_member(
3657            channel_id,
3658            invitee_id,
3659            session.user_id(),
3660            request.role().into(),
3661        )
3662        .await?;
3663
3664    let update = proto::UpdateChannels {
3665        channel_invitations: vec![channel.to_proto()],
3666        ..Default::default()
3667    };
3668
3669    let connection_pool = session.connection_pool().await;
3670    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3671        session.peer.send(connection_id, update.clone())?;
3672    }
3673
3674    send_notifications(&connection_pool, &session.peer, notifications);
3675
3676    response.send(proto::Ack {})?;
3677    Ok(())
3678}
3679
3680/// remove someone from a channel
3681async fn remove_channel_member(
3682    request: proto::RemoveChannelMember,
3683    response: Response<proto::RemoveChannelMember>,
3684    session: UserSession,
3685) -> Result<()> {
3686    let db = session.db().await;
3687    let channel_id = ChannelId::from_proto(request.channel_id);
3688    let member_id = UserId::from_proto(request.user_id);
3689
3690    let RemoveChannelMemberResult {
3691        membership_update,
3692        notification_id,
3693    } = db
3694        .remove_channel_member(channel_id, member_id, session.user_id())
3695        .await?;
3696
3697    let mut connection_pool = session.connection_pool().await;
3698    notify_membership_updated(
3699        &mut connection_pool,
3700        membership_update,
3701        member_id,
3702        &session.peer,
3703    );
3704    for connection_id in connection_pool.user_connection_ids(member_id) {
3705        if let Some(notification_id) = notification_id {
3706            session
3707                .peer
3708                .send(
3709                    connection_id,
3710                    proto::DeleteNotification {
3711                        notification_id: notification_id.to_proto(),
3712                    },
3713                )
3714                .trace_err();
3715        }
3716    }
3717
3718    response.send(proto::Ack {})?;
3719    Ok(())
3720}
3721
3722/// Toggle the channel between public and private.
3723/// Care is taken to maintain the invariant that public channels only descend from public channels,
3724/// (though members-only channels can appear at any point in the hierarchy).
3725async fn set_channel_visibility(
3726    request: proto::SetChannelVisibility,
3727    response: Response<proto::SetChannelVisibility>,
3728    session: UserSession,
3729) -> Result<()> {
3730    let db = session.db().await;
3731    let channel_id = ChannelId::from_proto(request.channel_id);
3732    let visibility = request.visibility().into();
3733
3734    let channel_model = db
3735        .set_channel_visibility(channel_id, visibility, session.user_id())
3736        .await?;
3737    let root_id = channel_model.root_id();
3738    let channel = Channel::from_model(channel_model);
3739
3740    let mut connection_pool = session.connection_pool().await;
3741    for (user_id, role) in connection_pool
3742        .channel_user_ids(root_id)
3743        .collect::<Vec<_>>()
3744        .into_iter()
3745    {
3746        let update = if role.can_see_channel(channel.visibility) {
3747            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3748            proto::UpdateChannels {
3749                channels: vec![channel.to_proto()],
3750                ..Default::default()
3751            }
3752        } else {
3753            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3754            proto::UpdateChannels {
3755                delete_channels: vec![channel.id.to_proto()],
3756                ..Default::default()
3757            }
3758        };
3759
3760        for connection_id in connection_pool.user_connection_ids(user_id) {
3761            session.peer.send(connection_id, update.clone())?;
3762        }
3763    }
3764
3765    response.send(proto::Ack {})?;
3766    Ok(())
3767}
3768
3769/// Alter the role for a user in the channel.
3770async fn set_channel_member_role(
3771    request: proto::SetChannelMemberRole,
3772    response: Response<proto::SetChannelMemberRole>,
3773    session: UserSession,
3774) -> Result<()> {
3775    let db = session.db().await;
3776    let channel_id = ChannelId::from_proto(request.channel_id);
3777    let member_id = UserId::from_proto(request.user_id);
3778    let result = db
3779        .set_channel_member_role(
3780            channel_id,
3781            session.user_id(),
3782            member_id,
3783            request.role().into(),
3784        )
3785        .await?;
3786
3787    match result {
3788        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3789            let mut connection_pool = session.connection_pool().await;
3790            notify_membership_updated(
3791                &mut connection_pool,
3792                membership_update,
3793                member_id,
3794                &session.peer,
3795            )
3796        }
3797        db::SetMemberRoleResult::InviteUpdated(channel) => {
3798            let update = proto::UpdateChannels {
3799                channel_invitations: vec![channel.to_proto()],
3800                ..Default::default()
3801            };
3802
3803            for connection_id in session
3804                .connection_pool()
3805                .await
3806                .user_connection_ids(member_id)
3807            {
3808                session.peer.send(connection_id, update.clone())?;
3809            }
3810        }
3811    }
3812
3813    response.send(proto::Ack {})?;
3814    Ok(())
3815}
3816
3817/// Change the name of a channel
3818async fn rename_channel(
3819    request: proto::RenameChannel,
3820    response: Response<proto::RenameChannel>,
3821    session: UserSession,
3822) -> Result<()> {
3823    let db = session.db().await;
3824    let channel_id = ChannelId::from_proto(request.channel_id);
3825    let channel_model = db
3826        .rename_channel(channel_id, session.user_id(), &request.name)
3827        .await?;
3828    let root_id = channel_model.root_id();
3829    let channel = Channel::from_model(channel_model);
3830
3831    response.send(proto::RenameChannelResponse {
3832        channel: Some(channel.to_proto()),
3833    })?;
3834
3835    let connection_pool = session.connection_pool().await;
3836    let update = proto::UpdateChannels {
3837        channels: vec![channel.to_proto()],
3838        ..Default::default()
3839    };
3840    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3841        if role.can_see_channel(channel.visibility) {
3842            session.peer.send(connection_id, update.clone())?;
3843        }
3844    }
3845
3846    Ok(())
3847}
3848
3849/// Move a channel to a new parent.
3850async fn move_channel(
3851    request: proto::MoveChannel,
3852    response: Response<proto::MoveChannel>,
3853    session: UserSession,
3854) -> Result<()> {
3855    let channel_id = ChannelId::from_proto(request.channel_id);
3856    let to = ChannelId::from_proto(request.to);
3857
3858    let (root_id, channels) = session
3859        .db()
3860        .await
3861        .move_channel(channel_id, to, session.user_id())
3862        .await?;
3863
3864    let connection_pool = session.connection_pool().await;
3865    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3866        let channels = channels
3867            .iter()
3868            .filter_map(|channel| {
3869                if role.can_see_channel(channel.visibility) {
3870                    Some(channel.to_proto())
3871                } else {
3872                    None
3873                }
3874            })
3875            .collect::<Vec<_>>();
3876        if channels.is_empty() {
3877            continue;
3878        }
3879
3880        let update = proto::UpdateChannels {
3881            channels,
3882            ..Default::default()
3883        };
3884
3885        session.peer.send(connection_id, update.clone())?;
3886    }
3887
3888    response.send(Ack {})?;
3889    Ok(())
3890}
3891
3892/// Get the list of channel members
3893async fn get_channel_members(
3894    request: proto::GetChannelMembers,
3895    response: Response<proto::GetChannelMembers>,
3896    session: UserSession,
3897) -> Result<()> {
3898    let db = session.db().await;
3899    let channel_id = ChannelId::from_proto(request.channel_id);
3900    let limit = if request.limit == 0 {
3901        u16::MAX as u64
3902    } else {
3903        request.limit
3904    };
3905    let (members, users) = db
3906        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3907        .await?;
3908    response.send(proto::GetChannelMembersResponse { members, users })?;
3909    Ok(())
3910}
3911
3912/// Accept or decline a channel invitation.
3913async fn respond_to_channel_invite(
3914    request: proto::RespondToChannelInvite,
3915    response: Response<proto::RespondToChannelInvite>,
3916    session: UserSession,
3917) -> Result<()> {
3918    let db = session.db().await;
3919    let channel_id = ChannelId::from_proto(request.channel_id);
3920    let RespondToChannelInvite {
3921        membership_update,
3922        notifications,
3923    } = db
3924        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3925        .await?;
3926
3927    let mut connection_pool = session.connection_pool().await;
3928    if let Some(membership_update) = membership_update {
3929        notify_membership_updated(
3930            &mut connection_pool,
3931            membership_update,
3932            session.user_id(),
3933            &session.peer,
3934        );
3935    } else {
3936        let update = proto::UpdateChannels {
3937            remove_channel_invitations: vec![channel_id.to_proto()],
3938            ..Default::default()
3939        };
3940
3941        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3942            session.peer.send(connection_id, update.clone())?;
3943        }
3944    };
3945
3946    send_notifications(&connection_pool, &session.peer, notifications);
3947
3948    response.send(proto::Ack {})?;
3949
3950    Ok(())
3951}
3952
3953/// Join the channels' room
3954async fn join_channel(
3955    request: proto::JoinChannel,
3956    response: Response<proto::JoinChannel>,
3957    session: UserSession,
3958) -> Result<()> {
3959    let channel_id = ChannelId::from_proto(request.channel_id);
3960    join_channel_internal(channel_id, Box::new(response), session).await
3961}
3962
/// A responder that can answer with a `JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` responders so that
/// `join_channel_internal` can reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request, consuming the responder.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to the inherent `Response::send`; spelling it out keeps this
        // from reading as (or becoming) a recursive call to the trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: delegate to the inherent `Response::send` explicitly.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3976
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Evict any stale room connection left behind by this user.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Reply through `response`.
/// 5. Notify membership subscribers if the database reported a membership change.
/// 6. Broadcast the room update, then the channel update.
/// 7. Refresh the user's contacts.
///
/// Errors: returns an error if the stale-connection cleanup, the database join, the reply,
/// or the contacts refresh fails, or if the database did not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the `db` and `connection_pool` guards taken inside are released before the
    // connection pool is acquired again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving the old room; `leave_room_for_session`
            // is given the session and presumably acquires the database itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and `can_publish: false`; every other role gets a room
        // token and may publish. If token creation fails, the error goes through
        // `trace_err()` and no LiveKit info is sent (the join itself still succeeds).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The joined room is expected to belong to a channel; treat a missing one as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4067
4068/// Start editing the channel notes
4069async fn join_channel_buffer(
4070    request: proto::JoinChannelBuffer,
4071    response: Response<proto::JoinChannelBuffer>,
4072    session: UserSession,
4073) -> Result<()> {
4074    let db = session.db().await;
4075    let channel_id = ChannelId::from_proto(request.channel_id);
4076
4077    let open_response = db
4078        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4079        .await?;
4080
4081    let collaborators = open_response.collaborators.clone();
4082    response.send(open_response)?;
4083
4084    let update = UpdateChannelBufferCollaborators {
4085        channel_id: channel_id.to_proto(),
4086        collaborators: collaborators.clone(),
4087    };
4088    channel_buffer_updated(
4089        session.connection_id,
4090        collaborators
4091            .iter()
4092            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4093        &update,
4094        &session.peer,
4095    );
4096
4097    Ok(())
4098}
4099
4100/// Edit the channel notes
4101async fn update_channel_buffer(
4102    request: proto::UpdateChannelBuffer,
4103    session: UserSession,
4104) -> Result<()> {
4105    let db = session.db().await;
4106    let channel_id = ChannelId::from_proto(request.channel_id);
4107
4108    let (collaborators, epoch, version) = db
4109        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4110        .await?;
4111
4112    channel_buffer_updated(
4113        session.connection_id,
4114        collaborators.clone(),
4115        &proto::UpdateChannelBuffer {
4116            channel_id: channel_id.to_proto(),
4117            operations: request.operations,
4118        },
4119        &session.peer,
4120    );
4121
4122    let pool = &*session.connection_pool().await;
4123
4124    let non_collaborators =
4125        pool.channel_connection_ids(channel_id)
4126            .filter_map(|(connection_id, _)| {
4127                if collaborators.contains(&connection_id) {
4128                    None
4129                } else {
4130                    Some(connection_id)
4131                }
4132            });
4133
4134    broadcast(None, non_collaborators, |peer_id| {
4135        session.peer.send(
4136            peer_id,
4137            proto::UpdateChannels {
4138                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4139                    channel_id: channel_id.to_proto(),
4140                    epoch: epoch as u64,
4141                    version: version.clone(),
4142                }],
4143                ..Default::default()
4144            },
4145        )
4146    });
4147
4148    Ok(())
4149}
4150
4151/// Rejoin the channel notes after a connection blip
4152async fn rejoin_channel_buffers(
4153    request: proto::RejoinChannelBuffers,
4154    response: Response<proto::RejoinChannelBuffers>,
4155    session: UserSession,
4156) -> Result<()> {
4157    let db = session.db().await;
4158    let buffers = db
4159        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4160        .await?;
4161
4162    for rejoined_buffer in &buffers {
4163        let collaborators_to_notify = rejoined_buffer
4164            .buffer
4165            .collaborators
4166            .iter()
4167            .filter_map(|c| Some(c.peer_id?.into()));
4168        channel_buffer_updated(
4169            session.connection_id,
4170            collaborators_to_notify,
4171            &proto::UpdateChannelBufferCollaborators {
4172                channel_id: rejoined_buffer.buffer.channel_id,
4173                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4174            },
4175            &session.peer,
4176        );
4177    }
4178
4179    response.send(proto::RejoinChannelBuffersResponse {
4180        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4181    })?;
4182
4183    Ok(())
4184}
4185
4186/// Stop editing the channel notes
4187async fn leave_channel_buffer(
4188    request: proto::LeaveChannelBuffer,
4189    response: Response<proto::LeaveChannelBuffer>,
4190    session: UserSession,
4191) -> Result<()> {
4192    let db = session.db().await;
4193    let channel_id = ChannelId::from_proto(request.channel_id);
4194
4195    let left_buffer = db
4196        .leave_channel_buffer(channel_id, session.connection_id)
4197        .await?;
4198
4199    response.send(Ack {})?;
4200
4201    channel_buffer_updated(
4202        session.connection_id,
4203        left_buffer.connections,
4204        &proto::UpdateChannelBufferCollaborators {
4205            channel_id: channel_id.to_proto(),
4206            collaborators: left_buffer.collaborators,
4207        },
4208        &session.peer,
4209    );
4210
4211    Ok(())
4212}
4213
4214fn channel_buffer_updated<T: EnvelopedMessage>(
4215    sender_id: ConnectionId,
4216    collaborators: impl IntoIterator<Item = ConnectionId>,
4217    message: &T,
4218    peer: &Peer,
4219) {
4220    broadcast(Some(sender_id), collaborators, |peer_id| {
4221        peer.send(peer_id, message.clone())
4222    });
4223}
4224
4225fn send_notifications(
4226    connection_pool: &ConnectionPool,
4227    peer: &Peer,
4228    notifications: db::NotificationBatch,
4229) {
4230    for (user_id, notification) in notifications {
4231        for connection_id in connection_pool.user_connection_ids(user_id) {
4232            if let Err(error) = peer.send(
4233                connection_id,
4234                proto::AddNotification {
4235                    notification: Some(notification.clone()),
4236                },
4237            ) {
4238                tracing::error!(
4239                    "failed to send notification to {:?} {}",
4240                    connection_id,
4241                    error
4242                );
4243            }
4244        }
4245    }
4246}
4247
4248/// Send a message to the channel
4249async fn send_channel_message(
4250    request: proto::SendChannelMessage,
4251    response: Response<proto::SendChannelMessage>,
4252    session: UserSession,
4253) -> Result<()> {
4254    // Validate the message body.
4255    let body = request.body.trim().to_string();
4256    if body.len() > MAX_MESSAGE_LEN {
4257        return Err(anyhow!("message is too long"))?;
4258    }
4259    if body.is_empty() {
4260        return Err(anyhow!("message can't be blank"))?;
4261    }
4262
4263    // TODO: adjust mentions if body is trimmed
4264
4265    let timestamp = OffsetDateTime::now_utc();
4266    let nonce = request
4267        .nonce
4268        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4269
4270    let channel_id = ChannelId::from_proto(request.channel_id);
4271    let CreatedChannelMessage {
4272        message_id,
4273        participant_connection_ids,
4274        notifications,
4275    } = session
4276        .db()
4277        .await
4278        .create_channel_message(
4279            channel_id,
4280            session.user_id(),
4281            &body,
4282            &request.mentions,
4283            timestamp,
4284            nonce.clone().into(),
4285            match request.reply_to_message_id {
4286                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4287                None => None,
4288            },
4289        )
4290        .await?;
4291
4292    let message = proto::ChannelMessage {
4293        sender_id: session.user_id().to_proto(),
4294        id: message_id.to_proto(),
4295        body,
4296        mentions: request.mentions,
4297        timestamp: timestamp.unix_timestamp() as u64,
4298        nonce: Some(nonce),
4299        reply_to_message_id: request.reply_to_message_id,
4300        edited_at: None,
4301    };
4302    broadcast(
4303        Some(session.connection_id),
4304        participant_connection_ids.clone(),
4305        |connection| {
4306            session.peer.send(
4307                connection,
4308                proto::ChannelMessageSent {
4309                    channel_id: channel_id.to_proto(),
4310                    message: Some(message.clone()),
4311                },
4312            )
4313        },
4314    );
4315    response.send(proto::SendChannelMessageResponse {
4316        message: Some(message),
4317    })?;
4318
4319    let pool = &*session.connection_pool().await;
4320    let non_participants =
4321        pool.channel_connection_ids(channel_id)
4322            .filter_map(|(connection_id, _)| {
4323                if participant_connection_ids.contains(&connection_id) {
4324                    None
4325                } else {
4326                    Some(connection_id)
4327                }
4328            });
4329    broadcast(None, non_participants, |peer_id| {
4330        session.peer.send(
4331            peer_id,
4332            proto::UpdateChannels {
4333                latest_channel_message_ids: vec![proto::ChannelMessageId {
4334                    channel_id: channel_id.to_proto(),
4335                    message_id: message_id.to_proto(),
4336                }],
4337                ..Default::default()
4338            },
4339        )
4340    });
4341    send_notifications(pool, &session.peer, notifications);
4342
4343    Ok(())
4344}
4345
4346/// Delete a channel message
4347async fn remove_channel_message(
4348    request: proto::RemoveChannelMessage,
4349    response: Response<proto::RemoveChannelMessage>,
4350    session: UserSession,
4351) -> Result<()> {
4352    let channel_id = ChannelId::from_proto(request.channel_id);
4353    let message_id = MessageId::from_proto(request.message_id);
4354    let (connection_ids, existing_notification_ids) = session
4355        .db()
4356        .await
4357        .remove_channel_message(channel_id, message_id, session.user_id())
4358        .await?;
4359
4360    broadcast(
4361        Some(session.connection_id),
4362        connection_ids,
4363        move |connection| {
4364            session.peer.send(connection, request.clone())?;
4365
4366            for notification_id in &existing_notification_ids {
4367                session.peer.send(
4368                    connection,
4369                    proto::DeleteNotification {
4370                        notification_id: (*notification_id).to_proto(),
4371                    },
4372                )?;
4373            }
4374
4375            Ok(())
4376        },
4377    );
4378    response.send(proto::Ack {})?;
4379    Ok(())
4380}
4381
4382async fn update_channel_message(
4383    request: proto::UpdateChannelMessage,
4384    response: Response<proto::UpdateChannelMessage>,
4385    session: UserSession,
4386) -> Result<()> {
4387    let channel_id = ChannelId::from_proto(request.channel_id);
4388    let message_id = MessageId::from_proto(request.message_id);
4389    let updated_at = OffsetDateTime::now_utc();
4390    let UpdatedChannelMessage {
4391        message_id,
4392        participant_connection_ids,
4393        notifications,
4394        reply_to_message_id,
4395        timestamp,
4396        deleted_mention_notification_ids,
4397        updated_mention_notifications,
4398    } = session
4399        .db()
4400        .await
4401        .update_channel_message(
4402            channel_id,
4403            message_id,
4404            session.user_id(),
4405            request.body.as_str(),
4406            &request.mentions,
4407            updated_at,
4408        )
4409        .await?;
4410
4411    let nonce = request
4412        .nonce
4413        .clone()
4414        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4415
4416    let message = proto::ChannelMessage {
4417        sender_id: session.user_id().to_proto(),
4418        id: message_id.to_proto(),
4419        body: request.body.clone(),
4420        mentions: request.mentions.clone(),
4421        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4422        nonce: Some(nonce),
4423        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4424        edited_at: Some(updated_at.unix_timestamp() as u64),
4425    };
4426
4427    response.send(proto::Ack {})?;
4428
4429    let pool = &*session.connection_pool().await;
4430    broadcast(
4431        Some(session.connection_id),
4432        participant_connection_ids,
4433        |connection| {
4434            session.peer.send(
4435                connection,
4436                proto::ChannelMessageUpdate {
4437                    channel_id: channel_id.to_proto(),
4438                    message: Some(message.clone()),
4439                },
4440            )?;
4441
4442            for notification_id in &deleted_mention_notification_ids {
4443                session.peer.send(
4444                    connection,
4445                    proto::DeleteNotification {
4446                        notification_id: (*notification_id).to_proto(),
4447                    },
4448                )?;
4449            }
4450
4451            for notification in &updated_mention_notifications {
4452                session.peer.send(
4453                    connection,
4454                    proto::UpdateNotification {
4455                        notification: Some(notification.clone()),
4456                    },
4457                )?;
4458            }
4459
4460            Ok(())
4461        },
4462    );
4463
4464    send_notifications(pool, &session.peer, notifications);
4465
4466    Ok(())
4467}
4468
4469/// Mark a channel message as read
4470async fn acknowledge_channel_message(
4471    request: proto::AckChannelMessage,
4472    session: UserSession,
4473) -> Result<()> {
4474    let channel_id = ChannelId::from_proto(request.channel_id);
4475    let message_id = MessageId::from_proto(request.message_id);
4476    let notifications = session
4477        .db()
4478        .await
4479        .observe_channel_message(channel_id, session.user_id(), message_id)
4480        .await?;
4481    send_notifications(
4482        &*session.connection_pool().await,
4483        &session.peer,
4484        notifications,
4485    );
4486    Ok(())
4487}
4488
4489/// Mark a buffer version as synced
4490async fn acknowledge_buffer_version(
4491    request: proto::AckBufferOperation,
4492    session: UserSession,
4493) -> Result<()> {
4494    let buffer_id = BufferId::from_proto(request.buffer_id);
4495    session
4496        .db()
4497        .await
4498        .observe_buffer_version(
4499            buffer_id,
4500            session.user_id(),
4501            request.epoch as i32,
4502            &request.version,
4503        )
4504        .await?;
4505    Ok(())
4506}
4507
4508struct CompleteWithLanguageModelRateLimit;
4509
4510impl RateLimit for CompleteWithLanguageModelRateLimit {
4511    fn capacity() -> usize {
4512        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4513            .ok()
4514            .and_then(|v| v.parse().ok())
4515            .unwrap_or(120) // Picked arbitrarily
4516    }
4517
4518    fn refill_duration() -> chrono::Duration {
4519        chrono::Duration::hours(1)
4520    }
4521
4522    fn db_name() -> &'static str {
4523        "complete-with-language-model"
4524    }
4525}
4526
4527async fn complete_with_language_model(
4528    request: proto::CompleteWithLanguageModel,
4529    response: Response<proto::CompleteWithLanguageModel>,
4530    session: Session,
4531    config: &Config,
4532) -> Result<()> {
4533    let Some(session) = session.for_user() else {
4534        return Err(anyhow!("user not found"))?;
4535    };
4536    authorize_access_to_language_models(&session).await?;
4537
4538    session
4539        .rate_limiter
4540        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4541        .await?;
4542
4543    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4544        Some(proto::LanguageModelProvider::Anthropic) => {
4545            let api_key = config
4546                .anthropic_api_key
4547                .as_ref()
4548                .context("no Anthropic AI API key configured on the server")?;
4549            anthropic::complete(
4550                session.http_client.as_ref(),
4551                anthropic::ANTHROPIC_API_URL,
4552                api_key,
4553                serde_json::from_str(&request.request)?,
4554            )
4555            .await?
4556        }
4557        _ => return Err(anyhow!("unsupported provider"))?,
4558    };
4559
4560    response.send(proto::CompleteWithLanguageModelResponse {
4561        completion: serde_json::to_string(&result)?,
4562    })?;
4563
4564    Ok(())
4565}
4566
4567async fn stream_complete_with_language_model(
4568    request: proto::StreamCompleteWithLanguageModel,
4569    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4570    session: Session,
4571    config: &Config,
4572) -> Result<()> {
4573    let Some(session) = session.for_user() else {
4574        return Err(anyhow!("user not found"))?;
4575    };
4576    authorize_access_to_language_models(&session).await?;
4577
4578    session
4579        .rate_limiter
4580        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4581        .await?;
4582
4583    match proto::LanguageModelProvider::from_i32(request.provider) {
4584        Some(proto::LanguageModelProvider::Anthropic) => {
4585            let api_key = config
4586                .anthropic_api_key
4587                .as_ref()
4588                .context("no Anthropic AI API key configured on the server")?;
4589            let mut chunks = anthropic::stream_completion(
4590                session.http_client.as_ref(),
4591                anthropic::ANTHROPIC_API_URL,
4592                api_key,
4593                serde_json::from_str(&request.request)?,
4594                None,
4595            )
4596            .await?;
4597            while let Some(event) = chunks.next().await {
4598                let chunk = event?;
4599                response.send(proto::StreamCompleteWithLanguageModelResponse {
4600                    event: serde_json::to_string(&chunk)?,
4601                })?;
4602            }
4603        }
4604        Some(proto::LanguageModelProvider::OpenAi) => {
4605            let api_key = config
4606                .openai_api_key
4607                .as_ref()
4608                .context("no OpenAI API key configured on the server")?;
4609            let mut events = open_ai::stream_completion(
4610                session.http_client.as_ref(),
4611                open_ai::OPEN_AI_API_URL,
4612                api_key,
4613                serde_json::from_str(&request.request)?,
4614                None,
4615            )
4616            .await?;
4617            while let Some(event) = events.next().await {
4618                let event = event?;
4619                response.send(proto::StreamCompleteWithLanguageModelResponse {
4620                    event: serde_json::to_string(&event)?,
4621                })?;
4622            }
4623        }
4624        Some(proto::LanguageModelProvider::Google) => {
4625            let api_key = config
4626                .google_ai_api_key
4627                .as_ref()
4628                .context("no Google AI API key configured on the server")?;
4629            let mut events = google_ai::stream_generate_content(
4630                session.http_client.as_ref(),
4631                google_ai::API_URL,
4632                api_key,
4633                serde_json::from_str(&request.request)?,
4634            )
4635            .await?;
4636            while let Some(event) = events.next().await {
4637                let event = event?;
4638                response.send(proto::StreamCompleteWithLanguageModelResponse {
4639                    event: serde_json::to_string(&event)?,
4640                })?;
4641            }
4642        }
4643        None => return Err(anyhow!("unknown provider"))?,
4644    }
4645
4646    Ok(())
4647}
4648
4649async fn count_language_model_tokens(
4650    request: proto::CountLanguageModelTokens,
4651    response: Response<proto::CountLanguageModelTokens>,
4652    session: Session,
4653    config: &Config,
4654) -> Result<()> {
4655    let Some(session) = session.for_user() else {
4656        return Err(anyhow!("user not found"))?;
4657    };
4658    authorize_access_to_language_models(&session).await?;
4659
4660    session
4661        .rate_limiter
4662        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4663        .await?;
4664
4665    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4666        Some(proto::LanguageModelProvider::Google) => {
4667            let api_key = config
4668                .google_ai_api_key
4669                .as_ref()
4670                .context("no Google AI API key configured on the server")?;
4671            google_ai::count_tokens(
4672                session.http_client.as_ref(),
4673                google_ai::API_URL,
4674                api_key,
4675                serde_json::from_str(&request.request)?,
4676            )
4677            .await?
4678        }
4679        _ => return Err(anyhow!("unsupported provider"))?,
4680    };
4681
4682    response.send(proto::CountLanguageModelTokensResponse {
4683        token_count: result.total_tokens as u32,
4684    })?;
4685
4686    Ok(())
4687}
4688
4689struct CountLanguageModelTokensRateLimit;
4690
4691impl RateLimit for CountLanguageModelTokensRateLimit {
4692    fn capacity() -> usize {
4693        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4694            .ok()
4695            .and_then(|v| v.parse().ok())
4696            .unwrap_or(600) // Picked arbitrarily
4697    }
4698
4699    fn refill_duration() -> chrono::Duration {
4700        chrono::Duration::hours(1)
4701    }
4702
4703    fn db_name() -> &'static str {
4704        "count-language-model-tokens"
4705    }
4706}
4707
4708struct ComputeEmbeddingsRateLimit;
4709
4710impl RateLimit for ComputeEmbeddingsRateLimit {
4711    fn capacity() -> usize {
4712        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4713            .ok()
4714            .and_then(|v| v.parse().ok())
4715            .unwrap_or(5000) // Picked arbitrarily
4716    }
4717
4718    fn refill_duration() -> chrono::Duration {
4719        chrono::Duration::hours(1)
4720    }
4721
4722    fn db_name() -> &'static str {
4723        "compute-embeddings"
4724    }
4725}
4726
4727async fn compute_embeddings(
4728    request: proto::ComputeEmbeddings,
4729    response: Response<proto::ComputeEmbeddings>,
4730    session: UserSession,
4731    api_key: Option<Arc<str>>,
4732) -> Result<()> {
4733    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4734    authorize_access_to_language_models(&session).await?;
4735
4736    session
4737        .rate_limiter
4738        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4739        .await?;
4740
4741    let embeddings = match request.model.as_str() {
4742        "openai/text-embedding-3-small" => {
4743            open_ai::embed(
4744                session.http_client.as_ref(),
4745                OPEN_AI_API_URL,
4746                &api_key,
4747                OpenAiEmbeddingModel::TextEmbedding3Small,
4748                request.texts.iter().map(|text| text.as_str()),
4749            )
4750            .await?
4751        }
4752        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4753    };
4754
4755    let embeddings = request
4756        .texts
4757        .iter()
4758        .map(|text| {
4759            let mut hasher = sha2::Sha256::new();
4760            hasher.update(text.as_bytes());
4761            let result = hasher.finalize();
4762            result.to_vec()
4763        })
4764        .zip(
4765            embeddings
4766                .data
4767                .into_iter()
4768                .map(|embedding| embedding.embedding),
4769        )
4770        .collect::<HashMap<_, _>>();
4771
4772    let db = session.db().await;
4773    db.save_embeddings(&request.model, &embeddings)
4774        .await
4775        .context("failed to save embeddings")
4776        .trace_err();
4777
4778    response.send(proto::ComputeEmbeddingsResponse {
4779        embeddings: embeddings
4780            .into_iter()
4781            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4782            .collect(),
4783    })?;
4784    Ok(())
4785}
4786
4787async fn get_cached_embeddings(
4788    request: proto::GetCachedEmbeddings,
4789    response: Response<proto::GetCachedEmbeddings>,
4790    session: UserSession,
4791) -> Result<()> {
4792    authorize_access_to_language_models(&session).await?;
4793
4794    let db = session.db().await;
4795    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4796
4797    response.send(proto::GetCachedEmbeddingsResponse {
4798        embeddings: embeddings
4799            .into_iter()
4800            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4801            .collect(),
4802    })?;
4803    Ok(())
4804}
4805
4806async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4807    let db = session.db().await;
4808    let flags = db.get_user_flags(session.user_id()).await?;
4809    if flags.iter().any(|flag| flag == "language-models") {
4810        Ok(())
4811    } else {
4812        Err(anyhow!("permission denied"))?
4813    }
4814}
4815
4816/// Get a Supermaven API key for the user
4817async fn get_supermaven_api_key(
4818    _request: proto::GetSupermavenApiKey,
4819    response: Response<proto::GetSupermavenApiKey>,
4820    session: UserSession,
4821) -> Result<()> {
4822    let user_id: String = session.user_id().to_string();
4823    if !session.is_staff() {
4824        return Err(anyhow!("supermaven not enabled for this account"))?;
4825    }
4826
4827    let email = session
4828        .email()
4829        .ok_or_else(|| anyhow!("user must have an email"))?;
4830
4831    let supermaven_admin_api = session
4832        .supermaven_client
4833        .as_ref()
4834        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4835
4836    let result = supermaven_admin_api
4837        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4838        .await?;
4839
4840    response.send(proto::GetSupermavenApiKeyResponse {
4841        api_key: result.api_key,
4842    })?;
4843
4844    Ok(())
4845}
4846
4847/// Start receiving chat updates for a channel
4848async fn join_channel_chat(
4849    request: proto::JoinChannelChat,
4850    response: Response<proto::JoinChannelChat>,
4851    session: UserSession,
4852) -> Result<()> {
4853    let channel_id = ChannelId::from_proto(request.channel_id);
4854
4855    let db = session.db().await;
4856    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4857        .await?;
4858    let messages = db
4859        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4860        .await?;
4861    response.send(proto::JoinChannelChatResponse {
4862        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4863        messages,
4864    })?;
4865    Ok(())
4866}
4867
4868/// Stop receiving chat updates for a channel
4869async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4870    let channel_id = ChannelId::from_proto(request.channel_id);
4871    session
4872        .db()
4873        .await
4874        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4875        .await?;
4876    Ok(())
4877}
4878
4879/// Retrieve the chat history for a channel
4880async fn get_channel_messages(
4881    request: proto::GetChannelMessages,
4882    response: Response<proto::GetChannelMessages>,
4883    session: UserSession,
4884) -> Result<()> {
4885    let channel_id = ChannelId::from_proto(request.channel_id);
4886    let messages = session
4887        .db()
4888        .await
4889        .get_channel_messages(
4890            channel_id,
4891            session.user_id(),
4892            MESSAGE_COUNT_PER_PAGE,
4893            Some(MessageId::from_proto(request.before_message_id)),
4894        )
4895        .await?;
4896    response.send(proto::GetChannelMessagesResponse {
4897        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4898        messages,
4899    })?;
4900    Ok(())
4901}
4902
4903/// Retrieve specific chat messages
4904async fn get_channel_messages_by_id(
4905    request: proto::GetChannelMessagesById,
4906    response: Response<proto::GetChannelMessagesById>,
4907    session: UserSession,
4908) -> Result<()> {
4909    let message_ids = request
4910        .message_ids
4911        .iter()
4912        .map(|id| MessageId::from_proto(*id))
4913        .collect::<Vec<_>>();
4914    let messages = session
4915        .db()
4916        .await
4917        .get_channel_messages_by_id(session.user_id(), &message_ids)
4918        .await?;
4919    response.send(proto::GetChannelMessagesResponse {
4920        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4921        messages,
4922    })?;
4923    Ok(())
4924}
4925
4926/// Retrieve the current users notifications
4927async fn get_notifications(
4928    request: proto::GetNotifications,
4929    response: Response<proto::GetNotifications>,
4930    session: UserSession,
4931) -> Result<()> {
4932    let notifications = session
4933        .db()
4934        .await
4935        .get_notifications(
4936            session.user_id(),
4937            NOTIFICATION_COUNT_PER_PAGE,
4938            request
4939                .before_id
4940                .map(|id| db::NotificationId::from_proto(id)),
4941        )
4942        .await?;
4943    response.send(proto::GetNotificationsResponse {
4944        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4945        notifications,
4946    })?;
4947    Ok(())
4948}
4949
4950/// Mark notifications as read
4951async fn mark_notification_as_read(
4952    request: proto::MarkNotificationRead,
4953    response: Response<proto::MarkNotificationRead>,
4954    session: UserSession,
4955) -> Result<()> {
4956    let database = &session.db().await;
4957    let notifications = database
4958        .mark_notification_as_read_by_id(
4959            session.user_id(),
4960            NotificationId::from_proto(request.notification_id),
4961        )
4962        .await?;
4963    send_notifications(
4964        &*session.connection_pool().await,
4965        &session.peer,
4966        notifications,
4967    );
4968    response.send(proto::Ack {})?;
4969    Ok(())
4970}
4971
4972/// Get the current users information
4973async fn get_private_user_info(
4974    _request: proto::GetPrivateUserInfo,
4975    response: Response<proto::GetPrivateUserInfo>,
4976    session: UserSession,
4977) -> Result<()> {
4978    let db = session.db().await;
4979
4980    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4981    let user = db
4982        .get_user_by_id(session.user_id())
4983        .await?
4984        .ok_or_else(|| anyhow!("user not found"))?;
4985    let flags = db.get_user_flags(session.user_id()).await?;
4986
4987    response.send(proto::GetPrivateUserInfoResponse {
4988        metrics_id,
4989        staff: user.admin,
4990        flags,
4991    })?;
4992    Ok(())
4993}
4994
4995fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4996    let message = match message {
4997        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4998        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4999        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5000        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5001        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5002            code: frame.code.into(),
5003            reason: frame.reason,
5004        })),
5005        // We should never receive a frame while reading the message, according
5006        // to the `tungstenite` maintainers:
5007        //
5008        // > It cannot occur when you read messages from the WebSocket, but it
5009        // > can be used when you want to send the raw frames (e.g. you want to
5010        // > send the frames to the WebSocket without composing the full message first).
5011        // >
5012        // > — https://github.com/snapview/tungstenite-rs/issues/268
5013        TungsteniteMessage::Frame(_) => {
5014            bail!("received an unexpected frame while reading the message")
5015        }
5016    };
5017
5018    Ok(message)
5019}
5020
5021fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5022    match message {
5023        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5024        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5025        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5026        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5027        AxumMessage::Close(frame) => {
5028            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5029                code: frame.code.into(),
5030                reason: frame.reason,
5031            }))
5032        }
5033    }
5034}
5035
5036fn notify_membership_updated(
5037    connection_pool: &mut ConnectionPool,
5038    result: MembershipUpdated,
5039    user_id: UserId,
5040    peer: &Peer,
5041) {
5042    for membership in &result.new_channels.channel_memberships {
5043        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5044    }
5045    for channel_id in &result.removed_channels {
5046        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5047    }
5048
5049    let user_channels_update = proto::UpdateUserChannels {
5050        channel_memberships: result
5051            .new_channels
5052            .channel_memberships
5053            .iter()
5054            .map(|cm| proto::ChannelMembership {
5055                channel_id: cm.channel_id.to_proto(),
5056                role: cm.role.into(),
5057            })
5058            .collect(),
5059        ..Default::default()
5060    };
5061
5062    let mut update = build_channels_update(result.new_channels);
5063    update.delete_channels = result
5064        .removed_channels
5065        .into_iter()
5066        .map(|id| id.to_proto())
5067        .collect();
5068    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5069
5070    for connection_id in connection_pool.user_connection_ids(user_id) {
5071        peer.send(connection_id, user_channels_update.clone())
5072            .trace_err();
5073        peer.send(connection_id, update.clone()).trace_err();
5074    }
5075}
5076
5077fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5078    proto::UpdateUserChannels {
5079        channel_memberships: channels
5080            .channel_memberships
5081            .iter()
5082            .map(|m| proto::ChannelMembership {
5083                channel_id: m.channel_id.to_proto(),
5084                role: m.role.into(),
5085            })
5086            .collect(),
5087        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5088        observed_channel_message_id: channels.observed_channel_messages.clone(),
5089    }
5090}
5091
5092fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5093    let mut update = proto::UpdateChannels::default();
5094
5095    for channel in channels.channels {
5096        update.channels.push(channel.to_proto());
5097    }
5098
5099    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5100    update.latest_channel_message_ids = channels.latest_channel_messages;
5101
5102    for (channel_id, participants) in channels.channel_participants {
5103        update
5104            .channel_participants
5105            .push(proto::ChannelParticipants {
5106                channel_id: channel_id.to_proto(),
5107                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5108            });
5109    }
5110
5111    for channel in channels.invited_channels {
5112        update.channel_invitations.push(channel.to_proto());
5113    }
5114
5115    update.hosted_projects = channels.hosted_projects;
5116    update
5117}
5118
5119fn build_initial_contacts_update(
5120    contacts: Vec<db::Contact>,
5121    pool: &ConnectionPool,
5122) -> proto::UpdateContacts {
5123    let mut update = proto::UpdateContacts::default();
5124
5125    for contact in contacts {
5126        match contact {
5127            db::Contact::Accepted { user_id, busy } => {
5128                update.contacts.push(contact_for_user(user_id, busy, &pool));
5129            }
5130            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5131            db::Contact::Incoming { user_id } => {
5132                update
5133                    .incoming_requests
5134                    .push(proto::IncomingContactRequest {
5135                        requester_id: user_id.to_proto(),
5136                    })
5137            }
5138        }
5139    }
5140
5141    update
5142}
5143
5144fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5145    proto::Contact {
5146        user_id: user_id.to_proto(),
5147        online: pool.is_user_online(user_id),
5148        busy,
5149    }
5150}
5151
5152fn room_updated(room: &proto::Room, peer: &Peer) {
5153    broadcast(
5154        None,
5155        room.participants
5156            .iter()
5157            .filter_map(|participant| Some(participant.peer_id?.into())),
5158        |peer_id| {
5159            peer.send(
5160                peer_id,
5161                proto::RoomUpdated {
5162                    room: Some(room.clone()),
5163                },
5164            )
5165        },
5166    );
5167}
5168
5169fn channel_updated(
5170    channel: &db::channel::Model,
5171    room: &proto::Room,
5172    peer: &Peer,
5173    pool: &ConnectionPool,
5174) {
5175    let participants = room
5176        .participants
5177        .iter()
5178        .map(|p| p.user_id)
5179        .collect::<Vec<_>>();
5180
5181    broadcast(
5182        None,
5183        pool.channel_connection_ids(channel.root_id())
5184            .filter_map(|(channel_id, role)| {
5185                role.can_see_channel(channel.visibility).then(|| channel_id)
5186            }),
5187        |peer_id| {
5188            peer.send(
5189                peer_id,
5190                proto::UpdateChannels {
5191                    channel_participants: vec![proto::ChannelParticipants {
5192                        channel_id: channel.id.to_proto(),
5193                        participant_user_ids: participants.clone(),
5194                    }],
5195                    ..Default::default()
5196                },
5197            )
5198        },
5199    );
5200}
5201
5202async fn send_dev_server_projects_update(
5203    user_id: UserId,
5204    mut status: proto::DevServerProjectsUpdate,
5205    session: &Session,
5206) {
5207    let pool = session.connection_pool().await;
5208    for dev_server in &mut status.dev_servers {
5209        dev_server.status =
5210            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5211    }
5212    let connections = pool.user_connection_ids(user_id);
5213    for connection_id in connections {
5214        session.peer.send(connection_id, status.clone()).trace_err();
5215    }
5216}
5217
5218async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5219    let db = session.db().await;
5220
5221    let contacts = db.get_contacts(user_id).await?;
5222    let busy = db.is_user_busy(user_id).await?;
5223
5224    let pool = session.connection_pool().await;
5225    let updated_contact = contact_for_user(user_id, busy, &pool);
5226    for contact in contacts {
5227        if let db::Contact::Accepted {
5228            user_id: contact_user_id,
5229            ..
5230        } = contact
5231        {
5232            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5233                session
5234                    .peer
5235                    .send(
5236                        contact_conn_id,
5237                        proto::UpdateContacts {
5238                            contacts: vec![updated_contact.clone()],
5239                            remove_contacts: Default::default(),
5240                            incoming_requests: Default::default(),
5241                            remove_incoming_requests: Default::default(),
5242                            outgoing_requests: Default::default(),
5243                            remove_outgoing_requests: Default::default(),
5244                        },
5245                    )
5246                    .trace_err();
5247            }
5248        }
5249    }
5250    Ok(())
5251}
5252
5253async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5254    log::info!("lost dev server connection, unsharing projects");
5255    let project_ids = session
5256        .db()
5257        .await
5258        .get_stale_dev_server_projects(session.connection_id)
5259        .await?;
5260
5261    for project_id in project_ids {
5262        // not unshare re-checks the connection ids match, so we get away with no transaction
5263        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5264    }
5265
5266    let user_id = session.dev_server().user_id;
5267    let update = session
5268        .db()
5269        .await
5270        .dev_server_projects_update(user_id)
5271        .await?;
5272
5273    send_dev_server_projects_update(user_id, update, session).await;
5274
5275    Ok(())
5276}
5277
5278async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5279    let mut contacts_to_update = HashSet::default();
5280
5281    let room_id;
5282    let canceled_calls_to_user_ids;
5283    let live_kit_room;
5284    let delete_live_kit_room;
5285    let room;
5286    let channel;
5287
5288    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5289        contacts_to_update.insert(session.user_id());
5290
5291        for project in left_room.left_projects.values() {
5292            project_left(project, session);
5293        }
5294
5295        room_id = RoomId::from_proto(left_room.room.id);
5296        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5297        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5298        delete_live_kit_room = left_room.deleted;
5299        room = mem::take(&mut left_room.room);
5300        channel = mem::take(&mut left_room.channel);
5301
5302        room_updated(&room, &session.peer);
5303    } else {
5304        return Ok(());
5305    }
5306
5307    if let Some(channel) = channel {
5308        channel_updated(
5309            &channel,
5310            &room,
5311            &session.peer,
5312            &*session.connection_pool().await,
5313        );
5314    }
5315
5316    {
5317        let pool = session.connection_pool().await;
5318        for canceled_user_id in canceled_calls_to_user_ids {
5319            for connection_id in pool.user_connection_ids(canceled_user_id) {
5320                session
5321                    .peer
5322                    .send(
5323                        connection_id,
5324                        proto::CallCanceled {
5325                            room_id: room_id.to_proto(),
5326                        },
5327                    )
5328                    .trace_err();
5329            }
5330            contacts_to_update.insert(canceled_user_id);
5331        }
5332    }
5333
5334    for contact_user_id in contacts_to_update {
5335        update_user_contacts(contact_user_id, &session).await?;
5336    }
5337
5338    if let Some(live_kit) = session.live_kit_client.as_ref() {
5339        live_kit
5340            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5341            .await
5342            .trace_err();
5343
5344        if delete_live_kit_room {
5345            live_kit.delete_room(live_kit_room).await.trace_err();
5346        }
5347    }
5348
5349    Ok(())
5350}
5351
5352async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5353    let left_channel_buffers = session
5354        .db()
5355        .await
5356        .leave_channel_buffers(session.connection_id)
5357        .await?;
5358
5359    for left_buffer in left_channel_buffers {
5360        channel_buffer_updated(
5361            session.connection_id,
5362            left_buffer.connections,
5363            &proto::UpdateChannelBufferCollaborators {
5364                channel_id: left_buffer.channel_id.to_proto(),
5365                collaborators: left_buffer.collaborators,
5366            },
5367            &session.peer,
5368        );
5369    }
5370
5371    Ok(())
5372}
5373
5374fn project_left(project: &db::LeftProject, session: &UserSession) {
5375    for connection_id in &project.connection_ids {
5376        if project.should_unshare {
5377            session
5378                .peer
5379                .send(
5380                    *connection_id,
5381                    proto::UnshareProject {
5382                        project_id: project.id.to_proto(),
5383                    },
5384                )
5385                .trace_err();
5386        } else {
5387            session
5388                .peer
5389                .send(
5390                    *connection_id,
5391                    proto::RemoveProjectCollaborator {
5392                        project_id: project.id.to_proto(),
5393                        peer_id: Some(session.connection_id.into()),
5394                    },
5395                )
5396                .trace_err();
5397        }
5398    }
5399}
5400
5401pub trait ResultExt {
5402    type Ok;
5403
5404    fn trace_err(self) -> Option<Self::Ok>;
5405}
5406
5407impl<T, E> ResultExt for Result<T, E>
5408where
5409    E: std::fmt::Debug,
5410{
5411    type Ok = T;
5412
5413    #[track_caller]
5414    fn trace_err(self) -> Option<T> {
5415        match self {
5416            Ok(value) => Some(value),
5417            Err(error) => {
5418                tracing::error!("{:?}", error);
5419                None
5420            }
5421        }
5422    }
5423}