rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_project_request_for_owner::<proto::TaskTemplates>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetHover>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetDefinition>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::GetTypeDefinition>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetReferences>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::SearchProject>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetProjectSymbols>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::OpenBufferById>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::InlayHints>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferByPath>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::GetCompletions>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCodeActions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCodeAction>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::PrepareRename>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::PerformRename>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::ReloadBuffers>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::FormatBuffers>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CreateProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::RenameProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::CopyProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::OnTypeFormatting>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::BlameBuffer>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::MultiLspQuery>,
 547            ))
 548            .add_message_handler(create_buffer_for_peer)
 549            .add_request_handler(update_buffer)
 550            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 551            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 552            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 553            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 554            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 555            .add_request_handler(get_users)
 556            .add_request_handler(user_handler(fuzzy_search_users))
 557            .add_request_handler(user_handler(request_contact))
 558            .add_request_handler(user_handler(remove_contact))
 559            .add_request_handler(user_handler(respond_to_contact_request))
 560            .add_message_handler(subscribe_to_channels)
 561            .add_request_handler(user_handler(create_channel))
 562            .add_request_handler(user_handler(delete_channel))
 563            .add_request_handler(user_handler(invite_channel_member))
 564            .add_request_handler(user_handler(remove_channel_member))
 565            .add_request_handler(user_handler(set_channel_member_role))
 566            .add_request_handler(user_handler(set_channel_visibility))
 567            .add_request_handler(user_handler(rename_channel))
 568            .add_request_handler(user_handler(join_channel_buffer))
 569            .add_request_handler(user_handler(leave_channel_buffer))
 570            .add_message_handler(user_message_handler(update_channel_buffer))
 571            .add_request_handler(user_handler(rejoin_channel_buffers))
 572            .add_request_handler(user_handler(get_channel_members))
 573            .add_request_handler(user_handler(respond_to_channel_invite))
 574            .add_request_handler(user_handler(join_channel))
 575            .add_request_handler(user_handler(join_channel_chat))
 576            .add_message_handler(user_message_handler(leave_channel_chat))
 577            .add_request_handler(user_handler(send_channel_message))
 578            .add_request_handler(user_handler(remove_channel_message))
 579            .add_request_handler(user_handler(update_channel_message))
 580            .add_request_handler(user_handler(get_channel_messages))
 581            .add_request_handler(user_handler(get_channel_messages_by_id))
 582            .add_request_handler(user_handler(get_notifications))
 583            .add_request_handler(user_handler(mark_notification_as_read))
 584            .add_request_handler(user_handler(move_channel))
 585            .add_request_handler(user_handler(follow))
 586            .add_message_handler(user_message_handler(unfollow))
 587            .add_message_handler(user_message_handler(update_followers))
 588            .add_request_handler(user_handler(get_private_user_info))
 589            .add_message_handler(user_message_handler(acknowledge_channel_message))
 590            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 591            .add_request_handler(user_handler(get_supermaven_api_key))
 592            .add_streaming_request_handler({
 593                let app_state = app_state.clone();
 594                move |request, response, session| {
 595                    complete_with_language_model(
 596                        request,
 597                        response,
 598                        session,
 599                        app_state.config.openai_api_key.clone(),
 600                        app_state.config.google_ai_api_key.clone(),
 601                        app_state.config.anthropic_api_key.clone(),
 602                    )
 603                }
 604            })
 605            .add_request_handler({
 606                let app_state = app_state.clone();
 607                user_handler(move |request, response, session| {
 608                    count_tokens_with_language_model(
 609                        request,
 610                        response,
 611                        session,
 612                        app_state.config.google_ai_api_key.clone(),
 613                    )
 614                })
 615            })
 616            .add_request_handler({
 617                user_handler(move |request, response, session| {
 618                    get_cached_embeddings(request, response, session)
 619                })
 620            })
 621            .add_request_handler({
 622                let app_state = app_state.clone();
 623                user_handler(move |request, response, session| {
 624                    compute_embeddings(
 625                        request,
 626                        response,
 627                        session,
 628                        app_state.config.openai_api_key.clone(),
 629                    )
 630                })
 631            });
 632
 633        Arc::new(server)
 634    }
 635
 636    pub async fn start(&self) -> Result<()> {
 637        let server_id = *self.id.lock();
 638        let app_state = self.app_state.clone();
 639        let peer = self.peer.clone();
 640        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 641        let pool = self.connection_pool.clone();
 642        let live_kit_client = self.app_state.live_kit_client.clone();
 643
 644        let span = info_span!("start server");
 645        self.app_state.executor.spawn_detached(
 646            async move {
 647                tracing::info!("waiting for cleanup timeout");
 648                timeout.await;
 649                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 650                if let Some((room_ids, channel_ids)) = app_state
 651                    .db
 652                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 653                    .await
 654                    .trace_err()
 655                {
 656                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 657                    tracing::info!(
 658                        stale_channel_buffer_count = channel_ids.len(),
 659                        "retrieved stale channel buffers"
 660                    );
 661
 662                    for channel_id in channel_ids {
 663                        if let Some(refreshed_channel_buffer) = app_state
 664                            .db
 665                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 666                            .await
 667                            .trace_err()
 668                        {
 669                            for connection_id in refreshed_channel_buffer.connection_ids {
 670                                peer.send(
 671                                    connection_id,
 672                                    proto::UpdateChannelBufferCollaborators {
 673                                        channel_id: channel_id.to_proto(),
 674                                        collaborators: refreshed_channel_buffer
 675                                            .collaborators
 676                                            .clone(),
 677                                    },
 678                                )
 679                                .trace_err();
 680                            }
 681                        }
 682                    }
 683
 684                    for room_id in room_ids {
 685                        let mut contacts_to_update = HashSet::default();
 686                        let mut canceled_calls_to_user_ids = Vec::new();
 687                        let mut live_kit_room = String::new();
 688                        let mut delete_live_kit_room = false;
 689
 690                        if let Some(mut refreshed_room) = app_state
 691                            .db
 692                            .clear_stale_room_participants(room_id, server_id)
 693                            .await
 694                            .trace_err()
 695                        {
 696                            tracing::info!(
 697                                room_id = room_id.0,
 698                                new_participant_count = refreshed_room.room.participants.len(),
 699                                "refreshed room"
 700                            );
 701                            room_updated(&refreshed_room.room, &peer);
 702                            if let Some(channel) = refreshed_room.channel.as_ref() {
 703                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 704                            }
 705                            contacts_to_update
 706                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 707                            contacts_to_update
 708                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 709                            canceled_calls_to_user_ids =
 710                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 711                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 712                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 713                        }
 714
 715                        {
 716                            let pool = pool.lock();
 717                            for canceled_user_id in canceled_calls_to_user_ids {
 718                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 719                                    peer.send(
 720                                        connection_id,
 721                                        proto::CallCanceled {
 722                                            room_id: room_id.to_proto(),
 723                                        },
 724                                    )
 725                                    .trace_err();
 726                                }
 727                            }
 728                        }
 729
 730                        for user_id in contacts_to_update {
 731                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 732                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 733                            if let Some((busy, contacts)) = busy.zip(contacts) {
 734                                let pool = pool.lock();
 735                                let updated_contact = contact_for_user(user_id, busy, &pool);
 736                                for contact in contacts {
 737                                    if let db::Contact::Accepted {
 738                                        user_id: contact_user_id,
 739                                        ..
 740                                    } = contact
 741                                    {
 742                                        for contact_conn_id in
 743                                            pool.user_connection_ids(contact_user_id)
 744                                        {
 745                                            peer.send(
 746                                                contact_conn_id,
 747                                                proto::UpdateContacts {
 748                                                    contacts: vec![updated_contact.clone()],
 749                                                    remove_contacts: Default::default(),
 750                                                    incoming_requests: Default::default(),
 751                                                    remove_incoming_requests: Default::default(),
 752                                                    outgoing_requests: Default::default(),
 753                                                    remove_outgoing_requests: Default::default(),
 754                                                },
 755                                            )
 756                                            .trace_err();
 757                                        }
 758                                    }
 759                                }
 760                            }
 761                        }
 762
 763                        if let Some(live_kit) = live_kit_client.as_ref() {
 764                            if delete_live_kit_room {
 765                                live_kit.delete_room(live_kit_room).await.trace_err();
 766                            }
 767                        }
 768                    }
 769                }
 770
 771                app_state
 772                    .db
 773                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 774                    .await
 775                    .trace_err();
 776            }
 777            .instrument(span),
 778        );
 779        Ok(())
 780    }
 781
 782    pub fn teardown(&self) {
 783        self.peer.teardown();
 784        self.connection_pool.lock().reset();
 785        let _ = self.teardown.send(true);
 786    }
 787
 788    #[cfg(test)]
 789    pub fn reset(&self, id: ServerId) {
 790        self.teardown();
 791        *self.id.lock() = id;
 792        self.peer.reset(id.0 as u32);
 793        let _ = self.teardown.send(false);
 794    }
 795
 796    #[cfg(test)]
 797    pub fn id(&self) -> ServerId {
 798        *self.id.lock()
 799    }
 800
 801    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 802    where
 803        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 804        Fut: 'static + Send + Future<Output = Result<()>>,
 805        M: EnvelopedMessage,
 806    {
 807        let prev_handler = self.handlers.insert(
 808            TypeId::of::<M>(),
 809            Box::new(move |envelope, session| {
 810                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 811                let received_at = envelope.received_at;
 812                tracing::info!("message received");
 813                let start_time = Instant::now();
 814                let future = (handler)(*envelope, session);
 815                async move {
 816                    let result = future.await;
 817                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 818                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 819                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 820                    let payload_type = M::NAME;
 821
 822                    match result {
 823                        Err(error) => {
 824                            tracing::error!(
 825                                ?error,
 826                                total_duration_ms,
 827                                processing_duration_ms,
 828                                queue_duration_ms,
 829                                payload_type,
 830                                "error handling message"
 831                            )
 832                        }
 833                        Ok(()) => tracing::info!(
 834                            total_duration_ms,
 835                            processing_duration_ms,
 836                            queue_duration_ms,
 837                            "finished handling message"
 838                        ),
 839                    }
 840                }
 841                .boxed()
 842            }),
 843        );
 844        if prev_handler.is_some() {
 845            panic!("registered a handler for the same message twice");
 846        }
 847        self
 848    }
 849
 850    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 851    where
 852        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 853        Fut: 'static + Send + Future<Output = Result<()>>,
 854        M: EnvelopedMessage,
 855    {
 856        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 857        self
 858    }
 859
 860    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 861    where
 862        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 863        Fut: Send + Future<Output = Result<()>>,
 864        M: RequestMessage,
 865    {
 866        let handler = Arc::new(handler);
 867        self.add_handler(move |envelope, session| {
 868            let receipt = envelope.receipt();
 869            let handler = handler.clone();
 870            async move {
 871                let peer = session.peer.clone();
 872                let responded = Arc::new(AtomicBool::default());
 873                let response = Response {
 874                    peer: peer.clone(),
 875                    responded: responded.clone(),
 876                    receipt,
 877                };
 878                match (handler)(envelope.payload, response, session).await {
 879                    Ok(()) => {
 880                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 881                            Ok(())
 882                        } else {
 883                            Err(anyhow!("handler did not send a response"))?
 884                        }
 885                    }
 886                    Err(error) => {
 887                        let proto_err = match &error {
 888                            Error::Internal(err) => err.to_proto(),
 889                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 890                        };
 891                        peer.respond_with_error(receipt, proto_err)?;
 892                        Err(error)
 893                    }
 894                }
 895            }
 896        })
 897    }
 898
 899    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 900    where
 901        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 902        Fut: Send + Future<Output = Result<()>>,
 903        M: RequestMessage,
 904    {
 905        let handler = Arc::new(handler);
 906        self.add_handler(move |envelope, session| {
 907            let receipt = envelope.receipt();
 908            let handler = handler.clone();
 909            async move {
 910                let peer = session.peer.clone();
 911                let response = StreamingResponse {
 912                    peer: peer.clone(),
 913                    receipt,
 914                };
 915                match (handler)(envelope.payload, response, session).await {
 916                    Ok(()) => {
 917                        peer.end_stream(receipt)?;
 918                        Ok(())
 919                    }
 920                    Err(error) => {
 921                        let proto_err = match &error {
 922                            Error::Internal(err) => err.to_proto(),
 923                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 924                        };
 925                        peer.respond_with_error(receipt, proto_err)?;
 926                        Err(error)
 927                    }
 928                }
 929            }
 930        })
 931    }
 932
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Registers `connection` with the peer, builds the per-connection [`Session`],
    /// sends the initial client update, and then dispatches incoming messages to the
    /// registered handlers. The loop stops when the I/O future finishes, the incoming
    /// stream ends, or the server's teardown flag changes. On the first two it runs
    /// `connection_lost` to release the connection's resources. On teardown it
    /// returns without doing so.
    ///
    /// `send_connection_id`, if provided, is forwarded to `send_initial_client_update`,
    /// which sends the assigned `ConnectionId` through it after the `Hello` message.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribe before the future runs so teardown signalled afterwards is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after
            // `send_initial_client_update` below) happens after `add_connection`
            // without running `connection_lost`; confirm the connection is not left
            // registered in the peer or the connection pool.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only created when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: a permit is
            // acquired before reading each message and released once its handler
            // finishes (or immediately if no handler is registered for it).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased select, in priority order: teardown, I/O completion,
                // progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Exit the span before moving on; the handler future is
                                // instrumented with `span` below instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1079
