rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use prometheus::{register_int_gauge, IntGauge};
  46use rpc::{
  47    proto::{
  48        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  49        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  50    },
  51    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  52};
  53use semantic_version::SemanticVersion;
  54use serde::{Serialize, Serializer};
  55use std::{
  56    any::TypeId,
  57    future::Future,
  58    marker::PhantomData,
  59    mem,
  60    net::SocketAddr,
  61    ops::{Deref, DerefMut},
  62    rc::Rc,
  63    sync::{
  64        atomic::{AtomicBool, Ordering::SeqCst},
  65        Arc, OnceLock,
  66    },
  67    time::{Duration, Instant},
  68};
  69use time::OffsetDateTime;
  70use tokio::sync::{watch, Semaphore};
  71use tower::ServiceBuilder;
  72use tracing::{
  73    field::{self},
  74    info_span, instrument, Instrument,
  75};
  76use util::http::IsahcHttpClient;
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(delete_dev_server))
 437            .add_request_handler(dev_server_handler(share_dev_server_project))
 438            .add_request_handler(dev_server_handler(shutdown_dev_server))
 439            .add_request_handler(dev_server_handler(reconnect_dev_server))
 440            .add_message_handler(user_message_handler(leave_project))
 441            .add_request_handler(update_project)
 442            .add_request_handler(update_worktree)
 443            .add_message_handler(start_language_server)
 444            .add_message_handler(update_language_server)
 445            .add_message_handler(update_diagnostic_summary)
 446            .add_message_handler(update_worktree_settings)
 447            .add_request_handler(user_handler(
 448                forward_read_only_project_request::<proto::GetHover>,
 449            ))
 450            .add_request_handler(user_handler(
 451                forward_read_only_project_request::<proto::GetDefinition>,
 452            ))
 453            .add_request_handler(user_handler(
 454                forward_read_only_project_request::<proto::GetTypeDefinition>,
 455            ))
 456            .add_request_handler(user_handler(
 457                forward_read_only_project_request::<proto::GetReferences>,
 458            ))
 459            .add_request_handler(user_handler(
 460                forward_read_only_project_request::<proto::SearchProject>,
 461            ))
 462            .add_request_handler(user_handler(
 463                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 464            ))
 465            .add_request_handler(user_handler(
 466                forward_read_only_project_request::<proto::GetProjectSymbols>,
 467            ))
 468            .add_request_handler(user_handler(
 469                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 470            ))
 471            .add_request_handler(user_handler(
 472                forward_read_only_project_request::<proto::OpenBufferById>,
 473            ))
 474            .add_request_handler(user_handler(
 475                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 476            ))
 477            .add_request_handler(user_handler(
 478                forward_read_only_project_request::<proto::InlayHints>,
 479            ))
 480            .add_request_handler(user_handler(
 481                forward_read_only_project_request::<proto::OpenBufferByPath>,
 482            ))
 483            .add_request_handler(user_handler(
 484                forward_mutating_project_request::<proto::GetCompletions>,
 485            ))
 486            .add_request_handler(user_handler(
 487                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 488            ))
 489            .add_request_handler(user_handler(
 490                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 491            ))
 492            .add_request_handler(user_handler(
 493                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 494            ))
 495            .add_request_handler(user_handler(
 496                forward_mutating_project_request::<proto::GetCodeActions>,
 497            ))
 498            .add_request_handler(user_handler(
 499                forward_mutating_project_request::<proto::ApplyCodeAction>,
 500            ))
 501            .add_request_handler(user_handler(
 502                forward_mutating_project_request::<proto::PrepareRename>,
 503            ))
 504            .add_request_handler(user_handler(
 505                forward_mutating_project_request::<proto::PerformRename>,
 506            ))
 507            .add_request_handler(user_handler(
 508                forward_mutating_project_request::<proto::ReloadBuffers>,
 509            ))
 510            .add_request_handler(user_handler(
 511                forward_mutating_project_request::<proto::FormatBuffers>,
 512            ))
 513            .add_request_handler(user_handler(
 514                forward_mutating_project_request::<proto::CreateProjectEntry>,
 515            ))
 516            .add_request_handler(user_handler(
 517                forward_mutating_project_request::<proto::RenameProjectEntry>,
 518            ))
 519            .add_request_handler(user_handler(
 520                forward_mutating_project_request::<proto::CopyProjectEntry>,
 521            ))
 522            .add_request_handler(user_handler(
 523                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 524            ))
 525            .add_request_handler(user_handler(
 526                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 527            ))
 528            .add_request_handler(user_handler(
 529                forward_mutating_project_request::<proto::OnTypeFormatting>,
 530            ))
 531            .add_request_handler(user_handler(
 532                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 533            ))
 534            .add_request_handler(user_handler(
 535                forward_mutating_project_request::<proto::BlameBuffer>,
 536            ))
 537            .add_request_handler(user_handler(
 538                forward_mutating_project_request::<proto::MultiLspQuery>,
 539            ))
 540            .add_message_handler(create_buffer_for_peer)
 541            .add_request_handler(update_buffer)
 542            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 543            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 544            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 545            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 546            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 547            .add_request_handler(get_users)
 548            .add_request_handler(user_handler(fuzzy_search_users))
 549            .add_request_handler(user_handler(request_contact))
 550            .add_request_handler(user_handler(remove_contact))
 551            .add_request_handler(user_handler(respond_to_contact_request))
 552            .add_request_handler(user_handler(create_channel))
 553            .add_request_handler(user_handler(delete_channel))
 554            .add_request_handler(user_handler(invite_channel_member))
 555            .add_request_handler(user_handler(remove_channel_member))
 556            .add_request_handler(user_handler(set_channel_member_role))
 557            .add_request_handler(user_handler(set_channel_visibility))
 558            .add_request_handler(user_handler(rename_channel))
 559            .add_request_handler(user_handler(join_channel_buffer))
 560            .add_request_handler(user_handler(leave_channel_buffer))
 561            .add_message_handler(user_message_handler(update_channel_buffer))
 562            .add_request_handler(user_handler(rejoin_channel_buffers))
 563            .add_request_handler(user_handler(get_channel_members))
 564            .add_request_handler(user_handler(respond_to_channel_invite))
 565            .add_request_handler(user_handler(join_channel))
 566            .add_request_handler(user_handler(join_channel_chat))
 567            .add_message_handler(user_message_handler(leave_channel_chat))
 568            .add_request_handler(user_handler(send_channel_message))
 569            .add_request_handler(user_handler(remove_channel_message))
 570            .add_request_handler(user_handler(update_channel_message))
 571            .add_request_handler(user_handler(get_channel_messages))
 572            .add_request_handler(user_handler(get_channel_messages_by_id))
 573            .add_request_handler(user_handler(get_notifications))
 574            .add_request_handler(user_handler(mark_notification_as_read))
 575            .add_request_handler(user_handler(move_channel))
 576            .add_request_handler(user_handler(follow))
 577            .add_message_handler(user_message_handler(unfollow))
 578            .add_message_handler(user_message_handler(update_followers))
 579            .add_request_handler(user_handler(get_private_user_info))
 580            .add_message_handler(user_message_handler(acknowledge_channel_message))
 581            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 582            .add_request_handler(user_handler(get_supermaven_api_key))
 583            .add_streaming_request_handler({
 584                let app_state = app_state.clone();
 585                move |request, response, session| {
 586                    complete_with_language_model(
 587                        request,
 588                        response,
 589                        session,
 590                        app_state.config.openai_api_key.clone(),
 591                        app_state.config.google_ai_api_key.clone(),
 592                        app_state.config.anthropic_api_key.clone(),
 593                    )
 594                }
 595            })
 596            .add_request_handler({
 597                let app_state = app_state.clone();
 598                user_handler(move |request, response, session| {
 599                    count_tokens_with_language_model(
 600                        request,
 601                        response,
 602                        session,
 603                        app_state.config.google_ai_api_key.clone(),
 604                    )
 605                })
 606            })
 607            .add_request_handler({
 608                user_handler(move |request, response, session| {
 609                    get_cached_embeddings(request, response, session)
 610                })
 611            })
 612            .add_request_handler({
 613                let app_state = app_state.clone();
 614                user_handler(move |request, response, session| {
 615                    compute_embeddings(
 616                        request,
 617                        response,
 618                        session,
 619                        app_state.config.openai_api_key.clone(),
 620                    )
 621                })
 622            });
 623
 624        Arc::new(server)
 625    }
 626
    /// Schedules background cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task, instrumented with the `"start server"` span, and returns
    /// `Ok(())` immediately. The task:
    /// 1. sleeps for [`CLEANUP_TIMEOUT`];
    /// 2. fetches stale room and channel ids for this environment and server id
    ///    (`stale_server_resource_ids`);
    /// 3. clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///    list to the remaining connections;
    /// 4. for each stale room:
    ///    - clears stale participants and broadcasts the refreshed room and channel;
    ///    - sends `CallCanceled` for canceled calls;
    ///    - pushes contact updates to affected users' contacts;
    ///    - deletes the LiveKit room once no participants remain;
    /// 5. deletes stale server records via `delete_stale_servers`.
    ///
    /// Errors from these steps go through `trace_err()` instead of being propagated.
    pub async fn start(&self) -> Result<()> {
        // Copy the id out so the spawned task does not need to lock `self.id`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created now and awaited first thing inside the task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        // Detached: `start` does not wait for the cleanup task to finish.
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop this server's stale collaborators from each channel buffer and
                    // send the refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only when the room refresh below succeeds; otherwise
                        // these defaults turn the follow-up steps into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            // Broadcast the refreshed room before `live_kit_room` is moved
                            // out of it below.
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the `.await`s below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // For each affected user, rebuild their contact entry (including
                        // busy status) and push it to every connection of each accepted
                        // contact. The pool lock is taken only after both awaits complete.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when a client is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 772
 773    pub fn teardown(&self) {
 774        self.peer.teardown();
 775        self.connection_pool.lock().reset();
 776        let _ = self.teardown.send(true);
 777    }
 778
 779    #[cfg(test)]
 780    pub fn reset(&self, id: ServerId) {
 781        self.teardown();
 782        *self.id.lock() = id;
 783        self.peer.reset(id.0 as u32);
 784        let _ = self.teardown.send(false);
 785    }
 786
 787    #[cfg(test)]
 788    pub fn id(&self) -> ServerId {
 789        *self.id.lock()
 790    }
 791
 792    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 793    where
 794        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 795        Fut: 'static + Send + Future<Output = Result<()>>,
 796        M: EnvelopedMessage,
 797    {
 798        let prev_handler = self.handlers.insert(
 799            TypeId::of::<M>(),
 800            Box::new(move |envelope, session| {
 801                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 802                let received_at = envelope.received_at;
 803                tracing::info!("message received");
 804                let start_time = Instant::now();
 805                let future = (handler)(*envelope, session);
 806                async move {
 807                    let result = future.await;
 808                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 809                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 810                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 811                    let payload_type = M::NAME;
 812
 813                    match result {
 814                        Err(error) => {
 815                            tracing::error!(
 816                                ?error,
 817                                total_duration_ms,
 818                                processing_duration_ms,
 819                                queue_duration_ms,
 820                                payload_type,
 821                                "error handling message"
 822                            )
 823                        }
 824                        Ok(()) => tracing::info!(
 825                            total_duration_ms,
 826                            processing_duration_ms,
 827                            queue_duration_ms,
 828                            "finished handling message"
 829                        ),
 830                    }
 831                }
 832                .boxed()
 833            }),
 834        );
 835        if prev_handler.is_some() {
 836            panic!("registered a handler for the same message twice");
 837        }
 838        self
 839    }
 840
 841    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 842    where
 843        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 844        Fut: 'static + Send + Future<Output = Result<()>>,
 845        M: EnvelopedMessage,
 846    {
 847        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 848        self
 849    }
 850
 851    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 852    where
 853        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 854        Fut: Send + Future<Output = Result<()>>,
 855        M: RequestMessage,
 856    {
 857        let handler = Arc::new(handler);
 858        self.add_handler(move |envelope, session| {
 859            let receipt = envelope.receipt();
 860            let handler = handler.clone();
 861            async move {
 862                let peer = session.peer.clone();
 863                let responded = Arc::new(AtomicBool::default());
 864                let response = Response {
 865                    peer: peer.clone(),
 866                    responded: responded.clone(),
 867                    receipt,
 868                };
 869                match (handler)(envelope.payload, response, session).await {
 870                    Ok(()) => {
 871                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 872                            Ok(())
 873                        } else {
 874                            Err(anyhow!("handler did not send a response"))?
 875                        }
 876                    }
 877                    Err(error) => {
 878                        let proto_err = match &error {
 879                            Error::Internal(err) => err.to_proto(),
 880                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 881                        };
 882                        peer.respond_with_error(receipt, proto_err)?;
 883                        Err(error)
 884                    }
 885                }
 886            }
 887        })
 888    }
 889
 890    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 891    where
 892        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 893        Fut: Send + Future<Output = Result<()>>,
 894        M: RequestMessage,
 895    {
 896        let handler = Arc::new(handler);
 897        self.add_handler(move |envelope, session| {
 898            let receipt = envelope.receipt();
 899            let handler = handler.clone();
 900            async move {
 901                let peer = session.peer.clone();
 902                let response = StreamingResponse {
 903                    peer: peer.clone(),
 904                    receipt,
 905                };
 906                match (handler)(envelope.payload, response, session).await {
 907                    Ok(()) => {
 908                        peer.end_stream(receipt)?;
 909                        Ok(())
 910                    }
 911                    Err(error) => {
 912                        let proto_err = match &error {
 913                            Error::Internal(err) => err.to_proto(),
 914                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 915                        };
 916                        peer.respond_with_error(receipt, proto_err)?;
 917                        Err(error)
 918                    }
 919                }
 920            }
 921        })
 922    }
 923
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Registers `connection` with the peer, builds the per-connection [`Session`], sends the
    /// initial client update, and then dispatches incoming messages to the registered
    /// handlers until the I/O future finishes, the incoming stream ends, or the server is torn
    /// down. On a normal exit (I/O finished or stream closed) it runs `connection_lost`.
    ///
    /// `send_connection_id`, when provided, is handed to `send_initial_client_update`, which
    /// sends the assigned connection id through it.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in later (connection id below, the
        // principal's fields by `update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this return and the one after `send_initial_client_update` below
            // happen after `add_connection` but skip `connection_lost`, so `peer.disconnect`
            // and connection-pool removal are not run on these paths — confirm this is intended.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only created when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256:
            // a permit is taken before the next message is read and released when its
            // handler completes, so a full semaphore stops reading from `incoming_rx`.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown wins over I/O, which wins over
                // draining finished foreground handlers, which wins over reading a new
                // message. Teardown returns immediately, without running `connection_lost`.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is constructed;
                            // the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the concurrency permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1070
