rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  50        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The collab RPC server: owns the messaging peer, the connection pool and the handler table.
pub struct Server {
    /// This server's id. It sits behind a mutex; `start` reads it with `*self.id.lock()`.
    id: parking_lot::Mutex<ServerId>,
    /// Messaging peer, created from the server id in `Server::new`.
    peer: Arc<Peer>,
    /// Registry of live connections (presumably shared with each `Session` — confirm where
    /// sessions are constructed).
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by a `TypeId` (presumably that of the message type; the
    /// registration methods are outside this excerpt).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal. `Server::new` initializes it to `false` and keeps only the sender.
    teardown: watch::Sender<bool>,
}
 384
/// Lock guard over the connection pool, as returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (`Rc` is not `Send`). Any future
/// that holds it across an `.await` is therefore `!Send` too — presumably to discourage holding
/// this synchronous lock across await points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable view of the server: the peer plus the locked connection pool.
///
/// The connection pool stays locked for as long as the snapshot (lifetime `'a`) is alive,
/// because it holds a `ConnectionPoolGuard`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(update_dev_server_project))
 435            .add_request_handler(user_handler(delete_dev_server_project))
 436            .add_request_handler(user_handler(create_dev_server))
 437            .add_request_handler(user_handler(regenerate_dev_server_token))
 438            .add_request_handler(user_handler(rename_dev_server))
 439            .add_request_handler(user_handler(delete_dev_server))
 440            .add_request_handler(user_handler(list_remote_directory))
 441            .add_request_handler(dev_server_handler(share_dev_server_project))
 442            .add_request_handler(dev_server_handler(shutdown_dev_server))
 443            .add_request_handler(dev_server_handler(reconnect_dev_server))
 444            .add_message_handler(user_message_handler(leave_project))
 445            .add_request_handler(update_project)
 446            .add_request_handler(update_worktree)
 447            .add_message_handler(start_language_server)
 448            .add_message_handler(update_language_server)
 449            .add_message_handler(update_diagnostic_summary)
 450            .add_message_handler(update_worktree_settings)
 451            .add_request_handler(user_handler(
 452                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 453            ))
 454            .add_request_handler(user_handler(
 455                forward_project_request_for_owner::<proto::TaskTemplates>,
 456            ))
 457            .add_request_handler(user_handler(
 458                forward_read_only_project_request::<proto::GetHover>,
 459            ))
 460            .add_request_handler(user_handler(
 461                forward_read_only_project_request::<proto::GetDefinition>,
 462            ))
 463            .add_request_handler(user_handler(
 464                forward_read_only_project_request::<proto::GetTypeDefinition>,
 465            ))
 466            .add_request_handler(user_handler(
 467                forward_read_only_project_request::<proto::GetReferences>,
 468            ))
 469            .add_request_handler(user_handler(
 470                forward_read_only_project_request::<proto::SearchProject>,
 471            ))
 472            .add_request_handler(user_handler(
 473                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_read_only_project_request::<proto::GetProjectSymbols>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 480            ))
 481            .add_request_handler(user_handler(
 482                forward_read_only_project_request::<proto::OpenBufferById>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_read_only_project_request::<proto::InlayHints>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_read_only_project_request::<proto::OpenBufferByPath>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_mutating_project_request::<proto::GetCompletions>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_mutating_project_request::<proto::GetCodeActions>,
 507            ))
 508            .add_request_handler(user_handler(
 509                forward_mutating_project_request::<proto::ApplyCodeAction>,
 510            ))
 511            .add_request_handler(user_handler(
 512                forward_mutating_project_request::<proto::PrepareRename>,
 513            ))
 514            .add_request_handler(user_handler(
 515                forward_mutating_project_request::<proto::PerformRename>,
 516            ))
 517            .add_request_handler(user_handler(
 518                forward_mutating_project_request::<proto::ReloadBuffers>,
 519            ))
 520            .add_request_handler(user_handler(
 521                forward_mutating_project_request::<proto::FormatBuffers>,
 522            ))
 523            .add_request_handler(user_handler(
 524                forward_mutating_project_request::<proto::CreateProjectEntry>,
 525            ))
 526            .add_request_handler(user_handler(
 527                forward_mutating_project_request::<proto::RenameProjectEntry>,
 528            ))
 529            .add_request_handler(user_handler(
 530                forward_mutating_project_request::<proto::CopyProjectEntry>,
 531            ))
 532            .add_request_handler(user_handler(
 533                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 534            ))
 535            .add_request_handler(user_handler(
 536                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 537            ))
 538            .add_request_handler(user_handler(
 539                forward_mutating_project_request::<proto::OnTypeFormatting>,
 540            ))
 541            .add_request_handler(user_handler(
 542                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 543            ))
 544            .add_request_handler(user_handler(
 545                forward_mutating_project_request::<proto::BlameBuffer>,
 546            ))
 547            .add_request_handler(user_handler(
 548                forward_mutating_project_request::<proto::MultiLspQuery>,
 549            ))
 550            .add_request_handler(user_handler(
 551                forward_mutating_project_request::<proto::RestartLanguageServers>,
 552            ))
 553            .add_request_handler(user_handler(
 554                forward_mutating_project_request::<proto::LinkedEditingRange>,
 555            ))
 556            .add_message_handler(create_buffer_for_peer)
 557            .add_request_handler(update_buffer)
 558            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 559            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 560            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 561            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 562            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 563            .add_request_handler(get_users)
 564            .add_request_handler(user_handler(fuzzy_search_users))
 565            .add_request_handler(user_handler(request_contact))
 566            .add_request_handler(user_handler(remove_contact))
 567            .add_request_handler(user_handler(respond_to_contact_request))
 568            .add_message_handler(subscribe_to_channels)
 569            .add_request_handler(user_handler(create_channel))
 570            .add_request_handler(user_handler(delete_channel))
 571            .add_request_handler(user_handler(invite_channel_member))
 572            .add_request_handler(user_handler(remove_channel_member))
 573            .add_request_handler(user_handler(set_channel_member_role))
 574            .add_request_handler(user_handler(set_channel_visibility))
 575            .add_request_handler(user_handler(rename_channel))
 576            .add_request_handler(user_handler(join_channel_buffer))
 577            .add_request_handler(user_handler(leave_channel_buffer))
 578            .add_message_handler(user_message_handler(update_channel_buffer))
 579            .add_request_handler(user_handler(rejoin_channel_buffers))
 580            .add_request_handler(user_handler(get_channel_members))
 581            .add_request_handler(user_handler(respond_to_channel_invite))
 582            .add_request_handler(user_handler(join_channel))
 583            .add_request_handler(user_handler(join_channel_chat))
 584            .add_message_handler(user_message_handler(leave_channel_chat))
 585            .add_request_handler(user_handler(send_channel_message))
 586            .add_request_handler(user_handler(remove_channel_message))
 587            .add_request_handler(user_handler(update_channel_message))
 588            .add_request_handler(user_handler(get_channel_messages))
 589            .add_request_handler(user_handler(get_channel_messages_by_id))
 590            .add_request_handler(user_handler(get_notifications))
 591            .add_request_handler(user_handler(mark_notification_as_read))
 592            .add_request_handler(user_handler(move_channel))
 593            .add_request_handler(user_handler(follow))
 594            .add_message_handler(user_message_handler(unfollow))
 595            .add_message_handler(user_message_handler(update_followers))
 596            .add_request_handler(user_handler(get_private_user_info))
 597            .add_message_handler(user_message_handler(acknowledge_channel_message))
 598            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 599            .add_request_handler(user_handler(get_supermaven_api_key))
 600            .add_request_handler(user_handler(
 601                forward_mutating_project_request::<proto::OpenContext>,
 602            ))
 603            .add_request_handler(user_handler(
 604                forward_mutating_project_request::<proto::SynchronizeContexts>,
 605            ))
 606            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 607            .add_message_handler(update_context)
 608            .add_streaming_request_handler({
 609                let app_state = app_state.clone();
 610                move |request, response, session| {
 611                    complete_with_language_model(
 612                        request,
 613                        response,
 614                        session,
 615                        app_state.config.openai_api_key.clone(),
 616                        app_state.config.google_ai_api_key.clone(),
 617                        app_state.config.anthropic_api_key.clone(),
 618                    )
 619                }
 620            })
 621            .add_request_handler({
 622                user_handler(move |request, response, session| {
 623                    get_cached_embeddings(request, response, session)
 624                })
 625            })
 626            .add_request_handler({
 627                let app_state = app_state.clone();
 628                user_handler(move |request, response, session| {
 629                    compute_embeddings(
 630                        request,
 631                        response,
 632                        session,
 633                        app_state.config.openai_api_key.clone(),
 634                    )
 635                })
 636            });
 637
 638        Arc::new(server)
 639    }
 640
 641    pub async fn start(&self) -> Result<()> {
 642        let server_id = *self.id.lock();
 643        let app_state = self.app_state.clone();
 644        let peer = self.peer.clone();
 645        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 646        let pool = self.connection_pool.clone();
 647        let live_kit_client = self.app_state.live_kit_client.clone();
 648
 649        let span = info_span!("start server");
 650        self.app_state.executor.spawn_detached(
 651            async move {
 652                tracing::info!("waiting for cleanup timeout");
 653                timeout.await;
 654                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 655                if let Some((room_ids, channel_ids)) = app_state
 656                    .db
 657                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 658                    .await
 659                    .trace_err()
 660                {
 661                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 662                    tracing::info!(
 663                        stale_channel_buffer_count = channel_ids.len(),
 664                        "retrieved stale channel buffers"
 665                    );
 666
 667                    for channel_id in channel_ids {
 668                        if let Some(refreshed_channel_buffer) = app_state
 669                            .db
 670                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 671                            .await
 672                            .trace_err()
 673                        {
 674                            for connection_id in refreshed_channel_buffer.connection_ids {
 675                                peer.send(
 676                                    connection_id,
 677                                    proto::UpdateChannelBufferCollaborators {
 678                                        channel_id: channel_id.to_proto(),
 679                                        collaborators: refreshed_channel_buffer
 680                                            .collaborators
 681                                            .clone(),
 682                                    },
 683                                )
 684                                .trace_err();
 685                            }
 686                        }
 687                    }
 688
 689                    for room_id in room_ids {
 690                        let mut contacts_to_update = HashSet::default();
 691                        let mut canceled_calls_to_user_ids = Vec::new();
 692                        let mut live_kit_room = String::new();
 693                        let mut delete_live_kit_room = false;
 694
 695                        if let Some(mut refreshed_room) = app_state
 696                            .db
 697                            .clear_stale_room_participants(room_id, server_id)
 698                            .await
 699                            .trace_err()
 700                        {
 701                            tracing::info!(
 702                                room_id = room_id.0,
 703                                new_participant_count = refreshed_room.room.participants.len(),
 704                                "refreshed room"
 705                            );
 706                            room_updated(&refreshed_room.room, &peer);
 707                            if let Some(channel) = refreshed_room.channel.as_ref() {
 708                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 709                            }
 710                            contacts_to_update
 711                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 712                            contacts_to_update
 713                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 714                            canceled_calls_to_user_ids =
 715                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 716                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 717                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 718                        }
 719
 720                        {
 721                            let pool = pool.lock();
 722                            for canceled_user_id in canceled_calls_to_user_ids {
 723                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 724                                    peer.send(
 725                                        connection_id,
 726                                        proto::CallCanceled {
 727                                            room_id: room_id.to_proto(),
 728                                        },
 729                                    )
 730                                    .trace_err();
 731                                }
 732                            }
 733                        }
 734
 735                        for user_id in contacts_to_update {
 736                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 737                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 738                            if let Some((busy, contacts)) = busy.zip(contacts) {
 739                                let pool = pool.lock();
 740                                let updated_contact = contact_for_user(user_id, busy, &pool);
 741                                for contact in contacts {
 742                                    if let db::Contact::Accepted {
 743                                        user_id: contact_user_id,
 744                                        ..
 745                                    } = contact
 746                                    {
 747                                        for contact_conn_id in
 748                                            pool.user_connection_ids(contact_user_id)
 749                                        {
 750                                            peer.send(
 751                                                contact_conn_id,
 752                                                proto::UpdateContacts {
 753                                                    contacts: vec![updated_contact.clone()],
 754                                                    remove_contacts: Default::default(),
 755                                                    incoming_requests: Default::default(),
 756                                                    remove_incoming_requests: Default::default(),
 757                                                    outgoing_requests: Default::default(),
 758                                                    remove_outgoing_requests: Default::default(),
 759                                                },
 760                                            )
 761                                            .trace_err();
 762                                        }
 763                                    }
 764                                }
 765                            }
 766                        }
 767
 768                        if let Some(live_kit) = live_kit_client.as_ref() {
 769                            if delete_live_kit_room {
 770                                live_kit.delete_room(live_kit_room).await.trace_err();
 771                            }
 772                        }
 773                    }
 774                }
 775
 776                app_state
 777                    .db
 778                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 779                    .await
 780                    .trace_err();
 781            }
 782            .instrument(span),
 783        );
 784        Ok(())
 785    }
 786
 787    pub fn teardown(&self) {
 788        self.peer.teardown();
 789        self.connection_pool.lock().reset();
 790        let _ = self.teardown.send(true);
 791    }
 792
 793    #[cfg(test)]
 794    pub fn reset(&self, id: ServerId) {
 795        self.teardown();
 796        *self.id.lock() = id;
 797        self.peer.reset(id.0 as u32);
 798        let _ = self.teardown.send(false);
 799    }
 800
 801    #[cfg(test)]
 802    pub fn id(&self) -> ServerId {
 803        *self.id.lock()
 804    }
 805
 806    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 807    where
 808        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 809        Fut: 'static + Send + Future<Output = Result<()>>,
 810        M: EnvelopedMessage,
 811    {
 812        let prev_handler = self.handlers.insert(
 813            TypeId::of::<M>(),
 814            Box::new(move |envelope, session| {
 815                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 816                let received_at = envelope.received_at;
 817                tracing::info!("message received");
 818                let start_time = Instant::now();
 819                let future = (handler)(*envelope, session);
 820                async move {
 821                    let result = future.await;
 822                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 823                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 824                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 825                    let payload_type = M::NAME;
 826
 827                    match result {
 828                        Err(error) => {
 829                            tracing::error!(
 830                                ?error,
 831                                total_duration_ms,
 832                                processing_duration_ms,
 833                                queue_duration_ms,
 834                                payload_type,
 835                                "error handling message"
 836                            )
 837                        }
 838                        Ok(()) => tracing::info!(
 839                            total_duration_ms,
 840                            processing_duration_ms,
 841                            queue_duration_ms,
 842                            "finished handling message"
 843                        ),
 844                    }
 845                }
 846                .boxed()
 847            }),
 848        );
 849        if prev_handler.is_some() {
 850            panic!("registered a handler for the same message twice");
 851        }
 852        self
 853    }
 854
 855    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 856    where
 857        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 858        Fut: 'static + Send + Future<Output = Result<()>>,
 859        M: EnvelopedMessage,
 860    {
 861        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 862        self
 863    }
 864
 865    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 866    where
 867        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 868        Fut: Send + Future<Output = Result<()>>,
 869        M: RequestMessage,
 870    {
 871        let handler = Arc::new(handler);
 872        self.add_handler(move |envelope, session| {
 873            let receipt = envelope.receipt();
 874            let handler = handler.clone();
 875            async move {
 876                let peer = session.peer.clone();
 877                let responded = Arc::new(AtomicBool::default());
 878                let response = Response {
 879                    peer: peer.clone(),
 880                    responded: responded.clone(),
 881                    receipt,
 882                };
 883                match (handler)(envelope.payload, response, session).await {
 884                    Ok(()) => {
 885                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 886                            Ok(())
 887                        } else {
 888                            Err(anyhow!("handler did not send a response"))?
 889                        }
 890                    }
 891                    Err(error) => {
 892                        let proto_err = match &error {
 893                            Error::Internal(err) => err.to_proto(),
 894                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 895                        };
 896                        peer.respond_with_error(receipt, proto_err)?;
 897                        Err(error)
 898                    }
 899                }
 900            }
 901        })
 902    }
 903
 904    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 905    where
 906        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 907        Fut: Send + Future<Output = Result<()>>,
 908        M: RequestMessage,
 909    {
 910        let handler = Arc::new(handler);
 911        self.add_handler(move |envelope, session| {
 912            let receipt = envelope.receipt();
 913            let handler = handler.clone();
 914            async move {
 915                let peer = session.peer.clone();
 916                let response = StreamingResponse {
 917                    peer: peer.clone(),
 918                    receipt,
 919                };
 920                match (handler)(envelope.payload, response, session).await {
 921                    Ok(()) => {
 922                        peer.end_stream(receipt)?;
 923                        Ok(())
 924                    }
 925                    Err(error) => {
 926                        let proto_err = match &error {
 927                            Error::Internal(err) => err.to_proto(),
 928                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 929                        };
 930                        peer.respond_with_error(receipt, proto_err)?;
 931                        Err(error)
 932                    }
 933                }
 934            }
 935        })
 936    }
 937
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future is instrumented with a `handle connection` span and
    /// goes through these steps:
    /// 1. Returns immediately if the server's `teardown` watch is already `true`.
    /// 2. Registers `connection` with the peer and records the assigned
    ///    connection id on the span.
    /// 3. Builds the per-connection [`Session`]. This includes an HTTP client, an
    ///    optional Supermaven admin client (only when an admin API key is
    ///    configured) and a database handle.
    /// 4. Sends the initial client update via `send_initial_client_update`.
    /// 5. Dispatches incoming messages to the registered handlers until one of
    ///    the following happens:
    ///    - the I/O future completes;
    ///    - the incoming stream ends;
    ///    - the `teardown` watch changes.
    /// 6. Unless it exited because of teardown, runs `connection_lost` to sign out.
    ///
    /// NOTE(review): the early `return`s that follow `add_connection` skip
    /// `connection_lost`. Those are the HTTP client creation failure and the
    /// initial update failure. As a result, anything registered for this
    /// connection so far is not cleaned up in this function. Confirm whether
    /// that is handled elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later: `connection_id` below,
        // the principal-related ones by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` is the connection's I/O future (polled in the loop
            // below); `incoming_rx` yields the messages received on it.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only build a Supermaven admin client when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            // Per-connection state handed (cloned) to every message handler.
            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) for
            // this connection at 256: a permit is taken before reading each message
            // and released only once that message's handler has finished.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Acquire a handler slot *before* pulling the next message, so reading
                // pauses while 256 handlers are still running.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Arms are checked in order:
                // 1. teardown;
                // 2. I/O completion;
                // 3. progress on foreground handlers;
                // 4. accepting a new message.
                futures::select_biased! {
                    // Server teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1084
1085    async fn send_initial_client_update(
1086        &self,
1087        connection_id: ConnectionId,
1088        principal: &Principal,
1089        zed_version: ZedVersion,
1090        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1091        session: &Session,
1092    ) -> Result<()> {
1093        self.peer.send(
1094            connection_id,
1095            proto::Hello {
1096                peer_id: Some(connection_id.into()),
1097            },
1098        )?;
1099        tracing::info!("sent hello message");
1100        if let Some(send_connection_id) = send_connection_id.take() {
1101            let _ = send_connection_id.send(connection_id);
1102        }
1103
1104        match principal {
1105            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1106                if !user.connected_once {
1107                    self.peer.send(connection_id, proto::ShowContacts {})?;
1108                    self.app_state
1109                        .db
1110                        .set_user_connected_once(user.id, true)
1111                        .await?;
1112                }
1113
1114                let (contacts, dev_server_projects) = future::try_join(
1115                    self.app_state.db.get_contacts(user.id),
1116                    self.app_state.db.dev_server_projects_update(user.id),
1117                )
1118                .await?;
1119
1120                {
1121                    let mut pool = self.connection_pool.lock();
1122                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1123                    self.peer.send(
1124                        connection_id,
1125                        build_initial_contacts_update(contacts, &pool),
1126                    )?;
1127                }
1128
1129                if should_auto_subscribe_to_channels(zed_version) {
1130                    subscribe_user_to_channels(user.id, session).await?;
1131                }
1132
1133                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1134
1135                if let Some(incoming_call) =
1136                    self.app_state.db.incoming_call_for_user(user.id).await?
1137                {
1138                    self.peer.send(connection_id, incoming_call)?;
1139                }
1140
1141                update_user_contacts(user.id, &session).await?;
1142            }
1143            Principal::DevServer(dev_server) => {
1144                {
1145                    let mut pool = self.connection_pool.lock();
1146                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1147                    {
1148                        self.peer.send(
1149                            stale_connection_id,
1150                            proto::ShutdownDevServer {
1151                                reason: Some(
1152                                    "another dev server connected with the same token".to_string(),
1153                                ),
1154                            },
1155                        )?;
1156                        pool.remove_connection(stale_connection_id)?;
1157                    };
1158                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1159                }
1160
1161                let projects = self
1162                    .app_state
1163                    .db
1164                    .get_projects_for_dev_server(dev_server.id)
1165                    .await?;
1166                self.peer
1167                    .send(connection_id, proto::DevServerInstructions { projects })?;
1168
1169                let status = self
1170                    .app_state
1171                    .db
1172                    .dev_server_projects_update(dev_server.user_id)
1173                    .await?;
1174                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1175            }
1176        }
1177
1178        Ok(())
1179    }
1180
1181    pub async fn invite_code_redeemed(
1182        self: &Arc<Self>,
1183        inviter_id: UserId,
1184        invitee_id: UserId,
1185    ) -> Result<()> {
1186        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1187            if let Some(code) = &user.invite_code {
1188                let pool = self.connection_pool.lock();
1189                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1190                for connection_id in pool.user_connection_ids(inviter_id) {
1191                    self.peer.send(
1192                        connection_id,
1193                        proto::UpdateContacts {
1194                            contacts: vec![invitee_contact.clone()],
1195                            ..Default::default()
1196                        },
1197                    )?;
1198                    self.peer.send(
1199                        connection_id,
1200                        proto::UpdateInviteInfo {
1201                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1202                            count: user.invite_count as u32,
1203                        },
1204                    )?;
1205                }
1206            }
1207        }
1208        Ok(())
1209    }
1210
1211    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1212        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1213            if let Some(invite_code) = &user.invite_code {
1214                let pool = self.connection_pool.lock();
1215                for connection_id in pool.user_connection_ids(user_id) {
1216                    self.peer.send(
1217                        connection_id,
1218                        proto::UpdateInviteInfo {
1219                            url: format!(
1220                                "{}{}",
1221                                self.app_state.config.invite_link_prefix, invite_code
1222                            ),
1223                            count: user.invite_count as u32,
1224                        },
1225                    )?;
1226                }
1227            }
1228        }
1229        Ok(())
1230    }
1231
1232    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1233        ServerSnapshot {
1234            connection_pool: ConnectionPoolGuard {
1235                guard: self.connection_pool.lock(),
1236                _not_send: PhantomData,
1237            },
1238            peer: &self.peer,
1239        }
1240    }
1241}
1242
1243impl<'a> Deref for ConnectionPoolGuard<'a> {
1244    type Target = ConnectionPool;
1245
1246    fn deref(&self) -> &Self::Target {
1247        &self.guard
1248    }
1249}
1250
1251impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1252    fn deref_mut(&mut self) -> &mut Self::Target {
1253        &mut self.guard
1254    }
1255}
1256
/// In test builds (`cfg(test)`), checks the connection pool's invariants
/// every time a guard is dropped.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled out entirely outside `cfg(test)`. NOTE(review): presumably
        // `check_invariants` panics on a violation; confirm in `connection_pool`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1263