1080    async fn send_initial_client_update(
1081        &self,
1082        connection_id: ConnectionId,
1083        principal: &Principal,
1084        zed_version: ZedVersion,
1085        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1086        session: &Session,
1087    ) -> Result<()> {
1088        self.peer.send(
1089            connection_id,
1090            proto::Hello {
1091                peer_id: Some(connection_id.into()),
1092            },
1093        )?;
1094        tracing::info!("sent hello message");
1095        if let Some(send_connection_id) = send_connection_id.take() {
1096            let _ = send_connection_id.send(connection_id);
1097        }
1098
1099        match principal {
1100            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1101                if !user.connected_once {
1102                    self.peer.send(connection_id, proto::ShowContacts {})?;
1103                    self.app_state
1104                        .db
1105                        .set_user_connected_once(user.id, true)
1106                        .await?;
1107                }
1108
1109                let (contacts, dev_server_projects) = future::try_join(
1110                    self.app_state.db.get_contacts(user.id),
1111                    self.app_state.db.dev_server_projects_update(user.id),
1112                )
1113                .await?;
1114
1115                {
1116                    let mut pool = self.connection_pool.lock();
1117                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1118                    self.peer.send(
1119                        connection_id,
1120                        build_initial_contacts_update(contacts, &pool),
1121                    )?;
1122                }
1123
1124                if should_auto_subscribe_to_channels(zed_version) {
1125                    subscribe_user_to_channels(user.id, session).await?;
1126                }
1127
1128                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1129
1130                if let Some(incoming_call) =
1131                    self.app_state.db.incoming_call_for_user(user.id).await?
1132                {
1133                    self.peer.send(connection_id, incoming_call)?;
1134                }
1135
1136                update_user_contacts(user.id, &session).await?;
1137            }
1138            Principal::DevServer(dev_server) => {
1139                {
1140                    let mut pool = self.connection_pool.lock();
1141                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1142                    {
1143                        self.peer.send(
1144                            stale_connection_id,
1145                            proto::ShutdownDevServer {
1146                                reason: Some(
1147                                    "another dev server connected with the same token".to_string(),
1148                                ),
1149                            },
1150                        )?;
1151                        pool.remove_connection(stale_connection_id)?;
1152                    };
1153                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1154                }
1155
1156                let projects = self
1157                    .app_state
1158                    .db
1159                    .get_projects_for_dev_server(dev_server.id)
1160                    .await?;
1161                self.peer
1162                    .send(connection_id, proto::DevServerInstructions { projects })?;
1163
1164                let status = self
1165                    .app_state
1166                    .db
1167                    .dev_server_projects_update(dev_server.user_id)
1168                    .await?;
1169                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1170            }
1171        }
1172
1173        Ok(())
1174    }
1175
1176    pub async fn invite_code_redeemed(
1177        self: &Arc<Self>,
1178        inviter_id: UserId,
1179        invitee_id: UserId,
1180    ) -> Result<()> {
1181        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1182            if let Some(code) = &user.invite_code {
1183                let pool = self.connection_pool.lock();
1184                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1185                for connection_id in pool.user_connection_ids(inviter_id) {
1186                    self.peer.send(
1187                        connection_id,
1188                        proto::UpdateContacts {
1189                            contacts: vec![invitee_contact.clone()],
1190                            ..Default::default()
1191                        },
1192                    )?;
1193                    self.peer.send(
1194                        connection_id,
1195                        proto::UpdateInviteInfo {
1196                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1197                            count: user.invite_count as u32,
1198                        },
1199                    )?;
1200                }
1201            }
1202        }
1203        Ok(())
1204    }
1205
1206    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1207        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1208            if let Some(invite_code) = &user.invite_code {
1209                let pool = self.connection_pool.lock();
1210                for connection_id in pool.user_connection_ids(user_id) {
1211                    self.peer.send(
1212                        connection_id,
1213                        proto::UpdateInviteInfo {
1214                            url: format!(
1215                                "{}{}",
1216                                self.app_state.config.invite_link_prefix, invite_code
1217                            ),
1218                            count: user.invite_count as u32,
1219                        },
1220                    )?;
1221                }
1222            }
1223        }
1224        Ok(())
1225    }
1226
1227    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1228        ServerSnapshot {
1229            connection_pool: ConnectionPoolGuard {
1230                guard: self.connection_pool.lock(),
1231                _not_send: PhantomData,
1232            },
1233            peer: &self.peer,
1234        }
1235    }
1236}
1237
1238impl<'a> Deref for ConnectionPoolGuard<'a> {
1239    type Target = ConnectionPool;
1240
1241    fn deref(&self) -> &Self::Target {
1242        &self.guard
1243    }
1244}
1245
1246impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1247    fn deref_mut(&mut self) -> &mut Self::Target {
1248        &mut self.guard
1249    }
1250}
1251
/// In test builds, runs `check_invariants` on the pool when the guard is dropped,
/// so changes made through this guard are validated before the lock is released.
/// In non-test builds dropping the guard does nothing extra.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1258
1259fn broadcast<F>(
1260    sender_id: Option<ConnectionId>,
1261    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1262    mut f: F,
1263) where
1264    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1265{
1266    for receiver_id in receiver_ids {
1267        if Some(receiver_id) != sender_id {
1268            if let Err(error) = f(receiver_id) {
1269                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1270            }
1271        }
1272    }
1273}
1274
1275pub struct ProtocolVersion(u32);
1276
1277impl Header for ProtocolVersion {
1278    fn name() -> &'static HeaderName {
1279        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1280        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1281    }
1282
1283    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1284    where
1285        Self: Sized,
1286        I: Iterator<Item = &'i axum::http::HeaderValue>,
1287    {
1288        let version = values
1289            .next()
1290            .ok_or_else(axum::headers::Error::invalid)?
1291            .to_str()
1292            .map_err(|_| axum::headers::Error::invalid())?
1293            .parse()
1294            .map_err(|_| axum::headers::Error::invalid())?;
1295        Ok(Self(version))
1296    }
1297
1298    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1299        values.extend([self.0.to_string().parse().unwrap()]);
1300    }
1301}
1302
1303pub struct AppVersionHeader(SemanticVersion);
1304impl Header for AppVersionHeader {
1305    fn name() -> &'static HeaderName {
1306        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1307        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1308    }
1309
1310    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1311    where
1312        Self: Sized,
1313        I: Iterator<Item = &'i axum::http::HeaderValue>,
1314    {
1315        let version = values
1316            .next()
1317            .ok_or_else(axum::headers::Error::invalid)?
1318            .to_str()
1319            .map_err(|_| axum::headers::Error::invalid())?
1320            .parse()
1321            .map_err(|_| axum::headers::Error::invalid())?;
1322        Ok(Self(version))
1323    }
1324
1325    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1326        values.extend([self.0.to_string().parse().unwrap()]);
1327    }
1328}
1329
1330pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1331    Router::new()
1332        .route("/rpc", get(handle_websocket_request))
1333        .layer(
1334            ServiceBuilder::new()
1335                .layer(Extension(server.app_state.clone()))
1336                .layer(middleware::from_fn(auth::validate_header)),
1337        )
1338        .route("/metrics", get(handle_metrics))
1339        .layer(Extension(server))
1340}
1341
1342pub async fn handle_websocket_request(
1343    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1344    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1345    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1346    Extension(server): Extension<Arc<Server>>,
1347    Extension(principal): Extension<Principal>,
1348    ws: WebSocketUpgrade,
1349) -> axum::response::Response {
1350    if protocol_version != rpc::PROTOCOL_VERSION {
1351        return (
1352            StatusCode::UPGRADE_REQUIRED,
1353            "client must be upgraded".to_string(),
1354        )
1355            .into_response();
1356    }
1357
1358    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1359        return (
1360            StatusCode::UPGRADE_REQUIRED,
1361            "no version header found".to_string(),
1362        )
1363            .into_response();
1364    };
1365
1366    if !version.can_collaborate() {
1367        return (
1368            StatusCode::UPGRADE_REQUIRED,
1369            "client must be upgraded".to_string(),
1370        )
1371            .into_response();
1372    }
1373
1374    let socket_address = socket_address.to_string();
1375    ws.on_upgrade(move |socket| {
1376        let socket = socket
1377            .map_ok(to_tungstenite_message)
1378            .err_into()
1379            .with(|message| async move { Ok(to_axum_message(message)) });
1380        let connection = Connection::new(Box::pin(socket));
1381        async move {
1382            server
1383                .handle_connection(
1384                    connection,
1385                    socket_address,
1386                    principal,
1387                    version,
1388                    None,
1389                    Executor::Production,
1390                )
1391                .await;
1392        }
1393    })
1394}
1395
1396pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1397    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1398    let connections_metric = CONNECTIONS_METRIC
1399        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1400
1401    let connections = server
1402        .connection_pool
1403        .lock()
1404        .connections()
1405        .filter(|connection| !connection.admin)
1406        .count();
1407    connections_metric.set(connections as _);
1408
1409    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1410    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1411        register_int_gauge!(
1412            "shared_projects",
1413            "number of open projects with one or more guests"
1414        )
1415        .unwrap()
1416    });
1417
1418    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1419    shared_projects_metric.set(shared_projects as _);
1420
1421    let encoder = prometheus::TextEncoder::new();
1422    let metric_families = prometheus::gather();
1423    let encoded_metrics = encoder
1424        .encode_to_string(&metric_families)
1425        .map_err(|err| anyhow!("{}", err))?;
1426    Ok(encoded_metrics)
1427}
1428
1429#[instrument(err, skip(executor))]
1430async fn connection_lost(
1431    session: Session,
1432    mut teardown: watch::Receiver<bool>,
1433    executor: Executor,
1434) -> Result<()> {
1435    session.peer.disconnect(session.connection_id);
1436    session
1437        .connection_pool()
1438        .await
1439        .remove_connection(session.connection_id)?;
1440
1441    session
1442        .db()
1443        .await
1444        .connection_lost(session.connection_id)
1445        .await
1446        .trace_err();
1447
1448    futures::select_biased! {
1449        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1450            match &session.principal {
1451                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1452                    let session = session.for_user().unwrap();
1453
1454                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1455                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1456                    leave_channel_buffers_for_session(&session)
1457                        .await
1458                        .trace_err();
1459
1460                    if !session
1461                        .connection_pool()
1462                        .await
1463                        .is_user_online(session.user_id())
1464                    {
1465                        let db = session.db().await;
1466                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1467                            room_updated(&room, &session.peer);
1468                        }
1469                    }
1470
1471                    update_user_contacts(session.user_id(), &session).await?;
1472                },
1473            Principal::DevServer(_) => {
1474                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1475            },
1476        }
1477        },
1478        _ = teardown.changed().fuse() => {}
1479    }
1480
1481    Ok(())
1482}
1483
1484/// Acknowledges a ping from a client, used to keep the connection alive.
1485async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1486    response.send(proto::Ack {})?;
1487    Ok(())
1488}
1489
1490/// Creates a new room for calling (outside of channels)
1491async fn create_room(
1492    _request: proto::CreateRoom,
1493    response: Response<proto::CreateRoom>,
1494    session: UserSession,
1495) -> Result<()> {
1496    let live_kit_room = nanoid::nanoid!(30);
1497
1498    let live_kit_connection_info = util::maybe!(async {
1499        let live_kit = session.live_kit_client.as_ref();
1500        let live_kit = live_kit?;
1501        let user_id = session.user_id().to_string();
1502
1503        let token = live_kit
1504            .room_token(&live_kit_room, &user_id.to_string())
1505            .trace_err()?;
1506
1507        Some(proto::LiveKitConnectionInfo {
1508            server_url: live_kit.url().into(),
1509            token,
1510            can_publish: true,
1511        })
1512    })
1513    .await;
1514
1515    let room = session
1516        .db()
1517        .await
1518        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1519        .await?;
1520
1521    response.send(proto::CreateRoomResponse {
1522        room: Some(room.clone()),
1523        live_kit_connection_info,
1524    })?;
1525
1526    update_user_contacts(session.user_id(), &session).await?;
1527    Ok(())
1528}
1529
1530/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1531async fn join_room(
1532    request: proto::JoinRoom,
1533    response: Response<proto::JoinRoom>,
1534    session: UserSession,
1535) -> Result<()> {
1536    let room_id = RoomId::from_proto(request.id);
1537
1538    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1539
1540    if let Some(channel_id) = channel_id {
1541        return join_channel_internal(channel_id, Box::new(response), session).await;
1542    }
1543
1544    let joined_room = {
1545        let room = session
1546            .db()
1547            .await
1548            .join_room(room_id, session.user_id(), session.connection_id)
1549            .await?;
1550        room_updated(&room.room, &session.peer);
1551        room.into_inner()
1552    };
1553
1554    for connection_id in session
1555        .connection_pool()
1556        .await
1557        .user_connection_ids(session.user_id())
1558    {
1559        session
1560            .peer
1561            .send(
1562                connection_id,
1563                proto::CallCanceled {
1564                    room_id: room_id.to_proto(),
1565                },
1566            )
1567            .trace_err();
1568    }
1569
1570    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1571        if let Some(token) = live_kit
1572            .room_token(
1573                &joined_room.room.live_kit_room,
1574                &session.user_id().to_string(),
1575            )
1576            .trace_err()
1577        {
1578            Some(proto::LiveKitConnectionInfo {
1579                server_url: live_kit.url().into(),
1580                token,
1581                can_publish: true,
1582            })
1583        } else {
1584            None
1585        }
1586    } else {
1587        None
1588    };
1589
1590    response.send(proto::JoinRoomResponse {
1591        room: Some(joined_room.room),
1592        channel_id: None,
1593        live_kit_connection_info,
1594    })?;
1595
1596    update_user_contacts(session.user_id(), &session).await?;
1597    Ok(())
1598}
1599
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call receives the request, the caller's user id and the
/// current connection id. The caller is then sent the room together with its
/// reshared and rejoined projects, and the room is broadcast as updated.
/// For each reshared project, its collaborators are told that the peer id
/// changed from the project's `old_connection_id` to this connection, and they
/// are forwarded the project's current worktree list. Rejoined projects are
/// re-streamed to this connection by `notify_rejoined_projects`. Finally, if
/// the room belongs to a channel, `channel_updated` is called, and the
/// caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in from inside the scope below, so that the guard returned by
    // `rejoin_room` is consumed (via `into_inner`) before the connection pool
    // is locked for `channel_updated`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: a failed send to one collaborator is only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list, sent on behalf of this
            // connection, to the project's collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Drains each rejoined project's worktrees while streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1691