1071    async fn send_initial_client_update(
1072        &self,
1073        connection_id: ConnectionId,
1074        principal: &Principal,
1075        zed_version: ZedVersion,
1076        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1077        session: &Session,
1078    ) -> Result<()> {
1079        self.peer.send(
1080            connection_id,
1081            proto::Hello {
1082                peer_id: Some(connection_id.into()),
1083            },
1084        )?;
1085        tracing::info!("sent hello message");
1086        if let Some(send_connection_id) = send_connection_id.take() {
1087            let _ = send_connection_id.send(connection_id);
1088        }
1089
1090        match principal {
1091            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1092                if !user.connected_once {
1093                    self.peer.send(connection_id, proto::ShowContacts {})?;
1094                    self.app_state
1095                        .db
1096                        .set_user_connected_once(user.id, true)
1097                        .await?;
1098                }
1099
1100                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1101                    future::try_join4(
1102                        self.app_state.db.get_contacts(user.id),
1103                        self.app_state.db.get_channels_for_user(user.id),
1104                        self.app_state.db.get_channel_invites_for_user(user.id),
1105                        self.app_state.db.dev_server_projects_update(user.id),
1106                    )
1107                    .await?;
1108
1109                {
1110                    let mut pool = self.connection_pool.lock();
1111                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1112                    for membership in &channels_for_user.channel_memberships {
1113                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1114                    }
1115                    self.peer.send(
1116                        connection_id,
1117                        build_initial_contacts_update(contacts, &pool),
1118                    )?;
1119                    self.peer.send(
1120                        connection_id,
1121                        build_update_user_channels(&channels_for_user),
1122                    )?;
1123                    self.peer.send(
1124                        connection_id,
1125                        build_channels_update(channels_for_user, channel_invites),
1126                    )?;
1127                }
1128                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1129
1130                if let Some(incoming_call) =
1131                    self.app_state.db.incoming_call_for_user(user.id).await?
1132                {
1133                    self.peer.send(connection_id, incoming_call)?;
1134                }
1135
1136                update_user_contacts(user.id, &session).await?;
1137            }
1138            Principal::DevServer(dev_server) => {
1139                {
1140                    let mut pool = self.connection_pool.lock();
1141                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1142                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1143                    };
1144                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1145                }
1146
1147                let projects = self
1148                    .app_state
1149                    .db
1150                    .get_projects_for_dev_server(dev_server.id)
1151                    .await?;
1152                self.peer
1153                    .send(connection_id, proto::DevServerInstructions { projects })?;
1154
1155                let status = self
1156                    .app_state
1157                    .db
1158                    .dev_server_projects_update(dev_server.user_id)
1159                    .await?;
1160                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1161            }
1162        }
1163
1164        Ok(())
1165    }
1166
1167    pub async fn invite_code_redeemed(
1168        self: &Arc<Self>,
1169        inviter_id: UserId,
1170        invitee_id: UserId,
1171    ) -> Result<()> {
1172        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173            if let Some(code) = &user.invite_code {
1174                let pool = self.connection_pool.lock();
1175                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176                for connection_id in pool.user_connection_ids(inviter_id) {
1177                    self.peer.send(
1178                        connection_id,
1179                        proto::UpdateContacts {
1180                            contacts: vec![invitee_contact.clone()],
1181                            ..Default::default()
1182                        },
1183                    )?;
1184                    self.peer.send(
1185                        connection_id,
1186                        proto::UpdateInviteInfo {
1187                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188                            count: user.invite_count as u32,
1189                        },
1190                    )?;
1191                }
1192            }
1193        }
1194        Ok(())
1195    }
1196
1197    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199            if let Some(invite_code) = &user.invite_code {
1200                let pool = self.connection_pool.lock();
1201                for connection_id in pool.user_connection_ids(user_id) {
1202                    self.peer.send(
1203                        connection_id,
1204                        proto::UpdateInviteInfo {
1205                            url: format!(
1206                                "{}{}",
1207                                self.app_state.config.invite_link_prefix, invite_code
1208                            ),
1209                            count: user.invite_count as u32,
1210                        },
1211                    )?;
1212                }
1213            }
1214        }
1215        Ok(())
1216    }
1217
1218    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219        ServerSnapshot {
1220            connection_pool: ConnectionPoolGuard {
1221                guard: self.connection_pool.lock(),
1222                _not_send: PhantomData,
1223            },
1224            peer: &self.peer,
1225        }
1226    }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230    type Target = ConnectionPool;
1231
1232    fn deref(&self) -> &Self::Target {
1233        &self.guard
1234    }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238    fn deref_mut(&mut self) -> &mut Self::Target {
1239        &mut self.guard
1240    }
1241}
1242
1243impl<'a> Drop for ConnectionPoolGuard<'a> {
1244    fn drop(&mut self) {
1245        #[cfg(test)]
1246        self.check_invariants();
1247    }
1248}
1249
1250fn broadcast<F>(
1251    sender_id: Option<ConnectionId>,
1252    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253    mut f: F,
1254) where
1255    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257    for receiver_id in receiver_ids {
1258        if Some(receiver_id) != sender_id {
1259            if let Err(error) = f(receiver_id) {
1260                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261            }
1262        }
1263    }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269    fn name() -> &'static HeaderName {
1270        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272    }
1273
1274    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275    where
1276        Self: Sized,
1277        I: Iterator<Item = &'i axum::http::HeaderValue>,
1278    {
1279        let version = values
1280            .next()
1281            .ok_or_else(axum::headers::Error::invalid)?
1282            .to_str()
1283            .map_err(|_| axum::headers::Error::invalid())?
1284            .parse()
1285            .map_err(|_| axum::headers::Error::invalid())?;
1286        Ok(Self(version))
1287    }
1288
1289    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290        values.extend([self.0.to_string().parse().unwrap()]);
1291    }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296    fn name() -> &'static HeaderName {
1297        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299    }
1300
1301    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302    where
1303        Self: Sized,
1304        I: Iterator<Item = &'i axum::http::HeaderValue>,
1305    {
1306        let version = values
1307            .next()
1308            .ok_or_else(axum::headers::Error::invalid)?
1309            .to_str()
1310            .map_err(|_| axum::headers::Error::invalid())?
1311            .parse()
1312            .map_err(|_| axum::headers::Error::invalid())?;
1313        Ok(Self(version))
1314    }
1315
1316    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317        values.extend([self.0.to_string().parse().unwrap()]);
1318    }
1319}
1320
1321pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1322    Router::new()
1323        .route("/rpc", get(handle_websocket_request))
1324        .layer(
1325            ServiceBuilder::new()
1326                .layer(Extension(server.app_state.clone()))
1327                .layer(middleware::from_fn(auth::validate_header)),
1328        )
1329        .route("/metrics", get(handle_metrics))
1330        .layer(Extension(server))
1331}
1332
1333pub async fn handle_websocket_request(
1334    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337    Extension(server): Extension<Arc<Server>>,
1338    Extension(principal): Extension<Principal>,
1339    ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341    if protocol_version != rpc::PROTOCOL_VERSION {
1342        return (
1343            StatusCode::UPGRADE_REQUIRED,
1344            "client must be upgraded".to_string(),
1345        )
1346            .into_response();
1347    }
1348
1349    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350        return (
1351            StatusCode::UPGRADE_REQUIRED,
1352            "no version header found".to_string(),
1353        )
1354            .into_response();
1355    };
1356
1357    if !version.can_collaborate() {
1358        return (
1359            StatusCode::UPGRADE_REQUIRED,
1360            "client must be upgraded".to_string(),
1361        )
1362            .into_response();
1363    }
1364
1365    let socket_address = socket_address.to_string();
1366    ws.on_upgrade(move |socket| {
1367        let socket = socket
1368            .map_ok(to_tungstenite_message)
1369            .err_into()
1370            .with(|message| async move { Ok(to_axum_message(message)) });
1371        let connection = Connection::new(Box::pin(socket));
1372        async move {
1373            server
1374                .handle_connection(
1375                    connection,
1376                    socket_address,
1377                    principal,
1378                    version,
1379                    None,
1380                    Executor::Production,
1381                )
1382                .await;
1383        }
1384    })
1385}
1386
1387pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1388    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1389    let connections_metric = CONNECTIONS_METRIC
1390        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1391
1392    let connections = server
1393        .connection_pool
1394        .lock()
1395        .connections()
1396        .filter(|connection| !connection.admin)
1397        .count();
1398    connections_metric.set(connections as _);
1399
1400    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1402        register_int_gauge!(
1403            "shared_projects",
1404            "number of open projects with one or more guests"
1405        )
1406        .unwrap()
1407    });
1408
1409    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1410    shared_projects_metric.set(shared_projects as _);
1411
1412    let encoder = prometheus::TextEncoder::new();
1413    let metric_families = prometheus::gather();
1414    let encoded_metrics = encoder
1415        .encode_to_string(&metric_families)
1416        .map_err(|err| anyhow!("{}", err))?;
1417    Ok(encoded_metrics)
1418}
1419
/// Cleans up after a connection ends.
///
/// Immediately disconnects the peer, removes the connection from the pool, and tells the
/// database the connection was lost. If `remove_connection` fails, the function returns
/// that error before the database is notified. It then waits `RECONNECT_TIMEOUT`, giving
/// the client a chance to reconnect, unless server teardown is signalled first, in which
/// case nothing more happens. After the timeout:
///
/// - user sessions leave their room and channel buffers; if the user has no remaining
///   connection in the pool, any pending call is declined and the room is updated; finally
///   the user's contacts are refreshed;
/// - dev server sessions are handed to `lost_dev_server_connection`.
///
/// Errors are returned to the caller and also recorded by the `instrument(err)` attribute.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // The sleep branch is listed first, so if both are ready it wins; otherwise whichever
    // completes first decides whether the delayed cleanup runs at all.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this user remains.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1474
1475/// Acknowledges a ping from a client, used to keep the connection alive.
1476async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1477    response.send(proto::Ack {})?;
1478    Ok(())
1479}
1480
1481/// Creates a new room for calling (outside of channels)
1482async fn create_room(
1483    _request: proto::CreateRoom,
1484    response: Response<proto::CreateRoom>,
1485    session: UserSession,
1486) -> Result<()> {
1487    let live_kit_room = nanoid::nanoid!(30);
1488
1489    let live_kit_connection_info = util::maybe!(async {
1490        let live_kit = session.live_kit_client.as_ref();
1491        let live_kit = live_kit?;
1492        let user_id = session.user_id().to_string();
1493
1494        let token = live_kit
1495            .room_token(&live_kit_room, &user_id.to_string())
1496            .trace_err()?;
1497
1498        Some(proto::LiveKitConnectionInfo {
1499            server_url: live_kit.url().into(),
1500            token,
1501            can_publish: true,
1502        })
1503    })
1504    .await;
1505
1506    let room = session
1507        .db()
1508        .await
1509        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1510        .await?;
1511
1512    response.send(proto::CreateRoomResponse {
1513        room: Some(room.clone()),
1514        live_kit_connection_info,
1515    })?;
1516
1517    update_user_contacts(session.user_id(), &session).await?;
1518    Ok(())
1519}
1520
1521/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1522async fn join_room(
1523    request: proto::JoinRoom,
1524    response: Response<proto::JoinRoom>,
1525    session: UserSession,
1526) -> Result<()> {
1527    let room_id = RoomId::from_proto(request.id);
1528
1529    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1530
1531    if let Some(channel_id) = channel_id {
1532        return join_channel_internal(channel_id, Box::new(response), session).await;
1533    }
1534
1535    let joined_room = {
1536        let room = session
1537            .db()
1538            .await
1539            .join_room(room_id, session.user_id(), session.connection_id)
1540            .await?;
1541        room_updated(&room.room, &session.peer);
1542        room.into_inner()
1543    };
1544
1545    for connection_id in session
1546        .connection_pool()
1547        .await
1548        .user_connection_ids(session.user_id())
1549    {
1550        session
1551            .peer
1552            .send(
1553                connection_id,
1554                proto::CallCanceled {
1555                    room_id: room_id.to_proto(),
1556                },
1557            )
1558            .trace_err();
1559    }
1560
1561    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1562        if let Some(token) = live_kit
1563            .room_token(
1564                &joined_room.room.live_kit_room,
1565                &session.user_id().to_string(),
1566            )
1567            .trace_err()
1568        {
1569            Some(proto::LiveKitConnectionInfo {
1570                server_url: live_kit.url().into(),
1571                token,
1572                can_publish: true,
1573            })
1574        } else {
1575            None
1576        }
1577    } else {
1578        None
1579    };
1580
1581    response.send(proto::JoinRoomResponse {
1582        room: Some(joined_room.room),
1583        channel_id: None,
1584        live_kit_connection_info,
1585    })?;
1586
1587    update_user_contacts(session.user_id(), &session).await?;
1588    Ok(())
1589}
1590
/// Reconnects the caller to a room after a connection error.
///
/// The database re-associates the caller's new connection with the room and returns two
/// project lists: `reshared_projects` (presumably projects the caller was hosting — TODO
/// confirm) and `rejoined_projects`.
///
/// The steps, in order:
/// - The caller receives a `RejoinRoomResponse` describing the room and both project lists.
/// - The room is broadcast.
/// - For each reshared project, its collaborators are told the peer id changed from
///   `old_connection_id` to this session's connection, and the worktree list is forwarded.
/// - Rejoined projects are refreshed via `notify_rejoined_projects`.
/// - If the room belongs to a channel, the channel is broadcast.
/// - The caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The database result exposes `into_inner` (it wraps the payload); keep it confined to
    // this scope so it is dropped before the connection pool is locked further below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // The response is sent before any broadcast to other participants.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old peer id now maps to this connection.
            // Send failures are logged and do not abort the rejoin.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to its collaborators, with this
            // session's connection passed to `broadcast` as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, broadcast the channel update as well.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1682