1264fn broadcast<F>(
1265    sender_id: Option<ConnectionId>,
1266    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1267    mut f: F,
1268) where
1269    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1270{
1271    for receiver_id in receiver_ids {
1272        if Some(receiver_id) != sender_id {
1273            if let Err(error) = f(receiver_id) {
1274                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1275            }
1276        }
1277    }
1278}
1279
1280pub struct ProtocolVersion(u32);
1281
1282impl Header for ProtocolVersion {
1283    fn name() -> &'static HeaderName {
1284        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1285        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1286    }
1287
1288    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1289    where
1290        Self: Sized,
1291        I: Iterator<Item = &'i axum::http::HeaderValue>,
1292    {
1293        let version = values
1294            .next()
1295            .ok_or_else(axum::headers::Error::invalid)?
1296            .to_str()
1297            .map_err(|_| axum::headers::Error::invalid())?
1298            .parse()
1299            .map_err(|_| axum::headers::Error::invalid())?;
1300        Ok(Self(version))
1301    }
1302
1303    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1304        values.extend([self.0.to_string().parse().unwrap()]);
1305    }
1306}
1307
1308pub struct AppVersionHeader(SemanticVersion);
1309impl Header for AppVersionHeader {
1310    fn name() -> &'static HeaderName {
1311        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1312        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1313    }
1314
1315    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1316    where
1317        Self: Sized,
1318        I: Iterator<Item = &'i axum::http::HeaderValue>,
1319    {
1320        let version = values
1321            .next()
1322            .ok_or_else(axum::headers::Error::invalid)?
1323            .to_str()
1324            .map_err(|_| axum::headers::Error::invalid())?
1325            .parse()
1326            .map_err(|_| axum::headers::Error::invalid())?;
1327        Ok(Self(version))
1328    }
1329
1330    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1331        values.extend([self.0.to_string().parse().unwrap()]);
1332    }
1333}
1334
1335pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1336    Router::new()
1337        .route("/rpc", get(handle_websocket_request))
1338        .layer(
1339            ServiceBuilder::new()
1340                .layer(Extension(server.app_state.clone()))
1341                .layer(middleware::from_fn(auth::validate_header)),
1342        )
1343        .route("/metrics", get(handle_metrics))
1344        .layer(Extension(server))
1345}
1346
1347pub async fn handle_websocket_request(
1348    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1349    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1350    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1351    Extension(server): Extension<Arc<Server>>,
1352    Extension(principal): Extension<Principal>,
1353    ws: WebSocketUpgrade,
1354) -> axum::response::Response {
1355    if protocol_version != rpc::PROTOCOL_VERSION {
1356        return (
1357            StatusCode::UPGRADE_REQUIRED,
1358            "client must be upgraded".to_string(),
1359        )
1360            .into_response();
1361    }
1362
1363    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1364        return (
1365            StatusCode::UPGRADE_REQUIRED,
1366            "no version header found".to_string(),
1367        )
1368            .into_response();
1369    };
1370
1371    if !version.can_collaborate() {
1372        return (
1373            StatusCode::UPGRADE_REQUIRED,
1374            "client must be upgraded".to_string(),
1375        )
1376            .into_response();
1377    }
1378
1379    let socket_address = socket_address.to_string();
1380    ws.on_upgrade(move |socket| {
1381        let socket = socket
1382            .map_ok(to_tungstenite_message)
1383            .err_into()
1384            .with(|message| async move { to_axum_message(message) });
1385        let connection = Connection::new(Box::pin(socket));
1386        async move {
1387            server
1388                .handle_connection(
1389                    connection,
1390                    socket_address,
1391                    principal,
1392                    version,
1393                    None,
1394                    Executor::Production,
1395                )
1396                .await;
1397        }
1398    })
1399}
1400
1401pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1402    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let connections_metric = CONNECTIONS_METRIC
1404        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1405
1406    let connections = server
1407        .connection_pool
1408        .lock()
1409        .connections()
1410        .filter(|connection| !connection.admin)
1411        .count();
1412    connections_metric.set(connections as _);
1413
1414    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1416        register_int_gauge!(
1417            "shared_projects",
1418            "number of open projects with one or more guests"
1419        )
1420        .unwrap()
1421    });
1422
1423    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1424    shared_projects_metric.set(shared_projects as _);
1425
1426    let encoder = prometheus::TextEncoder::new();
1427    let metric_families = prometheus::gather();
1428    let encoded_metrics = encoder
1429        .encode_to_string(&metric_families)
1430        .map_err(|err| anyhow!("{}", err))?;
1431    Ok(encoded_metrics)
1432}
1433
/// Cleans up after a connection ends.
///
/// Immediate steps:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool;
/// - records the loss in the database (database errors are only logged).
///
/// It then races `RECONNECT_TIMEOUT` against a change of the `teardown` watch.
/// If the watch changes first, nothing more is done. Otherwise the principal's
/// remaining state is released.
///
/// For users, that means:
/// - leave the current room and channel buffers (errors logged);
/// - if the user is no longer online, call `decline_call` and broadcast the
///   resulting room update;
/// - finally call `update_user_contacts`.
///
/// For dev servers, it runs `lost_dev_server_connection`.
///
/// # Errors
/// Returns an error in two cases:
/// - the connection can't be removed from the pool;
/// - after the timeout, `update_user_contacts` or `lost_dev_server_connection`
///   fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Wait `RECONNECT_TIMEOUT` (presumably to give the client a chance to
    // reconnect) before releasing resources. A teardown signal short-circuits
    // the wait and skips the cleanup below.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The match arm guarantees this is a user principal.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline calls when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1488
1489/// Acknowledges a ping from a client, used to keep the connection alive.
1490async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1491    response.send(proto::Ack {})?;
1492    Ok(())
1493}
1494
1495/// Creates a new room for calling (outside of channels)
1496async fn create_room(
1497    _request: proto::CreateRoom,
1498    response: Response<proto::CreateRoom>,
1499    session: UserSession,
1500) -> Result<()> {
1501    let live_kit_room = nanoid::nanoid!(30);
1502
1503    let live_kit_connection_info = util::maybe!(async {
1504        let live_kit = session.live_kit_client.as_ref();
1505        let live_kit = live_kit?;
1506        let user_id = session.user_id().to_string();
1507
1508        let token = live_kit
1509            .room_token(&live_kit_room, &user_id.to_string())
1510            .trace_err()?;
1511
1512        Some(proto::LiveKitConnectionInfo {
1513            server_url: live_kit.url().into(),
1514            token,
1515            can_publish: true,
1516        })
1517    })
1518    .await;
1519
1520    let room = session
1521        .db()
1522        .await
1523        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1524        .await?;
1525
1526    response.send(proto::CreateRoomResponse {
1527        room: Some(room.clone()),
1528        live_kit_connection_info,
1529    })?;
1530
1531    update_user_contacts(session.user_id(), &session).await?;
1532    Ok(())
1533}
1534
1535/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1536async fn join_room(
1537    request: proto::JoinRoom,
1538    response: Response<proto::JoinRoom>,
1539    session: UserSession,
1540) -> Result<()> {
1541    let room_id = RoomId::from_proto(request.id);
1542
1543    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1544
1545    if let Some(channel_id) = channel_id {
1546        return join_channel_internal(channel_id, Box::new(response), session).await;
1547    }
1548
1549    let joined_room = {
1550        let room = session
1551            .db()
1552            .await
1553            .join_room(room_id, session.user_id(), session.connection_id)
1554            .await?;
1555        room_updated(&room.room, &session.peer);
1556        room.into_inner()
1557    };
1558
1559    for connection_id in session
1560        .connection_pool()
1561        .await
1562        .user_connection_ids(session.user_id())
1563    {
1564        session
1565            .peer
1566            .send(
1567                connection_id,
1568                proto::CallCanceled {
1569                    room_id: room_id.to_proto(),
1570                },
1571            )
1572            .trace_err();
1573    }
1574
1575    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1576        if let Some(token) = live_kit
1577            .room_token(
1578                &joined_room.room.live_kit_room,
1579                &session.user_id().to_string(),
1580            )
1581            .trace_err()
1582        {
1583            Some(proto::LiveKitConnectionInfo {
1584                server_url: live_kit.url().into(),
1585                token,
1586                can_publish: true,
1587            })
1588        } else {
1589            None
1590        }
1591    } else {
1592        None
1593    };
1594
1595    response.send(proto::JoinRoomResponse {
1596        room: Some(joined_room.room),
1597        channel_id: None,
1598        live_kit_connection_info,
1599    })?;
1600
1601    update_user_contacts(session.user_id(), &session).await?;
1602    Ok(())
1603}
1604
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-associates the caller's new connection with the room. It returns the
/// `reshared_projects` (presumably projects this user was hosting) and the
/// `rejoined_projects` (presumably projects it had joined as a guest). The reply lists both.
/// Collaborators of each reshared project are told the peer id changed from the old
/// connection to the current one, and are forwarded the project's worktrees. Rejoined
/// projects are brought up to date via `notify_rejoined_projects`. Finally, the room's
/// channel (if any) is broadcast via `channel_updated` and the caller's contacts are
/// refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so the values extracted from the rejoined-room guard
    // survive past the block's end.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For every re-shared project: point each collaborator at the new peer id, then
        // forward the project's current worktree list on behalf of this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Streams worktrees/diagnostics/etc. of the rejoined projects back to this client.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard; only the room and the optional channel are needed from here on.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, broadcast the channel update as well.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1696