1692fn notify_rejoined_projects(
1693    rejoined_projects: &mut Vec<RejoinedProject>,
1694    session: &UserSession,
1695) -> Result<()> {
1696    for project in rejoined_projects.iter() {
1697        for collaborator in &project.collaborators {
1698            session
1699                .peer
1700                .send(
1701                    collaborator.connection_id,
1702                    proto::UpdateProjectCollaborator {
1703                        project_id: project.id.to_proto(),
1704                        old_peer_id: Some(project.old_connection_id.into()),
1705                        new_peer_id: Some(session.connection_id.into()),
1706                    },
1707                )
1708                .trace_err();
1709        }
1710    }
1711
1712    for project in rejoined_projects {
1713        for worktree in mem::take(&mut project.worktrees) {
1714            #[cfg(any(test, feature = "test-support"))]
1715            const MAX_CHUNK_SIZE: usize = 2;
1716            #[cfg(not(any(test, feature = "test-support")))]
1717            const MAX_CHUNK_SIZE: usize = 256;
1718
1719            // Stream this worktree's entries.
1720            let message = proto::UpdateWorktree {
1721                project_id: project.id.to_proto(),
1722                worktree_id: worktree.id,
1723                abs_path: worktree.abs_path.clone(),
1724                root_name: worktree.root_name,
1725                updated_entries: worktree.updated_entries,
1726                removed_entries: worktree.removed_entries,
1727                scan_id: worktree.scan_id,
1728                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1729                updated_repositories: worktree.updated_repositories,
1730                removed_repositories: worktree.removed_repositories,
1731            };
1732            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1733                session.peer.send(session.connection_id, update.clone())?;
1734            }
1735
1736            // Stream this worktree's diagnostics.
1737            for summary in worktree.diagnostic_summaries {
1738                session.peer.send(
1739                    session.connection_id,
1740                    proto::UpdateDiagnosticSummary {
1741                        project_id: project.id.to_proto(),
1742                        worktree_id: worktree.id,
1743                        summary: Some(summary),
1744                    },
1745                )?;
1746            }
1747
1748            for settings_file in worktree.settings_files {
1749                session.peer.send(
1750                    session.connection_id,
1751                    proto::UpdateWorktreeSettings {
1752                        project_id: project.id.to_proto(),
1753                        worktree_id: worktree.id,
1754                        path: settings_file.path,
1755                        content: Some(settings_file.content),
1756                    },
1757                )?;
1758            }
1759        }
1760
1761        for language_server in &project.language_servers {
1762            session.peer.send(
1763                session.connection_id,
1764                proto::UpdateLanguageServer {
1765                    project_id: project.id.to_proto(),
1766                    language_server_id: language_server.id,
1767                    variant: Some(
1768                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1769                            proto::LspDiskBasedDiagnosticsUpdated {},
1770                        ),
1771                    ),
1772                },
1773            )?;
1774        }
1775    }
1776    Ok(())
1777}
1778
1779/// leave room disconnects from the room.
1780async fn leave_room(
1781    _: proto::LeaveRoom,
1782    response: Response<proto::LeaveRoom>,
1783    session: UserSession,
1784) -> Result<()> {
1785    leave_room_for_session(&session, session.connection_id).await?;
1786    response.send(proto::Ack {})?;
1787    Ok(())
1788}
1789
1790/// Updates the permissions of someone else in the room.
1791async fn set_room_participant_role(
1792    request: proto::SetRoomParticipantRole,
1793    response: Response<proto::SetRoomParticipantRole>,
1794    session: UserSession,
1795) -> Result<()> {
1796    let user_id = UserId::from_proto(request.user_id);
1797    let role = ChannelRole::from(request.role());
1798
1799    let (live_kit_room, can_publish) = {
1800        let room = session
1801            .db()
1802            .await
1803            .set_room_participant_role(
1804                session.user_id(),
1805                RoomId::from_proto(request.room_id),
1806                user_id,
1807                role,
1808            )
1809            .await?;
1810
1811        let live_kit_room = room.live_kit_room.clone();
1812        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1813        room_updated(&room, &session.peer);
1814        (live_kit_room, can_publish)
1815    };
1816
1817    if let Some(live_kit) = session.live_kit_client.as_ref() {
1818        live_kit
1819            .update_participant(
1820                live_kit_room.clone(),
1821                request.user_id.to_string(),
1822                live_kit_server::proto::ParticipantPermission {
1823                    can_subscribe: true,
1824                    can_publish,
1825                    can_publish_data: can_publish,
1826                    hidden: false,
1827                    recorder: false,
1828                },
1829            )
1830            .await
1831            .trace_err();
1832    }
1833
1834    response.send(proto::Ack {})?;
1835    Ok(())
1836}
1837
1838/// Call someone else into the current room
1839async fn call(
1840    request: proto::Call,
1841    response: Response<proto::Call>,
1842    session: UserSession,
1843) -> Result<()> {
1844    let room_id = RoomId::from_proto(request.room_id);
1845    let calling_user_id = session.user_id();
1846    let calling_connection_id = session.connection_id;
1847    let called_user_id = UserId::from_proto(request.called_user_id);
1848    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1849    if !session
1850        .db()
1851        .await
1852        .has_contact(calling_user_id, called_user_id)
1853        .await?
1854    {
1855        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1856    }
1857
1858    let incoming_call = {
1859        let (room, incoming_call) = &mut *session
1860            .db()
1861            .await
1862            .call(
1863                room_id,
1864                calling_user_id,
1865                calling_connection_id,
1866                called_user_id,
1867                initial_project_id,
1868            )
1869            .await?;
1870        room_updated(&room, &session.peer);
1871        mem::take(incoming_call)
1872    };
1873    update_user_contacts(called_user_id, &session).await?;
1874
1875    let mut calls = session
1876        .connection_pool()
1877        .await
1878        .user_connection_ids(called_user_id)
1879        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1880        .collect::<FuturesUnordered<_>>();
1881
1882    while let Some(call_response) = calls.next().await {
1883        match call_response.as_ref() {
1884            Ok(_) => {
1885                response.send(proto::Ack {})?;
1886                return Ok(());
1887            }
1888            Err(_) => {
1889                call_response.trace_err();
1890            }
1891        }
1892    }
1893
1894    {
1895        let room = session
1896            .db()
1897            .await
1898            .call_failed(room_id, called_user_id)
1899            .await?;
1900        room_updated(&room, &session.peer);
1901    }
1902    update_user_contacts(called_user_id, &session).await?;
1903
1904    Err(anyhow!("failed to ring user"))?
1905}
1906
1907/// Cancel an outgoing call.
1908async fn cancel_call(
1909    request: proto::CancelCall,
1910    response: Response<proto::CancelCall>,
1911    session: UserSession,
1912) -> Result<()> {
1913    let called_user_id = UserId::from_proto(request.called_user_id);
1914    let room_id = RoomId::from_proto(request.room_id);
1915    {
1916        let room = session
1917            .db()
1918            .await
1919            .cancel_call(room_id, session.connection_id, called_user_id)
1920            .await?;
1921        room_updated(&room, &session.peer);
1922    }
1923
1924    for connection_id in session
1925        .connection_pool()
1926        .await
1927        .user_connection_ids(called_user_id)
1928    {
1929        session
1930            .peer
1931            .send(
1932                connection_id,
1933                proto::CallCanceled {
1934                    room_id: room_id.to_proto(),
1935                },
1936            )
1937            .trace_err();
1938    }
1939    response.send(proto::Ack {})?;
1940
1941    update_user_contacts(called_user_id, &session).await?;
1942    Ok(())
1943}
1944
1945/// Decline an incoming call.
1946async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1947    let room_id = RoomId::from_proto(message.room_id);
1948    {
1949        let room = session
1950            .db()
1951            .await
1952            .decline_call(Some(room_id), session.user_id())
1953            .await?
1954            .ok_or_else(|| anyhow!("failed to decline call"))?;
1955        room_updated(&room, &session.peer);
1956    }
1957
1958    for connection_id in session
1959        .connection_pool()
1960        .await
1961        .user_connection_ids(session.user_id())
1962    {
1963        session
1964            .peer
1965            .send(
1966                connection_id,
1967                proto::CallCanceled {
1968                    room_id: room_id.to_proto(),
1969                },
1970            )
1971            .trace_err();
1972    }
1973    update_user_contacts(session.user_id(), &session).await?;
1974    Ok(())
1975}
1976
1977/// Updates other participants in the room with your current location.
1978async fn update_participant_location(
1979    request: proto::UpdateParticipantLocation,
1980    response: Response<proto::UpdateParticipantLocation>,
1981    session: UserSession,
1982) -> Result<()> {
1983    let room_id = RoomId::from_proto(request.room_id);
1984    let location = request
1985        .location
1986        .ok_or_else(|| anyhow!("invalid location"))?;
1987
1988    let db = session.db().await;
1989    let room = db
1990        .update_room_participant_location(room_id, session.connection_id, location)
1991        .await?;
1992
1993    room_updated(&room, &session.peer);
1994    response.send(proto::Ack {})?;
1995    Ok(())
1996}
1997
1998/// Share a project into the room.
1999async fn share_project(
2000    request: proto::ShareProject,
2001    response: Response<proto::ShareProject>,
2002    session: UserSession,
2003) -> Result<()> {
2004    let (project_id, room) = &*session
2005        .db()
2006        .await
2007        .share_project(
2008            RoomId::from_proto(request.room_id),
2009            session.connection_id,
2010            &request.worktrees,
2011            request
2012                .dev_server_project_id
2013                .map(|id| DevServerProjectId::from_proto(id)),
2014        )
2015        .await?;
2016    response.send(proto::ShareProjectResponse {
2017        project_id: project_id.to_proto(),
2018    })?;
2019    room_updated(&room, &session.peer);
2020
2021    Ok(())
2022}
2023
2024/// Unshare a project from the room.
2025async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2026    let project_id = ProjectId::from_proto(message.project_id);
2027    unshare_project_internal(
2028        project_id,
2029        session.connection_id,
2030        session.user_id(),
2031        &session,
2032    )
2033    .await
2034}
2035
/// Unshares `project_id` on behalf of `connection_id` (and `user_id`, if any).
///
/// The guest connections reported by the database are sent `UnshareProject`
/// via `broadcast`, with `connection_id` passed as the sender (presumably so it
/// is excluded — confirm in `broadcast`), and the room, if any, is broadcast as
/// updated. If the database reports that the project should be deleted, it is
/// deleted in a second database call.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Confine the guard from `unshare_project` to this block so it is dropped
    // before the database is locked again for `delete_project` below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2074
2075/// DevServer makes a project available online
2076async fn share_dev_server_project(
2077    request: proto::ShareDevServerProject,
2078    response: Response<proto::ShareDevServerProject>,
2079    session: DevServerSession,
2080) -> Result<()> {
2081    let (dev_server_project, user_id, status) = session
2082        .db()
2083        .await
2084        .share_dev_server_project(
2085            DevServerProjectId::from_proto(request.dev_server_project_id),
2086            session.dev_server_id(),
2087            session.connection_id,
2088            &request.worktrees,
2089        )
2090        .await?;
2091    let Some(project_id) = dev_server_project.project_id else {
2092        return Err(anyhow!("failed to share remote project"))?;
2093    };
2094
2095    send_dev_server_projects_update(user_id, status, &session).await;
2096
2097    response.send(proto::ShareProjectResponse { project_id })?;
2098
2099    Ok(())
2100}
2101
2102/// Join someone elses shared project.
2103async fn join_project(
2104    request: proto::JoinProject,
2105    response: Response<proto::JoinProject>,
2106    session: UserSession,
2107) -> Result<()> {
2108    let project_id = ProjectId::from_proto(request.project_id);
2109
2110    tracing::info!(%project_id, "join project");
2111
2112    let db = session.db().await;
2113    let (project, replica_id) = &mut *db
2114        .join_project(project_id, session.connection_id, session.user_id())
2115        .await?;
2116    drop(db);
2117    tracing::info!(%project_id, "join remote project");
2118    join_project_internal(response, session, project, replica_id)
2119}
2120
/// A response handle that can carry a `proto::JoinProjectResponse`, allowing
/// `join_project_internal` to serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The qualified path resolves to the inherent `Response::send`
        // (inherent items take precedence over trait items), so this does
        // not recurse into the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request with the same payload type as
/// `JoinProject`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The qualified path resolves to the inherent `Response::send`
        // (inherent items take precedence over trait items), so this does
        // not recurse into the trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2134