1683fn notify_rejoined_projects(
1684    rejoined_projects: &mut Vec<RejoinedProject>,
1685    session: &UserSession,
1686) -> Result<()> {
1687    for project in rejoined_projects.iter() {
1688        for collaborator in &project.collaborators {
1689            session
1690                .peer
1691                .send(
1692                    collaborator.connection_id,
1693                    proto::UpdateProjectCollaborator {
1694                        project_id: project.id.to_proto(),
1695                        old_peer_id: Some(project.old_connection_id.into()),
1696                        new_peer_id: Some(session.connection_id.into()),
1697                    },
1698                )
1699                .trace_err();
1700        }
1701    }
1702
1703    for project in rejoined_projects {
1704        for worktree in mem::take(&mut project.worktrees) {
1705            #[cfg(any(test, feature = "test-support"))]
1706            const MAX_CHUNK_SIZE: usize = 2;
1707            #[cfg(not(any(test, feature = "test-support")))]
1708            const MAX_CHUNK_SIZE: usize = 256;
1709
1710            // Stream this worktree's entries.
1711            let message = proto::UpdateWorktree {
1712                project_id: project.id.to_proto(),
1713                worktree_id: worktree.id,
1714                abs_path: worktree.abs_path.clone(),
1715                root_name: worktree.root_name,
1716                updated_entries: worktree.updated_entries,
1717                removed_entries: worktree.removed_entries,
1718                scan_id: worktree.scan_id,
1719                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1720                updated_repositories: worktree.updated_repositories,
1721                removed_repositories: worktree.removed_repositories,
1722            };
1723            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1724                session.peer.send(session.connection_id, update.clone())?;
1725            }
1726
1727            // Stream this worktree's diagnostics.
1728            for summary in worktree.diagnostic_summaries {
1729                session.peer.send(
1730                    session.connection_id,
1731                    proto::UpdateDiagnosticSummary {
1732                        project_id: project.id.to_proto(),
1733                        worktree_id: worktree.id,
1734                        summary: Some(summary),
1735                    },
1736                )?;
1737            }
1738
1739            for settings_file in worktree.settings_files {
1740                session.peer.send(
1741                    session.connection_id,
1742                    proto::UpdateWorktreeSettings {
1743                        project_id: project.id.to_proto(),
1744                        worktree_id: worktree.id,
1745                        path: settings_file.path,
1746                        content: Some(settings_file.content),
1747                    },
1748                )?;
1749            }
1750        }
1751
1752        for language_server in &project.language_servers {
1753            session.peer.send(
1754                session.connection_id,
1755                proto::UpdateLanguageServer {
1756                    project_id: project.id.to_proto(),
1757                    language_server_id: language_server.id,
1758                    variant: Some(
1759                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1760                            proto::LspDiskBasedDiagnosticsUpdated {},
1761                        ),
1762                    ),
1763                },
1764            )?;
1765        }
1766    }
1767    Ok(())
1768}
1769
1770/// leave room disconnects from the room.
1771async fn leave_room(
1772    _: proto::LeaveRoom,
1773    response: Response<proto::LeaveRoom>,
1774    session: UserSession,
1775) -> Result<()> {
1776    leave_room_for_session(&session, session.connection_id).await?;
1777    response.send(proto::Ack {})?;
1778    Ok(())
1779}
1780
1781/// Updates the permissions of someone else in the room.
1782async fn set_room_participant_role(
1783    request: proto::SetRoomParticipantRole,
1784    response: Response<proto::SetRoomParticipantRole>,
1785    session: UserSession,
1786) -> Result<()> {
1787    let user_id = UserId::from_proto(request.user_id);
1788    let role = ChannelRole::from(request.role());
1789
1790    let (live_kit_room, can_publish) = {
1791        let room = session
1792            .db()
1793            .await
1794            .set_room_participant_role(
1795                session.user_id(),
1796                RoomId::from_proto(request.room_id),
1797                user_id,
1798                role,
1799            )
1800            .await?;
1801
1802        let live_kit_room = room.live_kit_room.clone();
1803        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1804        room_updated(&room, &session.peer);
1805        (live_kit_room, can_publish)
1806    };
1807
1808    if let Some(live_kit) = session.live_kit_client.as_ref() {
1809        live_kit
1810            .update_participant(
1811                live_kit_room.clone(),
1812                request.user_id.to_string(),
1813                live_kit_server::proto::ParticipantPermission {
1814                    can_subscribe: true,
1815                    can_publish,
1816                    can_publish_data: can_publish,
1817                    hidden: false,
1818                    recorder: false,
1819                },
1820            )
1821            .await
1822            .trace_err();
1823    }
1824
1825    response.send(proto::Ack {})?;
1826    Ok(())
1827}
1828
1829/// Call someone else into the current room
1830async fn call(
1831    request: proto::Call,
1832    response: Response<proto::Call>,
1833    session: UserSession,
1834) -> Result<()> {
1835    let room_id = RoomId::from_proto(request.room_id);
1836    let calling_user_id = session.user_id();
1837    let calling_connection_id = session.connection_id;
1838    let called_user_id = UserId::from_proto(request.called_user_id);
1839    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1840    if !session
1841        .db()
1842        .await
1843        .has_contact(calling_user_id, called_user_id)
1844        .await?
1845    {
1846        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1847    }
1848
1849    let incoming_call = {
1850        let (room, incoming_call) = &mut *session
1851            .db()
1852            .await
1853            .call(
1854                room_id,
1855                calling_user_id,
1856                calling_connection_id,
1857                called_user_id,
1858                initial_project_id,
1859            )
1860            .await?;
1861        room_updated(&room, &session.peer);
1862        mem::take(incoming_call)
1863    };
1864    update_user_contacts(called_user_id, &session).await?;
1865
1866    let mut calls = session
1867        .connection_pool()
1868        .await
1869        .user_connection_ids(called_user_id)
1870        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1871        .collect::<FuturesUnordered<_>>();
1872
1873    while let Some(call_response) = calls.next().await {
1874        match call_response.as_ref() {
1875            Ok(_) => {
1876                response.send(proto::Ack {})?;
1877                return Ok(());
1878            }
1879            Err(_) => {
1880                call_response.trace_err();
1881            }
1882        }
1883    }
1884
1885    {
1886        let room = session
1887            .db()
1888            .await
1889            .call_failed(room_id, called_user_id)
1890            .await?;
1891        room_updated(&room, &session.peer);
1892    }
1893    update_user_contacts(called_user_id, &session).await?;
1894
1895    Err(anyhow!("failed to ring user"))?
1896}
1897
1898/// Cancel an outgoing call.
1899async fn cancel_call(
1900    request: proto::CancelCall,
1901    response: Response<proto::CancelCall>,
1902    session: UserSession,
1903) -> Result<()> {
1904    let called_user_id = UserId::from_proto(request.called_user_id);
1905    let room_id = RoomId::from_proto(request.room_id);
1906    {
1907        let room = session
1908            .db()
1909            .await
1910            .cancel_call(room_id, session.connection_id, called_user_id)
1911            .await?;
1912        room_updated(&room, &session.peer);
1913    }
1914
1915    for connection_id in session
1916        .connection_pool()
1917        .await
1918        .user_connection_ids(called_user_id)
1919    {
1920        session
1921            .peer
1922            .send(
1923                connection_id,
1924                proto::CallCanceled {
1925                    room_id: room_id.to_proto(),
1926                },
1927            )
1928            .trace_err();
1929    }
1930    response.send(proto::Ack {})?;
1931
1932    update_user_contacts(called_user_id, &session).await?;
1933    Ok(())
1934}
1935
1936/// Decline an incoming call.
1937async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1938    let room_id = RoomId::from_proto(message.room_id);
1939    {
1940        let room = session
1941            .db()
1942            .await
1943            .decline_call(Some(room_id), session.user_id())
1944            .await?
1945            .ok_or_else(|| anyhow!("failed to decline call"))?;
1946        room_updated(&room, &session.peer);
1947    }
1948
1949    for connection_id in session
1950        .connection_pool()
1951        .await
1952        .user_connection_ids(session.user_id())
1953    {
1954        session
1955            .peer
1956            .send(
1957                connection_id,
1958                proto::CallCanceled {
1959                    room_id: room_id.to_proto(),
1960                },
1961            )
1962            .trace_err();
1963    }
1964    update_user_contacts(session.user_id(), &session).await?;
1965    Ok(())
1966}
1967
1968/// Updates other participants in the room with your current location.
1969async fn update_participant_location(
1970    request: proto::UpdateParticipantLocation,
1971    response: Response<proto::UpdateParticipantLocation>,
1972    session: UserSession,
1973) -> Result<()> {
1974    let room_id = RoomId::from_proto(request.room_id);
1975    let location = request
1976        .location
1977        .ok_or_else(|| anyhow!("invalid location"))?;
1978
1979    let db = session.db().await;
1980    let room = db
1981        .update_room_participant_location(room_id, session.connection_id, location)
1982        .await?;
1983
1984    room_updated(&room, &session.peer);
1985    response.send(proto::Ack {})?;
1986    Ok(())
1987}
1988
1989/// Share a project into the room.
1990async fn share_project(
1991    request: proto::ShareProject,
1992    response: Response<proto::ShareProject>,
1993    session: UserSession,
1994) -> Result<()> {
1995    let (project_id, room) = &*session
1996        .db()
1997        .await
1998        .share_project(
1999            RoomId::from_proto(request.room_id),
2000            session.connection_id,
2001            &request.worktrees,
2002            request
2003                .dev_server_project_id
2004                .map(|id| DevServerProjectId::from_proto(id)),
2005        )
2006        .await?;
2007    response.send(proto::ShareProjectResponse {
2008        project_id: project_id.to_proto(),
2009    })?;
2010    room_updated(&room, &session.peer);
2011
2012    Ok(())
2013}
2014
2015/// Unshare a project from the room.
2016async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2017    let project_id = ProjectId::from_proto(message.project_id);
2018    unshare_project_internal(
2019        project_id,
2020        session.connection_id,
2021        session.user_id(),
2022        &session,
2023    )
2024    .await
2025}
2026
2027async fn unshare_project_internal(
2028    project_id: ProjectId,
2029    connection_id: ConnectionId,
2030    user_id: Option<UserId>,
2031    session: &Session,
2032) -> Result<()> {
2033    let (room, guest_connection_ids) = &*session
2034        .db()
2035        .await
2036        .unshare_project(project_id, connection_id, user_id)
2037        .await?;
2038
2039    let message = proto::UnshareProject {
2040        project_id: project_id.to_proto(),
2041    };
2042
2043    broadcast(
2044        Some(connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |conn_id| session.peer.send(conn_id, message.clone()),
2047    );
2048    if let Some(room) = room {
2049        room_updated(room, &session.peer);
2050    }
2051
2052    Ok(())
2053}
2054
2055/// DevServer makes a project available online
2056async fn share_dev_server_project(
2057    request: proto::ShareDevServerProject,
2058    response: Response<proto::ShareDevServerProject>,
2059    session: DevServerSession,
2060) -> Result<()> {
2061    let (dev_server_project, user_id, status) = session
2062        .db()
2063        .await
2064        .share_dev_server_project(
2065            DevServerProjectId::from_proto(request.dev_server_project_id),
2066            session.dev_server_id(),
2067            session.connection_id,
2068            &request.worktrees,
2069        )
2070        .await?;
2071    let Some(project_id) = dev_server_project.project_id else {
2072        return Err(anyhow!("failed to share remote project"))?;
2073    };
2074
2075    send_dev_server_projects_update(user_id, status, &session).await;
2076
2077    response.send(proto::ShareProjectResponse { project_id })?;
2078
2079    Ok(())
2080}
2081
/// Joins someone else's shared project.
///
/// The database records the join and returns the project together with the replica id
/// assigned to the caller. The database handle is released, and the project state is then
/// streamed to the caller by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project state to the caller.
    drop(db);
    // NOTE(review): this message says "remote" although this is the regular (non-hosted)
    // join path — confirm whether that wording is intended.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2100
/// Abstracts over the response handles of the two join requests (`JoinProject` and
/// `JoinHostedProject`). Both reply with a `proto::JoinProjectResponse`, so
/// `join_project_internal` can serve either one.
trait JoinProjectInternalResponse {
    /// Sends the join result to the requester, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Regular project joins: delegate to `Response::send` for `JoinProject`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Hosted project joins: delegate to `Response::send` for `JoinHostedProject`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2114
2115fn join_project_internal(
2116    response: impl JoinProjectInternalResponse,
2117    session: UserSession,
2118    project: &mut Project,
2119    replica_id: &ReplicaId,
2120) -> Result<()> {
2121    let collaborators = project
2122        .collaborators
2123        .iter()
2124        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2125        .map(|collaborator| collaborator.to_proto())
2126        .collect::<Vec<_>>();
2127    let project_id = project.id;
2128    let guest_user_id = session.user_id();
2129
2130    let worktrees = project
2131        .worktrees
2132        .iter()
2133        .map(|(id, worktree)| proto::WorktreeMetadata {
2134            id: *id,
2135            root_name: worktree.root_name.clone(),
2136            visible: worktree.visible,
2137            abs_path: worktree.abs_path.clone(),
2138        })
2139        .collect::<Vec<_>>();
2140
2141    let add_project_collaborator = proto::AddProjectCollaborator {
2142        project_id: project_id.to_proto(),
2143        collaborator: Some(proto::Collaborator {
2144            peer_id: Some(session.connection_id.into()),
2145            replica_id: replica_id.0 as u32,
2146            user_id: guest_user_id.to_proto(),
2147        }),
2148    };
2149
2150    for collaborator in &collaborators {
2151        session
2152            .peer
2153            .send(
2154                collaborator.peer_id.unwrap().into(),
2155                add_project_collaborator.clone(),
2156            )
2157            .trace_err();
2158    }
2159
2160    // First, we send the metadata associated with each worktree.
2161    response.send(proto::JoinProjectResponse {
2162        project_id: project.id.0 as u64,
2163        worktrees: worktrees.clone(),
2164        replica_id: replica_id.0 as u32,
2165        collaborators: collaborators.clone(),
2166        language_servers: project.language_servers.clone(),
2167        role: project.role.into(),
2168        dev_server_project_id: project
2169            .dev_server_project_id
2170            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2171    })?;
2172
2173    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2174        #[cfg(any(test, feature = "test-support"))]
2175        const MAX_CHUNK_SIZE: usize = 2;
2176        #[cfg(not(any(test, feature = "test-support")))]
2177        const MAX_CHUNK_SIZE: usize = 256;
2178
2179        // Stream this worktree's entries.
2180        let message = proto::UpdateWorktree {
2181            project_id: project_id.to_proto(),
2182            worktree_id,
2183            abs_path: worktree.abs_path.clone(),
2184            root_name: worktree.root_name,
2185            updated_entries: worktree.entries,
2186            removed_entries: Default::default(),
2187            scan_id: worktree.scan_id,
2188            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2189            updated_repositories: worktree.repository_entries.into_values().collect(),
2190            removed_repositories: Default::default(),
2191        };
2192        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2193            session.peer.send(session.connection_id, update.clone())?;
2194        }
2195
2196        // Stream this worktree's diagnostics.
2197        for summary in worktree.diagnostic_summaries {
2198            session.peer.send(
2199                session.connection_id,
2200                proto::UpdateDiagnosticSummary {
2201                    project_id: project_id.to_proto(),
2202                    worktree_id: worktree.id,
2203                    summary: Some(summary),
2204                },
2205            )?;
2206        }
2207
2208        for settings_file in worktree.settings_files {
2209            session.peer.send(
2210                session.connection_id,
2211                proto::UpdateWorktreeSettings {
2212                    project_id: project_id.to_proto(),
2213                    worktree_id: worktree.id,
2214                    path: settings_file.path,
2215                    content: Some(settings_file.content),
2216                },
2217            )?;
2218        }
2219    }
2220
2221    for language_server in &project.language_servers {
2222        session.peer.send(
2223            session.connection_id,
2224            proto::UpdateLanguageServer {
2225                project_id: project_id.to_proto(),
2226                language_server_id: language_server.id,
2227                variant: Some(
2228                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2229                        proto::LspDiskBasedDiagnosticsUpdated {},
2230                    ),
2231                ),
2232            },
2233        )?;
2234    }
2235
2236    Ok(())
2237}
2238
2239/// Leave someone elses shared project.
2240async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2241    let sender_id = session.connection_id;
2242    let project_id = ProjectId::from_proto(request.project_id);
2243    let db = session.db().await;
2244    if db.is_hosted_project(project_id).await? {
2245        let project = db.leave_hosted_project(project_id, sender_id).await?;
2246        project_left(&project, &session);
2247        return Ok(());
2248    }
2249
2250    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2251    tracing::info!(
2252        %project_id,
2253        "leave project"
2254    );
2255
2256    project_left(&project, &session);
2257    if let Some(room) = room {
2258        room_updated(&room, &session.peer);
2259    }
2260
2261    Ok(())
2262}
2263
2264async fn join_hosted_project(
2265    request: proto::JoinHostedProject,
2266    response: Response<proto::JoinHostedProject>,
2267    session: UserSession,
2268) -> Result<()> {
2269    let (mut project, replica_id) = session
2270        .db()
2271        .await
2272        .join_hosted_project(
2273            ProjectId(request.project_id as i32),
2274            session.user_id(),
2275            session.connection_id,
2276        )
2277        .await?;
2278
2279    join_project_internal(response, session, &mut project, &replica_id)
2280}
2281
2282async fn create_dev_server_project(
2283    request: proto::CreateDevServerProject,
2284    response: Response<proto::CreateDevServerProject>,
2285    session: UserSession,
2286) -> Result<()> {
2287    let dev_server_id = DevServerId(request.dev_server_id as i32);
2288    let dev_server_connection_id = session
2289        .connection_pool()
2290        .await
2291        .dev_server_connection_id(dev_server_id);
2292    let Some(dev_server_connection_id) = dev_server_connection_id else {
2293        Err(ErrorCode::DevServerOffline
2294            .message("Cannot create a remote project when the dev server is offline".to_string())
2295            .anyhow())?
2296    };
2297
2298    let path = request.path.clone();
2299    //Check that the path exists on the dev server
2300    session
2301        .peer
2302        .forward_request(
2303            session.connection_id,
2304            dev_server_connection_id,
2305            proto::ValidateDevServerProjectRequest { path: path.clone() },
2306        )
2307        .await?;
2308
2309    let (dev_server_project, update) = session
2310        .db()
2311        .await
2312        .create_dev_server_project(
2313            DevServerId(request.dev_server_id as i32),
2314            &request.path,
2315            session.user_id(),
2316        )
2317        .await?;
2318
2319    let projects = session
2320        .db()
2321        .await
2322        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2323        .await?;
2324
2325    session.peer.send(
2326        dev_server_connection_id,
2327        proto::DevServerInstructions { projects },
2328    )?;
2329
2330    send_dev_server_projects_update(session.user_id(), update, &session).await;
2331
2332    response.send(proto::CreateDevServerProjectResponse {
2333        dev_server_project: Some(dev_server_project.to_proto(None)),
2334    })?;
2335    Ok(())
2336}
2337
2338async fn create_dev_server(
2339    request: proto::CreateDevServer,
2340    response: Response<proto::CreateDevServer>,
2341    session: UserSession,
2342) -> Result<()> {
2343    let access_token = auth::random_token();
2344    let hashed_access_token = auth::hash_access_token(&access_token);
2345
2346    let (dev_server, status) = session
2347        .db()
2348        .await
2349        .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2350        .await?;
2351
2352    send_dev_server_projects_update(session.user_id(), status, &session).await;
2353
2354    response.send(proto::CreateDevServerResponse {
2355        dev_server_id: dev_server.id.0 as u64,
2356        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2357        name: request.name.clone(),
2358    })?;
2359    Ok(())
2360}
2361
2362async fn delete_dev_server(
2363    request: proto::DeleteDevServer,
2364    response: Response<proto::DeleteDevServer>,
2365    session: UserSession,
2366) -> Result<()> {
2367    let dev_server_id = DevServerId(request.dev_server_id as i32);
2368    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2369    if dev_server.user_id != session.user_id() {
2370        return Err(anyhow!(ErrorCode::Forbidden))?;
2371    }
2372
2373    let connection_id = session
2374        .connection_pool()
2375        .await
2376        .dev_server_connection_id(dev_server_id);
2377    if let Some(connection_id) = connection_id {
2378        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2379        session
2380            .peer
2381            .send(connection_id, proto::ShutdownDevServer {})?;
2382    }
2383
2384    let status = session
2385        .db()
2386        .await
2387        .delete_dev_server(dev_server_id, session.user_id())
2388        .await?;
2389
2390    send_dev_server_projects_update(session.user_id(), status, &session).await;
2391
2392    response.send(proto::Ack {})?;
2393    Ok(())
2394}
2395
2396async fn delete_dev_server_project(
2397    request: proto::DeleteDevServerProject,
2398    response: Response<proto::DeleteDevServerProject>,
2399    session: UserSession,
2400) -> Result<()> {
2401    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2402    let dev_server_project = session
2403        .db()
2404        .await
2405        .get_dev_server_project(dev_server_project_id)
2406        .await?;
2407
2408    let dev_server = session
2409        .db()
2410        .await
2411        .get_dev_server(dev_server_project.dev_server_id)
2412        .await?;
2413    if dev_server.user_id != session.user_id() {
2414        return Err(anyhow!(ErrorCode::Forbidden))?;
2415    }
2416
2417    let dev_server_connection_id = session
2418        .connection_pool()
2419        .await
2420        .dev_server_connection_id(dev_server.id);
2421
2422    if let Some(dev_server_connection_id) = dev_server_connection_id {
2423        let project = session
2424            .db()
2425            .await
2426            .find_dev_server_project(dev_server_project_id)
2427            .await;
2428        if let Ok(project) = project {
2429            unshare_project_internal(
2430                project.id,
2431                dev_server_connection_id,
2432                Some(session.user_id()),
2433                &session,
2434            )
2435            .await?;
2436        }
2437    }
2438
2439    let (projects, status) = session
2440        .db()
2441        .await
2442        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2443        .await?;
2444
2445    if let Some(dev_server_connection_id) = dev_server_connection_id {
2446        session.peer.send(
2447            dev_server_connection_id,
2448            proto::DevServerInstructions { projects },
2449        )?;
2450    }
2451
2452    send_dev_server_projects_update(session.user_id(), status, &session).await;
2453
2454    response.send(proto::Ack {})?;
2455    Ok(())
2456}
2457
2458async fn rejoin_dev_server_projects(
2459    request: proto::RejoinRemoteProjects,
2460    response: Response<proto::RejoinRemoteProjects>,
2461    session: UserSession,
2462) -> Result<()> {
2463    let mut rejoined_projects = {
2464        let db = session.db().await;
2465        db.rejoin_dev_server_projects(
2466            &request.rejoined_projects,
2467            session.user_id(),
2468            session.0.connection_id,
2469        )
2470        .await?
2471    };
2472    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2473
2474    response.send(proto::RejoinRemoteProjectsResponse {
2475        rejoined_projects: rejoined_projects
2476            .into_iter()
2477            .map(|project| project.to_proto())
2478            .collect(),
2479    })
2480}
2481
2482async fn reconnect_dev_server(
2483    request: proto::ReconnectDevServer,
2484    response: Response<proto::ReconnectDevServer>,
2485    session: DevServerSession,
2486) -> Result<()> {
2487    let reshared_projects = {
2488        let db = session.db().await;
2489        db.reshare_dev_server_projects(
2490            &request.reshared_projects,
2491            session.dev_server_id(),
2492            session.0.connection_id,
2493        )
2494        .await?
2495    };
2496
2497    for project in &reshared_projects {
2498        for collaborator in &project.collaborators {
2499            session
2500                .peer
2501                .send(
2502                    collaborator.connection_id,
2503                    proto::UpdateProjectCollaborator {
2504                        project_id: project.id.to_proto(),
2505                        old_peer_id: Some(project.old_connection_id.into()),
2506                        new_peer_id: Some(session.connection_id.into()),
2507                    },
2508                )
2509                .trace_err();
2510        }
2511
2512        broadcast(
2513            Some(session.connection_id),
2514            project
2515                .collaborators
2516                .iter()
2517                .map(|collaborator| collaborator.connection_id),
2518            |connection_id| {
2519                session.peer.forward_send(
2520                    session.connection_id,
2521                    connection_id,
2522                    proto::UpdateProject {
2523                        project_id: project.id.to_proto(),
2524                        worktrees: project.worktrees.clone(),
2525                    },
2526                )
2527            },
2528        );
2529    }
2530
2531    response.send(proto::ReconnectDevServerResponse {
2532        reshared_projects: reshared_projects
2533            .iter()
2534            .map(|project| proto::ResharedProject {
2535                id: project.id.to_proto(),
2536                collaborators: project
2537                    .collaborators
2538                    .iter()
2539                    .map(|collaborator| collaborator.to_proto())
2540                    .collect(),
2541            })
2542            .collect(),
2543    })?;
2544
2545    Ok(())
2546}
2547
2548async fn shutdown_dev_server(
2549    _: proto::ShutdownDevServer,
2550    response: Response<proto::ShutdownDevServer>,
2551    session: DevServerSession,
2552) -> Result<()> {
2553    response.send(proto::Ack {})?;
2554    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2555}
2556
2557async fn shutdown_dev_server_internal(
2558    dev_server_id: DevServerId,
2559    connection_id: ConnectionId,
2560    session: &Session,
2561) -> Result<()> {
2562    let (dev_server_projects, dev_server) = {
2563        let db = session.db().await;
2564        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2565        let dev_server = db.get_dev_server(dev_server_id).await?;
2566        (dev_server_projects, dev_server)
2567    };
2568
2569    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2570        unshare_project_internal(
2571            ProjectId::from_proto(project_id),
2572            connection_id,
2573            None,
2574            session,
2575        )
2576        .await?;
2577    }
2578
2579    session
2580        .connection_pool()
2581        .await
2582        .set_dev_server_offline(dev_server_id);
2583
2584    let status = session
2585        .db()
2586        .await
2587        .dev_server_projects_update(dev_server.user_id)
2588        .await?;
2589    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2590
2591    Ok(())
2592}
2593
2594/// Updates other participants with changes to the project
2595async fn update_project(
2596    request: proto::UpdateProject,
2597    response: Response<proto::UpdateProject>,
2598    session: Session,
2599) -> Result<()> {
2600    let project_id = ProjectId::from_proto(request.project_id);
2601    let (room, guest_connection_ids) = &*session
2602        .db()
2603        .await
2604        .update_project(project_id, session.connection_id, &request.worktrees)
2605        .await?;
2606    broadcast(
2607        Some(session.connection_id),
2608        guest_connection_ids.iter().copied(),
2609        |connection_id| {
2610            session
2611                .peer
2612                .forward_send(session.connection_id, connection_id, request.clone())
2613        },
2614    );
2615    if let Some(room) = room {
2616        room_updated(&room, &session.peer);
2617    }
2618    response.send(proto::Ack {})?;
2619
2620    Ok(())
2621}
2622
2623/// Updates other participants with changes to the worktree
2624async fn update_worktree(
2625    request: proto::UpdateWorktree,
2626    response: Response<proto::UpdateWorktree>,
2627    session: Session,
2628) -> Result<()> {
2629    let guest_connection_ids = session
2630        .db()
2631        .await
2632        .update_worktree(&request, session.connection_id)
2633        .await?;
2634
2635    broadcast(
2636        Some(session.connection_id),
2637        guest_connection_ids.iter().copied(),
2638        |connection_id| {
2639            session
2640                .peer
2641                .forward_send(session.connection_id, connection_id, request.clone())
2642        },
2643    );
2644    response.send(proto::Ack {})?;
2645    Ok(())
2646}
2647
2648/// Updates other participants with changes to the diagnostics
2649async fn update_diagnostic_summary(
2650    message: proto::UpdateDiagnosticSummary,
2651    session: Session,
2652) -> Result<()> {
2653    let guest_connection_ids = session
2654        .db()
2655        .await
2656        .update_diagnostic_summary(&message, session.connection_id)
2657        .await?;
2658
2659    broadcast(
2660        Some(session.connection_id),
2661        guest_connection_ids.iter().copied(),
2662        |connection_id| {
2663            session
2664                .peer
2665                .forward_send(session.connection_id, connection_id, message.clone())
2666        },
2667    );
2668
2669    Ok(())
2670}
2671
2672/// Updates other participants with changes to the worktree settings
2673async fn update_worktree_settings(
2674    message: proto::UpdateWorktreeSettings,
2675    session: Session,
2676) -> Result<()> {
2677    let guest_connection_ids = session
2678        .db()
2679        .await
2680        .update_worktree_settings(&message, session.connection_id)
2681        .await?;
2682
2683    broadcast(
2684        Some(session.connection_id),
2685        guest_connection_ids.iter().copied(),
2686        |connection_id| {
2687            session
2688                .peer
2689                .forward_send(session.connection_id, connection_id, message.clone())
2690        },
2691    );
2692
2693    Ok(())
2694}
2695
2696/// Notify other participants that a  language server has started.
2697async fn start_language_server(
2698    request: proto::StartLanguageServer,
2699    session: Session,
2700) -> Result<()> {
2701    let guest_connection_ids = session
2702        .db()
2703        .await
2704        .start_language_server(&request, session.connection_id)
2705        .await?;
2706
2707    broadcast(
2708        Some(session.connection_id),
2709        guest_connection_ids.iter().copied(),
2710        |connection_id| {
2711            session
2712                .peer
2713                .forward_send(session.connection_id, connection_id, request.clone())
2714        },
2715    );
2716    Ok(())
2717}
2718
2719/// Notify other participants that a language server has changed.
2720async fn update_language_server(
2721    request: proto::UpdateLanguageServer,
2722    session: Session,
2723) -> Result<()> {
2724    let project_id = ProjectId::from_proto(request.project_id);
2725    let project_connection_ids = session
2726        .db()
2727        .await
2728        .project_connection_ids(project_id, session.connection_id, true)
2729        .await?;
2730    broadcast(
2731        Some(session.connection_id),
2732        project_connection_ids.iter().copied(),
2733        |connection_id| {
2734            session
2735                .peer
2736                .forward_send(session.connection_id, connection_id, request.clone())
2737        },
2738    );
2739    Ok(())
2740}
2741
2742/// forward a project request to the host. These requests should be read only
2743/// as guests are allowed to send them.
2744async fn forward_read_only_project_request<T>(
2745    request: T,
2746    response: Response<T>,
2747    session: UserSession,
2748) -> Result<()>
2749where
2750    T: EntityMessage + RequestMessage,
2751{
2752    let project_id = ProjectId::from_proto(request.remote_entity_id());
2753    let host_connection_id = session
2754        .db()
2755        .await
2756        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2757        .await?;
2758    let payload = session
2759        .peer
2760        .forward_request(session.connection_id, host_connection_id, request)
2761        .await?;
2762    response.send(payload)?;
2763    Ok(())
2764}
2765
2766/// forward a project request to the host. These requests are disallowed
2767/// for guests.
2768async fn forward_mutating_project_request<T>(
2769    request: T,
2770    response: Response<T>,
2771    session: UserSession,
2772) -> Result<()>
2773where
2774    T: EntityMessage + RequestMessage,
2775{
2776    let project_id = ProjectId::from_proto(request.remote_entity_id());
2777
2778    let host_connection_id = session
2779        .db()
2780        .await
2781        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2782        .await?;
2783    let payload = session
2784        .peer
2785        .forward_request(session.connection_id, host_connection_id, request)
2786        .await?;
2787    response.send(payload)?;
2788    Ok(())
2789}
2790
2791/// forward a project request to the host. These requests are disallowed
2792/// for guests.
2793async fn forward_versioned_mutating_project_request<T>(
2794    request: T,
2795    response: Response<T>,
2796    session: UserSession,
2797) -> Result<()>
2798where
2799    T: EntityMessage + RequestMessage + VersionedMessage,
2800{
2801    let project_id = ProjectId::from_proto(request.remote_entity_id());
2802
2803    let host_connection_id = session
2804        .db()
2805        .await
2806        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2807        .await?;
2808    if let Some(host_version) = session
2809        .connection_pool()
2810        .await
2811        .connection(host_connection_id)
2812        .map(|c| c.zed_version)
2813    {
2814        if let Some(min_required_version) = request.required_host_version() {
2815            if min_required_version > host_version {
2816                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2817                    .with_tag("required", &min_required_version.to_string())))?;
2818            }
2819        }
2820    }
2821
2822    let payload = session
2823        .peer
2824        .forward_request(session.connection_id, host_connection_id, request)
2825        .await?;
2826    response.send(payload)?;
2827    Ok(())
2828}
2829
2830/// Notify other participants that a new buffer has been created
2831async fn create_buffer_for_peer(
2832    request: proto::CreateBufferForPeer,
2833    session: Session,
2834) -> Result<()> {
2835    session
2836        .db()
2837        .await
2838        .check_user_is_project_host(
2839            ProjectId::from_proto(request.project_id),
2840            session.connection_id,
2841        )
2842        .await?;
2843    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2844    session
2845        .peer
2846        .forward_send(session.connection_id, peer_id.into(), request)?;
2847    Ok(())
2848}
2849
2850/// Notify other participants that a buffer has been updated. This is
2851/// allowed for guests as long as the update is limited to selections.
2852async fn update_buffer(
2853    request: proto::UpdateBuffer,
2854    response: Response<proto::UpdateBuffer>,
2855    session: Session,
2856) -> Result<()> {
2857    let project_id = ProjectId::from_proto(request.project_id);
2858    let mut capability = Capability::ReadOnly;
2859
2860    for op in request.operations.iter() {
2861        match op.variant {
2862            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2863            Some(_) => capability = Capability::ReadWrite,
2864        }
2865    }
2866
2867    let host = {
2868        let guard = session
2869            .db()
2870            .await
2871            .connections_for_buffer_update(
2872                project_id,
2873                session.principal_id(),
2874                session.connection_id,
2875                capability,
2876            )
2877            .await?;
2878
2879        let (host, guests) = &*guard;
2880
2881        broadcast(
2882            Some(session.connection_id),
2883            guests.clone(),
2884            |connection_id| {
2885                session
2886                    .peer
2887                    .forward_send(session.connection_id, connection_id, request.clone())
2888            },
2889        );
2890
2891        *host
2892    };
2893
2894    if host != session.connection_id {
2895        session
2896            .peer
2897            .forward_request(session.connection_id, host, request.clone())
2898            .await?;
2899    }
2900
2901    response.send(proto::Ack {})?;
2902    Ok(())
2903}
2904
2905/// Notify other participants that a project has been updated.
2906async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2907    request: T,
2908    session: Session,
2909) -> Result<()> {
2910    let project_id = ProjectId::from_proto(request.remote_entity_id());
2911    let project_connection_ids = session
2912        .db()
2913        .await
2914        .project_connection_ids(project_id, session.connection_id, false)
2915        .await?;
2916
2917    broadcast(
2918        Some(session.connection_id),
2919        project_connection_ids.iter().copied(),
2920        |connection_id| {
2921            session
2922                .peer
2923                .forward_send(session.connection_id, connection_id, request.clone())
2924        },
2925    );
2926    Ok(())
2927}
2928
2929/// Start following another user in a call.
2930async fn follow(
2931    request: proto::Follow,
2932    response: Response<proto::Follow>,
2933    session: UserSession,
2934) -> Result<()> {
2935    let room_id = RoomId::from_proto(request.room_id);
2936    let project_id = request.project_id.map(ProjectId::from_proto);
2937    let leader_id = request
2938        .leader_id
2939        .ok_or_else(|| anyhow!("invalid leader id"))?
2940        .into();
2941    let follower_id = session.connection_id;
2942
2943    session
2944        .db()
2945        .await
2946        .check_room_participants(room_id, leader_id, session.connection_id)
2947        .await?;
2948
2949    let response_payload = session
2950        .peer
2951        .forward_request(session.connection_id, leader_id, request)
2952        .await?;
2953    response.send(response_payload)?;
2954
2955    if let Some(project_id) = project_id {
2956        let room = session
2957            .db()
2958            .await
2959            .follow(room_id, project_id, leader_id, follower_id)
2960            .await?;
2961        room_updated(&room, &session.peer);
2962    }
2963
2964    Ok(())
2965}
2966
2967/// Stop following another user in a call.
2968async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2969    let room_id = RoomId::from_proto(request.room_id);
2970    let project_id = request.project_id.map(ProjectId::from_proto);
2971    let leader_id = request
2972        .leader_id
2973        .ok_or_else(|| anyhow!("invalid leader id"))?
2974        .into();
2975    let follower_id = session.connection_id;
2976
2977    session
2978        .db()
2979        .await
2980        .check_room_participants(room_id, leader_id, session.connection_id)
2981        .await?;
2982
2983    session
2984        .peer
2985        .forward_send(session.connection_id, leader_id, request)?;
2986
2987    if let Some(project_id) = project_id {
2988        let room = session
2989            .db()
2990            .await
2991            .unfollow(room_id, project_id, leader_id, follower_id)
2992            .await?;
2993        room_updated(&room, &session.peer);
2994    }
2995
2996    Ok(())
2997}
2998
2999/// Notify everyone following you of your current location.
3000async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3001    let room_id = RoomId::from_proto(request.room_id);
3002    let database = session.db.lock().await;
3003
3004    let connection_ids = if let Some(project_id) = request.project_id {
3005        let project_id = ProjectId::from_proto(project_id);
3006        database
3007            .project_connection_ids(project_id, session.connection_id, true)
3008            .await?
3009    } else {
3010        database
3011            .room_connection_ids(room_id, session.connection_id)
3012            .await?
3013    };
3014
3015    // For now, don't send view update messages back to that view's current leader.
3016    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3017        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3018        _ => None,
3019    });
3020
3021    for connection_id in connection_ids.iter().cloned() {
3022        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3023            session
3024                .peer
3025                .forward_send(session.connection_id, connection_id, request.clone())?;
3026        }
3027    }
3028    Ok(())
3029}
3030
3031/// Get public data about users.
3032async fn get_users(
3033    request: proto::GetUsers,
3034    response: Response<proto::GetUsers>,
3035    session: Session,
3036) -> Result<()> {
3037    let user_ids = request
3038        .user_ids
3039        .into_iter()
3040        .map(UserId::from_proto)
3041        .collect();
3042    let users = session
3043        .db()
3044        .await
3045        .get_users_by_ids(user_ids)
3046        .await?
3047        .into_iter()
3048        .map(|user| proto::User {
3049            id: user.id.to_proto(),
3050            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3051            github_login: user.github_login,
3052        })
3053        .collect();
3054    response.send(proto::UsersResponse { users })?;
3055    Ok(())
3056}
3057
3058/// Search for users (to invite) buy Github login
3059async fn fuzzy_search_users(
3060    request: proto::FuzzySearchUsers,
3061    response: Response<proto::FuzzySearchUsers>,
3062    session: UserSession,
3063) -> Result<()> {
3064    let query = request.query;
3065    let users = match query.len() {
3066        0 => vec![],
3067        1 | 2 => session
3068            .db()
3069            .await
3070            .get_user_by_github_login(&query)
3071            .await?
3072            .into_iter()
3073            .collect(),
3074        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3075    };
3076    let users = users
3077        .into_iter()
3078        .filter(|user| user.id != session.user_id())
3079        .map(|user| proto::User {
3080            id: user.id.to_proto(),
3081            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3082            github_login: user.github_login,
3083        })
3084        .collect();
3085    response.send(proto::UsersResponse { users })?;
3086    Ok(())
3087}
3088
3089/// Send a contact request to another user.
3090async fn request_contact(
3091    request: proto::RequestContact,
3092    response: Response<proto::RequestContact>,
3093    session: UserSession,
3094) -> Result<()> {
3095    let requester_id = session.user_id();
3096    let responder_id = UserId::from_proto(request.responder_id);
3097    if requester_id == responder_id {
3098        return Err(anyhow!("cannot add yourself as a contact"))?;
3099    }
3100
3101    let notifications = session
3102        .db()
3103        .await
3104        .send_contact_request(requester_id, responder_id)
3105        .await?;
3106
3107    // Update outgoing contact requests of requester
3108    let mut update = proto::UpdateContacts::default();
3109    update.outgoing_requests.push(responder_id.to_proto());
3110    for connection_id in session
3111        .connection_pool()
3112        .await
3113        .user_connection_ids(requester_id)
3114    {
3115        session.peer.send(connection_id, update.clone())?;
3116    }
3117
3118    // Update incoming contact requests of responder
3119    let mut update = proto::UpdateContacts::default();
3120    update
3121        .incoming_requests
3122        .push(proto::IncomingContactRequest {
3123            requester_id: requester_id.to_proto(),
3124        });
3125    let connection_pool = session.connection_pool().await;
3126    for connection_id in connection_pool.user_connection_ids(responder_id) {
3127        session.peer.send(connection_id, update.clone())?;
3128    }
3129
3130    send_notifications(&connection_pool, &session.peer, notifications);
3131
3132    response.send(proto::Ack {})?;
3133    Ok(())
3134}
3135
3136/// Accept or decline a contact request
3137async fn respond_to_contact_request(
3138    request: proto::RespondToContactRequest,
3139    response: Response<proto::RespondToContactRequest>,
3140    session: UserSession,
3141) -> Result<()> {
3142    let responder_id = session.user_id();
3143    let requester_id = UserId::from_proto(request.requester_id);
3144    let db = session.db().await;
3145    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3146        db.dismiss_contact_notification(responder_id, requester_id)
3147            .await?;
3148    } else {
3149        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3150
3151        let notifications = db
3152            .respond_to_contact_request(responder_id, requester_id, accept)
3153            .await?;
3154        let requester_busy = db.is_user_busy(requester_id).await?;
3155        let responder_busy = db.is_user_busy(responder_id).await?;
3156
3157        let pool = session.connection_pool().await;
3158        // Update responder with new contact
3159        let mut update = proto::UpdateContacts::default();
3160        if accept {
3161            update
3162                .contacts
3163                .push(contact_for_user(requester_id, requester_busy, &pool));
3164        }
3165        update
3166            .remove_incoming_requests
3167            .push(requester_id.to_proto());
3168        for connection_id in pool.user_connection_ids(responder_id) {
3169            session.peer.send(connection_id, update.clone())?;
3170        }
3171
3172        // Update requester with new contact
3173        let mut update = proto::UpdateContacts::default();
3174        if accept {
3175            update
3176                .contacts
3177                .push(contact_for_user(responder_id, responder_busy, &pool));
3178        }
3179        update
3180            .remove_outgoing_requests
3181            .push(responder_id.to_proto());
3182
3183        for connection_id in pool.user_connection_ids(requester_id) {
3184            session.peer.send(connection_id, update.clone())?;
3185        }
3186
3187        send_notifications(&pool, &session.peer, notifications);
3188    }
3189
3190    response.send(proto::Ack {})?;
3191    Ok(())
3192}
3193
3194/// Remove a contact.
3195async fn remove_contact(
3196    request: proto::RemoveContact,
3197    response: Response<proto::RemoveContact>,
3198    session: UserSession,
3199) -> Result<()> {
3200    let requester_id = session.user_id();
3201    let responder_id = UserId::from_proto(request.user_id);
3202    let db = session.db().await;
3203    let (contact_accepted, deleted_notification_id) =
3204        db.remove_contact(requester_id, responder_id).await?;
3205
3206    let pool = session.connection_pool().await;
3207    // Update outgoing contact requests of requester
3208    let mut update = proto::UpdateContacts::default();
3209    if contact_accepted {
3210        update.remove_contacts.push(responder_id.to_proto());
3211    } else {
3212        update
3213            .remove_outgoing_requests
3214            .push(responder_id.to_proto());
3215    }
3216    for connection_id in pool.user_connection_ids(requester_id) {
3217        session.peer.send(connection_id, update.clone())?;
3218    }
3219
3220    // Update incoming contact requests of responder
3221    let mut update = proto::UpdateContacts::default();
3222    if contact_accepted {
3223        update.remove_contacts.push(requester_id.to_proto());
3224    } else {
3225        update
3226            .remove_incoming_requests
3227            .push(requester_id.to_proto());
3228    }
3229    for connection_id in pool.user_connection_ids(responder_id) {
3230        session.peer.send(connection_id, update.clone())?;
3231        if let Some(notification_id) = deleted_notification_id {
3232            session.peer.send(
3233                connection_id,
3234                proto::DeleteNotification {
3235                    notification_id: notification_id.to_proto(),
3236                },
3237            )?;
3238        }
3239    }
3240
3241    response.send(proto::Ack {})?;
3242    Ok(())
3243}
3244
3245/// Creates a new channel.
3246async fn create_channel(
3247    request: proto::CreateChannel,
3248    response: Response<proto::CreateChannel>,
3249    session: UserSession,
3250) -> Result<()> {
3251    let db = session.db().await;
3252
3253    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3254    let (channel, membership) = db
3255        .create_channel(&request.name, parent_id, session.user_id())
3256        .await?;
3257
3258    let root_id = channel.root_id();
3259    let channel = Channel::from_model(channel);
3260
3261    response.send(proto::CreateChannelResponse {
3262        channel: Some(channel.to_proto()),
3263        parent_id: request.parent_id,
3264    })?;
3265
3266    let mut connection_pool = session.connection_pool().await;
3267    if let Some(membership) = membership {
3268        connection_pool.subscribe_to_channel(
3269            membership.user_id,
3270            membership.channel_id,
3271            membership.role,
3272        );
3273        let update = proto::UpdateUserChannels {
3274            channel_memberships: vec![proto::ChannelMembership {
3275                channel_id: membership.channel_id.to_proto(),
3276                role: membership.role.into(),
3277            }],
3278            ..Default::default()
3279        };
3280        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3281            session.peer.send(connection_id, update.clone())?;
3282        }
3283    }
3284
3285    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3286        if !role.can_see_channel(channel.visibility) {
3287            continue;
3288        }
3289
3290        let update = proto::UpdateChannels {
3291            channels: vec![channel.to_proto()],
3292            ..Default::default()
3293        };
3294        session.peer.send(connection_id, update.clone())?;
3295    }
3296
3297    Ok(())
3298}
3299
3300/// Delete a channel
3301async fn delete_channel(
3302    request: proto::DeleteChannel,
3303    response: Response<proto::DeleteChannel>,
3304    session: UserSession,
3305) -> Result<()> {
3306    let db = session.db().await;
3307
3308    let channel_id = request.channel_id;
3309    let (root_channel, removed_channels) = db
3310        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3311        .await?;
3312    response.send(proto::Ack {})?;
3313
3314    // Notify members of removed channels
3315    let mut update = proto::UpdateChannels::default();
3316    update
3317        .delete_channels
3318        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3319
3320    let connection_pool = session.connection_pool().await;
3321    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3322        session.peer.send(connection_id, update.clone())?;
3323    }
3324
3325    Ok(())
3326}
3327
3328/// Invite someone to join a channel.
3329async fn invite_channel_member(
3330    request: proto::InviteChannelMember,
3331    response: Response<proto::InviteChannelMember>,
3332    session: UserSession,
3333) -> Result<()> {
3334    let db = session.db().await;
3335    let channel_id = ChannelId::from_proto(request.channel_id);
3336    let invitee_id = UserId::from_proto(request.user_id);
3337    let InviteMemberResult {
3338        channel,
3339        notifications,
3340    } = db
3341        .invite_channel_member(
3342            channel_id,
3343            invitee_id,
3344            session.user_id(),
3345            request.role().into(),
3346        )
3347        .await?;
3348
3349    let update = proto::UpdateChannels {
3350        channel_invitations: vec![channel.to_proto()],
3351        ..Default::default()
3352    };
3353
3354    let connection_pool = session.connection_pool().await;
3355    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3356        session.peer.send(connection_id, update.clone())?;
3357    }
3358
3359    send_notifications(&connection_pool, &session.peer, notifications);
3360
3361    response.send(proto::Ack {})?;
3362    Ok(())
3363}
3364
3365/// remove someone from a channel
3366async fn remove_channel_member(
3367    request: proto::RemoveChannelMember,
3368    response: Response<proto::RemoveChannelMember>,
3369    session: UserSession,
3370) -> Result<()> {
3371    let db = session.db().await;
3372    let channel_id = ChannelId::from_proto(request.channel_id);
3373    let member_id = UserId::from_proto(request.user_id);
3374
3375    let RemoveChannelMemberResult {
3376        membership_update,
3377        notification_id,
3378    } = db
3379        .remove_channel_member(channel_id, member_id, session.user_id())
3380        .await?;
3381
3382    let mut connection_pool = session.connection_pool().await;
3383    notify_membership_updated(
3384        &mut connection_pool,
3385        membership_update,
3386        member_id,
3387        &session.peer,
3388    );
3389    for connection_id in connection_pool.user_connection_ids(member_id) {
3390        if let Some(notification_id) = notification_id {
3391            session
3392                .peer
3393                .send(
3394                    connection_id,
3395                    proto::DeleteNotification {
3396                        notification_id: notification_id.to_proto(),
3397                    },
3398                )
3399                .trace_err();
3400        }
3401    }
3402
3403    response.send(proto::Ack {})?;
3404    Ok(())
3405}
3406
3407/// Toggle the channel between public and private.
3408/// Care is taken to maintain the invariant that public channels only descend from public channels,
3409/// (though members-only channels can appear at any point in the hierarchy).
3410async fn set_channel_visibility(
3411    request: proto::SetChannelVisibility,
3412    response: Response<proto::SetChannelVisibility>,
3413    session: UserSession,
3414) -> Result<()> {
3415    let db = session.db().await;
3416    let channel_id = ChannelId::from_proto(request.channel_id);
3417    let visibility = request.visibility().into();
3418
3419    let channel_model = db
3420        .set_channel_visibility(channel_id, visibility, session.user_id())
3421        .await?;
3422    let root_id = channel_model.root_id();
3423    let channel = Channel::from_model(channel_model);
3424
3425    let mut connection_pool = session.connection_pool().await;
3426    for (user_id, role) in connection_pool
3427        .channel_user_ids(root_id)
3428        .collect::<Vec<_>>()
3429        .into_iter()
3430    {
3431        let update = if role.can_see_channel(channel.visibility) {
3432            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3433            proto::UpdateChannels {
3434                channels: vec![channel.to_proto()],
3435                ..Default::default()
3436            }
3437        } else {
3438            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3439            proto::UpdateChannels {
3440                delete_channels: vec![channel.id.to_proto()],
3441                ..Default::default()
3442            }
3443        };
3444
3445        for connection_id in connection_pool.user_connection_ids(user_id) {
3446            session.peer.send(connection_id, update.clone())?;
3447        }
3448    }
3449
3450    response.send(proto::Ack {})?;
3451    Ok(())
3452}
3453
3454/// Alter the role for a user in the channel.
3455async fn set_channel_member_role(
3456    request: proto::SetChannelMemberRole,
3457    response: Response<proto::SetChannelMemberRole>,
3458    session: UserSession,
3459) -> Result<()> {
3460    let db = session.db().await;
3461    let channel_id = ChannelId::from_proto(request.channel_id);
3462    let member_id = UserId::from_proto(request.user_id);
3463    let result = db
3464        .set_channel_member_role(
3465            channel_id,
3466            session.user_id(),
3467            member_id,
3468            request.role().into(),
3469        )
3470        .await?;
3471
3472    match result {
3473        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3474            let mut connection_pool = session.connection_pool().await;
3475            notify_membership_updated(
3476                &mut connection_pool,
3477                membership_update,
3478                member_id,
3479                &session.peer,
3480            )
3481        }
3482        db::SetMemberRoleResult::InviteUpdated(channel) => {
3483            let update = proto::UpdateChannels {
3484                channel_invitations: vec![channel.to_proto()],
3485                ..Default::default()
3486            };
3487
3488            for connection_id in session
3489                .connection_pool()
3490                .await
3491                .user_connection_ids(member_id)
3492            {
3493                session.peer.send(connection_id, update.clone())?;
3494            }
3495        }
3496    }
3497
3498    response.send(proto::Ack {})?;
3499    Ok(())
3500}
3501
3502/// Change the name of a channel
3503async fn rename_channel(
3504    request: proto::RenameChannel,
3505    response: Response<proto::RenameChannel>,
3506    session: UserSession,
3507) -> Result<()> {
3508    let db = session.db().await;
3509    let channel_id = ChannelId::from_proto(request.channel_id);
3510    let channel_model = db
3511        .rename_channel(channel_id, session.user_id(), &request.name)
3512        .await?;
3513    let root_id = channel_model.root_id();
3514    let channel = Channel::from_model(channel_model);
3515
3516    response.send(proto::RenameChannelResponse {
3517        channel: Some(channel.to_proto()),
3518    })?;
3519
3520    let connection_pool = session.connection_pool().await;
3521    let update = proto::UpdateChannels {
3522        channels: vec![channel.to_proto()],
3523        ..Default::default()
3524    };
3525    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3526        if role.can_see_channel(channel.visibility) {
3527            session.peer.send(connection_id, update.clone())?;
3528        }
3529    }
3530
3531    Ok(())
3532}
3533
3534/// Move a channel to a new parent.
3535async fn move_channel(
3536    request: proto::MoveChannel,
3537    response: Response<proto::MoveChannel>,
3538    session: UserSession,
3539) -> Result<()> {
3540    let channel_id = ChannelId::from_proto(request.channel_id);
3541    let to = ChannelId::from_proto(request.to);
3542
3543    let (root_id, channels) = session
3544        .db()
3545        .await
3546        .move_channel(channel_id, to, session.user_id())
3547        .await?;
3548
3549    let connection_pool = session.connection_pool().await;
3550    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3551        let channels = channels
3552            .iter()
3553            .filter_map(|channel| {
3554                if role.can_see_channel(channel.visibility) {
3555                    Some(channel.to_proto())
3556                } else {
3557                    None
3558                }
3559            })
3560            .collect::<Vec<_>>();
3561        if channels.is_empty() {
3562            continue;
3563        }
3564
3565        let update = proto::UpdateChannels {
3566            channels,
3567            ..Default::default()
3568        };
3569
3570        session.peer.send(connection_id, update.clone())?;
3571    }
3572
3573    response.send(Ack {})?;
3574    Ok(())
3575}
3576
3577/// Get the list of channel members
3578async fn get_channel_members(
3579    request: proto::GetChannelMembers,
3580    response: Response<proto::GetChannelMembers>,
3581    session: UserSession,
3582) -> Result<()> {
3583    let db = session.db().await;
3584    let channel_id = ChannelId::from_proto(request.channel_id);
3585    let members = db
3586        .get_channel_participant_details(channel_id, session.user_id())
3587        .await?;
3588    response.send(proto::GetChannelMembersResponse { members })?;
3589    Ok(())
3590}
3591
3592/// Accept or decline a channel invitation.
3593async fn respond_to_channel_invite(
3594    request: proto::RespondToChannelInvite,
3595    response: Response<proto::RespondToChannelInvite>,
3596    session: UserSession,
3597) -> Result<()> {
3598    let db = session.db().await;
3599    let channel_id = ChannelId::from_proto(request.channel_id);
3600    let RespondToChannelInvite {
3601        membership_update,
3602        notifications,
3603    } = db
3604        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3605        .await?;
3606
3607    let mut connection_pool = session.connection_pool().await;
3608    if let Some(membership_update) = membership_update {
3609        notify_membership_updated(
3610            &mut connection_pool,
3611            membership_update,
3612            session.user_id(),
3613            &session.peer,
3614        );
3615    } else {
3616        let update = proto::UpdateChannels {
3617            remove_channel_invitations: vec![channel_id.to_proto()],
3618            ..Default::default()
3619        };
3620
3621        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3622            session.peer.send(connection_id, update.clone())?;
3623        }
3624    };
3625
3626    send_notifications(&connection_pool, &session.peer, notifications);
3627
3628    response.send(proto::Ack {})?;
3629
3630    Ok(())
3631}
3632
3633/// Join the channels' room
3634async fn join_channel(
3635    request: proto::JoinChannel,
3636    response: Response<proto::JoinChannel>,
3637    session: UserSession,
3638) -> Result<()> {
3639    let channel_id = ChannelId::from_proto(request.channel_id);
3640    join_channel_internal(channel_id, Box::new(response), session).await
3641}
3642
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// which lets `join_channel_internal` reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): the fully-qualified path makes explicit that the inherent
        // `Response::send` is meant here, not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE(review): as above, this targets the inherent `Response::send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3656