1697fn notify_rejoined_projects(
1698    rejoined_projects: &mut Vec<RejoinedProject>,
1699    session: &UserSession,
1700) -> Result<()> {
1701    for project in rejoined_projects.iter() {
1702        for collaborator in &project.collaborators {
1703            session
1704                .peer
1705                .send(
1706                    collaborator.connection_id,
1707                    proto::UpdateProjectCollaborator {
1708                        project_id: project.id.to_proto(),
1709                        old_peer_id: Some(project.old_connection_id.into()),
1710                        new_peer_id: Some(session.connection_id.into()),
1711                    },
1712                )
1713                .trace_err();
1714        }
1715    }
1716
1717    for project in rejoined_projects {
1718        for worktree in mem::take(&mut project.worktrees) {
1719            #[cfg(any(test, feature = "test-support"))]
1720            const MAX_CHUNK_SIZE: usize = 2;
1721            #[cfg(not(any(test, feature = "test-support")))]
1722            const MAX_CHUNK_SIZE: usize = 256;
1723
1724            // Stream this worktree's entries.
1725            let message = proto::UpdateWorktree {
1726                project_id: project.id.to_proto(),
1727                worktree_id: worktree.id,
1728                abs_path: worktree.abs_path.clone(),
1729                root_name: worktree.root_name,
1730                updated_entries: worktree.updated_entries,
1731                removed_entries: worktree.removed_entries,
1732                scan_id: worktree.scan_id,
1733                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1734                updated_repositories: worktree.updated_repositories,
1735                removed_repositories: worktree.removed_repositories,
1736            };
1737            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1738                session.peer.send(session.connection_id, update.clone())?;
1739            }
1740
1741            // Stream this worktree's diagnostics.
1742            for summary in worktree.diagnostic_summaries {
1743                session.peer.send(
1744                    session.connection_id,
1745                    proto::UpdateDiagnosticSummary {
1746                        project_id: project.id.to_proto(),
1747                        worktree_id: worktree.id,
1748                        summary: Some(summary),
1749                    },
1750                )?;
1751            }
1752
1753            for settings_file in worktree.settings_files {
1754                session.peer.send(
1755                    session.connection_id,
1756                    proto::UpdateWorktreeSettings {
1757                        project_id: project.id.to_proto(),
1758                        worktree_id: worktree.id,
1759                        path: settings_file.path,
1760                        content: Some(settings_file.content),
1761                    },
1762                )?;
1763            }
1764        }
1765
1766        for language_server in &project.language_servers {
1767            session.peer.send(
1768                session.connection_id,
1769                proto::UpdateLanguageServer {
1770                    project_id: project.id.to_proto(),
1771                    language_server_id: language_server.id,
1772                    variant: Some(
1773                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1774                            proto::LspDiskBasedDiagnosticsUpdated {},
1775                        ),
1776                    ),
1777                },
1778            )?;
1779        }
1780    }
1781    Ok(())
1782}
1783
1784/// leave room disconnects from the room.
1785async fn leave_room(
1786    _: proto::LeaveRoom,
1787    response: Response<proto::LeaveRoom>,
1788    session: UserSession,
1789) -> Result<()> {
1790    leave_room_for_session(&session, session.connection_id).await?;
1791    response.send(proto::Ack {})?;
1792    Ok(())
1793}
1794
1795/// Updates the permissions of someone else in the room.
1796async fn set_room_participant_role(
1797    request: proto::SetRoomParticipantRole,
1798    response: Response<proto::SetRoomParticipantRole>,
1799    session: UserSession,
1800) -> Result<()> {
1801    let user_id = UserId::from_proto(request.user_id);
1802    let role = ChannelRole::from(request.role());
1803
1804    let (live_kit_room, can_publish) = {
1805        let room = session
1806            .db()
1807            .await
1808            .set_room_participant_role(
1809                session.user_id(),
1810                RoomId::from_proto(request.room_id),
1811                user_id,
1812                role,
1813            )
1814            .await?;
1815
1816        let live_kit_room = room.live_kit_room.clone();
1817        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1818        room_updated(&room, &session.peer);
1819        (live_kit_room, can_publish)
1820    };
1821
1822    if let Some(live_kit) = session.live_kit_client.as_ref() {
1823        live_kit
1824            .update_participant(
1825                live_kit_room.clone(),
1826                request.user_id.to_string(),
1827                live_kit_server::proto::ParticipantPermission {
1828                    can_subscribe: true,
1829                    can_publish,
1830                    can_publish_data: can_publish,
1831                    hidden: false,
1832                    recorder: false,
1833                },
1834            )
1835            .await
1836            .trace_err();
1837    }
1838
1839    response.send(proto::Ack {})?;
1840    Ok(())
1841}
1842
1843/// Call someone else into the current room
1844async fn call(
1845    request: proto::Call,
1846    response: Response<proto::Call>,
1847    session: UserSession,
1848) -> Result<()> {
1849    let room_id = RoomId::from_proto(request.room_id);
1850    let calling_user_id = session.user_id();
1851    let calling_connection_id = session.connection_id;
1852    let called_user_id = UserId::from_proto(request.called_user_id);
1853    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1854    if !session
1855        .db()
1856        .await
1857        .has_contact(calling_user_id, called_user_id)
1858        .await?
1859    {
1860        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1861    }
1862
1863    let incoming_call = {
1864        let (room, incoming_call) = &mut *session
1865            .db()
1866            .await
1867            .call(
1868                room_id,
1869                calling_user_id,
1870                calling_connection_id,
1871                called_user_id,
1872                initial_project_id,
1873            )
1874            .await?;
1875        room_updated(&room, &session.peer);
1876        mem::take(incoming_call)
1877    };
1878    update_user_contacts(called_user_id, &session).await?;
1879
1880    let mut calls = session
1881        .connection_pool()
1882        .await
1883        .user_connection_ids(called_user_id)
1884        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1885        .collect::<FuturesUnordered<_>>();
1886
1887    while let Some(call_response) = calls.next().await {
1888        match call_response.as_ref() {
1889            Ok(_) => {
1890                response.send(proto::Ack {})?;
1891                return Ok(());
1892            }
1893            Err(_) => {
1894                call_response.trace_err();
1895            }
1896        }
1897    }
1898
1899    {
1900        let room = session
1901            .db()
1902            .await
1903            .call_failed(room_id, called_user_id)
1904            .await?;
1905        room_updated(&room, &session.peer);
1906    }
1907    update_user_contacts(called_user_id, &session).await?;
1908
1909    Err(anyhow!("failed to ring user"))?
1910}
1911
1912/// Cancel an outgoing call.
1913async fn cancel_call(
1914    request: proto::CancelCall,
1915    response: Response<proto::CancelCall>,
1916    session: UserSession,
1917) -> Result<()> {
1918    let called_user_id = UserId::from_proto(request.called_user_id);
1919    let room_id = RoomId::from_proto(request.room_id);
1920    {
1921        let room = session
1922            .db()
1923            .await
1924            .cancel_call(room_id, session.connection_id, called_user_id)
1925            .await?;
1926        room_updated(&room, &session.peer);
1927    }
1928
1929    for connection_id in session
1930        .connection_pool()
1931        .await
1932        .user_connection_ids(called_user_id)
1933    {
1934        session
1935            .peer
1936            .send(
1937                connection_id,
1938                proto::CallCanceled {
1939                    room_id: room_id.to_proto(),
1940                },
1941            )
1942            .trace_err();
1943    }
1944    response.send(proto::Ack {})?;
1945
1946    update_user_contacts(called_user_id, &session).await?;
1947    Ok(())
1948}
1949
1950/// Decline an incoming call.
1951async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1952    let room_id = RoomId::from_proto(message.room_id);
1953    {
1954        let room = session
1955            .db()
1956            .await
1957            .decline_call(Some(room_id), session.user_id())
1958            .await?
1959            .ok_or_else(|| anyhow!("failed to decline call"))?;
1960        room_updated(&room, &session.peer);
1961    }
1962
1963    for connection_id in session
1964        .connection_pool()
1965        .await
1966        .user_connection_ids(session.user_id())
1967    {
1968        session
1969            .peer
1970            .send(
1971                connection_id,
1972                proto::CallCanceled {
1973                    room_id: room_id.to_proto(),
1974                },
1975            )
1976            .trace_err();
1977    }
1978    update_user_contacts(session.user_id(), &session).await?;
1979    Ok(())
1980}
1981
1982/// Updates other participants in the room with your current location.
1983async fn update_participant_location(
1984    request: proto::UpdateParticipantLocation,
1985    response: Response<proto::UpdateParticipantLocation>,
1986    session: UserSession,
1987) -> Result<()> {
1988    let room_id = RoomId::from_proto(request.room_id);
1989    let location = request
1990        .location
1991        .ok_or_else(|| anyhow!("invalid location"))?;
1992
1993    let db = session.db().await;
1994    let room = db
1995        .update_room_participant_location(room_id, session.connection_id, location)
1996        .await?;
1997
1998    room_updated(&room, &session.peer);
1999    response.send(proto::Ack {})?;
2000    Ok(())
2001}
2002
2003/// Share a project into the room.
2004async fn share_project(
2005    request: proto::ShareProject,
2006    response: Response<proto::ShareProject>,
2007    session: UserSession,
2008) -> Result<()> {
2009    let (project_id, room) = &*session
2010        .db()
2011        .await
2012        .share_project(
2013            RoomId::from_proto(request.room_id),
2014            session.connection_id,
2015            &request.worktrees,
2016            request
2017                .dev_server_project_id
2018                .map(|id| DevServerProjectId::from_proto(id)),
2019        )
2020        .await?;
2021    response.send(proto::ShareProjectResponse {
2022        project_id: project_id.to_proto(),
2023    })?;
2024    room_updated(&room, &session.peer);
2025
2026    Ok(())
2027}
2028
2029/// Unshare a project from the room.
2030async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2031    let project_id = ProjectId::from_proto(message.project_id);
2032    unshare_project_internal(
2033        project_id,
2034        session.connection_id,
2035        session.user_id(),
2036        &session,
2037    )
2038    .await
2039}
2040
2041async fn unshare_project_internal(
2042    project_id: ProjectId,
2043    connection_id: ConnectionId,
2044    user_id: Option<UserId>,
2045    session: &Session,
2046) -> Result<()> {
2047    let delete = {
2048        let room_guard = session
2049            .db()
2050            .await
2051            .unshare_project(project_id, connection_id, user_id)
2052            .await?;
2053
2054        let (delete, room, guest_connection_ids) = &*room_guard;
2055
2056        let message = proto::UnshareProject {
2057            project_id: project_id.to_proto(),
2058        };
2059
2060        broadcast(
2061            Some(connection_id),
2062            guest_connection_ids.iter().copied(),
2063            |conn_id| session.peer.send(conn_id, message.clone()),
2064        );
2065        if let Some(room) = room {
2066            room_updated(room, &session.peer);
2067        }
2068
2069        *delete
2070    };
2071
2072    if delete {
2073        let db = session.db().await;
2074        db.delete_project(project_id).await?;
2075    }
2076
2077    Ok(())
2078}
2079
2080/// DevServer makes a project available online
2081async fn share_dev_server_project(
2082    request: proto::ShareDevServerProject,
2083    response: Response<proto::ShareDevServerProject>,
2084    session: DevServerSession,
2085) -> Result<()> {
2086    let (dev_server_project, user_id, status) = session
2087        .db()
2088        .await
2089        .share_dev_server_project(
2090            DevServerProjectId::from_proto(request.dev_server_project_id),
2091            session.dev_server_id(),
2092            session.connection_id,
2093            &request.worktrees,
2094        )
2095        .await?;
2096    let Some(project_id) = dev_server_project.project_id else {
2097        return Err(anyhow!("failed to share remote project"))?;
2098    };
2099
2100    send_dev_server_projects_update(user_id, status, &session).await;
2101
2102    response.send(proto::ShareProjectResponse { project_id })?;
2103
2104    Ok(())
2105}
2106
2107/// Join someone elses shared project.
2108async fn join_project(
2109    request: proto::JoinProject,
2110    response: Response<proto::JoinProject>,
2111    session: UserSession,
2112) -> Result<()> {
2113    let project_id = ProjectId::from_proto(request.project_id);
2114
2115    tracing::info!(%project_id, "join project");
2116
2117    let db = session.db().await;
2118    let (project, replica_id) = &mut *db
2119        .join_project(project_id, session.connection_id, session.user_id())
2120        .await?;
2121    drop(db);
2122    tracing::info!(%project_id, "join remote project");
2123    join_project_internal(response, session, project, replica_id)
2124}
2125
/// Abstracts over the request types whose handlers reply with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `join_project` and `join_hosted_project`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Reply to a regular `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Type-qualified path resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Reply to a `JoinHostedProject` request, which is answered with the same response type.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Type-qualified path resolves to `Response`'s own `send`, not this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2139
2140fn join_project_internal(
2141    response: impl JoinProjectInternalResponse,
2142    session: UserSession,
2143    project: &mut Project,
2144    replica_id: &ReplicaId,
2145) -> Result<()> {
2146    let collaborators = project
2147        .collaborators
2148        .iter()
2149        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2150        .map(|collaborator| collaborator.to_proto())
2151        .collect::<Vec<_>>();
2152    let project_id = project.id;
2153    let guest_user_id = session.user_id();
2154
2155    let worktrees = project
2156        .worktrees
2157        .iter()
2158        .map(|(id, worktree)| proto::WorktreeMetadata {
2159            id: *id,
2160            root_name: worktree.root_name.clone(),
2161            visible: worktree.visible,
2162            abs_path: worktree.abs_path.clone(),
2163        })
2164        .collect::<Vec<_>>();
2165
2166    let add_project_collaborator = proto::AddProjectCollaborator {
2167        project_id: project_id.to_proto(),
2168        collaborator: Some(proto::Collaborator {
2169            peer_id: Some(session.connection_id.into()),
2170            replica_id: replica_id.0 as u32,
2171            user_id: guest_user_id.to_proto(),
2172        }),
2173    };
2174
2175    for collaborator in &collaborators {
2176        session
2177            .peer
2178            .send(
2179                collaborator.peer_id.unwrap().into(),
2180                add_project_collaborator.clone(),
2181            )
2182            .trace_err();
2183    }
2184
2185    // First, we send the metadata associated with each worktree.
2186    response.send(proto::JoinProjectResponse {
2187        project_id: project.id.0 as u64,
2188        worktrees: worktrees.clone(),
2189        replica_id: replica_id.0 as u32,
2190        collaborators: collaborators.clone(),
2191        language_servers: project.language_servers.clone(),
2192        role: project.role.into(),
2193        dev_server_project_id: project
2194            .dev_server_project_id
2195            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2196    })?;
2197
2198    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2199        #[cfg(any(test, feature = "test-support"))]
2200        const MAX_CHUNK_SIZE: usize = 2;
2201        #[cfg(not(any(test, feature = "test-support")))]
2202        const MAX_CHUNK_SIZE: usize = 256;
2203
2204        // Stream this worktree's entries.
2205        let message = proto::UpdateWorktree {
2206            project_id: project_id.to_proto(),
2207            worktree_id,
2208            abs_path: worktree.abs_path.clone(),
2209            root_name: worktree.root_name,
2210            updated_entries: worktree.entries,
2211            removed_entries: Default::default(),
2212            scan_id: worktree.scan_id,
2213            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2214            updated_repositories: worktree.repository_entries.into_values().collect(),
2215            removed_repositories: Default::default(),
2216        };
2217        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2218            session.peer.send(session.connection_id, update.clone())?;
2219        }
2220
2221        // Stream this worktree's diagnostics.
2222        for summary in worktree.diagnostic_summaries {
2223            session.peer.send(
2224                session.connection_id,
2225                proto::UpdateDiagnosticSummary {
2226                    project_id: project_id.to_proto(),
2227                    worktree_id: worktree.id,
2228                    summary: Some(summary),
2229                },
2230            )?;
2231        }
2232
2233        for settings_file in worktree.settings_files {
2234            session.peer.send(
2235                session.connection_id,
2236                proto::UpdateWorktreeSettings {
2237                    project_id: project_id.to_proto(),
2238                    worktree_id: worktree.id,
2239                    path: settings_file.path,
2240                    content: Some(settings_file.content),
2241                },
2242            )?;
2243        }
2244    }
2245
2246    for language_server in &project.language_servers {
2247        session.peer.send(
2248            session.connection_id,
2249            proto::UpdateLanguageServer {
2250                project_id: project_id.to_proto(),
2251                language_server_id: language_server.id,
2252                variant: Some(
2253                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2254                        proto::LspDiskBasedDiagnosticsUpdated {},
2255                    ),
2256                ),
2257            },
2258        )?;
2259    }
2260
2261    Ok(())
2262}
2263
2264/// Leave someone elses shared project.
2265async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2266    let sender_id = session.connection_id;
2267    let project_id = ProjectId::from_proto(request.project_id);
2268    let db = session.db().await;
2269    if db.is_hosted_project(project_id).await? {
2270        let project = db.leave_hosted_project(project_id, sender_id).await?;
2271        project_left(&project, &session);
2272        return Ok(());
2273    }
2274
2275    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2276    tracing::info!(
2277        %project_id,
2278        "leave project"
2279    );
2280
2281    project_left(&project, &session);
2282    if let Some(room) = room {
2283        room_updated(&room, &session.peer);
2284    }
2285
2286    Ok(())
2287}
2288
2289async fn join_hosted_project(
2290    request: proto::JoinHostedProject,
2291    response: Response<proto::JoinHostedProject>,
2292    session: UserSession,
2293) -> Result<()> {
2294    let (mut project, replica_id) = session
2295        .db()
2296        .await
2297        .join_hosted_project(
2298            ProjectId(request.project_id as i32),
2299            session.user_id(),
2300            session.connection_id,
2301        )
2302        .await?;
2303
2304    join_project_internal(response, session, &mut project, &replica_id)
2305}
2306
2307async fn list_remote_directory(
2308    request: proto::ListRemoteDirectory,
2309    response: Response<proto::ListRemoteDirectory>,
2310    session: UserSession,
2311) -> Result<()> {
2312    let dev_server_id = DevServerId(request.dev_server_id as i32);
2313    let dev_server_connection_id = session
2314        .connection_pool()
2315        .await
2316        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2317
2318    session
2319        .db()
2320        .await
2321        .get_dev_server_for_user(dev_server_id, session.user_id())
2322        .await?;
2323
2324    response.send(
2325        session
2326            .peer
2327            .forward_request(session.connection_id, dev_server_connection_id, request)
2328            .await?,
2329    )?;
2330    Ok(())
2331}
2332
2333async fn update_dev_server_project(
2334    request: proto::UpdateDevServerProject,
2335    response: Response<proto::UpdateDevServerProject>,
2336    session: UserSession,
2337) -> Result<()> {
2338    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2339
2340    let (dev_server_project, update) = session
2341        .db()
2342        .await
2343        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2344        .await?;
2345
2346    let projects = session
2347        .db()
2348        .await
2349        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2350        .await?;
2351
2352    let dev_server_connection_id = session
2353        .connection_pool()
2354        .await
2355        .dev_server_connection_id_supporting(
2356            dev_server_project.dev_server_id,
2357            ZedVersion::with_list_directory(),
2358        )?;
2359
2360    session.peer.send(
2361        dev_server_connection_id,
2362        proto::DevServerInstructions { projects },
2363    )?;
2364
2365    send_dev_server_projects_update(session.user_id(), update, &session).await;
2366
2367    response.send(proto::Ack {})
2368}
2369
2370async fn create_dev_server_project(
2371    request: proto::CreateDevServerProject,
2372    response: Response<proto::CreateDevServerProject>,
2373    session: UserSession,
2374) -> Result<()> {
2375    let dev_server_id = DevServerId(request.dev_server_id as i32);
2376    let dev_server_connection_id = session
2377        .connection_pool()
2378        .await
2379        .dev_server_connection_id(dev_server_id);
2380    let Some(dev_server_connection_id) = dev_server_connection_id else {
2381        Err(ErrorCode::DevServerOffline
2382            .message("Cannot create a remote project when the dev server is offline".to_string())
2383            .anyhow())?
2384    };
2385
2386    let path = request.path.clone();
2387    //Check that the path exists on the dev server
2388    session
2389        .peer
2390        .forward_request(
2391            session.connection_id,
2392            dev_server_connection_id,
2393            proto::ValidateDevServerProjectRequest { path: path.clone() },
2394        )
2395        .await?;
2396
2397    let (dev_server_project, update) = session
2398        .db()
2399        .await
2400        .create_dev_server_project(
2401            DevServerId(request.dev_server_id as i32),
2402            &request.path,
2403            session.user_id(),
2404        )
2405        .await?;
2406
2407    let projects = session
2408        .db()
2409        .await
2410        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2411        .await?;
2412
2413    session.peer.send(
2414        dev_server_connection_id,
2415        proto::DevServerInstructions { projects },
2416    )?;
2417
2418    send_dev_server_projects_update(session.user_id(), update, &session).await;
2419
2420    response.send(proto::CreateDevServerProjectResponse {
2421        dev_server_project: Some(dev_server_project.to_proto(None)),
2422    })?;
2423    Ok(())
2424}
2425
2426async fn create_dev_server(
2427    request: proto::CreateDevServer,
2428    response: Response<proto::CreateDevServer>,
2429    session: UserSession,
2430) -> Result<()> {
2431    let access_token = auth::random_token();
2432    let hashed_access_token = auth::hash_access_token(&access_token);
2433
2434    if request.name.is_empty() {
2435        return Err(proto::ErrorCode::Forbidden
2436            .message("Dev server name cannot be empty".to_string())
2437            .anyhow())?;
2438    }
2439
2440    let (dev_server, status) = session
2441        .db()
2442        .await
2443        .create_dev_server(
2444            &request.name,
2445            request.ssh_connection_string.as_deref(),
2446            &hashed_access_token,
2447            session.user_id(),
2448        )
2449        .await?;
2450
2451    send_dev_server_projects_update(session.user_id(), status, &session).await;
2452
2453    response.send(proto::CreateDevServerResponse {
2454        dev_server_id: dev_server.id.0 as u64,
2455        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2456        name: request.name,
2457    })?;
2458    Ok(())
2459}
2460
2461async fn regenerate_dev_server_token(
2462    request: proto::RegenerateDevServerToken,
2463    response: Response<proto::RegenerateDevServerToken>,
2464    session: UserSession,
2465) -> Result<()> {
2466    let dev_server_id = DevServerId(request.dev_server_id as i32);
2467    let access_token = auth::random_token();
2468    let hashed_access_token = auth::hash_access_token(&access_token);
2469
2470    let connection_id = session
2471        .connection_pool()
2472        .await
2473        .dev_server_connection_id(dev_server_id);
2474    if let Some(connection_id) = connection_id {
2475        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2476        session.peer.send(
2477            connection_id,
2478            proto::ShutdownDevServer {
2479                reason: Some("dev server token was regenerated".to_string()),
2480            },
2481        )?;
2482        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2483    }
2484
2485    let status = session
2486        .db()
2487        .await
2488        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2489        .await?;
2490
2491    send_dev_server_projects_update(session.user_id(), status, &session).await;
2492
2493    response.send(proto::RegenerateDevServerTokenResponse {
2494        dev_server_id: dev_server_id.to_proto(),
2495        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2496    })?;
2497    Ok(())
2498}
2499
2500async fn rename_dev_server(
2501    request: proto::RenameDevServer,
2502    response: Response<proto::RenameDevServer>,
2503    session: UserSession,
2504) -> Result<()> {
2505    if request.name.trim().is_empty() {
2506        return Err(proto::ErrorCode::Forbidden
2507            .message("Dev server name cannot be empty".to_string())
2508            .anyhow())?;
2509    }
2510
2511    let dev_server_id = DevServerId(request.dev_server_id as i32);
2512    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2513    if dev_server.user_id != session.user_id() {
2514        return Err(anyhow!(ErrorCode::Forbidden))?;
2515    }
2516
2517    let status = session
2518        .db()
2519        .await
2520        .rename_dev_server(
2521            dev_server_id,
2522            &request.name,
2523            request.ssh_connection_string.as_deref(),
2524            session.user_id(),
2525        )
2526        .await?;
2527
2528    send_dev_server_projects_update(session.user_id(), status, &session).await;
2529
2530    response.send(proto::Ack {})?;
2531    Ok(())
2532}
2533
2534async fn delete_dev_server(
2535    request: proto::DeleteDevServer,
2536    response: Response<proto::DeleteDevServer>,
2537    session: UserSession,
2538) -> Result<()> {
2539    let dev_server_id = DevServerId(request.dev_server_id as i32);
2540    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2541    if dev_server.user_id != session.user_id() {
2542        return Err(anyhow!(ErrorCode::Forbidden))?;
2543    }
2544
2545    let connection_id = session
2546        .connection_pool()
2547        .await
2548        .dev_server_connection_id(dev_server_id);
2549    if let Some(connection_id) = connection_id {
2550        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2551        session.peer.send(
2552            connection_id,
2553            proto::ShutdownDevServer {
2554                reason: Some("dev server was deleted".to_string()),
2555            },
2556        )?;
2557        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2558    }
2559
2560    let status = session
2561        .db()
2562        .await
2563        .delete_dev_server(dev_server_id, session.user_id())
2564        .await?;
2565
2566    send_dev_server_projects_update(session.user_id(), status, &session).await;
2567
2568    response.send(proto::Ack {})?;
2569    Ok(())
2570}
2571
2572async fn delete_dev_server_project(
2573    request: proto::DeleteDevServerProject,
2574    response: Response<proto::DeleteDevServerProject>,
2575    session: UserSession,
2576) -> Result<()> {
2577    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2578    let dev_server_project = session
2579        .db()
2580        .await
2581        .get_dev_server_project(dev_server_project_id)
2582        .await?;
2583
2584    let dev_server = session
2585        .db()
2586        .await
2587        .get_dev_server(dev_server_project.dev_server_id)
2588        .await?;
2589    if dev_server.user_id != session.user_id() {
2590        return Err(anyhow!(ErrorCode::Forbidden))?;
2591    }
2592
2593    let dev_server_connection_id = session
2594        .connection_pool()
2595        .await
2596        .dev_server_connection_id(dev_server.id);
2597
2598    if let Some(dev_server_connection_id) = dev_server_connection_id {
2599        let project = session
2600            .db()
2601            .await
2602            .find_dev_server_project(dev_server_project_id)
2603            .await;
2604        if let Ok(project) = project {
2605            unshare_project_internal(
2606                project.id,
2607                dev_server_connection_id,
2608                Some(session.user_id()),
2609                &session,
2610            )
2611            .await?;
2612        }
2613    }
2614
2615    let (projects, status) = session
2616        .db()
2617        .await
2618        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2619        .await?;
2620
2621    if let Some(dev_server_connection_id) = dev_server_connection_id {
2622        session.peer.send(
2623            dev_server_connection_id,
2624            proto::DevServerInstructions { projects },
2625        )?;
2626    }
2627
2628    send_dev_server_projects_update(session.user_id(), status, &session).await;
2629
2630    response.send(proto::Ack {})?;
2631    Ok(())
2632}
2633
2634async fn rejoin_dev_server_projects(
2635    request: proto::RejoinRemoteProjects,
2636    response: Response<proto::RejoinRemoteProjects>,
2637    session: UserSession,
2638) -> Result<()> {
2639    let mut rejoined_projects = {
2640        let db = session.db().await;
2641        db.rejoin_dev_server_projects(
2642            &request.rejoined_projects,
2643            session.user_id(),
2644            session.0.connection_id,
2645        )
2646        .await?
2647    };
2648    response.send(proto::RejoinRemoteProjectsResponse {
2649        rejoined_projects: rejoined_projects
2650            .iter()
2651            .map(|project| project.to_proto())
2652            .collect(),
2653    })?;
2654    notify_rejoined_projects(&mut rejoined_projects, &session)
2655}
2656
2657async fn reconnect_dev_server(
2658    request: proto::ReconnectDevServer,
2659    response: Response<proto::ReconnectDevServer>,
2660    session: DevServerSession,
2661) -> Result<()> {
2662    let reshared_projects = {
2663        let db = session.db().await;
2664        db.reshare_dev_server_projects(
2665            &request.reshared_projects,
2666            session.dev_server_id(),
2667            session.0.connection_id,
2668        )
2669        .await?
2670    };
2671
2672    for project in &reshared_projects {
2673        for collaborator in &project.collaborators {
2674            session
2675                .peer
2676                .send(
2677                    collaborator.connection_id,
2678                    proto::UpdateProjectCollaborator {
2679                        project_id: project.id.to_proto(),
2680                        old_peer_id: Some(project.old_connection_id.into()),
2681                        new_peer_id: Some(session.connection_id.into()),
2682                    },
2683                )
2684                .trace_err();
2685        }
2686
2687        broadcast(
2688            Some(session.connection_id),
2689            project
2690                .collaborators
2691                .iter()
2692                .map(|collaborator| collaborator.connection_id),
2693            |connection_id| {
2694                session.peer.forward_send(
2695                    session.connection_id,
2696                    connection_id,
2697                    proto::UpdateProject {
2698                        project_id: project.id.to_proto(),
2699                        worktrees: project.worktrees.clone(),
2700                    },
2701                )
2702            },
2703        );
2704    }
2705
2706    response.send(proto::ReconnectDevServerResponse {
2707        reshared_projects: reshared_projects
2708            .iter()
2709            .map(|project| proto::ResharedProject {
2710                id: project.id.to_proto(),
2711                collaborators: project
2712                    .collaborators
2713                    .iter()
2714                    .map(|collaborator| collaborator.to_proto())
2715                    .collect(),
2716            })
2717            .collect(),
2718    })?;
2719
2720    Ok(())
2721}
2722
2723async fn shutdown_dev_server(
2724    _: proto::ShutdownDevServer,
2725    response: Response<proto::ShutdownDevServer>,
2726    session: DevServerSession,
2727) -> Result<()> {
2728    response.send(proto::Ack {})?;
2729    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2730    remove_dev_server_connection(session.dev_server_id(), &session).await
2731}
2732
2733async fn shutdown_dev_server_internal(
2734    dev_server_id: DevServerId,
2735    connection_id: ConnectionId,
2736    session: &Session,
2737) -> Result<()> {
2738    let (dev_server_projects, dev_server) = {
2739        let db = session.db().await;
2740        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2741        let dev_server = db.get_dev_server(dev_server_id).await?;
2742        (dev_server_projects, dev_server)
2743    };
2744
2745    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2746        unshare_project_internal(
2747            ProjectId::from_proto(project_id),
2748            connection_id,
2749            None,
2750            session,
2751        )
2752        .await?;
2753    }
2754
2755    session
2756        .connection_pool()
2757        .await
2758        .set_dev_server_offline(dev_server_id);
2759
2760    let status = session
2761        .db()
2762        .await
2763        .dev_server_projects_update(dev_server.user_id)
2764        .await?;
2765    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2766
2767    Ok(())
2768}
2769
2770async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2771    let dev_server_connection = session
2772        .connection_pool()
2773        .await
2774        .dev_server_connection_id(dev_server_id);
2775
2776    if let Some(dev_server_connection) = dev_server_connection {
2777        session
2778            .connection_pool()
2779            .await
2780            .remove_connection(dev_server_connection)?;
2781    }
2782    Ok(())
2783}
2784
2785/// Updates other participants with changes to the project
2786async fn update_project(
2787    request: proto::UpdateProject,
2788    response: Response<proto::UpdateProject>,
2789    session: Session,
2790) -> Result<()> {
2791    let project_id = ProjectId::from_proto(request.project_id);
2792    let (room, guest_connection_ids) = &*session
2793        .db()
2794        .await
2795        .update_project(project_id, session.connection_id, &request.worktrees)
2796        .await?;
2797    broadcast(
2798        Some(session.connection_id),
2799        guest_connection_ids.iter().copied(),
2800        |connection_id| {
2801            session
2802                .peer
2803                .forward_send(session.connection_id, connection_id, request.clone())
2804        },
2805    );
2806    if let Some(room) = room {
2807        room_updated(&room, &session.peer);
2808    }
2809    response.send(proto::Ack {})?;
2810
2811    Ok(())
2812}
2813
2814/// Updates other participants with changes to the worktree
2815async fn update_worktree(
2816    request: proto::UpdateWorktree,
2817    response: Response<proto::UpdateWorktree>,
2818    session: Session,
2819) -> Result<()> {
2820    let guest_connection_ids = session
2821        .db()
2822        .await
2823        .update_worktree(&request, session.connection_id)
2824        .await?;
2825
2826    broadcast(
2827        Some(session.connection_id),
2828        guest_connection_ids.iter().copied(),
2829        |connection_id| {
2830            session
2831                .peer
2832                .forward_send(session.connection_id, connection_id, request.clone())
2833        },
2834    );
2835    response.send(proto::Ack {})?;
2836    Ok(())
2837}
2838
2839/// Updates other participants with changes to the diagnostics
2840async fn update_diagnostic_summary(
2841    message: proto::UpdateDiagnosticSummary,
2842    session: Session,
2843) -> Result<()> {
2844    let guest_connection_ids = session
2845        .db()
2846        .await
2847        .update_diagnostic_summary(&message, session.connection_id)
2848        .await?;
2849
2850    broadcast(
2851        Some(session.connection_id),
2852        guest_connection_ids.iter().copied(),
2853        |connection_id| {
2854            session
2855                .peer
2856                .forward_send(session.connection_id, connection_id, message.clone())
2857        },
2858    );
2859
2860    Ok(())
2861}
2862
2863/// Updates other participants with changes to the worktree settings
2864async fn update_worktree_settings(
2865    message: proto::UpdateWorktreeSettings,
2866    session: Session,
2867) -> Result<()> {
2868    let guest_connection_ids = session
2869        .db()
2870        .await
2871        .update_worktree_settings(&message, session.connection_id)
2872        .await?;
2873
2874    broadcast(
2875        Some(session.connection_id),
2876        guest_connection_ids.iter().copied(),
2877        |connection_id| {
2878            session
2879                .peer
2880                .forward_send(session.connection_id, connection_id, message.clone())
2881        },
2882    );
2883
2884    Ok(())
2885}
2886
2887/// Notify other participants that a  language server has started.
2888async fn start_language_server(
2889    request: proto::StartLanguageServer,
2890    session: Session,
2891) -> Result<()> {
2892    let guest_connection_ids = session
2893        .db()
2894        .await
2895        .start_language_server(&request, session.connection_id)
2896        .await?;
2897
2898    broadcast(
2899        Some(session.connection_id),
2900        guest_connection_ids.iter().copied(),
2901        |connection_id| {
2902            session
2903                .peer
2904                .forward_send(session.connection_id, connection_id, request.clone())
2905        },
2906    );
2907    Ok(())
2908}
2909
2910/// Notify other participants that a language server has changed.
2911async fn update_language_server(
2912    request: proto::UpdateLanguageServer,
2913    session: Session,
2914) -> Result<()> {
2915    let project_id = ProjectId::from_proto(request.project_id);
2916    let project_connection_ids = session
2917        .db()
2918        .await
2919        .project_connection_ids(project_id, session.connection_id, true)
2920        .await?;
2921    broadcast(
2922        Some(session.connection_id),
2923        project_connection_ids.iter().copied(),
2924        |connection_id| {
2925            session
2926                .peer
2927                .forward_send(session.connection_id, connection_id, request.clone())
2928        },
2929    );
2930    Ok(())
2931}
2932
2933/// forward a project request to the host. These requests should be read only
2934/// as guests are allowed to send them.
2935async fn forward_read_only_project_request<T>(
2936    request: T,
2937    response: Response<T>,
2938    session: UserSession,
2939) -> Result<()>
2940where
2941    T: EntityMessage + RequestMessage,
2942{
2943    let project_id = ProjectId::from_proto(request.remote_entity_id());
2944    let host_connection_id = session
2945        .db()
2946        .await
2947        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2948        .await?;
2949    let payload = session
2950        .peer
2951        .forward_request(session.connection_id, host_connection_id, request)
2952        .await?;
2953    response.send(payload)?;
2954    Ok(())
2955}
2956
2957/// forward a project request to the dev server. Only allowed
2958/// if it's your dev server.
2959async fn forward_project_request_for_owner<T>(
2960    request: T,
2961    response: Response<T>,
2962    session: UserSession,
2963) -> Result<()>
2964where
2965    T: EntityMessage + RequestMessage,
2966{
2967    let project_id = ProjectId::from_proto(request.remote_entity_id());
2968
2969    let host_connection_id = session
2970        .db()
2971        .await
2972        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2973        .await?;
2974    let payload = session
2975        .peer
2976        .forward_request(session.connection_id, host_connection_id, request)
2977        .await?;
2978    response.send(payload)?;
2979    Ok(())
2980}
2981
2982/// forward a project request to the host. These requests are disallowed
2983/// for guests.
2984async fn forward_mutating_project_request<T>(
2985    request: T,
2986    response: Response<T>,
2987    session: UserSession,
2988) -> Result<()>
2989where
2990    T: EntityMessage + RequestMessage,
2991{
2992    let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994    let host_connection_id = session
2995        .db()
2996        .await
2997        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2998        .await?;
2999    let payload = session
3000        .peer
3001        .forward_request(session.connection_id, host_connection_id, request)
3002        .await?;
3003    response.send(payload)?;
3004    Ok(())
3005}
3006
3007/// forward a project request to the host. These requests are disallowed
3008/// for guests.
3009async fn forward_versioned_mutating_project_request<T>(
3010    request: T,
3011    response: Response<T>,
3012    session: UserSession,
3013) -> Result<()>
3014where
3015    T: EntityMessage + RequestMessage + VersionedMessage,
3016{
3017    let project_id = ProjectId::from_proto(request.remote_entity_id());
3018
3019    let host_connection_id = session
3020        .db()
3021        .await
3022        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3023        .await?;
3024    if let Some(host_version) = session
3025        .connection_pool()
3026        .await
3027        .connection(host_connection_id)
3028        .map(|c| c.zed_version)
3029    {
3030        if let Some(min_required_version) = request.required_host_version() {
3031            if min_required_version > host_version {
3032                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3033                    .with_tag("required", &min_required_version.to_string())))?;
3034            }
3035        }
3036    }
3037
3038    let payload = session
3039        .peer
3040        .forward_request(session.connection_id, host_connection_id, request)
3041        .await?;
3042    response.send(payload)?;
3043    Ok(())
3044}
3045
3046/// Notify other participants that a new buffer has been created
3047async fn create_buffer_for_peer(
3048    request: proto::CreateBufferForPeer,
3049    session: Session,
3050) -> Result<()> {
3051    session
3052        .db()
3053        .await
3054        .check_user_is_project_host(
3055            ProjectId::from_proto(request.project_id),
3056            session.connection_id,
3057        )
3058        .await?;
3059    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3060    session
3061        .peer
3062        .forward_send(session.connection_id, peer_id.into(), request)?;
3063    Ok(())
3064}
3065
3066/// Notify other participants that a buffer has been updated. This is
3067/// allowed for guests as long as the update is limited to selections.
3068async fn update_buffer(
3069    request: proto::UpdateBuffer,
3070    response: Response<proto::UpdateBuffer>,
3071    session: Session,
3072) -> Result<()> {
3073    let project_id = ProjectId::from_proto(request.project_id);
3074    let mut capability = Capability::ReadOnly;
3075
3076    for op in request.operations.iter() {
3077        match op.variant {
3078            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3079            Some(_) => capability = Capability::ReadWrite,
3080        }
3081    }
3082
3083    let host = {
3084        let guard = session
3085            .db()
3086            .await
3087            .connections_for_buffer_update(
3088                project_id,
3089                session.principal_id(),
3090                session.connection_id,
3091                capability,
3092            )
3093            .await?;
3094
3095        let (host, guests) = &*guard;
3096
3097        broadcast(
3098            Some(session.connection_id),
3099            guests.clone(),
3100            |connection_id| {
3101                session
3102                    .peer
3103                    .forward_send(session.connection_id, connection_id, request.clone())
3104            },
3105        );
3106
3107        *host
3108    };
3109
3110    if host != session.connection_id {
3111        session
3112            .peer
3113            .forward_request(session.connection_id, host, request.clone())
3114            .await?;
3115    }
3116
3117    response.send(proto::Ack {})?;
3118    Ok(())
3119}
3120
3121async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3122    let project_id = ProjectId::from_proto(message.project_id);
3123
3124    let operation = message.operation.as_ref().context("invalid operation")?;
3125    let capability = match operation.variant.as_ref() {
3126        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3127            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3128                match buffer_op.variant {
3129                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3130                        Capability::ReadOnly
3131                    }
3132                    _ => Capability::ReadWrite,
3133                }
3134            } else {
3135                Capability::ReadWrite
3136            }
3137        }
3138        Some(_) => Capability::ReadWrite,
3139        None => Capability::ReadOnly,
3140    };
3141
3142    let guard = session
3143        .db()
3144        .await
3145        .connections_for_buffer_update(
3146            project_id,
3147            session.principal_id(),
3148            session.connection_id,
3149            capability,
3150        )
3151        .await?;
3152
3153    let (host, guests) = &*guard;
3154
3155    broadcast(
3156        Some(session.connection_id),
3157        guests.iter().chain([host]).copied(),
3158        |connection_id| {
3159            session
3160                .peer
3161                .forward_send(session.connection_id, connection_id, message.clone())
3162        },
3163    );
3164
3165    Ok(())
3166}
3167
3168/// Notify other participants that a project has been updated.
3169async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3170    request: T,
3171    session: Session,
3172) -> Result<()> {
3173    let project_id = ProjectId::from_proto(request.remote_entity_id());
3174    let project_connection_ids = session
3175        .db()
3176        .await
3177        .project_connection_ids(project_id, session.connection_id, false)
3178        .await?;
3179
3180    broadcast(
3181        Some(session.connection_id),
3182        project_connection_ids.iter().copied(),
3183        |connection_id| {
3184            session
3185                .peer
3186                .forward_send(session.connection_id, connection_id, request.clone())
3187        },
3188    );
3189    Ok(())
3190}
3191
3192/// Start following another user in a call.
3193async fn follow(
3194    request: proto::Follow,
3195    response: Response<proto::Follow>,
3196    session: UserSession,
3197) -> Result<()> {
3198    let room_id = RoomId::from_proto(request.room_id);
3199    let project_id = request.project_id.map(ProjectId::from_proto);
3200    let leader_id = request
3201        .leader_id
3202        .ok_or_else(|| anyhow!("invalid leader id"))?
3203        .into();
3204    let follower_id = session.connection_id;
3205
3206    session
3207        .db()
3208        .await
3209        .check_room_participants(room_id, leader_id, session.connection_id)
3210        .await?;
3211
3212    let response_payload = session
3213        .peer
3214        .forward_request(session.connection_id, leader_id, request)
3215        .await?;
3216    response.send(response_payload)?;
3217
3218    if let Some(project_id) = project_id {
3219        let room = session
3220            .db()
3221            .await
3222            .follow(room_id, project_id, leader_id, follower_id)
3223            .await?;
3224        room_updated(&room, &session.peer);
3225    }
3226
3227    Ok(())
3228}
3229
3230/// Stop following another user in a call.
3231async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3232    let room_id = RoomId::from_proto(request.room_id);
3233    let project_id = request.project_id.map(ProjectId::from_proto);
3234    let leader_id = request
3235        .leader_id
3236        .ok_or_else(|| anyhow!("invalid leader id"))?
3237        .into();
3238    let follower_id = session.connection_id;
3239
3240    session
3241        .db()
3242        .await
3243        .check_room_participants(room_id, leader_id, session.connection_id)
3244        .await?;
3245
3246    session
3247        .peer
3248        .forward_send(session.connection_id, leader_id, request)?;
3249
3250    if let Some(project_id) = project_id {
3251        let room = session
3252            .db()
3253            .await
3254            .unfollow(room_id, project_id, leader_id, follower_id)
3255            .await?;
3256        room_updated(&room, &session.peer);
3257    }
3258
3259    Ok(())
3260}
3261
3262/// Notify everyone following you of your current location.
3263async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3264    let room_id = RoomId::from_proto(request.room_id);
3265    let database = session.db.lock().await;
3266
3267    let connection_ids = if let Some(project_id) = request.project_id {
3268        let project_id = ProjectId::from_proto(project_id);
3269        database
3270            .project_connection_ids(project_id, session.connection_id, true)
3271            .await?
3272    } else {
3273        database
3274            .room_connection_ids(room_id, session.connection_id)
3275            .await?
3276    };
3277
3278    // For now, don't send view update messages back to that view's current leader.
3279    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3280        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3281        _ => None,
3282    });
3283
3284    for connection_id in connection_ids.iter().cloned() {
3285        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3286            session
3287                .peer
3288                .forward_send(session.connection_id, connection_id, request.clone())?;
3289        }
3290    }
3291    Ok(())
3292}
3293
3294/// Get public data about users.
3295async fn get_users(
3296    request: proto::GetUsers,
3297    response: Response<proto::GetUsers>,
3298    session: Session,
3299) -> Result<()> {
3300    let user_ids = request
3301        .user_ids
3302        .into_iter()
3303        .map(UserId::from_proto)
3304        .collect();
3305    let users = session
3306        .db()
3307        .await
3308        .get_users_by_ids(user_ids)
3309        .await?
3310        .into_iter()
3311        .map(|user| proto::User {
3312            id: user.id.to_proto(),
3313            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3314            github_login: user.github_login,
3315        })
3316        .collect();
3317    response.send(proto::UsersResponse { users })?;
3318    Ok(())
3319}
3320
3321/// Search for users (to invite) buy Github login
3322async fn fuzzy_search_users(
3323    request: proto::FuzzySearchUsers,
3324    response: Response<proto::FuzzySearchUsers>,
3325    session: UserSession,
3326) -> Result<()> {
3327    let query = request.query;
3328    let users = match query.len() {
3329        0 => vec![],
3330        1 | 2 => session
3331            .db()
3332            .await
3333            .get_user_by_github_login(&query)
3334            .await?
3335            .into_iter()
3336            .collect(),
3337        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3338    };
3339    let users = users
3340        .into_iter()
3341        .filter(|user| user.id != session.user_id())
3342        .map(|user| proto::User {
3343            id: user.id.to_proto(),
3344            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3345            github_login: user.github_login,
3346        })
3347        .collect();
3348    response.send(proto::UsersResponse { users })?;
3349    Ok(())
3350}
3351
3352/// Send a contact request to another user.
3353async fn request_contact(
3354    request: proto::RequestContact,
3355    response: Response<proto::RequestContact>,
3356    session: UserSession,
3357) -> Result<()> {
3358    let requester_id = session.user_id();
3359    let responder_id = UserId::from_proto(request.responder_id);
3360    if requester_id == responder_id {
3361        return Err(anyhow!("cannot add yourself as a contact"))?;
3362    }
3363
3364    let notifications = session
3365        .db()
3366        .await
3367        .send_contact_request(requester_id, responder_id)
3368        .await?;
3369
3370    // Update outgoing contact requests of requester
3371    let mut update = proto::UpdateContacts::default();
3372    update.outgoing_requests.push(responder_id.to_proto());
3373    for connection_id in session
3374        .connection_pool()
3375        .await
3376        .user_connection_ids(requester_id)
3377    {
3378        session.peer.send(connection_id, update.clone())?;
3379    }
3380
3381    // Update incoming contact requests of responder
3382    let mut update = proto::UpdateContacts::default();
3383    update
3384        .incoming_requests
3385        .push(proto::IncomingContactRequest {
3386            requester_id: requester_id.to_proto(),
3387        });
3388    let connection_pool = session.connection_pool().await;
3389    for connection_id in connection_pool.user_connection_ids(responder_id) {
3390        session.peer.send(connection_id, update.clone())?;
3391    }
3392
3393    send_notifications(&connection_pool, &session.peer, notifications);
3394
3395    response.send(proto::Ack {})?;
3396    Ok(())
3397}
3398
3399/// Accept or decline a contact request
3400async fn respond_to_contact_request(
3401    request: proto::RespondToContactRequest,
3402    response: Response<proto::RespondToContactRequest>,
3403    session: UserSession,
3404) -> Result<()> {
3405    let responder_id = session.user_id();
3406    let requester_id = UserId::from_proto(request.requester_id);
3407    let db = session.db().await;
3408    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3409        db.dismiss_contact_notification(responder_id, requester_id)
3410            .await?;
3411    } else {
3412        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3413
3414        let notifications = db
3415            .respond_to_contact_request(responder_id, requester_id, accept)
3416            .await?;
3417        let requester_busy = db.is_user_busy(requester_id).await?;
3418        let responder_busy = db.is_user_busy(responder_id).await?;
3419
3420        let pool = session.connection_pool().await;
3421        // Update responder with new contact
3422        let mut update = proto::UpdateContacts::default();
3423        if accept {
3424            update
3425                .contacts
3426                .push(contact_for_user(requester_id, requester_busy, &pool));
3427        }
3428        update
3429            .remove_incoming_requests
3430            .push(requester_id.to_proto());
3431        for connection_id in pool.user_connection_ids(responder_id) {
3432            session.peer.send(connection_id, update.clone())?;
3433        }
3434
3435        // Update requester with new contact
3436        let mut update = proto::UpdateContacts::default();
3437        if accept {
3438            update
3439                .contacts
3440                .push(contact_for_user(responder_id, responder_busy, &pool));
3441        }
3442        update
3443            .remove_outgoing_requests
3444            .push(responder_id.to_proto());
3445
3446        for connection_id in pool.user_connection_ids(requester_id) {
3447            session.peer.send(connection_id, update.clone())?;
3448        }
3449
3450        send_notifications(&pool, &session.peer, notifications);
3451    }
3452
3453    response.send(proto::Ack {})?;
3454    Ok(())
3455}
3456
3457/// Remove a contact.
3458async fn remove_contact(
3459    request: proto::RemoveContact,
3460    response: Response<proto::RemoveContact>,
3461    session: UserSession,
3462) -> Result<()> {
3463    let requester_id = session.user_id();
3464    let responder_id = UserId::from_proto(request.user_id);
3465    let db = session.db().await;
3466    let (contact_accepted, deleted_notification_id) =
3467        db.remove_contact(requester_id, responder_id).await?;
3468
3469    let pool = session.connection_pool().await;
3470    // Update outgoing contact requests of requester
3471    let mut update = proto::UpdateContacts::default();
3472    if contact_accepted {
3473        update.remove_contacts.push(responder_id.to_proto());
3474    } else {
3475        update
3476            .remove_outgoing_requests
3477            .push(responder_id.to_proto());
3478    }
3479    for connection_id in pool.user_connection_ids(requester_id) {
3480        session.peer.send(connection_id, update.clone())?;
3481    }
3482
3483    // Update incoming contact requests of responder
3484    let mut update = proto::UpdateContacts::default();
3485    if contact_accepted {
3486        update.remove_contacts.push(requester_id.to_proto());
3487    } else {
3488        update
3489            .remove_incoming_requests
3490            .push(requester_id.to_proto());
3491    }
3492    for connection_id in pool.user_connection_ids(responder_id) {
3493        session.peer.send(connection_id, update.clone())?;
3494        if let Some(notification_id) = deleted_notification_id {
3495            session.peer.send(
3496                connection_id,
3497                proto::DeleteNotification {
3498                    notification_id: notification_id.to_proto(),
3499                },
3500            )?;
3501        }
3502    }
3503
3504    response.send(proto::Ack {})?;
3505    Ok(())
3506}
3507
3508fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3509    version.0.minor() < 139
3510}
3511
3512async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3513    subscribe_user_to_channels(
3514        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3515        &session,
3516    )
3517    .await?;
3518    Ok(())
3519}
3520
3521async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3522    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3523    let mut pool = session.connection_pool().await;
3524    for membership in &channels_for_user.channel_memberships {
3525        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3526    }
3527    session.peer.send(
3528        session.connection_id,
3529        build_update_user_channels(&channels_for_user),
3530    )?;
3531    session.peer.send(
3532        session.connection_id,
3533        build_channels_update(channels_for_user),
3534    )?;
3535    Ok(())
3536}
3537
3538/// Creates a new channel.
3539async fn create_channel(
3540    request: proto::CreateChannel,
3541    response: Response<proto::CreateChannel>,
3542    session: UserSession,
3543) -> Result<()> {
3544    let db = session.db().await;
3545
3546    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3547    let (channel, membership) = db
3548        .create_channel(&request.name, parent_id, session.user_id())
3549        .await?;
3550
3551    let root_id = channel.root_id();
3552    let channel = Channel::from_model(channel);
3553
3554    response.send(proto::CreateChannelResponse {
3555        channel: Some(channel.to_proto()),
3556        parent_id: request.parent_id,
3557    })?;
3558
3559    let mut connection_pool = session.connection_pool().await;
3560    if let Some(membership) = membership {
3561        connection_pool.subscribe_to_channel(
3562            membership.user_id,
3563            membership.channel_id,
3564            membership.role,
3565        );
3566        let update = proto::UpdateUserChannels {
3567            channel_memberships: vec![proto::ChannelMembership {
3568                channel_id: membership.channel_id.to_proto(),
3569                role: membership.role.into(),
3570            }],
3571            ..Default::default()
3572        };
3573        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3574            session.peer.send(connection_id, update.clone())?;
3575        }
3576    }
3577
3578    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3579        if !role.can_see_channel(channel.visibility) {
3580            continue;
3581        }
3582
3583        let update = proto::UpdateChannels {
3584            channels: vec![channel.to_proto()],
3585            ..Default::default()
3586        };
3587        session.peer.send(connection_id, update.clone())?;
3588    }
3589
3590    Ok(())
3591}
3592
3593/// Delete a channel
3594async fn delete_channel(
3595    request: proto::DeleteChannel,
3596    response: Response<proto::DeleteChannel>,
3597    session: UserSession,
3598) -> Result<()> {
3599    let db = session.db().await;
3600
3601    let channel_id = request.channel_id;
3602    let (root_channel, removed_channels) = db
3603        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3604        .await?;
3605    response.send(proto::Ack {})?;
3606
3607    // Notify members of removed channels
3608    let mut update = proto::UpdateChannels::default();
3609    update
3610        .delete_channels
3611        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3612
3613    let connection_pool = session.connection_pool().await;
3614    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3615        session.peer.send(connection_id, update.clone())?;
3616    }
3617
3618    Ok(())
3619}
3620
3621/// Invite someone to join a channel.
3622async fn invite_channel_member(
3623    request: proto::InviteChannelMember,
3624    response: Response<proto::InviteChannelMember>,
3625    session: UserSession,
3626) -> Result<()> {
3627    let db = session.db().await;
3628    let channel_id = ChannelId::from_proto(request.channel_id);
3629    let invitee_id = UserId::from_proto(request.user_id);
3630    let InviteMemberResult {
3631        channel,
3632        notifications,
3633    } = db
3634        .invite_channel_member(
3635            channel_id,
3636            invitee_id,
3637            session.user_id(),
3638            request.role().into(),
3639        )
3640        .await?;
3641
3642    let update = proto::UpdateChannels {
3643        channel_invitations: vec![channel.to_proto()],
3644        ..Default::default()
3645    };
3646
3647    let connection_pool = session.connection_pool().await;
3648    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3649        session.peer.send(connection_id, update.clone())?;
3650    }
3651
3652    send_notifications(&connection_pool, &session.peer, notifications);
3653
3654    response.send(proto::Ack {})?;
3655    Ok(())
3656}
3657
3658/// remove someone from a channel
3659async fn remove_channel_member(
3660    request: proto::RemoveChannelMember,
3661    response: Response<proto::RemoveChannelMember>,
3662    session: UserSession,
3663) -> Result<()> {
3664    let db = session.db().await;
3665    let channel_id = ChannelId::from_proto(request.channel_id);
3666    let member_id = UserId::from_proto(request.user_id);
3667
3668    let RemoveChannelMemberResult {
3669        membership_update,
3670        notification_id,
3671    } = db
3672        .remove_channel_member(channel_id, member_id, session.user_id())
3673        .await?;
3674
3675    let mut connection_pool = session.connection_pool().await;
3676    notify_membership_updated(
3677        &mut connection_pool,
3678        membership_update,
3679        member_id,
3680        &session.peer,
3681    );
3682    for connection_id in connection_pool.user_connection_ids(member_id) {
3683        if let Some(notification_id) = notification_id {
3684            session
3685                .peer
3686                .send(
3687                    connection_id,
3688                    proto::DeleteNotification {
3689                        notification_id: notification_id.to_proto(),
3690                    },
3691                )
3692                .trace_err();
3693        }
3694    }
3695
3696    response.send(proto::Ack {})?;
3697    Ok(())
3698}
3699
3700/// Toggle the channel between public and private.
3701/// Care is taken to maintain the invariant that public channels only descend from public channels,
3702/// (though members-only channels can appear at any point in the hierarchy).
3703async fn set_channel_visibility(
3704    request: proto::SetChannelVisibility,
3705    response: Response<proto::SetChannelVisibility>,
3706    session: UserSession,
3707) -> Result<()> {
3708    let db = session.db().await;
3709    let channel_id = ChannelId::from_proto(request.channel_id);
3710    let visibility = request.visibility().into();
3711
3712    let channel_model = db
3713        .set_channel_visibility(channel_id, visibility, session.user_id())
3714        .await?;
3715    let root_id = channel_model.root_id();
3716    let channel = Channel::from_model(channel_model);
3717
3718    let mut connection_pool = session.connection_pool().await;
3719    for (user_id, role) in connection_pool
3720        .channel_user_ids(root_id)
3721        .collect::<Vec<_>>()
3722        .into_iter()
3723    {
3724        let update = if role.can_see_channel(channel.visibility) {
3725            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3726            proto::UpdateChannels {
3727                channels: vec![channel.to_proto()],
3728                ..Default::default()
3729            }
3730        } else {
3731            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3732            proto::UpdateChannels {
3733                delete_channels: vec![channel.id.to_proto()],
3734                ..Default::default()
3735            }
3736        };
3737
3738        for connection_id in connection_pool.user_connection_ids(user_id) {
3739            session.peer.send(connection_id, update.clone())?;
3740        }
3741    }
3742
3743    response.send(proto::Ack {})?;
3744    Ok(())
3745}
3746
3747/// Alter the role for a user in the channel.
3748async fn set_channel_member_role(
3749    request: proto::SetChannelMemberRole,
3750    response: Response<proto::SetChannelMemberRole>,
3751    session: UserSession,
3752) -> Result<()> {
3753    let db = session.db().await;
3754    let channel_id = ChannelId::from_proto(request.channel_id);
3755    let member_id = UserId::from_proto(request.user_id);
3756    let result = db
3757        .set_channel_member_role(
3758            channel_id,
3759            session.user_id(),
3760            member_id,
3761            request.role().into(),
3762        )
3763        .await?;
3764
3765    match result {
3766        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3767            let mut connection_pool = session.connection_pool().await;
3768            notify_membership_updated(
3769                &mut connection_pool,
3770                membership_update,
3771                member_id,
3772                &session.peer,
3773            )
3774        }
3775        db::SetMemberRoleResult::InviteUpdated(channel) => {
3776            let update = proto::UpdateChannels {
3777                channel_invitations: vec![channel.to_proto()],
3778                ..Default::default()
3779            };
3780
3781            for connection_id in session
3782                .connection_pool()
3783                .await
3784                .user_connection_ids(member_id)
3785            {
3786                session.peer.send(connection_id, update.clone())?;
3787            }
3788        }
3789    }
3790
3791    response.send(proto::Ack {})?;
3792    Ok(())
3793}
3794
3795/// Change the name of a channel
3796async fn rename_channel(
3797    request: proto::RenameChannel,
3798    response: Response<proto::RenameChannel>,
3799    session: UserSession,
3800) -> Result<()> {
3801    let db = session.db().await;
3802    let channel_id = ChannelId::from_proto(request.channel_id);
3803    let channel_model = db
3804        .rename_channel(channel_id, session.user_id(), &request.name)
3805        .await?;
3806    let root_id = channel_model.root_id();
3807    let channel = Channel::from_model(channel_model);
3808
3809    response.send(proto::RenameChannelResponse {
3810        channel: Some(channel.to_proto()),
3811    })?;
3812
3813    let connection_pool = session.connection_pool().await;
3814    let update = proto::UpdateChannels {
3815        channels: vec![channel.to_proto()],
3816        ..Default::default()
3817    };
3818    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3819        if role.can_see_channel(channel.visibility) {
3820            session.peer.send(connection_id, update.clone())?;
3821        }
3822    }
3823
3824    Ok(())
3825}
3826
3827/// Move a channel to a new parent.
3828async fn move_channel(
3829    request: proto::MoveChannel,
3830    response: Response<proto::MoveChannel>,
3831    session: UserSession,
3832) -> Result<()> {
3833    let channel_id = ChannelId::from_proto(request.channel_id);
3834    let to = ChannelId::from_proto(request.to);
3835
3836    let (root_id, channels) = session
3837        .db()
3838        .await
3839        .move_channel(channel_id, to, session.user_id())
3840        .await?;
3841
3842    let connection_pool = session.connection_pool().await;
3843    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3844        let channels = channels
3845            .iter()
3846            .filter_map(|channel| {
3847                if role.can_see_channel(channel.visibility) {
3848                    Some(channel.to_proto())
3849                } else {
3850                    None
3851                }
3852            })
3853            .collect::<Vec<_>>();
3854        if channels.is_empty() {
3855            continue;
3856        }
3857
3858        let update = proto::UpdateChannels {
3859            channels,
3860            ..Default::default()
3861        };
3862
3863        session.peer.send(connection_id, update.clone())?;
3864    }
3865
3866    response.send(Ack {})?;
3867    Ok(())
3868}
3869
3870/// Get the list of channel members
3871async fn get_channel_members(
3872    request: proto::GetChannelMembers,
3873    response: Response<proto::GetChannelMembers>,
3874    session: UserSession,
3875) -> Result<()> {
3876    let db = session.db().await;
3877    let channel_id = ChannelId::from_proto(request.channel_id);
3878    let limit = if request.limit == 0 {
3879        u16::MAX as u64
3880    } else {
3881        request.limit
3882    };
3883    let (members, users) = db
3884        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3885        .await?;
3886    response.send(proto::GetChannelMembersResponse { members, users })?;
3887    Ok(())
3888}
3889
3890/// Accept or decline a channel invitation.
3891async fn respond_to_channel_invite(
3892    request: proto::RespondToChannelInvite,
3893    response: Response<proto::RespondToChannelInvite>,
3894    session: UserSession,
3895) -> Result<()> {
3896    let db = session.db().await;
3897    let channel_id = ChannelId::from_proto(request.channel_id);
3898    let RespondToChannelInvite {
3899        membership_update,
3900        notifications,
3901    } = db
3902        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3903        .await?;
3904
3905    let mut connection_pool = session.connection_pool().await;
3906    if let Some(membership_update) = membership_update {
3907        notify_membership_updated(
3908            &mut connection_pool,
3909            membership_update,
3910            session.user_id(),
3911            &session.peer,
3912        );
3913    } else {
3914        let update = proto::UpdateChannels {
3915            remove_channel_invitations: vec![channel_id.to_proto()],
3916            ..Default::default()
3917        };
3918
3919        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3920            session.peer.send(connection_id, update.clone())?;
3921        }
3922    };
3923
3924    send_notifications(&connection_pool, &session.peer, notifications);
3925
3926    response.send(proto::Ack {})?;
3927
3928    Ok(())
3929}
3930
3931/// Join the channels' room
3932async fn join_channel(
3933    request: proto::JoinChannel,
3934    response: Response<proto::JoinChannel>,
3935    session: UserSession,
3936) -> Result<()> {
3937    let channel_id = ChannelId::from_proto(request.channel_id);
3938    join_channel_internal(channel_id, Box::new(response), session).await
3939}
3940
/// A response handle that can be answered with a `JoinRoomResponse`, letting
/// `join_channel_internal` reply to either a `JoinChannel` or a `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// `JoinChannel` requests are answered with a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit path to `Response::send`. Assuming that is an inherent method, it takes
        // precedence over this trait's `send`, so this call does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// `JoinRoom` requests are answered with the same `JoinRoomResponse` payload.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` response handle.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3954