2135fn join_project_internal(
2136    response: impl JoinProjectInternalResponse,
2137    session: UserSession,
2138    project: &mut Project,
2139    replica_id: &ReplicaId,
2140) -> Result<()> {
2141    let collaborators = project
2142        .collaborators
2143        .iter()
2144        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2145        .map(|collaborator| collaborator.to_proto())
2146        .collect::<Vec<_>>();
2147    let project_id = project.id;
2148    let guest_user_id = session.user_id();
2149
2150    let worktrees = project
2151        .worktrees
2152        .iter()
2153        .map(|(id, worktree)| proto::WorktreeMetadata {
2154            id: *id,
2155            root_name: worktree.root_name.clone(),
2156            visible: worktree.visible,
2157            abs_path: worktree.abs_path.clone(),
2158        })
2159        .collect::<Vec<_>>();
2160
2161    let add_project_collaborator = proto::AddProjectCollaborator {
2162        project_id: project_id.to_proto(),
2163        collaborator: Some(proto::Collaborator {
2164            peer_id: Some(session.connection_id.into()),
2165            replica_id: replica_id.0 as u32,
2166            user_id: guest_user_id.to_proto(),
2167        }),
2168    };
2169
2170    for collaborator in &collaborators {
2171        session
2172            .peer
2173            .send(
2174                collaborator.peer_id.unwrap().into(),
2175                add_project_collaborator.clone(),
2176            )
2177            .trace_err();
2178    }
2179
2180    // First, we send the metadata associated with each worktree.
2181    response.send(proto::JoinProjectResponse {
2182        project_id: project.id.0 as u64,
2183        worktrees: worktrees.clone(),
2184        replica_id: replica_id.0 as u32,
2185        collaborators: collaborators.clone(),
2186        language_servers: project.language_servers.clone(),
2187        role: project.role.into(),
2188        dev_server_project_id: project
2189            .dev_server_project_id
2190            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2191    })?;
2192
2193    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2194        #[cfg(any(test, feature = "test-support"))]
2195        const MAX_CHUNK_SIZE: usize = 2;
2196        #[cfg(not(any(test, feature = "test-support")))]
2197        const MAX_CHUNK_SIZE: usize = 256;
2198
2199        // Stream this worktree's entries.
2200        let message = proto::UpdateWorktree {
2201            project_id: project_id.to_proto(),
2202            worktree_id,
2203            abs_path: worktree.abs_path.clone(),
2204            root_name: worktree.root_name,
2205            updated_entries: worktree.entries,
2206            removed_entries: Default::default(),
2207            scan_id: worktree.scan_id,
2208            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2209            updated_repositories: worktree.repository_entries.into_values().collect(),
2210            removed_repositories: Default::default(),
2211        };
2212        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2213            session.peer.send(session.connection_id, update.clone())?;
2214        }
2215
2216        // Stream this worktree's diagnostics.
2217        for summary in worktree.diagnostic_summaries {
2218            session.peer.send(
2219                session.connection_id,
2220                proto::UpdateDiagnosticSummary {
2221                    project_id: project_id.to_proto(),
2222                    worktree_id: worktree.id,
2223                    summary: Some(summary),
2224                },
2225            )?;
2226        }
2227
2228        for settings_file in worktree.settings_files {
2229            session.peer.send(
2230                session.connection_id,
2231                proto::UpdateWorktreeSettings {
2232                    project_id: project_id.to_proto(),
2233                    worktree_id: worktree.id,
2234                    path: settings_file.path,
2235                    content: Some(settings_file.content),
2236                },
2237            )?;
2238        }
2239    }
2240
2241    for language_server in &project.language_servers {
2242        session.peer.send(
2243            session.connection_id,
2244            proto::UpdateLanguageServer {
2245                project_id: project_id.to_proto(),
2246                language_server_id: language_server.id,
2247                variant: Some(
2248                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2249                        proto::LspDiskBasedDiagnosticsUpdated {},
2250                    ),
2251                ),
2252            },
2253        )?;
2254    }
2255
2256    Ok(())
2257}
2258
2259/// Leave someone elses shared project.
2260async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2261    let sender_id = session.connection_id;
2262    let project_id = ProjectId::from_proto(request.project_id);
2263    let db = session.db().await;
2264    if db.is_hosted_project(project_id).await? {
2265        let project = db.leave_hosted_project(project_id, sender_id).await?;
2266        project_left(&project, &session);
2267        return Ok(());
2268    }
2269
2270    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2271    tracing::info!(
2272        %project_id,
2273        "leave project"
2274    );
2275
2276    project_left(&project, &session);
2277    if let Some(room) = room {
2278        room_updated(&room, &session.peer);
2279    }
2280
2281    Ok(())
2282}
2283
2284async fn join_hosted_project(
2285    request: proto::JoinHostedProject,
2286    response: Response<proto::JoinHostedProject>,
2287    session: UserSession,
2288) -> Result<()> {
2289    let (mut project, replica_id) = session
2290        .db()
2291        .await
2292        .join_hosted_project(
2293            ProjectId(request.project_id as i32),
2294            session.user_id(),
2295            session.connection_id,
2296        )
2297        .await?;
2298
2299    join_project_internal(response, session, &mut project, &replica_id)
2300}
2301
2302async fn create_dev_server_project(
2303    request: proto::CreateDevServerProject,
2304    response: Response<proto::CreateDevServerProject>,
2305    session: UserSession,
2306) -> Result<()> {
2307    let dev_server_id = DevServerId(request.dev_server_id as i32);
2308    let dev_server_connection_id = session
2309        .connection_pool()
2310        .await
2311        .dev_server_connection_id(dev_server_id);
2312    let Some(dev_server_connection_id) = dev_server_connection_id else {
2313        Err(ErrorCode::DevServerOffline
2314            .message("Cannot create a remote project when the dev server is offline".to_string())
2315            .anyhow())?
2316    };
2317
2318    let path = request.path.clone();
2319    //Check that the path exists on the dev server
2320    session
2321        .peer
2322        .forward_request(
2323            session.connection_id,
2324            dev_server_connection_id,
2325            proto::ValidateDevServerProjectRequest { path: path.clone() },
2326        )
2327        .await?;
2328
2329    let (dev_server_project, update) = session
2330        .db()
2331        .await
2332        .create_dev_server_project(
2333            DevServerId(request.dev_server_id as i32),
2334            &request.path,
2335            session.user_id(),
2336        )
2337        .await?;
2338
2339    let projects = session
2340        .db()
2341        .await
2342        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2343        .await?;
2344
2345    session.peer.send(
2346        dev_server_connection_id,
2347        proto::DevServerInstructions { projects },
2348    )?;
2349
2350    send_dev_server_projects_update(session.user_id(), update, &session).await;
2351
2352    response.send(proto::CreateDevServerProjectResponse {
2353        dev_server_project: Some(dev_server_project.to_proto(None)),
2354    })?;
2355    Ok(())
2356}
2357
2358async fn create_dev_server(
2359    request: proto::CreateDevServer,
2360    response: Response<proto::CreateDevServer>,
2361    session: UserSession,
2362) -> Result<()> {
2363    let access_token = auth::random_token();
2364    let hashed_access_token = auth::hash_access_token(&access_token);
2365
2366    if request.name.is_empty() {
2367        return Err(proto::ErrorCode::Forbidden
2368            .message("Dev server name cannot be empty".to_string())
2369            .anyhow())?;
2370    }
2371
2372    let (dev_server, status) = session
2373        .db()
2374        .await
2375        .create_dev_server(
2376            &request.name,
2377            request.ssh_connection_string.as_deref(),
2378            &hashed_access_token,
2379            session.user_id(),
2380        )
2381        .await?;
2382
2383    send_dev_server_projects_update(session.user_id(), status, &session).await;
2384
2385    response.send(proto::CreateDevServerResponse {
2386        dev_server_id: dev_server.id.0 as u64,
2387        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2388        name: request.name,
2389    })?;
2390    Ok(())
2391}
2392
2393async fn regenerate_dev_server_token(
2394    request: proto::RegenerateDevServerToken,
2395    response: Response<proto::RegenerateDevServerToken>,
2396    session: UserSession,
2397) -> Result<()> {
2398    let dev_server_id = DevServerId(request.dev_server_id as i32);
2399    let access_token = auth::random_token();
2400    let hashed_access_token = auth::hash_access_token(&access_token);
2401
2402    let connection_id = session
2403        .connection_pool()
2404        .await
2405        .dev_server_connection_id(dev_server_id);
2406    if let Some(connection_id) = connection_id {
2407        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2408        session.peer.send(
2409            connection_id,
2410            proto::ShutdownDevServer {
2411                reason: Some("dev server token was regenerated".to_string()),
2412            },
2413        )?;
2414        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2415    }
2416
2417    let status = session
2418        .db()
2419        .await
2420        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2421        .await?;
2422
2423    send_dev_server_projects_update(session.user_id(), status, &session).await;
2424
2425    response.send(proto::RegenerateDevServerTokenResponse {
2426        dev_server_id: dev_server_id.to_proto(),
2427        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2428    })?;
2429    Ok(())
2430}
2431
2432async fn rename_dev_server(
2433    request: proto::RenameDevServer,
2434    response: Response<proto::RenameDevServer>,
2435    session: UserSession,
2436) -> Result<()> {
2437    if request.name.trim().is_empty() {
2438        return Err(proto::ErrorCode::Forbidden
2439            .message("Dev server name cannot be empty".to_string())
2440            .anyhow())?;
2441    }
2442
2443    let dev_server_id = DevServerId(request.dev_server_id as i32);
2444    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2445    if dev_server.user_id != session.user_id() {
2446        return Err(anyhow!(ErrorCode::Forbidden))?;
2447    }
2448
2449    let status = session
2450        .db()
2451        .await
2452        .rename_dev_server(
2453            dev_server_id,
2454            &request.name,
2455            request.ssh_connection_string.as_deref(),
2456            session.user_id(),
2457        )
2458        .await?;
2459
2460    send_dev_server_projects_update(session.user_id(), status, &session).await;
2461
2462    response.send(proto::Ack {})?;
2463    Ok(())
2464}
2465
2466async fn delete_dev_server(
2467    request: proto::DeleteDevServer,
2468    response: Response<proto::DeleteDevServer>,
2469    session: UserSession,
2470) -> Result<()> {
2471    let dev_server_id = DevServerId(request.dev_server_id as i32);
2472    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2473    if dev_server.user_id != session.user_id() {
2474        return Err(anyhow!(ErrorCode::Forbidden))?;
2475    }
2476
2477    let connection_id = session
2478        .connection_pool()
2479        .await
2480        .dev_server_connection_id(dev_server_id);
2481    if let Some(connection_id) = connection_id {
2482        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2483        session.peer.send(
2484            connection_id,
2485            proto::ShutdownDevServer {
2486                reason: Some("dev server was deleted".to_string()),
2487            },
2488        )?;
2489        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2490    }
2491
2492    let status = session
2493        .db()
2494        .await
2495        .delete_dev_server(dev_server_id, session.user_id())
2496        .await?;
2497
2498    send_dev_server_projects_update(session.user_id(), status, &session).await;
2499
2500    response.send(proto::Ack {})?;
2501    Ok(())
2502}
2503
2504async fn delete_dev_server_project(
2505    request: proto::DeleteDevServerProject,
2506    response: Response<proto::DeleteDevServerProject>,
2507    session: UserSession,
2508) -> Result<()> {
2509    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2510    let dev_server_project = session
2511        .db()
2512        .await
2513        .get_dev_server_project(dev_server_project_id)
2514        .await?;
2515
2516    let dev_server = session
2517        .db()
2518        .await
2519        .get_dev_server(dev_server_project.dev_server_id)
2520        .await?;
2521    if dev_server.user_id != session.user_id() {
2522        return Err(anyhow!(ErrorCode::Forbidden))?;
2523    }
2524
2525    let dev_server_connection_id = session
2526        .connection_pool()
2527        .await
2528        .dev_server_connection_id(dev_server.id);
2529
2530    if let Some(dev_server_connection_id) = dev_server_connection_id {
2531        let project = session
2532            .db()
2533            .await
2534            .find_dev_server_project(dev_server_project_id)
2535            .await;
2536        if let Ok(project) = project {
2537            unshare_project_internal(
2538                project.id,
2539                dev_server_connection_id,
2540                Some(session.user_id()),
2541                &session,
2542            )
2543            .await?;
2544        }
2545    }
2546
2547    let (projects, status) = session
2548        .db()
2549        .await
2550        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2551        .await?;
2552
2553    if let Some(dev_server_connection_id) = dev_server_connection_id {
2554        session.peer.send(
2555            dev_server_connection_id,
2556            proto::DevServerInstructions { projects },
2557        )?;
2558    }
2559
2560    send_dev_server_projects_update(session.user_id(), status, &session).await;
2561
2562    response.send(proto::Ack {})?;
2563    Ok(())
2564}
2565
2566async fn rejoin_dev_server_projects(
2567    request: proto::RejoinRemoteProjects,
2568    response: Response<proto::RejoinRemoteProjects>,
2569    session: UserSession,
2570) -> Result<()> {
2571    let mut rejoined_projects = {
2572        let db = session.db().await;
2573        db.rejoin_dev_server_projects(
2574            &request.rejoined_projects,
2575            session.user_id(),
2576            session.0.connection_id,
2577        )
2578        .await?
2579    };
2580    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2581
2582    response.send(proto::RejoinRemoteProjectsResponse {
2583        rejoined_projects: rejoined_projects
2584            .into_iter()
2585            .map(|project| project.to_proto())
2586            .collect(),
2587    })
2588}
2589
2590async fn reconnect_dev_server(
2591    request: proto::ReconnectDevServer,
2592    response: Response<proto::ReconnectDevServer>,
2593    session: DevServerSession,
2594) -> Result<()> {
2595    let reshared_projects = {
2596        let db = session.db().await;
2597        db.reshare_dev_server_projects(
2598            &request.reshared_projects,
2599            session.dev_server_id(),
2600            session.0.connection_id,
2601        )
2602        .await?
2603    };
2604
2605    for project in &reshared_projects {
2606        for collaborator in &project.collaborators {
2607            session
2608                .peer
2609                .send(
2610                    collaborator.connection_id,
2611                    proto::UpdateProjectCollaborator {
2612                        project_id: project.id.to_proto(),
2613                        old_peer_id: Some(project.old_connection_id.into()),
2614                        new_peer_id: Some(session.connection_id.into()),
2615                    },
2616                )
2617                .trace_err();
2618        }
2619
2620        broadcast(
2621            Some(session.connection_id),
2622            project
2623                .collaborators
2624                .iter()
2625                .map(|collaborator| collaborator.connection_id),
2626            |connection_id| {
2627                session.peer.forward_send(
2628                    session.connection_id,
2629                    connection_id,
2630                    proto::UpdateProject {
2631                        project_id: project.id.to_proto(),
2632                        worktrees: project.worktrees.clone(),
2633                    },
2634                )
2635            },
2636        );
2637    }
2638
2639    response.send(proto::ReconnectDevServerResponse {
2640        reshared_projects: reshared_projects
2641            .iter()
2642            .map(|project| proto::ResharedProject {
2643                id: project.id.to_proto(),
2644                collaborators: project
2645                    .collaborators
2646                    .iter()
2647                    .map(|collaborator| collaborator.to_proto())
2648                    .collect(),
2649            })
2650            .collect(),
2651    })?;
2652
2653    Ok(())
2654}
2655
2656async fn shutdown_dev_server(
2657    _: proto::ShutdownDevServer,
2658    response: Response<proto::ShutdownDevServer>,
2659    session: DevServerSession,
2660) -> Result<()> {
2661    response.send(proto::Ack {})?;
2662    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2663    remove_dev_server_connection(session.dev_server_id(), &session).await
2664}
2665
2666async fn shutdown_dev_server_internal(
2667    dev_server_id: DevServerId,
2668    connection_id: ConnectionId,
2669    session: &Session,
2670) -> Result<()> {
2671    let (dev_server_projects, dev_server) = {
2672        let db = session.db().await;
2673        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2674        let dev_server = db.get_dev_server(dev_server_id).await?;
2675        (dev_server_projects, dev_server)
2676    };
2677
2678    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2679        unshare_project_internal(
2680            ProjectId::from_proto(project_id),
2681            connection_id,
2682            None,
2683            session,
2684        )
2685        .await?;
2686    }
2687
2688    session
2689        .connection_pool()
2690        .await
2691        .set_dev_server_offline(dev_server_id);
2692
2693    let status = session
2694        .db()
2695        .await
2696        .dev_server_projects_update(dev_server.user_id)
2697        .await?;
2698    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2699
2700    Ok(())
2701}
2702
2703async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2704    let dev_server_connection = session
2705        .connection_pool()
2706        .await
2707        .dev_server_connection_id(dev_server_id);
2708
2709    if let Some(dev_server_connection) = dev_server_connection {
2710        session
2711            .connection_pool()
2712            .await
2713            .remove_connection(dev_server_connection)?;
2714    }
2715    Ok(())
2716}
2717
2718/// Updates other participants with changes to the project
2719async fn update_project(
2720    request: proto::UpdateProject,
2721    response: Response<proto::UpdateProject>,
2722    session: Session,
2723) -> Result<()> {
2724    let project_id = ProjectId::from_proto(request.project_id);
2725    let (room, guest_connection_ids) = &*session
2726        .db()
2727        .await
2728        .update_project(project_id, session.connection_id, &request.worktrees)
2729        .await?;
2730    broadcast(
2731        Some(session.connection_id),
2732        guest_connection_ids.iter().copied(),
2733        |connection_id| {
2734            session
2735                .peer
2736                .forward_send(session.connection_id, connection_id, request.clone())
2737        },
2738    );
2739    if let Some(room) = room {
2740        room_updated(&room, &session.peer);
2741    }
2742    response.send(proto::Ack {})?;
2743
2744    Ok(())
2745}
2746
2747/// Updates other participants with changes to the worktree
2748async fn update_worktree(
2749    request: proto::UpdateWorktree,
2750    response: Response<proto::UpdateWorktree>,
2751    session: Session,
2752) -> Result<()> {
2753    let guest_connection_ids = session
2754        .db()
2755        .await
2756        .update_worktree(&request, session.connection_id)
2757        .await?;
2758
2759    broadcast(
2760        Some(session.connection_id),
2761        guest_connection_ids.iter().copied(),
2762        |connection_id| {
2763            session
2764                .peer
2765                .forward_send(session.connection_id, connection_id, request.clone())
2766        },
2767    );
2768    response.send(proto::Ack {})?;
2769    Ok(())
2770}
2771
2772/// Updates other participants with changes to the diagnostics
2773async fn update_diagnostic_summary(
2774    message: proto::UpdateDiagnosticSummary,
2775    session: Session,
2776) -> Result<()> {
2777    let guest_connection_ids = session
2778        .db()
2779        .await
2780        .update_diagnostic_summary(&message, session.connection_id)
2781        .await?;
2782
2783    broadcast(
2784        Some(session.connection_id),
2785        guest_connection_ids.iter().copied(),
2786        |connection_id| {
2787            session
2788                .peer
2789                .forward_send(session.connection_id, connection_id, message.clone())
2790        },
2791    );
2792
2793    Ok(())
2794}
2795
2796/// Updates other participants with changes to the worktree settings
2797async fn update_worktree_settings(
2798    message: proto::UpdateWorktreeSettings,
2799    session: Session,
2800) -> Result<()> {
2801    let guest_connection_ids = session
2802        .db()
2803        .await
2804        .update_worktree_settings(&message, session.connection_id)
2805        .await?;
2806
2807    broadcast(
2808        Some(session.connection_id),
2809        guest_connection_ids.iter().copied(),
2810        |connection_id| {
2811            session
2812                .peer
2813                .forward_send(session.connection_id, connection_id, message.clone())
2814        },
2815    );
2816
2817    Ok(())
2818}
2819
2820/// Notify other participants that a  language server has started.
2821async fn start_language_server(
2822    request: proto::StartLanguageServer,
2823    session: Session,
2824) -> Result<()> {
2825    let guest_connection_ids = session
2826        .db()
2827        .await
2828        .start_language_server(&request, session.connection_id)
2829        .await?;
2830
2831    broadcast(
2832        Some(session.connection_id),
2833        guest_connection_ids.iter().copied(),
2834        |connection_id| {
2835            session
2836                .peer
2837                .forward_send(session.connection_id, connection_id, request.clone())
2838        },
2839    );
2840    Ok(())
2841}
2842
2843/// Notify other participants that a language server has changed.
2844async fn update_language_server(
2845    request: proto::UpdateLanguageServer,
2846    session: Session,
2847) -> Result<()> {
2848    let project_id = ProjectId::from_proto(request.project_id);
2849    let project_connection_ids = session
2850        .db()
2851        .await
2852        .project_connection_ids(project_id, session.connection_id, true)
2853        .await?;
2854    broadcast(
2855        Some(session.connection_id),
2856        project_connection_ids.iter().copied(),
2857        |connection_id| {
2858            session
2859                .peer
2860                .forward_send(session.connection_id, connection_id, request.clone())
2861        },
2862    );
2863    Ok(())
2864}
2865
2866/// forward a project request to the host. These requests should be read only
2867/// as guests are allowed to send them.
2868async fn forward_read_only_project_request<T>(
2869    request: T,
2870    response: Response<T>,
2871    session: UserSession,
2872) -> Result<()>
2873where
2874    T: EntityMessage + RequestMessage,
2875{
2876    let project_id = ProjectId::from_proto(request.remote_entity_id());
2877    let host_connection_id = session
2878        .db()
2879        .await
2880        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2881        .await?;
2882    let payload = session
2883        .peer
2884        .forward_request(session.connection_id, host_connection_id, request)
2885        .await?;
2886    response.send(payload)?;
2887    Ok(())
2888}
2889
2890/// forward a project request to the dev server. Only allowed
2891/// if it's your dev server.
2892async fn forward_project_request_for_owner<T>(
2893    request: T,
2894    response: Response<T>,
2895    session: UserSession,
2896) -> Result<()>
2897where
2898    T: EntityMessage + RequestMessage,
2899{
2900    let project_id = ProjectId::from_proto(request.remote_entity_id());
2901
2902    let host_connection_id = session
2903        .db()
2904        .await
2905        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2906        .await?;
2907    let payload = session
2908        .peer
2909        .forward_request(session.connection_id, host_connection_id, request)
2910        .await?;
2911    response.send(payload)?;
2912    Ok(())
2913}
2914
2915/// forward a project request to the host. These requests are disallowed
2916/// for guests.
2917async fn forward_mutating_project_request<T>(
2918    request: T,
2919    response: Response<T>,
2920    session: UserSession,
2921) -> Result<()>
2922where
2923    T: EntityMessage + RequestMessage,
2924{
2925    let project_id = ProjectId::from_proto(request.remote_entity_id());
2926
2927    let host_connection_id = session
2928        .db()
2929        .await
2930        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2931        .await?;
2932    let payload = session
2933        .peer
2934        .forward_request(session.connection_id, host_connection_id, request)
2935        .await?;
2936    response.send(payload)?;
2937    Ok(())
2938}
2939
2940/// forward a project request to the host. These requests are disallowed
2941/// for guests.
2942async fn forward_versioned_mutating_project_request<T>(
2943    request: T,
2944    response: Response<T>,
2945    session: UserSession,
2946) -> Result<()>
2947where
2948    T: EntityMessage + RequestMessage + VersionedMessage,
2949{
2950    let project_id = ProjectId::from_proto(request.remote_entity_id());
2951
2952    let host_connection_id = session
2953        .db()
2954        .await
2955        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2956        .await?;
2957    if let Some(host_version) = session
2958        .connection_pool()
2959        .await
2960        .connection(host_connection_id)
2961        .map(|c| c.zed_version)
2962    {
2963        if let Some(min_required_version) = request.required_host_version() {
2964            if min_required_version > host_version {
2965                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2966                    .with_tag("required", &min_required_version.to_string())))?;
2967            }
2968        }
2969    }
2970
2971    let payload = session
2972        .peer
2973        .forward_request(session.connection_id, host_connection_id, request)
2974        .await?;
2975    response.send(payload)?;
2976    Ok(())
2977}
2978
2979/// Notify other participants that a new buffer has been created
2980async fn create_buffer_for_peer(
2981    request: proto::CreateBufferForPeer,
2982    session: Session,
2983) -> Result<()> {
2984    session
2985        .db()
2986        .await
2987        .check_user_is_project_host(
2988            ProjectId::from_proto(request.project_id),
2989            session.connection_id,
2990        )
2991        .await?;
2992    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2993    session
2994        .peer
2995        .forward_send(session.connection_id, peer_id.into(), request)?;
2996    Ok(())
2997}
2998
2999/// Notify other participants that a buffer has been updated. This is
3000/// allowed for guests as long as the update is limited to selections.
3001async fn update_buffer(
3002    request: proto::UpdateBuffer,
3003    response: Response<proto::UpdateBuffer>,
3004    session: Session,
3005) -> Result<()> {
3006    let project_id = ProjectId::from_proto(request.project_id);
3007    let mut capability = Capability::ReadOnly;
3008
3009    for op in request.operations.iter() {
3010        match op.variant {
3011            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3012            Some(_) => capability = Capability::ReadWrite,
3013        }
3014    }
3015
3016    let host = {
3017        let guard = session
3018            .db()
3019            .await
3020            .connections_for_buffer_update(
3021                project_id,
3022                session.principal_id(),
3023                session.connection_id,
3024                capability,
3025            )
3026            .await?;
3027
3028        let (host, guests) = &*guard;
3029
3030        broadcast(
3031            Some(session.connection_id),
3032            guests.clone(),
3033            |connection_id| {
3034                session
3035                    .peer
3036                    .forward_send(session.connection_id, connection_id, request.clone())
3037            },
3038        );
3039
3040        *host
3041    };
3042
3043    if host != session.connection_id {
3044        session
3045            .peer
3046            .forward_request(session.connection_id, host, request.clone())
3047            .await?;
3048    }
3049
3050    response.send(proto::Ack {})?;
3051    Ok(())
3052}
3053
3054/// Notify other participants that a project has been updated.
3055async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3056    request: T,
3057    session: Session,
3058) -> Result<()> {
3059    let project_id = ProjectId::from_proto(request.remote_entity_id());
3060    let project_connection_ids = session
3061        .db()
3062        .await
3063        .project_connection_ids(project_id, session.connection_id, false)
3064        .await?;
3065
3066    broadcast(
3067        Some(session.connection_id),
3068        project_connection_ids.iter().copied(),
3069        |connection_id| {
3070            session
3071                .peer
3072                .forward_send(session.connection_id, connection_id, request.clone())
3073        },
3074    );
3075    Ok(())
3076}
3077
3078/// Start following another user in a call.
3079async fn follow(
3080    request: proto::Follow,
3081    response: Response<proto::Follow>,
3082    session: UserSession,
3083) -> Result<()> {
3084    let room_id = RoomId::from_proto(request.room_id);
3085    let project_id = request.project_id.map(ProjectId::from_proto);
3086    let leader_id = request
3087        .leader_id
3088        .ok_or_else(|| anyhow!("invalid leader id"))?
3089        .into();
3090    let follower_id = session.connection_id;
3091
3092    session
3093        .db()
3094        .await
3095        .check_room_participants(room_id, leader_id, session.connection_id)
3096        .await?;
3097
3098    let response_payload = session
3099        .peer
3100        .forward_request(session.connection_id, leader_id, request)
3101        .await?;
3102    response.send(response_payload)?;
3103
3104    if let Some(project_id) = project_id {
3105        let room = session
3106            .db()
3107            .await
3108            .follow(room_id, project_id, leader_id, follower_id)
3109            .await?;
3110        room_updated(&room, &session.peer);
3111    }
3112
3113    Ok(())
3114}
3115
3116/// Stop following another user in a call.
3117async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3118    let room_id = RoomId::from_proto(request.room_id);
3119    let project_id = request.project_id.map(ProjectId::from_proto);
3120    let leader_id = request
3121        .leader_id
3122        .ok_or_else(|| anyhow!("invalid leader id"))?
3123        .into();
3124    let follower_id = session.connection_id;
3125
3126    session
3127        .db()
3128        .await
3129        .check_room_participants(room_id, leader_id, session.connection_id)
3130        .await?;
3131
3132    session
3133        .peer
3134        .forward_send(session.connection_id, leader_id, request)?;
3135
3136    if let Some(project_id) = project_id {
3137        let room = session
3138            .db()
3139            .await
3140            .unfollow(room_id, project_id, leader_id, follower_id)
3141            .await?;
3142        room_updated(&room, &session.peer);
3143    }
3144
3145    Ok(())
3146}
3147
3148/// Notify everyone following you of your current location.
3149async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3150    let room_id = RoomId::from_proto(request.room_id);
3151    let database = session.db.lock().await;
3152
3153    let connection_ids = if let Some(project_id) = request.project_id {
3154        let project_id = ProjectId::from_proto(project_id);
3155        database
3156            .project_connection_ids(project_id, session.connection_id, true)
3157            .await?
3158    } else {
3159        database
3160            .room_connection_ids(room_id, session.connection_id)
3161            .await?
3162    };
3163
3164    // For now, don't send view update messages back to that view's current leader.
3165    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3166        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3167        _ => None,
3168    });
3169
3170    for connection_id in connection_ids.iter().cloned() {
3171        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3172            session
3173                .peer
3174                .forward_send(session.connection_id, connection_id, request.clone())?;
3175        }
3176    }
3177    Ok(())
3178}
3179
3180/// Get public data about users.
3181async fn get_users(
3182    request: proto::GetUsers,
3183    response: Response<proto::GetUsers>,
3184    session: Session,
3185) -> Result<()> {
3186    let user_ids = request
3187        .user_ids
3188        .into_iter()
3189        .map(UserId::from_proto)
3190        .collect();
3191    let users = session
3192        .db()
3193        .await
3194        .get_users_by_ids(user_ids)
3195        .await?
3196        .into_iter()
3197        .map(|user| proto::User {
3198            id: user.id.to_proto(),
3199            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3200            github_login: user.github_login,
3201        })
3202        .collect();
3203    response.send(proto::UsersResponse { users })?;
3204    Ok(())
3205}
3206
3207/// Search for users (to invite) buy Github login
3208async fn fuzzy_search_users(
3209    request: proto::FuzzySearchUsers,
3210    response: Response<proto::FuzzySearchUsers>,
3211    session: UserSession,
3212) -> Result<()> {
3213    let query = request.query;
3214    let users = match query.len() {
3215        0 => vec![],
3216        1 | 2 => session
3217            .db()
3218            .await
3219            .get_user_by_github_login(&query)
3220            .await?
3221            .into_iter()
3222            .collect(),
3223        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3224    };
3225    let users = users
3226        .into_iter()
3227        .filter(|user| user.id != session.user_id())
3228        .map(|user| proto::User {
3229            id: user.id.to_proto(),
3230            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3231            github_login: user.github_login,
3232        })
3233        .collect();
3234    response.send(proto::UsersResponse { users })?;
3235    Ok(())
3236}
3237
3238/// Send a contact request to another user.
3239async fn request_contact(
3240    request: proto::RequestContact,
3241    response: Response<proto::RequestContact>,
3242    session: UserSession,
3243) -> Result<()> {
3244    let requester_id = session.user_id();
3245    let responder_id = UserId::from_proto(request.responder_id);
3246    if requester_id == responder_id {
3247        return Err(anyhow!("cannot add yourself as a contact"))?;
3248    }
3249
3250    let notifications = session
3251        .db()
3252        .await
3253        .send_contact_request(requester_id, responder_id)
3254        .await?;
3255
3256    // Update outgoing contact requests of requester
3257    let mut update = proto::UpdateContacts::default();
3258    update.outgoing_requests.push(responder_id.to_proto());
3259    for connection_id in session
3260        .connection_pool()
3261        .await
3262        .user_connection_ids(requester_id)
3263    {
3264        session.peer.send(connection_id, update.clone())?;
3265    }
3266
3267    // Update incoming contact requests of responder
3268    let mut update = proto::UpdateContacts::default();
3269    update
3270        .incoming_requests
3271        .push(proto::IncomingContactRequest {
3272            requester_id: requester_id.to_proto(),
3273        });
3274    let connection_pool = session.connection_pool().await;
3275    for connection_id in connection_pool.user_connection_ids(responder_id) {
3276        session.peer.send(connection_id, update.clone())?;
3277    }
3278
3279    send_notifications(&connection_pool, &session.peer, notifications);
3280
3281    response.send(proto::Ack {})?;
3282    Ok(())
3283}
3284
3285/// Accept or decline a contact request
3286async fn respond_to_contact_request(
3287    request: proto::RespondToContactRequest,
3288    response: Response<proto::RespondToContactRequest>,
3289    session: UserSession,
3290) -> Result<()> {
3291    let responder_id = session.user_id();
3292    let requester_id = UserId::from_proto(request.requester_id);
3293    let db = session.db().await;
3294    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3295        db.dismiss_contact_notification(responder_id, requester_id)
3296            .await?;
3297    } else {
3298        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3299
3300        let notifications = db
3301            .respond_to_contact_request(responder_id, requester_id, accept)
3302            .await?;
3303        let requester_busy = db.is_user_busy(requester_id).await?;
3304        let responder_busy = db.is_user_busy(responder_id).await?;
3305
3306        let pool = session.connection_pool().await;
3307        // Update responder with new contact
3308        let mut update = proto::UpdateContacts::default();
3309        if accept {
3310            update
3311                .contacts
3312                .push(contact_for_user(requester_id, requester_busy, &pool));
3313        }
3314        update
3315            .remove_incoming_requests
3316            .push(requester_id.to_proto());
3317        for connection_id in pool.user_connection_ids(responder_id) {
3318            session.peer.send(connection_id, update.clone())?;
3319        }
3320
3321        // Update requester with new contact
3322        let mut update = proto::UpdateContacts::default();
3323        if accept {
3324            update
3325                .contacts
3326                .push(contact_for_user(responder_id, responder_busy, &pool));
3327        }
3328        update
3329            .remove_outgoing_requests
3330            .push(responder_id.to_proto());
3331
3332        for connection_id in pool.user_connection_ids(requester_id) {
3333            session.peer.send(connection_id, update.clone())?;
3334        }
3335
3336        send_notifications(&pool, &session.peer, notifications);
3337    }
3338
3339    response.send(proto::Ack {})?;
3340    Ok(())
3341}
3342
3343/// Remove a contact.
3344async fn remove_contact(
3345    request: proto::RemoveContact,
3346    response: Response<proto::RemoveContact>,
3347    session: UserSession,
3348) -> Result<()> {
3349    let requester_id = session.user_id();
3350    let responder_id = UserId::from_proto(request.user_id);
3351    let db = session.db().await;
3352    let (contact_accepted, deleted_notification_id) =
3353        db.remove_contact(requester_id, responder_id).await?;
3354
3355    let pool = session.connection_pool().await;
3356    // Update outgoing contact requests of requester
3357    let mut update = proto::UpdateContacts::default();
3358    if contact_accepted {
3359        update.remove_contacts.push(responder_id.to_proto());
3360    } else {
3361        update
3362            .remove_outgoing_requests
3363            .push(responder_id.to_proto());
3364    }
3365    for connection_id in pool.user_connection_ids(requester_id) {
3366        session.peer.send(connection_id, update.clone())?;
3367    }
3368
3369    // Update incoming contact requests of responder
3370    let mut update = proto::UpdateContacts::default();
3371    if contact_accepted {
3372        update.remove_contacts.push(requester_id.to_proto());
3373    } else {
3374        update
3375            .remove_incoming_requests
3376            .push(requester_id.to_proto());
3377    }
3378    for connection_id in pool.user_connection_ids(responder_id) {
3379        session.peer.send(connection_id, update.clone())?;
3380        if let Some(notification_id) = deleted_notification_id {
3381            session.peer.send(
3382                connection_id,
3383                proto::DeleteNotification {
3384                    notification_id: notification_id.to_proto(),
3385                },
3386            )?;
3387        }
3388    }
3389
3390    response.send(proto::Ack {})?;
3391    Ok(())
3392}
3393
3394fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3395    version.0.minor() < 139
3396}
3397
3398async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3399    subscribe_user_to_channels(
3400        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3401        &session,
3402    )
3403    .await?;
3404    Ok(())
3405}
3406
3407async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3408    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3409    let mut pool = session.connection_pool().await;
3410    for membership in &channels_for_user.channel_memberships {
3411        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3412    }
3413    session.peer.send(
3414        session.connection_id,
3415        build_update_user_channels(&channels_for_user),
3416    )?;
3417    session.peer.send(
3418        session.connection_id,
3419        build_channels_update(channels_for_user),
3420    )?;
3421    Ok(())
3422}
3423
3424/// Creates a new channel.
3425async fn create_channel(
3426    request: proto::CreateChannel,
3427    response: Response<proto::CreateChannel>,
3428    session: UserSession,
3429) -> Result<()> {
3430    let db = session.db().await;
3431
3432    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3433    let (channel, membership) = db
3434        .create_channel(&request.name, parent_id, session.user_id())
3435        .await?;
3436
3437    let root_id = channel.root_id();
3438    let channel = Channel::from_model(channel);
3439
3440    response.send(proto::CreateChannelResponse {
3441        channel: Some(channel.to_proto()),
3442        parent_id: request.parent_id,
3443    })?;
3444
3445    let mut connection_pool = session.connection_pool().await;
3446    if let Some(membership) = membership {
3447        connection_pool.subscribe_to_channel(
3448            membership.user_id,
3449            membership.channel_id,
3450            membership.role,
3451        );
3452        let update = proto::UpdateUserChannels {
3453            channel_memberships: vec![proto::ChannelMembership {
3454                channel_id: membership.channel_id.to_proto(),
3455                role: membership.role.into(),
3456            }],
3457            ..Default::default()
3458        };
3459        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3460            session.peer.send(connection_id, update.clone())?;
3461        }
3462    }
3463
3464    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3465        if !role.can_see_channel(channel.visibility) {
3466            continue;
3467        }
3468
3469        let update = proto::UpdateChannels {
3470            channels: vec![channel.to_proto()],
3471            ..Default::default()
3472        };
3473        session.peer.send(connection_id, update.clone())?;
3474    }
3475
3476    Ok(())
3477}
3478
3479/// Delete a channel
3480async fn delete_channel(
3481    request: proto::DeleteChannel,
3482    response: Response<proto::DeleteChannel>,
3483    session: UserSession,
3484) -> Result<()> {
3485    let db = session.db().await;
3486
3487    let channel_id = request.channel_id;
3488    let (root_channel, removed_channels) = db
3489        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3490        .await?;
3491    response.send(proto::Ack {})?;
3492
3493    // Notify members of removed channels
3494    let mut update = proto::UpdateChannels::default();
3495    update
3496        .delete_channels
3497        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3498
3499    let connection_pool = session.connection_pool().await;
3500    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3501        session.peer.send(connection_id, update.clone())?;
3502    }
3503
3504    Ok(())
3505}
3506
3507/// Invite someone to join a channel.
3508async fn invite_channel_member(
3509    request: proto::InviteChannelMember,
3510    response: Response<proto::InviteChannelMember>,
3511    session: UserSession,
3512) -> Result<()> {
3513    let db = session.db().await;
3514    let channel_id = ChannelId::from_proto(request.channel_id);
3515    let invitee_id = UserId::from_proto(request.user_id);
3516    let InviteMemberResult {
3517        channel,
3518        notifications,
3519    } = db
3520        .invite_channel_member(
3521            channel_id,
3522            invitee_id,
3523            session.user_id(),
3524            request.role().into(),
3525        )
3526        .await?;
3527
3528    let update = proto::UpdateChannels {
3529        channel_invitations: vec![channel.to_proto()],
3530        ..Default::default()
3531    };
3532
3533    let connection_pool = session.connection_pool().await;
3534    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3535        session.peer.send(connection_id, update.clone())?;
3536    }
3537
3538    send_notifications(&connection_pool, &session.peer, notifications);
3539
3540    response.send(proto::Ack {})?;
3541    Ok(())
3542}
3543
3544/// remove someone from a channel
3545async fn remove_channel_member(
3546    request: proto::RemoveChannelMember,
3547    response: Response<proto::RemoveChannelMember>,
3548    session: UserSession,
3549) -> Result<()> {
3550    let db = session.db().await;
3551    let channel_id = ChannelId::from_proto(request.channel_id);
3552    let member_id = UserId::from_proto(request.user_id);
3553
3554    let RemoveChannelMemberResult {
3555        membership_update,
3556        notification_id,
3557    } = db
3558        .remove_channel_member(channel_id, member_id, session.user_id())
3559        .await?;
3560
3561    let mut connection_pool = session.connection_pool().await;
3562    notify_membership_updated(
3563        &mut connection_pool,
3564        membership_update,
3565        member_id,
3566        &session.peer,
3567    );
3568    for connection_id in connection_pool.user_connection_ids(member_id) {
3569        if let Some(notification_id) = notification_id {
3570            session
3571                .peer
3572                .send(
3573                    connection_id,
3574                    proto::DeleteNotification {
3575                        notification_id: notification_id.to_proto(),
3576                    },
3577                )
3578                .trace_err();
3579        }
3580    }
3581
3582    response.send(proto::Ack {})?;
3583    Ok(())
3584}
3585
3586/// Toggle the channel between public and private.
3587/// Care is taken to maintain the invariant that public channels only descend from public channels,
3588/// (though members-only channels can appear at any point in the hierarchy).
3589async fn set_channel_visibility(
3590    request: proto::SetChannelVisibility,
3591    response: Response<proto::SetChannelVisibility>,
3592    session: UserSession,
3593) -> Result<()> {
3594    let db = session.db().await;
3595    let channel_id = ChannelId::from_proto(request.channel_id);
3596    let visibility = request.visibility().into();
3597
3598    let channel_model = db
3599        .set_channel_visibility(channel_id, visibility, session.user_id())
3600        .await?;
3601    let root_id = channel_model.root_id();
3602    let channel = Channel::from_model(channel_model);
3603
3604    let mut connection_pool = session.connection_pool().await;
3605    for (user_id, role) in connection_pool
3606        .channel_user_ids(root_id)
3607        .collect::<Vec<_>>()
3608        .into_iter()
3609    {
3610        let update = if role.can_see_channel(channel.visibility) {
3611            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3612            proto::UpdateChannels {
3613                channels: vec![channel.to_proto()],
3614                ..Default::default()
3615            }
3616        } else {
3617            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3618            proto::UpdateChannels {
3619                delete_channels: vec![channel.id.to_proto()],
3620                ..Default::default()
3621            }
3622        };
3623
3624        for connection_id in connection_pool.user_connection_ids(user_id) {
3625            session.peer.send(connection_id, update.clone())?;
3626        }
3627    }
3628
3629    response.send(proto::Ack {})?;
3630    Ok(())
3631}
3632
3633/// Alter the role for a user in the channel.
3634async fn set_channel_member_role(
3635    request: proto::SetChannelMemberRole,
3636    response: Response<proto::SetChannelMemberRole>,
3637    session: UserSession,
3638) -> Result<()> {
3639    let db = session.db().await;
3640    let channel_id = ChannelId::from_proto(request.channel_id);
3641    let member_id = UserId::from_proto(request.user_id);
3642    let result = db
3643        .set_channel_member_role(
3644            channel_id,
3645            session.user_id(),
3646            member_id,
3647            request.role().into(),
3648        )
3649        .await?;
3650
3651    match result {
3652        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3653            let mut connection_pool = session.connection_pool().await;
3654            notify_membership_updated(
3655                &mut connection_pool,
3656                membership_update,
3657                member_id,
3658                &session.peer,
3659            )
3660        }
3661        db::SetMemberRoleResult::InviteUpdated(channel) => {
3662            let update = proto::UpdateChannels {
3663                channel_invitations: vec![channel.to_proto()],
3664                ..Default::default()
3665            };
3666
3667            for connection_id in session
3668                .connection_pool()
3669                .await
3670                .user_connection_ids(member_id)
3671            {
3672                session.peer.send(connection_id, update.clone())?;
3673            }
3674        }
3675    }
3676
3677    response.send(proto::Ack {})?;
3678    Ok(())
3679}
3680
3681/// Change the name of a channel
3682async fn rename_channel(
3683    request: proto::RenameChannel,
3684    response: Response<proto::RenameChannel>,
3685    session: UserSession,
3686) -> Result<()> {
3687    let db = session.db().await;
3688    let channel_id = ChannelId::from_proto(request.channel_id);
3689    let channel_model = db
3690        .rename_channel(channel_id, session.user_id(), &request.name)
3691        .await?;
3692    let root_id = channel_model.root_id();
3693    let channel = Channel::from_model(channel_model);
3694
3695    response.send(proto::RenameChannelResponse {
3696        channel: Some(channel.to_proto()),
3697    })?;
3698
3699    let connection_pool = session.connection_pool().await;
3700    let update = proto::UpdateChannels {
3701        channels: vec![channel.to_proto()],
3702        ..Default::default()
3703    };
3704    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3705        if role.can_see_channel(channel.visibility) {
3706            session.peer.send(connection_id, update.clone())?;
3707        }
3708    }
3709
3710    Ok(())
3711}
3712
3713/// Move a channel to a new parent.
3714async fn move_channel(
3715    request: proto::MoveChannel,
3716    response: Response<proto::MoveChannel>,
3717    session: UserSession,
3718) -> Result<()> {
3719    let channel_id = ChannelId::from_proto(request.channel_id);
3720    let to = ChannelId::from_proto(request.to);
3721
3722    let (root_id, channels) = session
3723        .db()
3724        .await
3725        .move_channel(channel_id, to, session.user_id())
3726        .await?;
3727
3728    let connection_pool = session.connection_pool().await;
3729    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3730        let channels = channels
3731            .iter()
3732            .filter_map(|channel| {
3733                if role.can_see_channel(channel.visibility) {
3734                    Some(channel.to_proto())
3735                } else {
3736                    None
3737                }
3738            })
3739            .collect::<Vec<_>>();
3740        if channels.is_empty() {
3741            continue;
3742        }
3743
3744        let update = proto::UpdateChannels {
3745            channels,
3746            ..Default::default()
3747        };
3748
3749        session.peer.send(connection_id, update.clone())?;
3750    }
3751
3752    response.send(Ack {})?;
3753    Ok(())
3754}
3755
3756/// Get the list of channel members
3757async fn get_channel_members(
3758    request: proto::GetChannelMembers,
3759    response: Response<proto::GetChannelMembers>,
3760    session: UserSession,
3761) -> Result<()> {
3762    let db = session.db().await;
3763    let channel_id = ChannelId::from_proto(request.channel_id);
3764    let limit = if request.limit == 0 {
3765        u16::MAX as u64
3766    } else {
3767        request.limit
3768    };
3769    let (members, users) = db
3770        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3771        .await?;
3772    response.send(proto::GetChannelMembersResponse { members, users })?;
3773    Ok(())
3774}
3775
3776/// Accept or decline a channel invitation.
3777async fn respond_to_channel_invite(
3778    request: proto::RespondToChannelInvite,
3779    response: Response<proto::RespondToChannelInvite>,
3780    session: UserSession,
3781) -> Result<()> {
3782    let db = session.db().await;
3783    let channel_id = ChannelId::from_proto(request.channel_id);
3784    let RespondToChannelInvite {
3785        membership_update,
3786        notifications,
3787    } = db
3788        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3789        .await?;
3790
3791    let mut connection_pool = session.connection_pool().await;
3792    if let Some(membership_update) = membership_update {
3793        notify_membership_updated(
3794            &mut connection_pool,
3795            membership_update,
3796            session.user_id(),
3797            &session.peer,
3798        );
3799    } else {
3800        let update = proto::UpdateChannels {
3801            remove_channel_invitations: vec![channel_id.to_proto()],
3802            ..Default::default()
3803        };
3804
3805        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3806            session.peer.send(connection_id, update.clone())?;
3807        }
3808    };
3809
3810    send_notifications(&connection_pool, &session.peer, notifications);
3811
3812    response.send(proto::Ack {})?;
3813
3814    Ok(())
3815}
3816
3817/// Join the channels' room
3818async fn join_channel(
3819    request: proto::JoinChannel,
3820    response: Response<proto::JoinChannel>,
3821    session: UserSession,
3822) -> Result<()> {
3823    let channel_id = ChannelId::from_proto(request.channel_id);
3824    join_channel_internal(channel_id, Box::new(response), session).await
3825}
3826
3827trait JoinChannelInternalResponse {
3828    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3829}
3830impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3831    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3832        Response::<proto::JoinChannel>::send(self, result)
3833    }
3834}
3835impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3836    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3837        Response::<proto::JoinRoom>::send(self, result)
3838    }
3839}
3840
/// Shared implementation for joining a channel's room; `response` is any handle
/// implementing `JoinChannelInternalResponse` (currently `JoinChannel` and
/// `JoinRoom` responses).
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel id and (when available) LiveKit
///    connection info.
/// 4. Propagate any membership change and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any database step or the reply fails. It also errors
/// with "channel not returned" when the joined room carries no channel; note
/// that in that case the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the database guard and the first connection-pool
    // guard, so both are released before `channel_updated` below acquires the
    // pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it. Presumably `leave_room_for_session` takes the
            // guard itself. TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit info is `None` when no LiveKit client is configured, or when
        // token creation fails (the error goes through `trace_err`). Guests get
        // a guest token and `can_publish: false`; every other role gets a room
        // token and `can_publish: true`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3931