/// Shared implementation for joining a channel's room. The steps run in this order:
///
/// 1. If the database reports a stale room connection for this user, leave that room first.
///    The database handle is released while leaving and re-acquired afterwards.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, create a token for the user:
///    - guests get a guest token with `can_publish = false`;
///    - every other role gets a room token with `can_publish = true`.
///
///    If creating the token fails, the error goes through `trace_err` and the connection info
///    is omitted (`None`).
/// 4. Reply with the room, the room's channel id (if any), and the LiveKit info.
/// 5. Propagate any membership update, then broadcast the room update with `room_updated`.
///    Both happen while the connection pool is held.
/// 6. Once the database handle and pool guard from the inner block are dropped, broadcast the
///    channel update and refresh the user's contacts.
///
/// Returns an error if the joined room has no channel. That check happens after the reply
/// has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the database handle and connection-pool guard are dropped before the
    // channel/contacts broadcasts below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room (presumably because
            // `leave_room_for_session` acquires it itself), then take it back.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or when creating the token failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may not publish and get a guest token; other roles get a full room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Moves the channel out of `joined_room`. Errors if the room had no channel, even though the
    // client has already been answered above.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3747
3748/// Start editing the channel notes
3749async fn join_channel_buffer(
3750    request: proto::JoinChannelBuffer,
3751    response: Response<proto::JoinChannelBuffer>,
3752    session: UserSession,
3753) -> Result<()> {
3754    let db = session.db().await;
3755    let channel_id = ChannelId::from_proto(request.channel_id);
3756
3757    let open_response = db
3758        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3759        .await?;
3760
3761    let collaborators = open_response.collaborators.clone();
3762    response.send(open_response)?;
3763
3764    let update = UpdateChannelBufferCollaborators {
3765        channel_id: channel_id.to_proto(),
3766        collaborators: collaborators.clone(),
3767    };
3768    channel_buffer_updated(
3769        session.connection_id,
3770        collaborators
3771            .iter()
3772            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3773        &update,
3774        &session.peer,
3775    );
3776
3777    Ok(())
3778}
3779
3780/// Edit the channel notes
3781async fn update_channel_buffer(
3782    request: proto::UpdateChannelBuffer,
3783    session: UserSession,
3784) -> Result<()> {
3785    let db = session.db().await;
3786    let channel_id = ChannelId::from_proto(request.channel_id);
3787
3788    let (collaborators, non_collaborators, epoch, version) = db
3789        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3790        .await?;
3791
3792    channel_buffer_updated(
3793        session.connection_id,
3794        collaborators,
3795        &proto::UpdateChannelBuffer {
3796            channel_id: channel_id.to_proto(),
3797            operations: request.operations,
3798        },
3799        &session.peer,
3800    );
3801
3802    let pool = &*session.connection_pool().await;
3803
3804    broadcast(
3805        None,
3806        non_collaborators
3807            .iter()
3808            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3809        |peer_id| {
3810            session.peer.send(
3811                peer_id,
3812                proto::UpdateChannels {
3813                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3814                        channel_id: channel_id.to_proto(),
3815                        epoch: epoch as u64,
3816                        version: version.clone(),
3817                    }],
3818                    ..Default::default()
3819                },
3820            )
3821        },
3822    );
3823
3824    Ok(())
3825}
3826
3827/// Rejoin the channel notes after a connection blip
3828async fn rejoin_channel_buffers(
3829    request: proto::RejoinChannelBuffers,
3830    response: Response<proto::RejoinChannelBuffers>,
3831    session: UserSession,
3832) -> Result<()> {
3833    let db = session.db().await;
3834    let buffers = db
3835        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3836        .await?;
3837
3838    for rejoined_buffer in &buffers {
3839        let collaborators_to_notify = rejoined_buffer
3840            .buffer
3841            .collaborators
3842            .iter()
3843            .filter_map(|c| Some(c.peer_id?.into()));
3844        channel_buffer_updated(
3845            session.connection_id,
3846            collaborators_to_notify,
3847            &proto::UpdateChannelBufferCollaborators {
3848                channel_id: rejoined_buffer.buffer.channel_id,
3849                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3850            },
3851            &session.peer,
3852        );
3853    }
3854
3855    response.send(proto::RejoinChannelBuffersResponse {
3856        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3857    })?;
3858
3859    Ok(())
3860}
3861
3862/// Stop editing the channel notes
3863async fn leave_channel_buffer(
3864    request: proto::LeaveChannelBuffer,
3865    response: Response<proto::LeaveChannelBuffer>,
3866    session: UserSession,
3867) -> Result<()> {
3868    let db = session.db().await;
3869    let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871    let left_buffer = db
3872        .leave_channel_buffer(channel_id, session.connection_id)
3873        .await?;
3874
3875    response.send(Ack {})?;
3876
3877    channel_buffer_updated(
3878        session.connection_id,
3879        left_buffer.connections,
3880        &proto::UpdateChannelBufferCollaborators {
3881            channel_id: channel_id.to_proto(),
3882            collaborators: left_buffer.collaborators,
3883        },
3884        &session.peer,
3885    );
3886
3887    Ok(())
3888}
3889
3890fn channel_buffer_updated<T: EnvelopedMessage>(
3891    sender_id: ConnectionId,
3892    collaborators: impl IntoIterator<Item = ConnectionId>,
3893    message: &T,
3894    peer: &Peer,
3895) {
3896    broadcast(Some(sender_id), collaborators, |peer_id| {
3897        peer.send(peer_id, message.clone())
3898    });
3899}
3900
3901fn send_notifications(
3902    connection_pool: &ConnectionPool,
3903    peer: &Peer,
3904    notifications: db::NotificationBatch,
3905) {
3906    for (user_id, notification) in notifications {
3907        for connection_id in connection_pool.user_connection_ids(user_id) {
3908            if let Err(error) = peer.send(
3909                connection_id,
3910                proto::AddNotification {
3911                    notification: Some(notification.clone()),
3912                },
3913            ) {
3914                tracing::error!(
3915                    "failed to send notification to {:?} {}",
3916                    connection_id,
3917                    error
3918                );
3919            }
3920        }
3921    }
3922}
3923
3924/// Send a message to the channel
3925async fn send_channel_message(
3926    request: proto::SendChannelMessage,
3927    response: Response<proto::SendChannelMessage>,
3928    session: UserSession,
3929) -> Result<()> {
3930    // Validate the message body.
3931    let body = request.body.trim().to_string();
3932    if body.len() > MAX_MESSAGE_LEN {
3933        return Err(anyhow!("message is too long"))?;
3934    }
3935    if body.is_empty() {
3936        return Err(anyhow!("message can't be blank"))?;
3937    }
3938
3939    // TODO: adjust mentions if body is trimmed
3940
3941    let timestamp = OffsetDateTime::now_utc();
3942    let nonce = request
3943        .nonce
3944        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3945
3946    let channel_id = ChannelId::from_proto(request.channel_id);
3947    let CreatedChannelMessage {
3948        message_id,
3949        participant_connection_ids,
3950        channel_members,
3951        notifications,
3952    } = session
3953        .db()
3954        .await
3955        .create_channel_message(
3956            channel_id,
3957            session.user_id(),
3958            &body,
3959            &request.mentions,
3960            timestamp,
3961            nonce.clone().into(),
3962            match request.reply_to_message_id {
3963                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3964                None => None,
3965            },
3966        )
3967        .await?;
3968
3969    let message = proto::ChannelMessage {
3970        sender_id: session.user_id().to_proto(),
3971        id: message_id.to_proto(),
3972        body,
3973        mentions: request.mentions,
3974        timestamp: timestamp.unix_timestamp() as u64,
3975        nonce: Some(nonce),
3976        reply_to_message_id: request.reply_to_message_id,
3977        edited_at: None,
3978    };
3979    broadcast(
3980        Some(session.connection_id),
3981        participant_connection_ids,
3982        |connection| {
3983            session.peer.send(
3984                connection,
3985                proto::ChannelMessageSent {
3986                    channel_id: channel_id.to_proto(),
3987                    message: Some(message.clone()),
3988                },
3989            )
3990        },
3991    );
3992    response.send(proto::SendChannelMessageResponse {
3993        message: Some(message),
3994    })?;
3995
3996    let pool = &*session.connection_pool().await;
3997    broadcast(
3998        None,
3999        channel_members
4000            .iter()
4001            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
4002        |peer_id| {
4003            session.peer.send(
4004                peer_id,
4005                proto::UpdateChannels {
4006                    latest_channel_message_ids: vec![proto::ChannelMessageId {
4007                        channel_id: channel_id.to_proto(),
4008                        message_id: message_id.to_proto(),
4009                    }],
4010                    ..Default::default()
4011                },
4012            )
4013        },
4014    );
4015    send_notifications(pool, &session.peer, notifications);
4016
4017    Ok(())
4018}
4019
4020/// Delete a channel message
4021async fn remove_channel_message(
4022    request: proto::RemoveChannelMessage,
4023    response: Response<proto::RemoveChannelMessage>,
4024    session: UserSession,
4025) -> Result<()> {
4026    let channel_id = ChannelId::from_proto(request.channel_id);
4027    let message_id = MessageId::from_proto(request.message_id);
4028    let (connection_ids, existing_notification_ids) = session
4029        .db()
4030        .await
4031        .remove_channel_message(channel_id, message_id, session.user_id())
4032        .await?;
4033
4034    broadcast(
4035        Some(session.connection_id),
4036        connection_ids,
4037        move |connection| {
4038            session.peer.send(connection, request.clone())?;
4039
4040            for notification_id in &existing_notification_ids {
4041                session.peer.send(
4042                    connection,
4043                    proto::DeleteNotification {
4044                        notification_id: (*notification_id).to_proto(),
4045                    },
4046                )?;
4047            }
4048
4049            Ok(())
4050        },
4051    );
4052    response.send(proto::Ack {})?;
4053    Ok(())
4054}
4055
4056async fn update_channel_message(
4057    request: proto::UpdateChannelMessage,
4058    response: Response<proto::UpdateChannelMessage>,
4059    session: UserSession,
4060) -> Result<()> {
4061    let channel_id = ChannelId::from_proto(request.channel_id);
4062    let message_id = MessageId::from_proto(request.message_id);
4063    let updated_at = OffsetDateTime::now_utc();
4064    let UpdatedChannelMessage {
4065        message_id,
4066        participant_connection_ids,
4067        notifications,
4068        reply_to_message_id,
4069        timestamp,
4070        deleted_mention_notification_ids,
4071        updated_mention_notifications,
4072    } = session
4073        .db()
4074        .await
4075        .update_channel_message(
4076            channel_id,
4077            message_id,
4078            session.user_id(),
4079            request.body.as_str(),
4080            &request.mentions,
4081            updated_at,
4082        )
4083        .await?;
4084
4085    let nonce = request
4086        .nonce
4087        .clone()
4088        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4089
4090    let message = proto::ChannelMessage {
4091        sender_id: session.user_id().to_proto(),
4092        id: message_id.to_proto(),
4093        body: request.body.clone(),
4094        mentions: request.mentions.clone(),
4095        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4096        nonce: Some(nonce),
4097        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4098        edited_at: Some(updated_at.unix_timestamp() as u64),
4099    };
4100
4101    response.send(proto::Ack {})?;
4102
4103    let pool = &*session.connection_pool().await;
4104    broadcast(
4105        Some(session.connection_id),
4106        participant_connection_ids,
4107        |connection| {
4108            session.peer.send(
4109                connection,
4110                proto::ChannelMessageUpdate {
4111                    channel_id: channel_id.to_proto(),
4112                    message: Some(message.clone()),
4113                },
4114            )?;
4115
4116            for notification_id in &deleted_mention_notification_ids {
4117                session.peer.send(
4118                    connection,
4119                    proto::DeleteNotification {
4120                        notification_id: (*notification_id).to_proto(),
4121                    },
4122                )?;
4123            }
4124
4125            for notification in &updated_mention_notifications {
4126                session.peer.send(
4127                    connection,
4128                    proto::UpdateNotification {
4129                        notification: Some(notification.clone()),
4130                    },
4131                )?;
4132            }
4133
4134            Ok(())
4135        },
4136    );
4137
4138    send_notifications(pool, &session.peer, notifications);
4139
4140    Ok(())
4141}
4142
4143/// Mark a channel message as read
4144async fn acknowledge_channel_message(
4145    request: proto::AckChannelMessage,
4146    session: UserSession,
4147) -> Result<()> {
4148    let channel_id = ChannelId::from_proto(request.channel_id);
4149    let message_id = MessageId::from_proto(request.message_id);
4150    let notifications = session
4151        .db()
4152        .await
4153        .observe_channel_message(channel_id, session.user_id(), message_id)
4154        .await?;
4155    send_notifications(
4156        &*session.connection_pool().await,
4157        &session.peer,
4158        notifications,
4159    );
4160    Ok(())
4161}
4162
4163/// Mark a buffer version as synced
4164async fn acknowledge_buffer_version(
4165    request: proto::AckBufferOperation,
4166    session: UserSession,
4167) -> Result<()> {
4168    let buffer_id = BufferId::from_proto(request.buffer_id);
4169    session
4170        .db()
4171        .await
4172        .observe_buffer_version(
4173            buffer_id,
4174            session.user_id(),
4175            request.epoch as i32,
4176            &request.version,
4177        )
4178        .await?;
4179    Ok(())
4180}
4181
4182struct CompleteWithLanguageModelRateLimit;
4183
4184impl RateLimit for CompleteWithLanguageModelRateLimit {
4185    fn capacity() -> usize {
4186        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4187            .ok()
4188            .and_then(|v| v.parse().ok())
4189            .unwrap_or(120) // Picked arbitrarily
4190    }
4191
4192    fn refill_duration() -> chrono::Duration {
4193        chrono::Duration::hours(1)
4194    }
4195
4196    fn db_name() -> &'static str {
4197        "complete-with-language-model"
4198    }
4199}
4200
4201async fn complete_with_language_model(
4202    request: proto::CompleteWithLanguageModel,
4203    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4204    session: Session,
4205    open_ai_api_key: Option<Arc<str>>,
4206    google_ai_api_key: Option<Arc<str>>,
4207    anthropic_api_key: Option<Arc<str>>,
4208) -> Result<()> {
4209    let Some(session) = session.for_user() else {
4210        return Err(anyhow!("user not found"))?;
4211    };
4212    authorize_access_to_language_models(&session).await?;
4213    session
4214        .rate_limiter
4215        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4216        .await?;
4217
4218    if request.model.starts_with("gpt") {
4219        let api_key =
4220            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4221        complete_with_open_ai(request, response, session, api_key).await?;
4222    } else if request.model.starts_with("gemini") {
4223        let api_key = google_ai_api_key
4224            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4225        complete_with_google_ai(request, response, session, api_key).await?;
4226    } else if request.model.starts_with("claude") {
4227        let api_key = anthropic_api_key
4228            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4229        complete_with_anthropic(request, response, session, api_key).await?;
4230    }
4231
4232    Ok(())
4233}
4234
4235async fn complete_with_open_ai(
4236    request: proto::CompleteWithLanguageModel,
4237    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4238    session: UserSession,
4239    api_key: Arc<str>,
4240) -> Result<()> {
4241    let mut completion_stream = open_ai::stream_completion(
4242        session.http_client.as_ref(),
4243        OPEN_AI_API_URL,
4244        &api_key,
4245        crate::ai::language_model_request_to_open_ai(request)?,
4246    )
4247    .await
4248    .context("open_ai::stream_completion request failed within collab")?;
4249
4250    while let Some(event) = completion_stream.next().await {
4251        let event = event?;
4252        response.send(proto::LanguageModelResponse {
4253            choices: event
4254                .choices
4255                .into_iter()
4256                .map(|choice| proto::LanguageModelChoiceDelta {
4257                    index: choice.index,
4258                    delta: Some(proto::LanguageModelResponseMessage {
4259                        role: choice.delta.role.map(|role| match role {
4260                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4261                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4262                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4263                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4264                        } as i32),
4265                        content: choice.delta.content,
4266                        tool_calls: choice
4267                            .delta
4268                            .tool_calls
4269                            .into_iter()
4270                            .map(|delta| proto::ToolCallDelta {
4271                                index: delta.index as u32,
4272                                id: delta.id,
4273                                variant: match delta.function {
4274                                    Some(function) => {
4275                                        let name = function.name;
4276                                        let arguments = function.arguments;
4277
4278                                        Some(proto::tool_call_delta::Variant::Function(
4279                                            proto::tool_call_delta::FunctionCallDelta {
4280                                                name,
4281                                                arguments,
4282                                            },
4283                                        ))
4284                                    }
4285                                    None => None,
4286                                },
4287                            })
4288                            .collect(),
4289                    }),
4290                    finish_reason: choice.finish_reason,
4291                })
4292                .collect(),
4293        })?;
4294    }
4295
4296    Ok(())
4297}
4298
4299async fn complete_with_google_ai(
4300    request: proto::CompleteWithLanguageModel,
4301    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4302    session: UserSession,
4303    api_key: Arc<str>,
4304) -> Result<()> {
4305    let mut stream = google_ai::stream_generate_content(
4306        session.http_client.clone(),
4307        google_ai::API_URL,
4308        api_key.as_ref(),
4309        crate::ai::language_model_request_to_google_ai(request)?,
4310    )
4311    .await
4312    .context("google_ai::stream_generate_content request failed")?;
4313
4314    while let Some(event) = stream.next().await {
4315        let event = event?;
4316        response.send(proto::LanguageModelResponse {
4317            choices: event
4318                .candidates
4319                .unwrap_or_default()
4320                .into_iter()
4321                .map(|candidate| proto::LanguageModelChoiceDelta {
4322                    index: candidate.index as u32,
4323                    delta: Some(proto::LanguageModelResponseMessage {
4324                        role: Some(match candidate.content.role {
4325                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4326                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4327                        } as i32),
4328                        content: Some(
4329                            candidate
4330                                .content
4331                                .parts
4332                                .into_iter()
4333                                .filter_map(|part| match part {
4334                                    google_ai::Part::TextPart(part) => Some(part.text),
4335                                    google_ai::Part::InlineDataPart(_) => None,
4336                                })
4337                                .collect(),
4338                        ),
4339                        // Tool calls are not supported for Google
4340                        tool_calls: Vec::new(),
4341                    }),
4342                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4343                })
4344                .collect(),
4345        })?;
4346    }
4347
4348    Ok(())
4349}
4350
4351async fn complete_with_anthropic(
4352    request: proto::CompleteWithLanguageModel,
4353    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4354    session: UserSession,
4355    api_key: Arc<str>,
4356) -> Result<()> {
4357    let model = anthropic::Model::from_id(&request.model)?;
4358
4359    let mut system_message = String::new();
4360    let messages = request
4361        .messages
4362        .into_iter()
4363        .filter_map(|message| {
4364            match message.role() {
4365                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4366                    role: anthropic::Role::User,
4367                    content: message.content,
4368                }),
4369                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4370                    role: anthropic::Role::Assistant,
4371                    content: message.content,
4372                }),
4373                // Anthropic's API breaks system instructions out as a separate field rather
4374                // than having a system message role.
4375                LanguageModelRole::LanguageModelSystem => {
4376                    if !system_message.is_empty() {
4377                        system_message.push_str("\n\n");
4378                    }
4379                    system_message.push_str(&message.content);
4380
4381                    None
4382                }
4383                // We don't yet support tool calls for Anthropic
4384                LanguageModelRole::LanguageModelTool => None,
4385            }
4386        })
4387        .collect();
4388
4389    let mut stream = anthropic::stream_completion(
4390        session.http_client.clone(),
4391        "https://api.anthropic.com",
4392        &api_key,
4393        anthropic::Request {
4394            model,
4395            messages,
4396            stream: true,
4397            system: system_message,
4398            max_tokens: 4092,
4399        },
4400    )
4401    .await?;
4402
4403    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4404
4405    while let Some(event) = stream.next().await {
4406        let event = event?;
4407
4408        match event {
4409            anthropic::ResponseEvent::MessageStart { message } => {
4410                if let Some(role) = message.role {
4411                    if role == "assistant" {
4412                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4413                    } else if role == "user" {
4414                        current_role = proto::LanguageModelRole::LanguageModelUser;
4415                    }
4416                }
4417            }
4418            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4419                match content_block {
4420                    anthropic::ContentBlock::Text { text } => {
4421                        if !text.is_empty() {
4422                            response.send(proto::LanguageModelResponse {
4423                                choices: vec![proto::LanguageModelChoiceDelta {
4424                                    index: 0,
4425                                    delta: Some(proto::LanguageModelResponseMessage {
4426                                        role: Some(current_role as i32),
4427                                        content: Some(text),
4428                                        tool_calls: Vec::new(),
4429                                    }),
4430                                    finish_reason: None,
4431                                }],
4432                            })?;
4433                        }
4434                    }
4435                }
4436            }
4437            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4438                anthropic::TextDelta::TextDelta { text } => {
4439                    response.send(proto::LanguageModelResponse {
4440                        choices: vec![proto::LanguageModelChoiceDelta {
4441                            index: 0,
4442                            delta: Some(proto::LanguageModelResponseMessage {
4443                                role: Some(current_role as i32),
4444                                content: Some(text),
4445                                tool_calls: Vec::new(),
4446                            }),
4447                            finish_reason: None,
4448                        }],
4449                    })?;
4450                }
4451            },
4452            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4453                if let Some(stop_reason) = delta.stop_reason {
4454                    response.send(proto::LanguageModelResponse {
4455                        choices: vec![proto::LanguageModelChoiceDelta {
4456                            index: 0,
4457                            delta: None,
4458                            finish_reason: Some(stop_reason),
4459                        }],
4460                    })?;
4461                }
4462            }
4463            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4464            anthropic::ResponseEvent::MessageStop {} => {}
4465            anthropic::ResponseEvent::Ping {} => {}
4466        }
4467    }
4468
4469    Ok(())
4470}
4471
4472struct CountTokensWithLanguageModelRateLimit;
4473
4474impl RateLimit for CountTokensWithLanguageModelRateLimit {
4475    fn capacity() -> usize {
4476        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4477            .ok()
4478            .and_then(|v| v.parse().ok())
4479            .unwrap_or(600) // Picked arbitrarily
4480    }
4481
4482    fn refill_duration() -> chrono::Duration {
4483        chrono::Duration::hours(1)
4484    }
4485
4486    fn db_name() -> &'static str {
4487        "count-tokens-with-language-model"
4488    }
4489}
4490
4491async fn count_tokens_with_language_model(
4492    request: proto::CountTokensWithLanguageModel,
4493    response: Response<proto::CountTokensWithLanguageModel>,
4494    session: UserSession,
4495    google_ai_api_key: Option<Arc<str>>,
4496) -> Result<()> {
4497    authorize_access_to_language_models(&session).await?;
4498
4499    if !request.model.starts_with("gemini") {
4500        return Err(anyhow!(
4501            "counting tokens for model: {:?} is not supported",
4502            request.model
4503        ))?;
4504    }
4505
4506    session
4507        .rate_limiter
4508        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4509        .await?;
4510
4511    let api_key = google_ai_api_key
4512        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4513    let tokens_response = google_ai::count_tokens(
4514        session.http_client.as_ref(),
4515        google_ai::API_URL,
4516        &api_key,
4517        crate::ai::count_tokens_request_to_google_ai(request)?,
4518    )
4519    .await?;
4520    response.send(proto::CountTokensResponse {
4521        token_count: tokens_response.total_tokens as u32,
4522    })?;
4523    Ok(())
4524}
4525
4526struct ComputeEmbeddingsRateLimit;
4527
4528impl RateLimit for ComputeEmbeddingsRateLimit {
4529    fn capacity() -> usize {
4530        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4531            .ok()
4532            .and_then(|v| v.parse().ok())
4533            .unwrap_or(5000) // Picked arbitrarily
4534    }
4535
4536    fn refill_duration() -> chrono::Duration {
4537        chrono::Duration::hours(1)
4538    }
4539
4540    fn db_name() -> &'static str {
4541        "compute-embeddings"
4542    }
4543}
4544
4545async fn compute_embeddings(
4546    request: proto::ComputeEmbeddings,
4547    response: Response<proto::ComputeEmbeddings>,
4548    session: UserSession,
4549    api_key: Option<Arc<str>>,
4550) -> Result<()> {
4551    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4552    authorize_access_to_language_models(&session).await?;
4553
4554    session
4555        .rate_limiter
4556        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4557        .await?;
4558
4559    let embeddings = match request.model.as_str() {
4560        "openai/text-embedding-3-small" => {
4561            open_ai::embed(
4562                session.http_client.as_ref(),
4563                OPEN_AI_API_URL,
4564                &api_key,
4565                OpenAiEmbeddingModel::TextEmbedding3Small,
4566                request.texts.iter().map(|text| text.as_str()),
4567            )
4568            .await?
4569        }
4570        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4571    };
4572
4573    let embeddings = request
4574        .texts
4575        .iter()
4576        .map(|text| {
4577            let mut hasher = sha2::Sha256::new();
4578            hasher.update(text.as_bytes());
4579            let result = hasher.finalize();
4580            result.to_vec()
4581        })
4582        .zip(
4583            embeddings
4584                .data
4585                .into_iter()
4586                .map(|embedding| embedding.embedding),
4587        )
4588        .collect::<HashMap<_, _>>();
4589
4590    let db = session.db().await;
4591    db.save_embeddings(&request.model, &embeddings)
4592        .await
4593        .context("failed to save embeddings")
4594        .trace_err();
4595
4596    response.send(proto::ComputeEmbeddingsResponse {
4597        embeddings: embeddings
4598            .into_iter()
4599            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4600            .collect(),
4601    })?;
4602    Ok(())
4603}
4604
4605async fn get_cached_embeddings(
4606    request: proto::GetCachedEmbeddings,
4607    response: Response<proto::GetCachedEmbeddings>,
4608    session: UserSession,
4609) -> Result<()> {
4610    authorize_access_to_language_models(&session).await?;
4611
4612    let db = session.db().await;
4613    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4614
4615    response.send(proto::GetCachedEmbeddingsResponse {
4616        embeddings: embeddings
4617            .into_iter()
4618            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4619            .collect(),
4620    })?;
4621    Ok(())
4622}
4623
4624async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4625    let db = session.db().await;
4626    let flags = db.get_user_flags(session.user_id()).await?;
4627    if flags.iter().any(|flag| flag == "language-models") {
4628        Ok(())
4629    } else {
4630        Err(anyhow!("permission denied"))?
4631    }
4632}
4633
4634/// Get a Supermaven API key for the user
4635async fn get_supermaven_api_key(
4636    _request: proto::GetSupermavenApiKey,
4637    response: Response<proto::GetSupermavenApiKey>,
4638    session: UserSession,
4639) -> Result<()> {
4640    let user_id: String = session.user_id().to_string();
4641    if !session.is_staff() {
4642        return Err(anyhow!("supermaven not enabled for this account"))?;
4643    }
4644
4645    let email = session
4646        .email()
4647        .ok_or_else(|| anyhow!("user must have an email"))?;
4648
4649    let supermaven_admin_api = session
4650        .supermaven_client
4651        .as_ref()
4652        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4653
4654    let result = supermaven_admin_api
4655        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4656        .await?;
4657
4658    response.send(proto::GetSupermavenApiKeyResponse {
4659        api_key: result.api_key,
4660    })?;
4661
4662    Ok(())
4663}
4664
4665/// Start receiving chat updates for a channel
4666async fn join_channel_chat(
4667    request: proto::JoinChannelChat,
4668    response: Response<proto::JoinChannelChat>,
4669    session: UserSession,
4670) -> Result<()> {
4671    let channel_id = ChannelId::from_proto(request.channel_id);
4672
4673    let db = session.db().await;
4674    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4675        .await?;
4676    let messages = db
4677        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4678        .await?;
4679    response.send(proto::JoinChannelChatResponse {
4680        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4681        messages,
4682    })?;
4683    Ok(())
4684}
4685
4686/// Stop receiving chat updates for a channel
4687async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4688    let channel_id = ChannelId::from_proto(request.channel_id);
4689    session
4690        .db()
4691        .await
4692        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4693        .await?;
4694    Ok(())
4695}
4696
4697/// Retrieve the chat history for a channel
4698async fn get_channel_messages(
4699    request: proto::GetChannelMessages,
4700    response: Response<proto::GetChannelMessages>,
4701    session: UserSession,
4702) -> Result<()> {
4703    let channel_id = ChannelId::from_proto(request.channel_id);
4704    let messages = session
4705        .db()
4706        .await
4707        .get_channel_messages(
4708            channel_id,
4709            session.user_id(),
4710            MESSAGE_COUNT_PER_PAGE,
4711            Some(MessageId::from_proto(request.before_message_id)),
4712        )
4713        .await?;
4714    response.send(proto::GetChannelMessagesResponse {
4715        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4716        messages,
4717    })?;
4718    Ok(())
4719}
4720
4721/// Retrieve specific chat messages
4722async fn get_channel_messages_by_id(
4723    request: proto::GetChannelMessagesById,
4724    response: Response<proto::GetChannelMessagesById>,
4725    session: UserSession,
4726) -> Result<()> {
4727    let message_ids = request
4728        .message_ids
4729        .iter()
4730        .map(|id| MessageId::from_proto(*id))
4731        .collect::<Vec<_>>();
4732    let messages = session
4733        .db()
4734        .await
4735        .get_channel_messages_by_id(session.user_id(), &message_ids)
4736        .await?;
4737    response.send(proto::GetChannelMessagesResponse {
4738        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4739        messages,
4740    })?;
4741    Ok(())
4742}
4743
4744/// Retrieve the current users notifications
4745async fn get_notifications(
4746    request: proto::GetNotifications,
4747    response: Response<proto::GetNotifications>,
4748    session: UserSession,
4749) -> Result<()> {
4750    let notifications = session
4751        .db()
4752        .await
4753        .get_notifications(
4754            session.user_id(),
4755            NOTIFICATION_COUNT_PER_PAGE,
4756            request
4757                .before_id
4758                .map(|id| db::NotificationId::from_proto(id)),
4759        )
4760        .await?;
4761    response.send(proto::GetNotificationsResponse {
4762        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4763        notifications,
4764    })?;
4765    Ok(())
4766}
4767
4768/// Mark notifications as read
4769async fn mark_notification_as_read(
4770    request: proto::MarkNotificationRead,
4771    response: Response<proto::MarkNotificationRead>,
4772    session: UserSession,
4773) -> Result<()> {
4774    let database = &session.db().await;
4775    let notifications = database
4776        .mark_notification_as_read_by_id(
4777            session.user_id(),
4778            NotificationId::from_proto(request.notification_id),
4779        )
4780        .await?;
4781    send_notifications(
4782        &*session.connection_pool().await,
4783        &session.peer,
4784        notifications,
4785    );
4786    response.send(proto::Ack {})?;
4787    Ok(())
4788}
4789
4790/// Get the current users information
4791async fn get_private_user_info(
4792    _request: proto::GetPrivateUserInfo,
4793    response: Response<proto::GetPrivateUserInfo>,
4794    session: UserSession,
4795) -> Result<()> {
4796    let db = session.db().await;
4797
4798    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4799    let user = db
4800        .get_user_by_id(session.user_id())
4801        .await?
4802        .ok_or_else(|| anyhow!("user not found"))?;
4803    let flags = db.get_user_flags(session.user_id()).await?;
4804
4805    response.send(proto::GetPrivateUserInfoResponse {
4806        metrics_id,
4807        staff: user.admin,
4808        flags,
4809    })?;
4810    Ok(())
4811}
4812
4813fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4814    match message {
4815        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4816        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4817        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4818        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4819        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4820            code: frame.code.into(),
4821            reason: frame.reason,
4822        })),
4823    }
4824}
4825
4826fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4827    match message {
4828        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4829        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4830        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4831        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4832        AxumMessage::Close(frame) => {
4833            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4834                code: frame.code.into(),
4835                reason: frame.reason,
4836            }))
4837        }
4838    }
4839}
4840
4841fn notify_membership_updated(
4842    connection_pool: &mut ConnectionPool,
4843    result: MembershipUpdated,
4844    user_id: UserId,
4845    peer: &Peer,
4846) {
4847    for membership in &result.new_channels.channel_memberships {
4848        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4849    }
4850    for channel_id in &result.removed_channels {
4851        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4852    }
4853
4854    let user_channels_update = proto::UpdateUserChannels {
4855        channel_memberships: result
4856            .new_channels
4857            .channel_memberships
4858            .iter()
4859            .map(|cm| proto::ChannelMembership {
4860                channel_id: cm.channel_id.to_proto(),
4861                role: cm.role.into(),
4862            })
4863            .collect(),
4864        ..Default::default()
4865    };
4866
4867    let mut update = build_channels_update(result.new_channels, vec![]);
4868    update.delete_channels = result
4869        .removed_channels
4870        .into_iter()
4871        .map(|id| id.to_proto())
4872        .collect();
4873    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4874
4875    for connection_id in connection_pool.user_connection_ids(user_id) {
4876        peer.send(connection_id, user_channels_update.clone())
4877            .trace_err();
4878        peer.send(connection_id, update.clone()).trace_err();
4879    }
4880}
4881
4882fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4883    proto::UpdateUserChannels {
4884        channel_memberships: channels
4885            .channel_memberships
4886            .iter()
4887            .map(|m| proto::ChannelMembership {
4888                channel_id: m.channel_id.to_proto(),
4889                role: m.role.into(),
4890            })
4891            .collect(),
4892        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4893        observed_channel_message_id: channels.observed_channel_messages.clone(),
4894    }
4895}
4896
4897fn build_channels_update(
4898    channels: ChannelsForUser,
4899    channel_invites: Vec<db::Channel>,
4900) -> proto::UpdateChannels {
4901    let mut update = proto::UpdateChannels::default();
4902
4903    for channel in channels.channels {
4904        update.channels.push(channel.to_proto());
4905    }
4906
4907    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4908    update.latest_channel_message_ids = channels.latest_channel_messages;
4909
4910    for (channel_id, participants) in channels.channel_participants {
4911        update
4912            .channel_participants
4913            .push(proto::ChannelParticipants {
4914                channel_id: channel_id.to_proto(),
4915                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4916            });
4917    }
4918
4919    for channel in channel_invites {
4920        update.channel_invitations.push(channel.to_proto());
4921    }
4922
4923    update.hosted_projects = channels.hosted_projects;
4924    update
4925}
4926
4927fn build_initial_contacts_update(
4928    contacts: Vec<db::Contact>,
4929    pool: &ConnectionPool,
4930) -> proto::UpdateContacts {
4931    let mut update = proto::UpdateContacts::default();
4932
4933    for contact in contacts {
4934        match contact {
4935            db::Contact::Accepted { user_id, busy } => {
4936                update.contacts.push(contact_for_user(user_id, busy, &pool));
4937            }
4938            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4939            db::Contact::Incoming { user_id } => {
4940                update
4941                    .incoming_requests
4942                    .push(proto::IncomingContactRequest {
4943                        requester_id: user_id.to_proto(),
4944                    })
4945            }
4946        }
4947    }
4948
4949    update
4950}
4951
4952fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4953    proto::Contact {
4954        user_id: user_id.to_proto(),
4955        online: pool.is_user_online(user_id),
4956        busy,
4957    }
4958}
4959
4960fn room_updated(room: &proto::Room, peer: &Peer) {
4961    broadcast(
4962        None,
4963        room.participants
4964            .iter()
4965            .filter_map(|participant| Some(participant.peer_id?.into())),
4966        |peer_id| {
4967            peer.send(
4968                peer_id,
4969                proto::RoomUpdated {
4970                    room: Some(room.clone()),
4971                },
4972            )
4973        },
4974    );
4975}
4976
4977fn channel_updated(
4978    channel: &db::channel::Model,
4979    room: &proto::Room,
4980    peer: &Peer,
4981    pool: &ConnectionPool,
4982) {
4983    let participants = room
4984        .participants
4985        .iter()
4986        .map(|p| p.user_id)
4987        .collect::<Vec<_>>();
4988
4989    broadcast(
4990        None,
4991        pool.channel_connection_ids(channel.root_id())
4992            .filter_map(|(channel_id, role)| {
4993                role.can_see_channel(channel.visibility).then(|| channel_id)
4994            }),
4995        |peer_id| {
4996            peer.send(
4997                peer_id,
4998                proto::UpdateChannels {
4999                    channel_participants: vec![proto::ChannelParticipants {
5000                        channel_id: channel.id.to_proto(),
5001                        participant_user_ids: participants.clone(),
5002                    }],
5003                    ..Default::default()
5004                },
5005            )
5006        },
5007    );
5008}
5009
5010async fn send_dev_server_projects_update(
5011    user_id: UserId,
5012    mut status: proto::DevServerProjectsUpdate,
5013    session: &Session,
5014) {
5015    let pool = session.connection_pool().await;
5016    for dev_server in &mut status.dev_servers {
5017        dev_server.status =
5018            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5019    }
5020    let connections = pool.user_connection_ids(user_id);
5021    for connection_id in connections {
5022        session.peer.send(connection_id, status.clone()).trace_err();
5023    }
5024}
5025
5026async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5027    let db = session.db().await;
5028
5029    let contacts = db.get_contacts(user_id).await?;
5030    let busy = db.is_user_busy(user_id).await?;
5031
5032    let pool = session.connection_pool().await;
5033    let updated_contact = contact_for_user(user_id, busy, &pool);
5034    for contact in contacts {
5035        if let db::Contact::Accepted {
5036            user_id: contact_user_id,
5037            ..
5038        } = contact
5039        {
5040            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5041                session
5042                    .peer
5043                    .send(
5044                        contact_conn_id,
5045                        proto::UpdateContacts {
5046                            contacts: vec![updated_contact.clone()],
5047                            remove_contacts: Default::default(),
5048                            incoming_requests: Default::default(),
5049                            remove_incoming_requests: Default::default(),
5050                            outgoing_requests: Default::default(),
5051                            remove_outgoing_requests: Default::default(),
5052                        },
5053                    )
5054                    .trace_err();
5055            }
5056        }
5057    }
5058    Ok(())
5059}
5060
5061async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5062    log::info!("lost dev server connection, unsharing projects");
5063    let project_ids = session
5064        .db()
5065        .await
5066        .get_stale_dev_server_projects(session.connection_id)
5067        .await?;
5068
5069    for project_id in project_ids {
5070        // not unshare re-checks the connection ids match, so we get away with no transaction
5071        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5072    }
5073
5074    let user_id = session.dev_server().user_id;
5075    let update = session
5076        .db()
5077        .await
5078        .dev_server_projects_update(user_id)
5079        .await?;
5080
5081    send_dev_server_projects_update(user_id, update, session).await;
5082
5083    Ok(())
5084}
5085
5086async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5087    let mut contacts_to_update = HashSet::default();
5088
5089    let room_id;
5090    let canceled_calls_to_user_ids;
5091    let live_kit_room;
5092    let delete_live_kit_room;
5093    let room;
5094    let channel;
5095
5096    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5097        contacts_to_update.insert(session.user_id());
5098
5099        for project in left_room.left_projects.values() {
5100            project_left(project, session);
5101        }
5102
5103        room_id = RoomId::from_proto(left_room.room.id);
5104        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5105        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5106        delete_live_kit_room = left_room.deleted;
5107        room = mem::take(&mut left_room.room);
5108        channel = mem::take(&mut left_room.channel);
5109
5110        room_updated(&room, &session.peer);
5111    } else {
5112        return Ok(());
5113    }
5114
5115    if let Some(channel) = channel {
5116        channel_updated(
5117            &channel,
5118            &room,
5119            &session.peer,
5120            &*session.connection_pool().await,
5121        );
5122    }
5123
5124    {
5125        let pool = session.connection_pool().await;
5126        for canceled_user_id in canceled_calls_to_user_ids {
5127            for connection_id in pool.user_connection_ids(canceled_user_id) {
5128                session
5129                    .peer
5130                    .send(
5131                        connection_id,
5132                        proto::CallCanceled {
5133                            room_id: room_id.to_proto(),
5134                        },
5135                    )
5136                    .trace_err();
5137            }
5138            contacts_to_update.insert(canceled_user_id);
5139        }
5140    }
5141
5142    for contact_user_id in contacts_to_update {
5143        update_user_contacts(contact_user_id, &session).await?;
5144    }
5145
5146    if let Some(live_kit) = session.live_kit_client.as_ref() {
5147        live_kit
5148            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5149            .await
5150            .trace_err();
5151
5152        if delete_live_kit_room {
5153            live_kit.delete_room(live_kit_room).await.trace_err();
5154        }
5155    }
5156
5157    Ok(())
5158}
5159
5160async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5161    let left_channel_buffers = session
5162        .db()
5163        .await
5164        .leave_channel_buffers(session.connection_id)
5165        .await?;
5166
5167    for left_buffer in left_channel_buffers {
5168        channel_buffer_updated(
5169            session.connection_id,
5170            left_buffer.connections,
5171            &proto::UpdateChannelBufferCollaborators {
5172                channel_id: left_buffer.channel_id.to_proto(),
5173                collaborators: left_buffer.collaborators,
5174            },
5175            &session.peer,
5176        );
5177    }
5178
5179    Ok(())
5180}
5181
5182fn project_left(project: &db::LeftProject, session: &UserSession) {
5183    for connection_id in &project.connection_ids {
5184        if project.should_unshare {
5185            session
5186                .peer
5187                .send(
5188                    *connection_id,
5189                    proto::UnshareProject {
5190                        project_id: project.id.to_proto(),
5191                    },
5192                )
5193                .trace_err();
5194        } else {
5195            session
5196                .peer
5197                .send(
5198                    *connection_id,
5199                    proto::RemoveProjectCollaborator {
5200                        project_id: project.id.to_proto(),
5201                        peer_id: Some(session.connection_id.into()),
5202                    },
5203                )
5204                .trace_err();
5205        }
5206    }
5207}
5208
5209pub trait ResultExt {
5210    type Ok;
5211
5212    fn trace_err(self) -> Option<Self::Ok>;
5213}
5214
5215impl<T, E> ResultExt for Result<T, E>
5216where
5217    E: std::fmt::Debug,
5218{
5219    type Ok = T;
5220
5221    #[track_caller]
5222    fn trace_err(self) -> Option<T> {
5223        match self {
5224            Ok(value) => Some(value),
5225            Err(error) => {
5226                tracing::error!("{:?}", error);
5227                None
5228            }
5229        }
5230    }
5231}