/// Shared implementation for joining a channel's room, replying through any
/// `JoinChannelInternalResponse` handle (`JoinChannel` or `JoinRoom`).
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel in the database, which yields the room, an optional membership
///    update, and the user's role.
/// 3. Reply with the room, its channel id, and LiveKit credentials when a LiveKit client
///    is configured.
/// 4. Propagate any membership update, then broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard and connection-pool guard taken inside this block are dropped when
    // it ends, before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before `leave_room_for_session`, which receives the whole
            // session and presumably acquires the database itself; then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a room token
        // and may publish. A token error is logged by `trace_err` and turns into `None`, so
        // the join still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // The caller is answered before other participants are notified.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to produce a channel; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4045
4046/// Start editing the channel notes
4047async fn join_channel_buffer(
4048    request: proto::JoinChannelBuffer,
4049    response: Response<proto::JoinChannelBuffer>,
4050    session: UserSession,
4051) -> Result<()> {
4052    let db = session.db().await;
4053    let channel_id = ChannelId::from_proto(request.channel_id);
4054
4055    let open_response = db
4056        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4057        .await?;
4058
4059    let collaborators = open_response.collaborators.clone();
4060    response.send(open_response)?;
4061
4062    let update = UpdateChannelBufferCollaborators {
4063        channel_id: channel_id.to_proto(),
4064        collaborators: collaborators.clone(),
4065    };
4066    channel_buffer_updated(
4067        session.connection_id,
4068        collaborators
4069            .iter()
4070            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4071        &update,
4072        &session.peer,
4073    );
4074
4075    Ok(())
4076}
4077
4078/// Edit the channel notes
4079async fn update_channel_buffer(
4080    request: proto::UpdateChannelBuffer,
4081    session: UserSession,
4082) -> Result<()> {
4083    let db = session.db().await;
4084    let channel_id = ChannelId::from_proto(request.channel_id);
4085
4086    let (collaborators, epoch, version) = db
4087        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4088        .await?;
4089
4090    channel_buffer_updated(
4091        session.connection_id,
4092        collaborators.clone(),
4093        &proto::UpdateChannelBuffer {
4094            channel_id: channel_id.to_proto(),
4095            operations: request.operations,
4096        },
4097        &session.peer,
4098    );
4099
4100    let pool = &*session.connection_pool().await;
4101
4102    let non_collaborators =
4103        pool.channel_connection_ids(channel_id)
4104            .filter_map(|(connection_id, _)| {
4105                if collaborators.contains(&connection_id) {
4106                    None
4107                } else {
4108                    Some(connection_id)
4109                }
4110            });
4111
4112    broadcast(None, non_collaborators, |peer_id| {
4113        session.peer.send(
4114            peer_id,
4115            proto::UpdateChannels {
4116                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4117                    channel_id: channel_id.to_proto(),
4118                    epoch: epoch as u64,
4119                    version: version.clone(),
4120                }],
4121                ..Default::default()
4122            },
4123        )
4124    });
4125
4126    Ok(())
4127}
4128
4129/// Rejoin the channel notes after a connection blip
4130async fn rejoin_channel_buffers(
4131    request: proto::RejoinChannelBuffers,
4132    response: Response<proto::RejoinChannelBuffers>,
4133    session: UserSession,
4134) -> Result<()> {
4135    let db = session.db().await;
4136    let buffers = db
4137        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4138        .await?;
4139
4140    for rejoined_buffer in &buffers {
4141        let collaborators_to_notify = rejoined_buffer
4142            .buffer
4143            .collaborators
4144            .iter()
4145            .filter_map(|c| Some(c.peer_id?.into()));
4146        channel_buffer_updated(
4147            session.connection_id,
4148            collaborators_to_notify,
4149            &proto::UpdateChannelBufferCollaborators {
4150                channel_id: rejoined_buffer.buffer.channel_id,
4151                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4152            },
4153            &session.peer,
4154        );
4155    }
4156
4157    response.send(proto::RejoinChannelBuffersResponse {
4158        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4159    })?;
4160
4161    Ok(())
4162}
4163
4164/// Stop editing the channel notes
4165async fn leave_channel_buffer(
4166    request: proto::LeaveChannelBuffer,
4167    response: Response<proto::LeaveChannelBuffer>,
4168    session: UserSession,
4169) -> Result<()> {
4170    let db = session.db().await;
4171    let channel_id = ChannelId::from_proto(request.channel_id);
4172
4173    let left_buffer = db
4174        .leave_channel_buffer(channel_id, session.connection_id)
4175        .await?;
4176
4177    response.send(Ack {})?;
4178
4179    channel_buffer_updated(
4180        session.connection_id,
4181        left_buffer.connections,
4182        &proto::UpdateChannelBufferCollaborators {
4183            channel_id: channel_id.to_proto(),
4184            collaborators: left_buffer.collaborators,
4185        },
4186        &session.peer,
4187    );
4188
4189    Ok(())
4190}
4191
4192fn channel_buffer_updated<T: EnvelopedMessage>(
4193    sender_id: ConnectionId,
4194    collaborators: impl IntoIterator<Item = ConnectionId>,
4195    message: &T,
4196    peer: &Peer,
4197) {
4198    broadcast(Some(sender_id), collaborators, |peer_id| {
4199        peer.send(peer_id, message.clone())
4200    });
4201}
4202
4203fn send_notifications(
4204    connection_pool: &ConnectionPool,
4205    peer: &Peer,
4206    notifications: db::NotificationBatch,
4207) {
4208    for (user_id, notification) in notifications {
4209        for connection_id in connection_pool.user_connection_ids(user_id) {
4210            if let Err(error) = peer.send(
4211                connection_id,
4212                proto::AddNotification {
4213                    notification: Some(notification.clone()),
4214                },
4215            ) {
4216                tracing::error!(
4217                    "failed to send notification to {:?} {}",
4218                    connection_id,
4219                    error
4220                );
4221            }
4222        }
4223    }
4224}
4225
4226/// Send a message to the channel
4227async fn send_channel_message(
4228    request: proto::SendChannelMessage,
4229    response: Response<proto::SendChannelMessage>,
4230    session: UserSession,
4231) -> Result<()> {
4232    // Validate the message body.
4233    let body = request.body.trim().to_string();
4234    if body.len() > MAX_MESSAGE_LEN {
4235        return Err(anyhow!("message is too long"))?;
4236    }
4237    if body.is_empty() {
4238        return Err(anyhow!("message can't be blank"))?;
4239    }
4240
4241    // TODO: adjust mentions if body is trimmed
4242
4243    let timestamp = OffsetDateTime::now_utc();
4244    let nonce = request
4245        .nonce
4246        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4247
4248    let channel_id = ChannelId::from_proto(request.channel_id);
4249    let CreatedChannelMessage {
4250        message_id,
4251        participant_connection_ids,
4252        notifications,
4253    } = session
4254        .db()
4255        .await
4256        .create_channel_message(
4257            channel_id,
4258            session.user_id(),
4259            &body,
4260            &request.mentions,
4261            timestamp,
4262            nonce.clone().into(),
4263            match request.reply_to_message_id {
4264                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4265                None => None,
4266            },
4267        )
4268        .await?;
4269
4270    let message = proto::ChannelMessage {
4271        sender_id: session.user_id().to_proto(),
4272        id: message_id.to_proto(),
4273        body,
4274        mentions: request.mentions,
4275        timestamp: timestamp.unix_timestamp() as u64,
4276        nonce: Some(nonce),
4277        reply_to_message_id: request.reply_to_message_id,
4278        edited_at: None,
4279    };
4280    broadcast(
4281        Some(session.connection_id),
4282        participant_connection_ids.clone(),
4283        |connection| {
4284            session.peer.send(
4285                connection,
4286                proto::ChannelMessageSent {
4287                    channel_id: channel_id.to_proto(),
4288                    message: Some(message.clone()),
4289                },
4290            )
4291        },
4292    );
4293    response.send(proto::SendChannelMessageResponse {
4294        message: Some(message),
4295    })?;
4296
4297    let pool = &*session.connection_pool().await;
4298    let non_participants =
4299        pool.channel_connection_ids(channel_id)
4300            .filter_map(|(connection_id, _)| {
4301                if participant_connection_ids.contains(&connection_id) {
4302                    None
4303                } else {
4304                    Some(connection_id)
4305                }
4306            });
4307    broadcast(None, non_participants, |peer_id| {
4308        session.peer.send(
4309            peer_id,
4310            proto::UpdateChannels {
4311                latest_channel_message_ids: vec![proto::ChannelMessageId {
4312                    channel_id: channel_id.to_proto(),
4313                    message_id: message_id.to_proto(),
4314                }],
4315                ..Default::default()
4316            },
4317        )
4318    });
4319    send_notifications(pool, &session.peer, notifications);
4320
4321    Ok(())
4322}
4323
4324/// Delete a channel message
4325async fn remove_channel_message(
4326    request: proto::RemoveChannelMessage,
4327    response: Response<proto::RemoveChannelMessage>,
4328    session: UserSession,
4329) -> Result<()> {
4330    let channel_id = ChannelId::from_proto(request.channel_id);
4331    let message_id = MessageId::from_proto(request.message_id);
4332    let (connection_ids, existing_notification_ids) = session
4333        .db()
4334        .await
4335        .remove_channel_message(channel_id, message_id, session.user_id())
4336        .await?;
4337
4338    broadcast(
4339        Some(session.connection_id),
4340        connection_ids,
4341        move |connection| {
4342            session.peer.send(connection, request.clone())?;
4343
4344            for notification_id in &existing_notification_ids {
4345                session.peer.send(
4346                    connection,
4347                    proto::DeleteNotification {
4348                        notification_id: (*notification_id).to_proto(),
4349                    },
4350                )?;
4351            }
4352
4353            Ok(())
4354        },
4355    );
4356    response.send(proto::Ack {})?;
4357    Ok(())
4358}
4359
4360async fn update_channel_message(
4361    request: proto::UpdateChannelMessage,
4362    response: Response<proto::UpdateChannelMessage>,
4363    session: UserSession,
4364) -> Result<()> {
4365    let channel_id = ChannelId::from_proto(request.channel_id);
4366    let message_id = MessageId::from_proto(request.message_id);
4367    let updated_at = OffsetDateTime::now_utc();
4368    let UpdatedChannelMessage {
4369        message_id,
4370        participant_connection_ids,
4371        notifications,
4372        reply_to_message_id,
4373        timestamp,
4374        deleted_mention_notification_ids,
4375        updated_mention_notifications,
4376    } = session
4377        .db()
4378        .await
4379        .update_channel_message(
4380            channel_id,
4381            message_id,
4382            session.user_id(),
4383            request.body.as_str(),
4384            &request.mentions,
4385            updated_at,
4386        )
4387        .await?;
4388
4389    let nonce = request
4390        .nonce
4391        .clone()
4392        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4393
4394    let message = proto::ChannelMessage {
4395        sender_id: session.user_id().to_proto(),
4396        id: message_id.to_proto(),
4397        body: request.body.clone(),
4398        mentions: request.mentions.clone(),
4399        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4400        nonce: Some(nonce),
4401        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4402        edited_at: Some(updated_at.unix_timestamp() as u64),
4403    };
4404
4405    response.send(proto::Ack {})?;
4406
4407    let pool = &*session.connection_pool().await;
4408    broadcast(
4409        Some(session.connection_id),
4410        participant_connection_ids,
4411        |connection| {
4412            session.peer.send(
4413                connection,
4414                proto::ChannelMessageUpdate {
4415                    channel_id: channel_id.to_proto(),
4416                    message: Some(message.clone()),
4417                },
4418            )?;
4419
4420            for notification_id in &deleted_mention_notification_ids {
4421                session.peer.send(
4422                    connection,
4423                    proto::DeleteNotification {
4424                        notification_id: (*notification_id).to_proto(),
4425                    },
4426                )?;
4427            }
4428
4429            for notification in &updated_mention_notifications {
4430                session.peer.send(
4431                    connection,
4432                    proto::UpdateNotification {
4433                        notification: Some(notification.clone()),
4434                    },
4435                )?;
4436            }
4437
4438            Ok(())
4439        },
4440    );
4441
4442    send_notifications(pool, &session.peer, notifications);
4443
4444    Ok(())
4445}
4446
4447/// Mark a channel message as read
4448async fn acknowledge_channel_message(
4449    request: proto::AckChannelMessage,
4450    session: UserSession,
4451) -> Result<()> {
4452    let channel_id = ChannelId::from_proto(request.channel_id);
4453    let message_id = MessageId::from_proto(request.message_id);
4454    let notifications = session
4455        .db()
4456        .await
4457        .observe_channel_message(channel_id, session.user_id(), message_id)
4458        .await?;
4459    send_notifications(
4460        &*session.connection_pool().await,
4461        &session.peer,
4462        notifications,
4463    );
4464    Ok(())
4465}
4466
4467/// Mark a buffer version as synced
4468async fn acknowledge_buffer_version(
4469    request: proto::AckBufferOperation,
4470    session: UserSession,
4471) -> Result<()> {
4472    let buffer_id = BufferId::from_proto(request.buffer_id);
4473    session
4474        .db()
4475        .await
4476        .observe_buffer_version(
4477            buffer_id,
4478            session.user_id(),
4479            request.epoch as i32,
4480            &request.version,
4481        )
4482        .await?;
4483    Ok(())
4484}
4485
4486struct CompleteWithLanguageModelRateLimit;
4487
4488impl RateLimit for CompleteWithLanguageModelRateLimit {
4489    fn capacity() -> usize {
4490        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4491            .ok()
4492            .and_then(|v| v.parse().ok())
4493            .unwrap_or(120) // Picked arbitrarily
4494    }
4495
4496    fn refill_duration() -> chrono::Duration {
4497        chrono::Duration::hours(1)
4498    }
4499
4500    fn db_name() -> &'static str {
4501        "complete-with-language-model"
4502    }
4503}
4504
4505async fn complete_with_language_model(
4506    query: proto::QueryLanguageModel,
4507    response: StreamingResponse<proto::QueryLanguageModel>,
4508    session: Session,
4509    open_ai_api_key: Option<Arc<str>>,
4510    google_ai_api_key: Option<Arc<str>>,
4511    anthropic_api_key: Option<Arc<str>>,
4512) -> Result<()> {
4513    let Some(session) = session.for_user() else {
4514        return Err(anyhow!("user not found"))?;
4515    };
4516    authorize_access_to_language_models(&session).await?;
4517
4518    match proto::LanguageModelRequestKind::from_i32(query.kind) {
4519        Some(proto::LanguageModelRequestKind::Complete) => {
4520            session
4521                .rate_limiter
4522                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4523                .await?;
4524        }
4525        Some(proto::LanguageModelRequestKind::CountTokens) => {
4526            session
4527                .rate_limiter
4528                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4529                .await?;
4530        }
4531        None => Err(anyhow!("unknown request kind"))?,
4532    }
4533
4534    match proto::LanguageModelProvider::from_i32(query.provider) {
4535        Some(proto::LanguageModelProvider::Anthropic) => {
4536            let api_key =
4537                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
4538            let mut chunks = anthropic::stream_completion(
4539                session.http_client.as_ref(),
4540                anthropic::ANTHROPIC_API_URL,
4541                &api_key,
4542                serde_json::from_str(&query.request)?,
4543                None,
4544            )
4545            .await?;
4546            while let Some(chunk) = chunks.next().await {
4547                let chunk = chunk?;
4548                response.send(proto::QueryLanguageModelResponse {
4549                    response: serde_json::to_string(&chunk)?,
4550                })?;
4551            }
4552        }
4553        Some(proto::LanguageModelProvider::OpenAi) => {
4554            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
4555            let mut chunks = open_ai::stream_completion(
4556                session.http_client.as_ref(),
4557                open_ai::OPEN_AI_API_URL,
4558                &api_key,
4559                serde_json::from_str(&query.request)?,
4560                None,
4561            )
4562            .await?;
4563            while let Some(chunk) = chunks.next().await {
4564                let chunk = chunk?;
4565                response.send(proto::QueryLanguageModelResponse {
4566                    response: serde_json::to_string(&chunk)?,
4567                })?;
4568            }
4569        }
4570        Some(proto::LanguageModelProvider::Google) => {
4571            let api_key =
4572                google_ai_api_key.context("no Google AI API key configured on the server")?;
4573
4574            match proto::LanguageModelRequestKind::from_i32(query.kind) {
4575                Some(proto::LanguageModelRequestKind::Complete) => {
4576                    let mut chunks = google_ai::stream_generate_content(
4577                        session.http_client.as_ref(),
4578                        google_ai::API_URL,
4579                        &api_key,
4580                        serde_json::from_str(&query.request)?,
4581                    )
4582                    .await?;
4583                    while let Some(chunk) = chunks.next().await {
4584                        let chunk = chunk?;
4585                        response.send(proto::QueryLanguageModelResponse {
4586                            response: serde_json::to_string(&chunk)?,
4587                        })?;
4588                    }
4589                }
4590                Some(proto::LanguageModelRequestKind::CountTokens) => {
4591                    let tokens_response = google_ai::count_tokens(
4592                        session.http_client.as_ref(),
4593                        google_ai::API_URL,
4594                        &api_key,
4595                        serde_json::from_str(&query.request)?,
4596                    )
4597                    .await?;
4598                    response.send(proto::QueryLanguageModelResponse {
4599                        response: serde_json::to_string(&tokens_response)?,
4600                    })?;
4601                }
4602                None => Err(anyhow!("unknown request kind"))?,
4603            }
4604        }
4605        None => return Err(anyhow!("unknown provider"))?,
4606    }
4607
4608    Ok(())
4609}
4610
4611struct CountTokensWithLanguageModelRateLimit;
4612
4613impl RateLimit for CountTokensWithLanguageModelRateLimit {
4614    fn capacity() -> usize {
4615        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4616            .ok()
4617            .and_then(|v| v.parse().ok())
4618            .unwrap_or(600) // Picked arbitrarily
4619    }
4620
4621    fn refill_duration() -> chrono::Duration {
4622        chrono::Duration::hours(1)
4623    }
4624
4625    fn db_name() -> &'static str {
4626        "count-tokens-with-language-model"
4627    }
4628}
4629
4630struct ComputeEmbeddingsRateLimit;
4631
4632impl RateLimit for ComputeEmbeddingsRateLimit {
4633    fn capacity() -> usize {
4634        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4635            .ok()
4636            .and_then(|v| v.parse().ok())
4637            .unwrap_or(5000) // Picked arbitrarily
4638    }
4639
4640    fn refill_duration() -> chrono::Duration {
4641        chrono::Duration::hours(1)
4642    }
4643
4644    fn db_name() -> &'static str {
4645        "compute-embeddings"
4646    }
4647}
4648
4649async fn compute_embeddings(
4650    request: proto::ComputeEmbeddings,
4651    response: Response<proto::ComputeEmbeddings>,
4652    session: UserSession,
4653    api_key: Option<Arc<str>>,
4654) -> Result<()> {
4655    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4656    authorize_access_to_language_models(&session).await?;
4657
4658    session
4659        .rate_limiter
4660        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4661        .await?;
4662
4663    let embeddings = match request.model.as_str() {
4664        "openai/text-embedding-3-small" => {
4665            open_ai::embed(
4666                session.http_client.as_ref(),
4667                OPEN_AI_API_URL,
4668                &api_key,
4669                OpenAiEmbeddingModel::TextEmbedding3Small,
4670                request.texts.iter().map(|text| text.as_str()),
4671            )
4672            .await?
4673        }
4674        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4675    };
4676
4677    let embeddings = request
4678        .texts
4679        .iter()
4680        .map(|text| {
4681            let mut hasher = sha2::Sha256::new();
4682            hasher.update(text.as_bytes());
4683            let result = hasher.finalize();
4684            result.to_vec()
4685        })
4686        .zip(
4687            embeddings
4688                .data
4689                .into_iter()
4690                .map(|embedding| embedding.embedding),
4691        )
4692        .collect::<HashMap<_, _>>();
4693
4694    let db = session.db().await;
4695    db.save_embeddings(&request.model, &embeddings)
4696        .await
4697        .context("failed to save embeddings")
4698        .trace_err();
4699
4700    response.send(proto::ComputeEmbeddingsResponse {
4701        embeddings: embeddings
4702            .into_iter()
4703            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4704            .collect(),
4705    })?;
4706    Ok(())
4707}
4708
4709async fn get_cached_embeddings(
4710    request: proto::GetCachedEmbeddings,
4711    response: Response<proto::GetCachedEmbeddings>,
4712    session: UserSession,
4713) -> Result<()> {
4714    authorize_access_to_language_models(&session).await?;
4715
4716    let db = session.db().await;
4717    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4718
4719    response.send(proto::GetCachedEmbeddingsResponse {
4720        embeddings: embeddings
4721            .into_iter()
4722            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4723            .collect(),
4724    })?;
4725    Ok(())
4726}
4727
4728async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4729    let db = session.db().await;
4730    let flags = db.get_user_flags(session.user_id()).await?;
4731    if flags.iter().any(|flag| flag == "language-models") {
4732        Ok(())
4733    } else {
4734        Err(anyhow!("permission denied"))?
4735    }
4736}
4737
4738/// Get a Supermaven API key for the user
4739async fn get_supermaven_api_key(
4740    _request: proto::GetSupermavenApiKey,
4741    response: Response<proto::GetSupermavenApiKey>,
4742    session: UserSession,
4743) -> Result<()> {
4744    let user_id: String = session.user_id().to_string();
4745    if !session.is_staff() {
4746        return Err(anyhow!("supermaven not enabled for this account"))?;
4747    }
4748
4749    let email = session
4750        .email()
4751        .ok_or_else(|| anyhow!("user must have an email"))?;
4752
4753    let supermaven_admin_api = session
4754        .supermaven_client
4755        .as_ref()
4756        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4757
4758    let result = supermaven_admin_api
4759        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4760        .await?;
4761
4762    response.send(proto::GetSupermavenApiKeyResponse {
4763        api_key: result.api_key,
4764    })?;
4765
4766    Ok(())
4767}
4768
4769/// Start receiving chat updates for a channel
4770async fn join_channel_chat(
4771    request: proto::JoinChannelChat,
4772    response: Response<proto::JoinChannelChat>,
4773    session: UserSession,
4774) -> Result<()> {
4775    let channel_id = ChannelId::from_proto(request.channel_id);
4776
4777    let db = session.db().await;
4778    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4779        .await?;
4780    let messages = db
4781        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4782        .await?;
4783    response.send(proto::JoinChannelChatResponse {
4784        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4785        messages,
4786    })?;
4787    Ok(())
4788}
4789
4790/// Stop receiving chat updates for a channel
4791async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4792    let channel_id = ChannelId::from_proto(request.channel_id);
4793    session
4794        .db()
4795        .await
4796        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4797        .await?;
4798    Ok(())
4799}
4800
4801/// Retrieve the chat history for a channel
4802async fn get_channel_messages(
4803    request: proto::GetChannelMessages,
4804    response: Response<proto::GetChannelMessages>,
4805    session: UserSession,
4806) -> Result<()> {
4807    let channel_id = ChannelId::from_proto(request.channel_id);
4808    let messages = session
4809        .db()
4810        .await
4811        .get_channel_messages(
4812            channel_id,
4813            session.user_id(),
4814            MESSAGE_COUNT_PER_PAGE,
4815            Some(MessageId::from_proto(request.before_message_id)),
4816        )
4817        .await?;
4818    response.send(proto::GetChannelMessagesResponse {
4819        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4820        messages,
4821    })?;
4822    Ok(())
4823}
4824
4825/// Retrieve specific chat messages
4826async fn get_channel_messages_by_id(
4827    request: proto::GetChannelMessagesById,
4828    response: Response<proto::GetChannelMessagesById>,
4829    session: UserSession,
4830) -> Result<()> {
4831    let message_ids = request
4832        .message_ids
4833        .iter()
4834        .map(|id| MessageId::from_proto(*id))
4835        .collect::<Vec<_>>();
4836    let messages = session
4837        .db()
4838        .await
4839        .get_channel_messages_by_id(session.user_id(), &message_ids)
4840        .await?;
4841    response.send(proto::GetChannelMessagesResponse {
4842        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4843        messages,
4844    })?;
4845    Ok(())
4846}
4847
4848/// Retrieve the current users notifications
4849async fn get_notifications(
4850    request: proto::GetNotifications,
4851    response: Response<proto::GetNotifications>,
4852    session: UserSession,
4853) -> Result<()> {
4854    let notifications = session
4855        .db()
4856        .await
4857        .get_notifications(
4858            session.user_id(),
4859            NOTIFICATION_COUNT_PER_PAGE,
4860            request
4861                .before_id
4862                .map(|id| db::NotificationId::from_proto(id)),
4863        )
4864        .await?;
4865    response.send(proto::GetNotificationsResponse {
4866        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4867        notifications,
4868    })?;
4869    Ok(())
4870}
4871
4872/// Mark notifications as read
4873async fn mark_notification_as_read(
4874    request: proto::MarkNotificationRead,
4875    response: Response<proto::MarkNotificationRead>,
4876    session: UserSession,
4877) -> Result<()> {
4878    let database = &session.db().await;
4879    let notifications = database
4880        .mark_notification_as_read_by_id(
4881            session.user_id(),
4882            NotificationId::from_proto(request.notification_id),
4883        )
4884        .await?;
4885    send_notifications(
4886        &*session.connection_pool().await,
4887        &session.peer,
4888        notifications,
4889    );
4890    response.send(proto::Ack {})?;
4891    Ok(())
4892}
4893
4894/// Get the current users information
4895async fn get_private_user_info(
4896    _request: proto::GetPrivateUserInfo,
4897    response: Response<proto::GetPrivateUserInfo>,
4898    session: UserSession,
4899) -> Result<()> {
4900    let db = session.db().await;
4901
4902    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4903    let user = db
4904        .get_user_by_id(session.user_id())
4905        .await?
4906        .ok_or_else(|| anyhow!("user not found"))?;
4907    let flags = db.get_user_flags(session.user_id()).await?;
4908
4909    response.send(proto::GetPrivateUserInfoResponse {
4910        metrics_id,
4911        staff: user.admin,
4912        flags,
4913    })?;
4914    Ok(())
4915}
4916
4917fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4918    let message = match message {
4919        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4920        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4921        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4922        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4923        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4924            code: frame.code.into(),
4925            reason: frame.reason,
4926        })),
4927        // We should never receive a frame while reading the message, according
4928        // to the `tungstenite` maintainers:
4929        //
4930        // > It cannot occur when you read messages from the WebSocket, but it
4931        // > can be used when you want to send the raw frames (e.g. you want to
4932        // > send the frames to the WebSocket without composing the full message first).
4933        // >
4934        // > — https://github.com/snapview/tungstenite-rs/issues/268
4935        TungsteniteMessage::Frame(_) => {
4936            bail!("received an unexpected frame while reading the message")
4937        }
4938    };
4939
4940    Ok(message)
4941}
4942
4943fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4944    match message {
4945        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4946        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4947        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4948        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4949        AxumMessage::Close(frame) => {
4950            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4951                code: frame.code.into(),
4952                reason: frame.reason,
4953            }))
4954        }
4955    }
4956}
4957
4958fn notify_membership_updated(
4959    connection_pool: &mut ConnectionPool,
4960    result: MembershipUpdated,
4961    user_id: UserId,
4962    peer: &Peer,
4963) {
4964    for membership in &result.new_channels.channel_memberships {
4965        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4966    }
4967    for channel_id in &result.removed_channels {
4968        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4969    }
4970
4971    let user_channels_update = proto::UpdateUserChannels {
4972        channel_memberships: result
4973            .new_channels
4974            .channel_memberships
4975            .iter()
4976            .map(|cm| proto::ChannelMembership {
4977                channel_id: cm.channel_id.to_proto(),
4978                role: cm.role.into(),
4979            })
4980            .collect(),
4981        ..Default::default()
4982    };
4983
4984    let mut update = build_channels_update(result.new_channels);
4985    update.delete_channels = result
4986        .removed_channels
4987        .into_iter()
4988        .map(|id| id.to_proto())
4989        .collect();
4990    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4991
4992    for connection_id in connection_pool.user_connection_ids(user_id) {
4993        peer.send(connection_id, user_channels_update.clone())
4994            .trace_err();
4995        peer.send(connection_id, update.clone()).trace_err();
4996    }
4997}
4998
4999fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5000    proto::UpdateUserChannels {
5001        channel_memberships: channels
5002            .channel_memberships
5003            .iter()
5004            .map(|m| proto::ChannelMembership {
5005                channel_id: m.channel_id.to_proto(),
5006                role: m.role.into(),
5007            })
5008            .collect(),
5009        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5010        observed_channel_message_id: channels.observed_channel_messages.clone(),
5011    }
5012}
5013
5014fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5015    let mut update = proto::UpdateChannels::default();
5016
5017    for channel in channels.channels {
5018        update.channels.push(channel.to_proto());
5019    }
5020
5021    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5022    update.latest_channel_message_ids = channels.latest_channel_messages;
5023
5024    for (channel_id, participants) in channels.channel_participants {
5025        update
5026            .channel_participants
5027            .push(proto::ChannelParticipants {
5028                channel_id: channel_id.to_proto(),
5029                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5030            });
5031    }
5032
5033    for channel in channels.invited_channels {
5034        update.channel_invitations.push(channel.to_proto());
5035    }
5036
5037    update.hosted_projects = channels.hosted_projects;
5038    update
5039}
5040
5041fn build_initial_contacts_update(
5042    contacts: Vec<db::Contact>,
5043    pool: &ConnectionPool,
5044) -> proto::UpdateContacts {
5045    let mut update = proto::UpdateContacts::default();
5046
5047    for contact in contacts {
5048        match contact {
5049            db::Contact::Accepted { user_id, busy } => {
5050                update.contacts.push(contact_for_user(user_id, busy, &pool));
5051            }
5052            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5053            db::Contact::Incoming { user_id } => {
5054                update
5055                    .incoming_requests
5056                    .push(proto::IncomingContactRequest {
5057                        requester_id: user_id.to_proto(),
5058                    })
5059            }
5060        }
5061    }
5062
5063    update
5064}
5065
5066fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5067    proto::Contact {
5068        user_id: user_id.to_proto(),
5069        online: pool.is_user_online(user_id),
5070        busy,
5071    }
5072}
5073
5074fn room_updated(room: &proto::Room, peer: &Peer) {
5075    broadcast(
5076        None,
5077        room.participants
5078            .iter()
5079            .filter_map(|participant| Some(participant.peer_id?.into())),
5080        |peer_id| {
5081            peer.send(
5082                peer_id,
5083                proto::RoomUpdated {
5084                    room: Some(room.clone()),
5085                },
5086            )
5087        },
5088    );
5089}
5090
5091fn channel_updated(
5092    channel: &db::channel::Model,
5093    room: &proto::Room,
5094    peer: &Peer,
5095    pool: &ConnectionPool,
5096) {
5097    let participants = room
5098        .participants
5099        .iter()
5100        .map(|p| p.user_id)
5101        .collect::<Vec<_>>();
5102
5103    broadcast(
5104        None,
5105        pool.channel_connection_ids(channel.root_id())
5106            .filter_map(|(channel_id, role)| {
5107                role.can_see_channel(channel.visibility).then(|| channel_id)
5108            }),
5109        |peer_id| {
5110            peer.send(
5111                peer_id,
5112                proto::UpdateChannels {
5113                    channel_participants: vec![proto::ChannelParticipants {
5114                        channel_id: channel.id.to_proto(),
5115                        participant_user_ids: participants.clone(),
5116                    }],
5117                    ..Default::default()
5118                },
5119            )
5120        },
5121    );
5122}
5123
5124async fn send_dev_server_projects_update(
5125    user_id: UserId,
5126    mut status: proto::DevServerProjectsUpdate,
5127    session: &Session,
5128) {
5129    let pool = session.connection_pool().await;
5130    for dev_server in &mut status.dev_servers {
5131        dev_server.status =
5132            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5133    }
5134    let connections = pool.user_connection_ids(user_id);
5135    for connection_id in connections {
5136        session.peer.send(connection_id, status.clone()).trace_err();
5137    }
5138}
5139
5140async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5141    let db = session.db().await;
5142
5143    let contacts = db.get_contacts(user_id).await?;
5144    let busy = db.is_user_busy(user_id).await?;
5145
5146    let pool = session.connection_pool().await;
5147    let updated_contact = contact_for_user(user_id, busy, &pool);
5148    for contact in contacts {
5149        if let db::Contact::Accepted {
5150            user_id: contact_user_id,
5151            ..
5152        } = contact
5153        {
5154            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5155                session
5156                    .peer
5157                    .send(
5158                        contact_conn_id,
5159                        proto::UpdateContacts {
5160                            contacts: vec![updated_contact.clone()],
5161                            remove_contacts: Default::default(),
5162                            incoming_requests: Default::default(),
5163                            remove_incoming_requests: Default::default(),
5164                            outgoing_requests: Default::default(),
5165                            remove_outgoing_requests: Default::default(),
5166                        },
5167                    )
5168                    .trace_err();
5169            }
5170        }
5171    }
5172    Ok(())
5173}
5174
5175async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5176    log::info!("lost dev server connection, unsharing projects");
5177    let project_ids = session
5178        .db()
5179        .await
5180        .get_stale_dev_server_projects(session.connection_id)
5181        .await?;
5182
5183    for project_id in project_ids {
5184        // not unshare re-checks the connection ids match, so we get away with no transaction
5185        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5186    }
5187
5188    let user_id = session.dev_server().user_id;
5189    let update = session
5190        .db()
5191        .await
5192        .dev_server_projects_update(user_id)
5193        .await?;
5194
5195    send_dev_server_projects_update(user_id, update, session).await;
5196
5197    Ok(())
5198}
5199
5200async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5201    let mut contacts_to_update = HashSet::default();
5202
5203    let room_id;
5204    let canceled_calls_to_user_ids;
5205    let live_kit_room;
5206    let delete_live_kit_room;
5207    let room;
5208    let channel;
5209
5210    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5211        contacts_to_update.insert(session.user_id());
5212
5213        for project in left_room.left_projects.values() {
5214            project_left(project, session);
5215        }
5216
5217        room_id = RoomId::from_proto(left_room.room.id);
5218        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5219        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5220        delete_live_kit_room = left_room.deleted;
5221        room = mem::take(&mut left_room.room);
5222        channel = mem::take(&mut left_room.channel);
5223
5224        room_updated(&room, &session.peer);
5225    } else {
5226        return Ok(());
5227    }
5228
5229    if let Some(channel) = channel {
5230        channel_updated(
5231            &channel,
5232            &room,
5233            &session.peer,
5234            &*session.connection_pool().await,
5235        );
5236    }
5237
5238    {
5239        let pool = session.connection_pool().await;
5240        for canceled_user_id in canceled_calls_to_user_ids {
5241            for connection_id in pool.user_connection_ids(canceled_user_id) {
5242                session
5243                    .peer
5244                    .send(
5245                        connection_id,
5246                        proto::CallCanceled {
5247                            room_id: room_id.to_proto(),
5248                        },
5249                    )
5250                    .trace_err();
5251            }
5252            contacts_to_update.insert(canceled_user_id);
5253        }
5254    }
5255
5256    for contact_user_id in contacts_to_update {
5257        update_user_contacts(contact_user_id, &session).await?;
5258    }
5259
5260    if let Some(live_kit) = session.live_kit_client.as_ref() {
5261        live_kit
5262            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5263            .await
5264            .trace_err();
5265
5266        if delete_live_kit_room {
5267            live_kit.delete_room(live_kit_room).await.trace_err();
5268        }
5269    }
5270
5271    Ok(())
5272}
5273
5274async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5275    let left_channel_buffers = session
5276        .db()
5277        .await
5278        .leave_channel_buffers(session.connection_id)
5279        .await?;
5280
5281    for left_buffer in left_channel_buffers {
5282        channel_buffer_updated(
5283            session.connection_id,
5284            left_buffer.connections,
5285            &proto::UpdateChannelBufferCollaborators {
5286                channel_id: left_buffer.channel_id.to_proto(),
5287                collaborators: left_buffer.collaborators,
5288            },
5289            &session.peer,
5290        );
5291    }
5292
5293    Ok(())
5294}
5295
5296fn project_left(project: &db::LeftProject, session: &UserSession) {
5297    for connection_id in &project.connection_ids {
5298        if project.should_unshare {
5299            session
5300                .peer
5301                .send(
5302                    *connection_id,
5303                    proto::UnshareProject {
5304                        project_id: project.id.to_proto(),
5305                    },
5306                )
5307                .trace_err();
5308        } else {
5309            session
5310                .peer
5311                .send(
5312                    *connection_id,
5313                    proto::RemoveProjectCollaborator {
5314                        project_id: project.id.to_proto(),
5315                        peer_id: Some(session.connection_id.into()),
5316                    },
5317                )
5318                .trace_err();
5319        }
5320    }
5321}
5322
/// Extension for fallible values whose error should be reported rather than propagated.
pub trait ResultExt {
    /// The value produced when `self` is not an error.
    type Ok;

    /// Converts `self` into an `Option`, discarding any error (the `Result` impl logs it
    /// at `error` level before returning `None`).
    fn trace_err(self) -> Option<Self::Ok>;
}
5328
5329impl<T, E> ResultExt for Result<T, E>
5330where
5331    E: std::fmt::Debug,
5332{
5333    type Ok = T;
5334
5335    #[track_caller]
5336    fn trace_err(self) -> Option<T> {
5337        match self {
5338            Ok(value) => Some(value),
5339            Err(error) => {
5340                tracing::error!("{:?}", error);
5341                None
5342            }
5343        }
5344    }
5345}