3932/// Start editing the channel notes
3933async fn join_channel_buffer(
3934    request: proto::JoinChannelBuffer,
3935    response: Response<proto::JoinChannelBuffer>,
3936    session: UserSession,
3937) -> Result<()> {
3938    let db = session.db().await;
3939    let channel_id = ChannelId::from_proto(request.channel_id);
3940
3941    let open_response = db
3942        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3943        .await?;
3944
3945    let collaborators = open_response.collaborators.clone();
3946    response.send(open_response)?;
3947
3948    let update = UpdateChannelBufferCollaborators {
3949        channel_id: channel_id.to_proto(),
3950        collaborators: collaborators.clone(),
3951    };
3952    channel_buffer_updated(
3953        session.connection_id,
3954        collaborators
3955            .iter()
3956            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3957        &update,
3958        &session.peer,
3959    );
3960
3961    Ok(())
3962}
3963
3964/// Edit the channel notes
3965async fn update_channel_buffer(
3966    request: proto::UpdateChannelBuffer,
3967    session: UserSession,
3968) -> Result<()> {
3969    let db = session.db().await;
3970    let channel_id = ChannelId::from_proto(request.channel_id);
3971
3972    let (collaborators, epoch, version) = db
3973        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3974        .await?;
3975
3976    channel_buffer_updated(
3977        session.connection_id,
3978        collaborators.clone(),
3979        &proto::UpdateChannelBuffer {
3980            channel_id: channel_id.to_proto(),
3981            operations: request.operations,
3982        },
3983        &session.peer,
3984    );
3985
3986    let pool = &*session.connection_pool().await;
3987
3988    let non_collaborators =
3989        pool.channel_connection_ids(channel_id)
3990            .filter_map(|(connection_id, _)| {
3991                if collaborators.contains(&connection_id) {
3992                    None
3993                } else {
3994                    Some(connection_id)
3995                }
3996            });
3997
3998    broadcast(None, non_collaborators, |peer_id| {
3999        session.peer.send(
4000            peer_id,
4001            proto::UpdateChannels {
4002                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4003                    channel_id: channel_id.to_proto(),
4004                    epoch: epoch as u64,
4005                    version: version.clone(),
4006                }],
4007                ..Default::default()
4008            },
4009        )
4010    });
4011
4012    Ok(())
4013}
4014
4015/// Rejoin the channel notes after a connection blip
4016async fn rejoin_channel_buffers(
4017    request: proto::RejoinChannelBuffers,
4018    response: Response<proto::RejoinChannelBuffers>,
4019    session: UserSession,
4020) -> Result<()> {
4021    let db = session.db().await;
4022    let buffers = db
4023        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4024        .await?;
4025
4026    for rejoined_buffer in &buffers {
4027        let collaborators_to_notify = rejoined_buffer
4028            .buffer
4029            .collaborators
4030            .iter()
4031            .filter_map(|c| Some(c.peer_id?.into()));
4032        channel_buffer_updated(
4033            session.connection_id,
4034            collaborators_to_notify,
4035            &proto::UpdateChannelBufferCollaborators {
4036                channel_id: rejoined_buffer.buffer.channel_id,
4037                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4038            },
4039            &session.peer,
4040        );
4041    }
4042
4043    response.send(proto::RejoinChannelBuffersResponse {
4044        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4045    })?;
4046
4047    Ok(())
4048}
4049
4050/// Stop editing the channel notes
4051async fn leave_channel_buffer(
4052    request: proto::LeaveChannelBuffer,
4053    response: Response<proto::LeaveChannelBuffer>,
4054    session: UserSession,
4055) -> Result<()> {
4056    let db = session.db().await;
4057    let channel_id = ChannelId::from_proto(request.channel_id);
4058
4059    let left_buffer = db
4060        .leave_channel_buffer(channel_id, session.connection_id)
4061        .await?;
4062
4063    response.send(Ack {})?;
4064
4065    channel_buffer_updated(
4066        session.connection_id,
4067        left_buffer.connections,
4068        &proto::UpdateChannelBufferCollaborators {
4069            channel_id: channel_id.to_proto(),
4070            collaborators: left_buffer.collaborators,
4071        },
4072        &session.peer,
4073    );
4074
4075    Ok(())
4076}
4077
4078fn channel_buffer_updated<T: EnvelopedMessage>(
4079    sender_id: ConnectionId,
4080    collaborators: impl IntoIterator<Item = ConnectionId>,
4081    message: &T,
4082    peer: &Peer,
4083) {
4084    broadcast(Some(sender_id), collaborators, |peer_id| {
4085        peer.send(peer_id, message.clone())
4086    });
4087}
4088
4089fn send_notifications(
4090    connection_pool: &ConnectionPool,
4091    peer: &Peer,
4092    notifications: db::NotificationBatch,
4093) {
4094    for (user_id, notification) in notifications {
4095        for connection_id in connection_pool.user_connection_ids(user_id) {
4096            if let Err(error) = peer.send(
4097                connection_id,
4098                proto::AddNotification {
4099                    notification: Some(notification.clone()),
4100                },
4101            ) {
4102                tracing::error!(
4103                    "failed to send notification to {:?} {}",
4104                    connection_id,
4105                    error
4106                );
4107            }
4108        }
4109    }
4110}
4111
4112/// Send a message to the channel
4113async fn send_channel_message(
4114    request: proto::SendChannelMessage,
4115    response: Response<proto::SendChannelMessage>,
4116    session: UserSession,
4117) -> Result<()> {
4118    // Validate the message body.
4119    let body = request.body.trim().to_string();
4120    if body.len() > MAX_MESSAGE_LEN {
4121        return Err(anyhow!("message is too long"))?;
4122    }
4123    if body.is_empty() {
4124        return Err(anyhow!("message can't be blank"))?;
4125    }
4126
4127    // TODO: adjust mentions if body is trimmed
4128
4129    let timestamp = OffsetDateTime::now_utc();
4130    let nonce = request
4131        .nonce
4132        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4133
4134    let channel_id = ChannelId::from_proto(request.channel_id);
4135    let CreatedChannelMessage {
4136        message_id,
4137        participant_connection_ids,
4138        notifications,
4139    } = session
4140        .db()
4141        .await
4142        .create_channel_message(
4143            channel_id,
4144            session.user_id(),
4145            &body,
4146            &request.mentions,
4147            timestamp,
4148            nonce.clone().into(),
4149            match request.reply_to_message_id {
4150                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4151                None => None,
4152            },
4153        )
4154        .await?;
4155
4156    let message = proto::ChannelMessage {
4157        sender_id: session.user_id().to_proto(),
4158        id: message_id.to_proto(),
4159        body,
4160        mentions: request.mentions,
4161        timestamp: timestamp.unix_timestamp() as u64,
4162        nonce: Some(nonce),
4163        reply_to_message_id: request.reply_to_message_id,
4164        edited_at: None,
4165    };
4166    broadcast(
4167        Some(session.connection_id),
4168        participant_connection_ids.clone(),
4169        |connection| {
4170            session.peer.send(
4171                connection,
4172                proto::ChannelMessageSent {
4173                    channel_id: channel_id.to_proto(),
4174                    message: Some(message.clone()),
4175                },
4176            )
4177        },
4178    );
4179    response.send(proto::SendChannelMessageResponse {
4180        message: Some(message),
4181    })?;
4182
4183    let pool = &*session.connection_pool().await;
4184    let non_participants =
4185        pool.channel_connection_ids(channel_id)
4186            .filter_map(|(connection_id, _)| {
4187                if participant_connection_ids.contains(&connection_id) {
4188                    None
4189                } else {
4190                    Some(connection_id)
4191                }
4192            });
4193    broadcast(None, non_participants, |peer_id| {
4194        session.peer.send(
4195            peer_id,
4196            proto::UpdateChannels {
4197                latest_channel_message_ids: vec![proto::ChannelMessageId {
4198                    channel_id: channel_id.to_proto(),
4199                    message_id: message_id.to_proto(),
4200                }],
4201                ..Default::default()
4202            },
4203        )
4204    });
4205    send_notifications(pool, &session.peer, notifications);
4206
4207    Ok(())
4208}
4209
4210/// Delete a channel message
4211async fn remove_channel_message(
4212    request: proto::RemoveChannelMessage,
4213    response: Response<proto::RemoveChannelMessage>,
4214    session: UserSession,
4215) -> Result<()> {
4216    let channel_id = ChannelId::from_proto(request.channel_id);
4217    let message_id = MessageId::from_proto(request.message_id);
4218    let (connection_ids, existing_notification_ids) = session
4219        .db()
4220        .await
4221        .remove_channel_message(channel_id, message_id, session.user_id())
4222        .await?;
4223
4224    broadcast(
4225        Some(session.connection_id),
4226        connection_ids,
4227        move |connection| {
4228            session.peer.send(connection, request.clone())?;
4229
4230            for notification_id in &existing_notification_ids {
4231                session.peer.send(
4232                    connection,
4233                    proto::DeleteNotification {
4234                        notification_id: (*notification_id).to_proto(),
4235                    },
4236                )?;
4237            }
4238
4239            Ok(())
4240        },
4241    );
4242    response.send(proto::Ack {})?;
4243    Ok(())
4244}
4245
4246async fn update_channel_message(
4247    request: proto::UpdateChannelMessage,
4248    response: Response<proto::UpdateChannelMessage>,
4249    session: UserSession,
4250) -> Result<()> {
4251    let channel_id = ChannelId::from_proto(request.channel_id);
4252    let message_id = MessageId::from_proto(request.message_id);
4253    let updated_at = OffsetDateTime::now_utc();
4254    let UpdatedChannelMessage {
4255        message_id,
4256        participant_connection_ids,
4257        notifications,
4258        reply_to_message_id,
4259        timestamp,
4260        deleted_mention_notification_ids,
4261        updated_mention_notifications,
4262    } = session
4263        .db()
4264        .await
4265        .update_channel_message(
4266            channel_id,
4267            message_id,
4268            session.user_id(),
4269            request.body.as_str(),
4270            &request.mentions,
4271            updated_at,
4272        )
4273        .await?;
4274
4275    let nonce = request
4276        .nonce
4277        .clone()
4278        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4279
4280    let message = proto::ChannelMessage {
4281        sender_id: session.user_id().to_proto(),
4282        id: message_id.to_proto(),
4283        body: request.body.clone(),
4284        mentions: request.mentions.clone(),
4285        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4286        nonce: Some(nonce),
4287        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4288        edited_at: Some(updated_at.unix_timestamp() as u64),
4289    };
4290
4291    response.send(proto::Ack {})?;
4292
4293    let pool = &*session.connection_pool().await;
4294    broadcast(
4295        Some(session.connection_id),
4296        participant_connection_ids,
4297        |connection| {
4298            session.peer.send(
4299                connection,
4300                proto::ChannelMessageUpdate {
4301                    channel_id: channel_id.to_proto(),
4302                    message: Some(message.clone()),
4303                },
4304            )?;
4305
4306            for notification_id in &deleted_mention_notification_ids {
4307                session.peer.send(
4308                    connection,
4309                    proto::DeleteNotification {
4310                        notification_id: (*notification_id).to_proto(),
4311                    },
4312                )?;
4313            }
4314
4315            for notification in &updated_mention_notifications {
4316                session.peer.send(
4317                    connection,
4318                    proto::UpdateNotification {
4319                        notification: Some(notification.clone()),
4320                    },
4321                )?;
4322            }
4323
4324            Ok(())
4325        },
4326    );
4327
4328    send_notifications(pool, &session.peer, notifications);
4329
4330    Ok(())
4331}
4332
4333/// Mark a channel message as read
4334async fn acknowledge_channel_message(
4335    request: proto::AckChannelMessage,
4336    session: UserSession,
4337) -> Result<()> {
4338    let channel_id = ChannelId::from_proto(request.channel_id);
4339    let message_id = MessageId::from_proto(request.message_id);
4340    let notifications = session
4341        .db()
4342        .await
4343        .observe_channel_message(channel_id, session.user_id(), message_id)
4344        .await?;
4345    send_notifications(
4346        &*session.connection_pool().await,
4347        &session.peer,
4348        notifications,
4349    );
4350    Ok(())
4351}
4352
4353/// Mark a buffer version as synced
4354async fn acknowledge_buffer_version(
4355    request: proto::AckBufferOperation,
4356    session: UserSession,
4357) -> Result<()> {
4358    let buffer_id = BufferId::from_proto(request.buffer_id);
4359    session
4360        .db()
4361        .await
4362        .observe_buffer_version(
4363            buffer_id,
4364            session.user_id(),
4365            request.epoch as i32,
4366            &request.version,
4367        )
4368        .await?;
4369    Ok(())
4370}
4371
4372struct CompleteWithLanguageModelRateLimit;
4373
4374impl RateLimit for CompleteWithLanguageModelRateLimit {
4375    fn capacity() -> usize {
4376        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4377            .ok()
4378            .and_then(|v| v.parse().ok())
4379            .unwrap_or(120) // Picked arbitrarily
4380    }
4381
4382    fn refill_duration() -> chrono::Duration {
4383        chrono::Duration::hours(1)
4384    }
4385
4386    fn db_name() -> &'static str {
4387        "complete-with-language-model"
4388    }
4389}
4390
4391async fn complete_with_language_model(
4392    request: proto::CompleteWithLanguageModel,
4393    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4394    session: Session,
4395    open_ai_api_key: Option<Arc<str>>,
4396    google_ai_api_key: Option<Arc<str>>,
4397    anthropic_api_key: Option<Arc<str>>,
4398) -> Result<()> {
4399    let Some(session) = session.for_user() else {
4400        return Err(anyhow!("user not found"))?;
4401    };
4402    authorize_access_to_language_models(&session).await?;
4403    session
4404        .rate_limiter
4405        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4406        .await?;
4407
4408    if request.model.starts_with("gpt") {
4409        let api_key =
4410            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4411        complete_with_open_ai(request, response, session, api_key).await?;
4412    } else if request.model.starts_with("gemini") {
4413        let api_key = google_ai_api_key
4414            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4415        complete_with_google_ai(request, response, session, api_key).await?;
4416    } else if request.model.starts_with("claude") {
4417        let api_key = anthropic_api_key
4418            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4419        complete_with_anthropic(request, response, session, api_key).await?;
4420    }
4421
4422    Ok(())
4423}
4424
4425async fn complete_with_open_ai(
4426    request: proto::CompleteWithLanguageModel,
4427    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4428    session: UserSession,
4429    api_key: Arc<str>,
4430) -> Result<()> {
4431    let mut completion_stream = open_ai::stream_completion(
4432        session.http_client.as_ref(),
4433        OPEN_AI_API_URL,
4434        &api_key,
4435        crate::ai::language_model_request_to_open_ai(request)?,
4436        None,
4437    )
4438    .await
4439    .context("open_ai::stream_completion request failed within collab")?;
4440
4441    while let Some(event) = completion_stream.next().await {
4442        let event = event?;
4443        response.send(proto::LanguageModelResponse {
4444            choices: event
4445                .choices
4446                .into_iter()
4447                .map(|choice| proto::LanguageModelChoiceDelta {
4448                    index: choice.index,
4449                    delta: Some(proto::LanguageModelResponseMessage {
4450                        role: choice.delta.role.map(|role| match role {
4451                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4452                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4453                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4454                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4455                        } as i32),
4456                        content: choice.delta.content,
4457                        tool_calls: choice
4458                            .delta
4459                            .tool_calls
4460                            .into_iter()
4461                            .map(|delta| proto::ToolCallDelta {
4462                                index: delta.index as u32,
4463                                id: delta.id,
4464                                variant: match delta.function {
4465                                    Some(function) => {
4466                                        let name = function.name;
4467                                        let arguments = function.arguments;
4468
4469                                        Some(proto::tool_call_delta::Variant::Function(
4470                                            proto::tool_call_delta::FunctionCallDelta {
4471                                                name,
4472                                                arguments,
4473                                            },
4474                                        ))
4475                                    }
4476                                    None => None,
4477                                },
4478                            })
4479                            .collect(),
4480                    }),
4481                    finish_reason: choice.finish_reason,
4482                })
4483                .collect(),
4484        })?;
4485    }
4486
4487    Ok(())
4488}
4489
4490async fn complete_with_google_ai(
4491    request: proto::CompleteWithLanguageModel,
4492    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4493    session: UserSession,
4494    api_key: Arc<str>,
4495) -> Result<()> {
4496    let mut stream = google_ai::stream_generate_content(
4497        session.http_client.clone(),
4498        google_ai::API_URL,
4499        api_key.as_ref(),
4500        crate::ai::language_model_request_to_google_ai(request)?,
4501    )
4502    .await
4503    .context("google_ai::stream_generate_content request failed")?;
4504
4505    while let Some(event) = stream.next().await {
4506        let event = event?;
4507        response.send(proto::LanguageModelResponse {
4508            choices: event
4509                .candidates
4510                .unwrap_or_default()
4511                .into_iter()
4512                .map(|candidate| proto::LanguageModelChoiceDelta {
4513                    index: candidate.index as u32,
4514                    delta: Some(proto::LanguageModelResponseMessage {
4515                        role: Some(match candidate.content.role {
4516                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4517                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4518                        } as i32),
4519                        content: Some(
4520                            candidate
4521                                .content
4522                                .parts
4523                                .into_iter()
4524                                .filter_map(|part| match part {
4525                                    google_ai::Part::TextPart(part) => Some(part.text),
4526                                    google_ai::Part::InlineDataPart(_) => None,
4527                                })
4528                                .collect(),
4529                        ),
4530                        // Tool calls are not supported for Google
4531                        tool_calls: Vec::new(),
4532                    }),
4533                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4534                })
4535                .collect(),
4536        })?;
4537    }
4538
4539    Ok(())
4540}
4541
4542async fn complete_with_anthropic(
4543    request: proto::CompleteWithLanguageModel,
4544    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4545    session: UserSession,
4546    api_key: Arc<str>,
4547) -> Result<()> {
4548    let model = anthropic::Model::from_id(&request.model)?;
4549
4550    let mut system_message = String::new();
4551    let messages = request
4552        .messages
4553        .into_iter()
4554        .filter_map(|message| {
4555            match message.role() {
4556                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4557                    role: anthropic::Role::User,
4558                    content: message.content,
4559                }),
4560                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4561                    role: anthropic::Role::Assistant,
4562                    content: message.content,
4563                }),
4564                // Anthropic's API breaks system instructions out as a separate field rather
4565                // than having a system message role.
4566                LanguageModelRole::LanguageModelSystem => {
4567                    if !system_message.is_empty() {
4568                        system_message.push_str("\n\n");
4569                    }
4570                    system_message.push_str(&message.content);
4571
4572                    None
4573                }
4574                // We don't yet support tool calls for Anthropic
4575                LanguageModelRole::LanguageModelTool => None,
4576            }
4577        })
4578        .collect();
4579
4580    let mut stream = anthropic::stream_completion(
4581        session.http_client.as_ref(),
4582        anthropic::ANTHROPIC_API_URL,
4583        &api_key,
4584        anthropic::Request {
4585            model,
4586            messages,
4587            stream: true,
4588            system: system_message,
4589            max_tokens: 4092,
4590        },
4591        None,
4592    )
4593    .await?;
4594
4595    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4596
4597    while let Some(event) = stream.next().await {
4598        let event = event?;
4599
4600        match event {
4601            anthropic::ResponseEvent::MessageStart { message } => {
4602                if let Some(role) = message.role {
4603                    if role == "assistant" {
4604                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4605                    } else if role == "user" {
4606                        current_role = proto::LanguageModelRole::LanguageModelUser;
4607                    }
4608                }
4609            }
4610            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4611                match content_block {
4612                    anthropic::ContentBlock::Text { text } => {
4613                        if !text.is_empty() {
4614                            response.send(proto::LanguageModelResponse {
4615                                choices: vec![proto::LanguageModelChoiceDelta {
4616                                    index: 0,
4617                                    delta: Some(proto::LanguageModelResponseMessage {
4618                                        role: Some(current_role as i32),
4619                                        content: Some(text),
4620                                        tool_calls: Vec::new(),
4621                                    }),
4622                                    finish_reason: None,
4623                                }],
4624                            })?;
4625                        }
4626                    }
4627                }
4628            }
4629            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4630                anthropic::TextDelta::TextDelta { text } => {
4631                    response.send(proto::LanguageModelResponse {
4632                        choices: vec![proto::LanguageModelChoiceDelta {
4633                            index: 0,
4634                            delta: Some(proto::LanguageModelResponseMessage {
4635                                role: Some(current_role as i32),
4636                                content: Some(text),
4637                                tool_calls: Vec::new(),
4638                            }),
4639                            finish_reason: None,
4640                        }],
4641                    })?;
4642                }
4643            },
4644            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4645                if let Some(stop_reason) = delta.stop_reason {
4646                    response.send(proto::LanguageModelResponse {
4647                        choices: vec![proto::LanguageModelChoiceDelta {
4648                            index: 0,
4649                            delta: None,
4650                            finish_reason: Some(stop_reason),
4651                        }],
4652                    })?;
4653                }
4654            }
4655            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4656            anthropic::ResponseEvent::MessageStop {} => {}
4657            anthropic::ResponseEvent::Ping {} => {}
4658        }
4659    }
4660
4661    Ok(())
4662}
4663
4664struct CountTokensWithLanguageModelRateLimit;
4665
4666impl RateLimit for CountTokensWithLanguageModelRateLimit {
4667    fn capacity() -> usize {
4668        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4669            .ok()
4670            .and_then(|v| v.parse().ok())
4671            .unwrap_or(600) // Picked arbitrarily
4672    }
4673
4674    fn refill_duration() -> chrono::Duration {
4675        chrono::Duration::hours(1)
4676    }
4677
4678    fn db_name() -> &'static str {
4679        "count-tokens-with-language-model"
4680    }
4681}
4682
4683async fn count_tokens_with_language_model(
4684    request: proto::CountTokensWithLanguageModel,
4685    response: Response<proto::CountTokensWithLanguageModel>,
4686    session: UserSession,
4687    google_ai_api_key: Option<Arc<str>>,
4688) -> Result<()> {
4689    authorize_access_to_language_models(&session).await?;
4690
4691    if !request.model.starts_with("gemini") {
4692        return Err(anyhow!(
4693            "counting tokens for model: {:?} is not supported",
4694            request.model
4695        ))?;
4696    }
4697
4698    session
4699        .rate_limiter
4700        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4701        .await?;
4702
4703    let api_key = google_ai_api_key
4704        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4705    let tokens_response = google_ai::count_tokens(
4706        session.http_client.as_ref(),
4707        google_ai::API_URL,
4708        &api_key,
4709        crate::ai::count_tokens_request_to_google_ai(request)?,
4710    )
4711    .await?;
4712    response.send(proto::CountTokensResponse {
4713        token_count: tokens_response.total_tokens as u32,
4714    })?;
4715    Ok(())
4716}
4717
4718struct ComputeEmbeddingsRateLimit;
4719
4720impl RateLimit for ComputeEmbeddingsRateLimit {
4721    fn capacity() -> usize {
4722        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4723            .ok()
4724            .and_then(|v| v.parse().ok())
4725            .unwrap_or(5000) // Picked arbitrarily
4726    }
4727
4728    fn refill_duration() -> chrono::Duration {
4729        chrono::Duration::hours(1)
4730    }
4731
4732    fn db_name() -> &'static str {
4733        "compute-embeddings"
4734    }
4735}
4736
4737async fn compute_embeddings(
4738    request: proto::ComputeEmbeddings,
4739    response: Response<proto::ComputeEmbeddings>,
4740    session: UserSession,
4741    api_key: Option<Arc<str>>,
4742) -> Result<()> {
4743    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4744    authorize_access_to_language_models(&session).await?;
4745
4746    session
4747        .rate_limiter
4748        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4749        .await?;
4750
4751    let embeddings = match request.model.as_str() {
4752        "openai/text-embedding-3-small" => {
4753            open_ai::embed(
4754                session.http_client.as_ref(),
4755                OPEN_AI_API_URL,
4756                &api_key,
4757                OpenAiEmbeddingModel::TextEmbedding3Small,
4758                request.texts.iter().map(|text| text.as_str()),
4759            )
4760            .await?
4761        }
4762        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4763    };
4764
4765    let embeddings = request
4766        .texts
4767        .iter()
4768        .map(|text| {
4769            let mut hasher = sha2::Sha256::new();
4770            hasher.update(text.as_bytes());
4771            let result = hasher.finalize();
4772            result.to_vec()
4773        })
4774        .zip(
4775            embeddings
4776                .data
4777                .into_iter()
4778                .map(|embedding| embedding.embedding),
4779        )
4780        .collect::<HashMap<_, _>>();
4781
4782    let db = session.db().await;
4783    db.save_embeddings(&request.model, &embeddings)
4784        .await
4785        .context("failed to save embeddings")
4786        .trace_err();
4787
4788    response.send(proto::ComputeEmbeddingsResponse {
4789        embeddings: embeddings
4790            .into_iter()
4791            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4792            .collect(),
4793    })?;
4794    Ok(())
4795}
4796
4797async fn get_cached_embeddings(
4798    request: proto::GetCachedEmbeddings,
4799    response: Response<proto::GetCachedEmbeddings>,
4800    session: UserSession,
4801) -> Result<()> {
4802    authorize_access_to_language_models(&session).await?;
4803
4804    let db = session.db().await;
4805    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4806
4807    response.send(proto::GetCachedEmbeddingsResponse {
4808        embeddings: embeddings
4809            .into_iter()
4810            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4811            .collect(),
4812    })?;
4813    Ok(())
4814}
4815
4816async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4817    let db = session.db().await;
4818    let flags = db.get_user_flags(session.user_id()).await?;
4819    if flags.iter().any(|flag| flag == "language-models") {
4820        Ok(())
4821    } else {
4822        Err(anyhow!("permission denied"))?
4823    }
4824}
4825
4826/// Get a Supermaven API key for the user
4827async fn get_supermaven_api_key(
4828    _request: proto::GetSupermavenApiKey,
4829    response: Response<proto::GetSupermavenApiKey>,
4830    session: UserSession,
4831) -> Result<()> {
4832    let user_id: String = session.user_id().to_string();
4833    if !session.is_staff() {
4834        return Err(anyhow!("supermaven not enabled for this account"))?;
4835    }
4836
4837    let email = session
4838        .email()
4839        .ok_or_else(|| anyhow!("user must have an email"))?;
4840
4841    let supermaven_admin_api = session
4842        .supermaven_client
4843        .as_ref()
4844        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4845
4846    let result = supermaven_admin_api
4847        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4848        .await?;
4849
4850    response.send(proto::GetSupermavenApiKeyResponse {
4851        api_key: result.api_key,
4852    })?;
4853
4854    Ok(())
4855}
4856
4857/// Start receiving chat updates for a channel
4858async fn join_channel_chat(
4859    request: proto::JoinChannelChat,
4860    response: Response<proto::JoinChannelChat>,
4861    session: UserSession,
4862) -> Result<()> {
4863    let channel_id = ChannelId::from_proto(request.channel_id);
4864
4865    let db = session.db().await;
4866    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4867        .await?;
4868    let messages = db
4869        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4870        .await?;
4871    response.send(proto::JoinChannelChatResponse {
4872        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4873        messages,
4874    })?;
4875    Ok(())
4876}
4877
4878/// Stop receiving chat updates for a channel
4879async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4880    let channel_id = ChannelId::from_proto(request.channel_id);
4881    session
4882        .db()
4883        .await
4884        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4885        .await?;
4886    Ok(())
4887}
4888
4889/// Retrieve the chat history for a channel
4890async fn get_channel_messages(
4891    request: proto::GetChannelMessages,
4892    response: Response<proto::GetChannelMessages>,
4893    session: UserSession,
4894) -> Result<()> {
4895    let channel_id = ChannelId::from_proto(request.channel_id);
4896    let messages = session
4897        .db()
4898        .await
4899        .get_channel_messages(
4900            channel_id,
4901            session.user_id(),
4902            MESSAGE_COUNT_PER_PAGE,
4903            Some(MessageId::from_proto(request.before_message_id)),
4904        )
4905        .await?;
4906    response.send(proto::GetChannelMessagesResponse {
4907        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4908        messages,
4909    })?;
4910    Ok(())
4911}
4912
4913/// Retrieve specific chat messages
4914async fn get_channel_messages_by_id(
4915    request: proto::GetChannelMessagesById,
4916    response: Response<proto::GetChannelMessagesById>,
4917    session: UserSession,
4918) -> Result<()> {
4919    let message_ids = request
4920        .message_ids
4921        .iter()
4922        .map(|id| MessageId::from_proto(*id))
4923        .collect::<Vec<_>>();
4924    let messages = session
4925        .db()
4926        .await
4927        .get_channel_messages_by_id(session.user_id(), &message_ids)
4928        .await?;
4929    response.send(proto::GetChannelMessagesResponse {
4930        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4931        messages,
4932    })?;
4933    Ok(())
4934}
4935
4936/// Retrieve the current users notifications
4937async fn get_notifications(
4938    request: proto::GetNotifications,
4939    response: Response<proto::GetNotifications>,
4940    session: UserSession,
4941) -> Result<()> {
4942    let notifications = session
4943        .db()
4944        .await
4945        .get_notifications(
4946            session.user_id(),
4947            NOTIFICATION_COUNT_PER_PAGE,
4948            request
4949                .before_id
4950                .map(|id| db::NotificationId::from_proto(id)),
4951        )
4952        .await?;
4953    response.send(proto::GetNotificationsResponse {
4954        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4955        notifications,
4956    })?;
4957    Ok(())
4958}
4959
4960/// Mark notifications as read
4961async fn mark_notification_as_read(
4962    request: proto::MarkNotificationRead,
4963    response: Response<proto::MarkNotificationRead>,
4964    session: UserSession,
4965) -> Result<()> {
4966    let database = &session.db().await;
4967    let notifications = database
4968        .mark_notification_as_read_by_id(
4969            session.user_id(),
4970            NotificationId::from_proto(request.notification_id),
4971        )
4972        .await?;
4973    send_notifications(
4974        &*session.connection_pool().await,
4975        &session.peer,
4976        notifications,
4977    );
4978    response.send(proto::Ack {})?;
4979    Ok(())
4980}
4981
4982/// Get the current users information
4983async fn get_private_user_info(
4984    _request: proto::GetPrivateUserInfo,
4985    response: Response<proto::GetPrivateUserInfo>,
4986    session: UserSession,
4987) -> Result<()> {
4988    let db = session.db().await;
4989
4990    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4991    let user = db
4992        .get_user_by_id(session.user_id())
4993        .await?
4994        .ok_or_else(|| anyhow!("user not found"))?;
4995    let flags = db.get_user_flags(session.user_id()).await?;
4996
4997    response.send(proto::GetPrivateUserInfoResponse {
4998        metrics_id,
4999        staff: user.admin,
5000        flags,
5001    })?;
5002    Ok(())
5003}
5004
5005fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5006    match message {
5007        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5008        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5009        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5010        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5011        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5012            code: frame.code.into(),
5013            reason: frame.reason,
5014        })),
5015    }
5016}
5017
5018fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5019    match message {
5020        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5021        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5022        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5023        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5024        AxumMessage::Close(frame) => {
5025            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5026                code: frame.code.into(),
5027                reason: frame.reason,
5028            }))
5029        }
5030    }
5031}
5032
5033fn notify_membership_updated(
5034    connection_pool: &mut ConnectionPool,
5035    result: MembershipUpdated,
5036    user_id: UserId,
5037    peer: &Peer,
5038) {
5039    for membership in &result.new_channels.channel_memberships {
5040        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5041    }
5042    for channel_id in &result.removed_channels {
5043        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5044    }
5045
5046    let user_channels_update = proto::UpdateUserChannels {
5047        channel_memberships: result
5048            .new_channels
5049            .channel_memberships
5050            .iter()
5051            .map(|cm| proto::ChannelMembership {
5052                channel_id: cm.channel_id.to_proto(),
5053                role: cm.role.into(),
5054            })
5055            .collect(),
5056        ..Default::default()
5057    };
5058
5059    let mut update = build_channels_update(result.new_channels);
5060    update.delete_channels = result
5061        .removed_channels
5062        .into_iter()
5063        .map(|id| id.to_proto())
5064        .collect();
5065    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5066
5067    for connection_id in connection_pool.user_connection_ids(user_id) {
5068        peer.send(connection_id, user_channels_update.clone())
5069            .trace_err();
5070        peer.send(connection_id, update.clone()).trace_err();
5071    }
5072}
5073
5074fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5075    proto::UpdateUserChannels {
5076        channel_memberships: channels
5077            .channel_memberships
5078            .iter()
5079            .map(|m| proto::ChannelMembership {
5080                channel_id: m.channel_id.to_proto(),
5081                role: m.role.into(),
5082            })
5083            .collect(),
5084        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5085        observed_channel_message_id: channels.observed_channel_messages.clone(),
5086    }
5087}
5088
5089fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5090    let mut update = proto::UpdateChannels::default();
5091
5092    for channel in channels.channels {
5093        update.channels.push(channel.to_proto());
5094    }
5095
5096    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5097    update.latest_channel_message_ids = channels.latest_channel_messages;
5098
5099    for (channel_id, participants) in channels.channel_participants {
5100        update
5101            .channel_participants
5102            .push(proto::ChannelParticipants {
5103                channel_id: channel_id.to_proto(),
5104                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5105            });
5106    }
5107
5108    for channel in channels.invited_channels {
5109        update.channel_invitations.push(channel.to_proto());
5110    }
5111
5112    update.hosted_projects = channels.hosted_projects;
5113    update
5114}
5115
5116fn build_initial_contacts_update(
5117    contacts: Vec<db::Contact>,
5118    pool: &ConnectionPool,
5119) -> proto::UpdateContacts {
5120    let mut update = proto::UpdateContacts::default();
5121
5122    for contact in contacts {
5123        match contact {
5124            db::Contact::Accepted { user_id, busy } => {
5125                update.contacts.push(contact_for_user(user_id, busy, &pool));
5126            }
5127            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5128            db::Contact::Incoming { user_id } => {
5129                update
5130                    .incoming_requests
5131                    .push(proto::IncomingContactRequest {
5132                        requester_id: user_id.to_proto(),
5133                    })
5134            }
5135        }
5136    }
5137
5138    update
5139}
5140
5141fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5142    proto::Contact {
5143        user_id: user_id.to_proto(),
5144        online: pool.is_user_online(user_id),
5145        busy,
5146    }
5147}
5148
5149fn room_updated(room: &proto::Room, peer: &Peer) {
5150    broadcast(
5151        None,
5152        room.participants
5153            .iter()
5154            .filter_map(|participant| Some(participant.peer_id?.into())),
5155        |peer_id| {
5156            peer.send(
5157                peer_id,
5158                proto::RoomUpdated {
5159                    room: Some(room.clone()),
5160                },
5161            )
5162        },
5163    );
5164}
5165
5166fn channel_updated(
5167    channel: &db::channel::Model,
5168    room: &proto::Room,
5169    peer: &Peer,
5170    pool: &ConnectionPool,
5171) {
5172    let participants = room
5173        .participants
5174        .iter()
5175        .map(|p| p.user_id)
5176        .collect::<Vec<_>>();
5177
5178    broadcast(
5179        None,
5180        pool.channel_connection_ids(channel.root_id())
5181            .filter_map(|(channel_id, role)| {
5182                role.can_see_channel(channel.visibility).then(|| channel_id)
5183            }),
5184        |peer_id| {
5185            peer.send(
5186                peer_id,
5187                proto::UpdateChannels {
5188                    channel_participants: vec![proto::ChannelParticipants {
5189                        channel_id: channel.id.to_proto(),
5190                        participant_user_ids: participants.clone(),
5191                    }],
5192                    ..Default::default()
5193                },
5194            )
5195        },
5196    );
5197}
5198
5199async fn send_dev_server_projects_update(
5200    user_id: UserId,
5201    mut status: proto::DevServerProjectsUpdate,
5202    session: &Session,
5203) {
5204    let pool = session.connection_pool().await;
5205    for dev_server in &mut status.dev_servers {
5206        dev_server.status =
5207            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5208    }
5209    let connections = pool.user_connection_ids(user_id);
5210    for connection_id in connections {
5211        session.peer.send(connection_id, status.clone()).trace_err();
5212    }
5213}
5214
5215async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5216    let db = session.db().await;
5217
5218    let contacts = db.get_contacts(user_id).await?;
5219    let busy = db.is_user_busy(user_id).await?;
5220
5221    let pool = session.connection_pool().await;
5222    let updated_contact = contact_for_user(user_id, busy, &pool);
5223    for contact in contacts {
5224        if let db::Contact::Accepted {
5225            user_id: contact_user_id,
5226            ..
5227        } = contact
5228        {
5229            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5230                session
5231                    .peer
5232                    .send(
5233                        contact_conn_id,
5234                        proto::UpdateContacts {
5235                            contacts: vec![updated_contact.clone()],
5236                            remove_contacts: Default::default(),
5237                            incoming_requests: Default::default(),
5238                            remove_incoming_requests: Default::default(),
5239                            outgoing_requests: Default::default(),
5240                            remove_outgoing_requests: Default::default(),
5241                        },
5242                    )
5243                    .trace_err();
5244            }
5245        }
5246    }
5247    Ok(())
5248}
5249
5250async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5251    log::info!("lost dev server connection, unsharing projects");
5252    let project_ids = session
5253        .db()
5254        .await
5255        .get_stale_dev_server_projects(session.connection_id)
5256        .await?;
5257
5258    for project_id in project_ids {
5259        // not unshare re-checks the connection ids match, so we get away with no transaction
5260        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5261    }
5262
5263    let user_id = session.dev_server().user_id;
5264    let update = session
5265        .db()
5266        .await
5267        .dev_server_projects_update(user_id)
5268        .await?;
5269
5270    send_dev_server_projects_update(user_id, update, session).await;
5271
5272    Ok(())
5273}
5274
/// Removes `connection_id` from its room via `Database::leave_room` and
/// propagates the departure.
///
/// The steps, in order:
/// - calls `project_left` for every project the connection left;
/// - broadcasts the returned room with `room_updated` and, if the room has a
///   channel, its participants with `channel_updated`;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes contact entries for this user and those callees via
///   `update_user_contacts`;
/// - removes the user from the LiveKit room, deleting the room when
///   `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately if `leave_room` returns `None`.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`; send and
/// LiveKit errors are only traced.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below. Declaring them here lets `left_room`
    // be dropped at the end of that block, together with the temporary
    // returned by `session.db().await` in its scrutinee. Only the extracted
    // values outlive the block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // the room itself is taken. The `room` broadcast below therefore carries
        // a default (emptied) `live_kit_room`. Confirm that receivers of
        // `RoomUpdated` don't rely on that field.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection pool value is a temporary, dropped at the end of this
        // call statement.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` below, which
    // calls `connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: errors are traced, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5348
5349async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5350    let left_channel_buffers = session
5351        .db()
5352        .await
5353        .leave_channel_buffers(session.connection_id)
5354        .await?;
5355
5356    for left_buffer in left_channel_buffers {
5357        channel_buffer_updated(
5358            session.connection_id,
5359            left_buffer.connections,
5360            &proto::UpdateChannelBufferCollaborators {
5361                channel_id: left_buffer.channel_id.to_proto(),
5362                collaborators: left_buffer.collaborators,
5363            },
5364            &session.peer,
5365        );
5366    }
5367
5368    Ok(())
5369}
5370
5371fn project_left(project: &db::LeftProject, session: &UserSession) {
5372    for connection_id in &project.connection_ids {
5373        if project.should_unshare {
5374            session
5375                .peer
5376                .send(
5377                    *connection_id,
5378                    proto::UnshareProject {
5379                        project_id: project.id.to_proto(),
5380                    },
5381                )
5382                .trace_err();
5383        } else {
5384            session
5385                .peer
5386                .send(
5387                    *connection_id,
5388                    proto::RemoveProjectCollaborator {
5389                        project_id: project.id.to_proto(),
5390                        peer_id: Some(session.connection_id.into()),
5391                    },
5392                )
5393                .trace_err();
5394        }
5395    }
5396}
5397
5398pub trait ResultExt {
5399    type Ok;
5400
5401    fn trace_err(self) -> Option<Self::Ok>;
5402}
5403
5404impl<T, E> ResultExt for Result<T, E>
5405where
5406    E: std::fmt::Debug,
5407{
5408    type Ok = T;
5409
5410    #[track_caller]
5411    fn trace_err(self) -> Option<T> {
5412        match self {
5413            Ok(value) => Some(value),
5414            Err(error) => {
5415                tracing::error!("{:?}", error);
5416                None
5417            }
5418        }
5419    }
5420}