rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_project_request_for_owner::<proto::TaskTemplates>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetHover>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetDefinition>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::GetTypeDefinition>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetReferences>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::SearchProject>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetProjectSymbols>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::OpenBufferById>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::InlayHints>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferByPath>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::GetCompletions>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCodeActions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCodeAction>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::PrepareRename>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::PerformRename>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::ReloadBuffers>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::FormatBuffers>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CreateProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::RenameProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::CopyProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::OnTypeFormatting>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::BlameBuffer>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::MultiLspQuery>,
 547            ))
 548            .add_request_handler(user_handler(
 549                forward_mutating_project_request::<proto::RestartLanguageServers>,
 550            ))
 551            .add_message_handler(create_buffer_for_peer)
 552            .add_request_handler(update_buffer)
 553            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 554            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 555            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 556            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 557            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 558            .add_request_handler(get_users)
 559            .add_request_handler(user_handler(fuzzy_search_users))
 560            .add_request_handler(user_handler(request_contact))
 561            .add_request_handler(user_handler(remove_contact))
 562            .add_request_handler(user_handler(respond_to_contact_request))
 563            .add_message_handler(subscribe_to_channels)
 564            .add_request_handler(user_handler(create_channel))
 565            .add_request_handler(user_handler(delete_channel))
 566            .add_request_handler(user_handler(invite_channel_member))
 567            .add_request_handler(user_handler(remove_channel_member))
 568            .add_request_handler(user_handler(set_channel_member_role))
 569            .add_request_handler(user_handler(set_channel_visibility))
 570            .add_request_handler(user_handler(rename_channel))
 571            .add_request_handler(user_handler(join_channel_buffer))
 572            .add_request_handler(user_handler(leave_channel_buffer))
 573            .add_message_handler(user_message_handler(update_channel_buffer))
 574            .add_request_handler(user_handler(rejoin_channel_buffers))
 575            .add_request_handler(user_handler(get_channel_members))
 576            .add_request_handler(user_handler(respond_to_channel_invite))
 577            .add_request_handler(user_handler(join_channel))
 578            .add_request_handler(user_handler(join_channel_chat))
 579            .add_message_handler(user_message_handler(leave_channel_chat))
 580            .add_request_handler(user_handler(send_channel_message))
 581            .add_request_handler(user_handler(remove_channel_message))
 582            .add_request_handler(user_handler(update_channel_message))
 583            .add_request_handler(user_handler(get_channel_messages))
 584            .add_request_handler(user_handler(get_channel_messages_by_id))
 585            .add_request_handler(user_handler(get_notifications))
 586            .add_request_handler(user_handler(mark_notification_as_read))
 587            .add_request_handler(user_handler(move_channel))
 588            .add_request_handler(user_handler(follow))
 589            .add_message_handler(user_message_handler(unfollow))
 590            .add_message_handler(user_message_handler(update_followers))
 591            .add_request_handler(user_handler(get_private_user_info))
 592            .add_message_handler(user_message_handler(acknowledge_channel_message))
 593            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 594            .add_request_handler(user_handler(get_supermaven_api_key))
 595            .add_streaming_request_handler({
 596                let app_state = app_state.clone();
 597                move |request, response, session| {
 598                    complete_with_language_model(
 599                        request,
 600                        response,
 601                        session,
 602                        app_state.config.openai_api_key.clone(),
 603                        app_state.config.google_ai_api_key.clone(),
 604                        app_state.config.anthropic_api_key.clone(),
 605                    )
 606                }
 607            })
 608            .add_request_handler({
 609                let app_state = app_state.clone();
 610                user_handler(move |request, response, session| {
 611                    count_tokens_with_language_model(
 612                        request,
 613                        response,
 614                        session,
 615                        app_state.config.google_ai_api_key.clone(),
 616                    )
 617                })
 618            })
 619            .add_request_handler({
 620                user_handler(move |request, response, session| {
 621                    get_cached_embeddings(request, response, session)
 622                })
 623            })
 624            .add_request_handler({
 625                let app_state = app_state.clone();
 626                user_handler(move |request, response, session| {
 627                    compute_embeddings(
 628                        request,
 629                        response,
 630                        session,
 631                        app_state.config.openai_api_key.clone(),
 632                    )
 633                })
 634            });
 635
 636        Arc::new(server)
 637    }
 638
 639    pub async fn start(&self) -> Result<()> {
 640        let server_id = *self.id.lock();
 641        let app_state = self.app_state.clone();
 642        let peer = self.peer.clone();
 643        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 644        let pool = self.connection_pool.clone();
 645        let live_kit_client = self.app_state.live_kit_client.clone();
 646
 647        let span = info_span!("start server");
 648        self.app_state.executor.spawn_detached(
 649            async move {
 650                tracing::info!("waiting for cleanup timeout");
 651                timeout.await;
 652                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 653                if let Some((room_ids, channel_ids)) = app_state
 654                    .db
 655                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 656                    .await
 657                    .trace_err()
 658                {
 659                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 660                    tracing::info!(
 661                        stale_channel_buffer_count = channel_ids.len(),
 662                        "retrieved stale channel buffers"
 663                    );
 664
 665                    for channel_id in channel_ids {
 666                        if let Some(refreshed_channel_buffer) = app_state
 667                            .db
 668                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 669                            .await
 670                            .trace_err()
 671                        {
 672                            for connection_id in refreshed_channel_buffer.connection_ids {
 673                                peer.send(
 674                                    connection_id,
 675                                    proto::UpdateChannelBufferCollaborators {
 676                                        channel_id: channel_id.to_proto(),
 677                                        collaborators: refreshed_channel_buffer
 678                                            .collaborators
 679                                            .clone(),
 680                                    },
 681                                )
 682                                .trace_err();
 683                            }
 684                        }
 685                    }
 686
 687                    for room_id in room_ids {
 688                        let mut contacts_to_update = HashSet::default();
 689                        let mut canceled_calls_to_user_ids = Vec::new();
 690                        let mut live_kit_room = String::new();
 691                        let mut delete_live_kit_room = false;
 692
 693                        if let Some(mut refreshed_room) = app_state
 694                            .db
 695                            .clear_stale_room_participants(room_id, server_id)
 696                            .await
 697                            .trace_err()
 698                        {
 699                            tracing::info!(
 700                                room_id = room_id.0,
 701                                new_participant_count = refreshed_room.room.participants.len(),
 702                                "refreshed room"
 703                            );
 704                            room_updated(&refreshed_room.room, &peer);
 705                            if let Some(channel) = refreshed_room.channel.as_ref() {
 706                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 707                            }
 708                            contacts_to_update
 709                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 710                            contacts_to_update
 711                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 712                            canceled_calls_to_user_ids =
 713                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 714                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 715                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 716                        }
 717
 718                        {
 719                            let pool = pool.lock();
 720                            for canceled_user_id in canceled_calls_to_user_ids {
 721                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 722                                    peer.send(
 723                                        connection_id,
 724                                        proto::CallCanceled {
 725                                            room_id: room_id.to_proto(),
 726                                        },
 727                                    )
 728                                    .trace_err();
 729                                }
 730                            }
 731                        }
 732
 733                        for user_id in contacts_to_update {
 734                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 735                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 736                            if let Some((busy, contacts)) = busy.zip(contacts) {
 737                                let pool = pool.lock();
 738                                let updated_contact = contact_for_user(user_id, busy, &pool);
 739                                for contact in contacts {
 740                                    if let db::Contact::Accepted {
 741                                        user_id: contact_user_id,
 742                                        ..
 743                                    } = contact
 744                                    {
 745                                        for contact_conn_id in
 746                                            pool.user_connection_ids(contact_user_id)
 747                                        {
 748                                            peer.send(
 749                                                contact_conn_id,
 750                                                proto::UpdateContacts {
 751                                                    contacts: vec![updated_contact.clone()],
 752                                                    remove_contacts: Default::default(),
 753                                                    incoming_requests: Default::default(),
 754                                                    remove_incoming_requests: Default::default(),
 755                                                    outgoing_requests: Default::default(),
 756                                                    remove_outgoing_requests: Default::default(),
 757                                                },
 758                                            )
 759                                            .trace_err();
 760                                        }
 761                                    }
 762                                }
 763                            }
 764                        }
 765
 766                        if let Some(live_kit) = live_kit_client.as_ref() {
 767                            if delete_live_kit_room {
 768                                live_kit.delete_room(live_kit_room).await.trace_err();
 769                            }
 770                        }
 771                    }
 772                }
 773
 774                app_state
 775                    .db
 776                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 777                    .await
 778                    .trace_err();
 779            }
 780            .instrument(span),
 781        );
 782        Ok(())
 783    }
 784
 785    pub fn teardown(&self) {
 786        self.peer.teardown();
 787        self.connection_pool.lock().reset();
 788        let _ = self.teardown.send(true);
 789    }
 790
 791    #[cfg(test)]
 792    pub fn reset(&self, id: ServerId) {
 793        self.teardown();
 794        *self.id.lock() = id;
 795        self.peer.reset(id.0 as u32);
 796        let _ = self.teardown.send(false);
 797    }
 798
 799    #[cfg(test)]
 800    pub fn id(&self) -> ServerId {
 801        *self.id.lock()
 802    }
 803
 804    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 805    where
 806        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 807        Fut: 'static + Send + Future<Output = Result<()>>,
 808        M: EnvelopedMessage,
 809    {
 810        let prev_handler = self.handlers.insert(
 811            TypeId::of::<M>(),
 812            Box::new(move |envelope, session| {
 813                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 814                let received_at = envelope.received_at;
 815                tracing::info!("message received");
 816                let start_time = Instant::now();
 817                let future = (handler)(*envelope, session);
 818                async move {
 819                    let result = future.await;
 820                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 821                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 822                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 823                    let payload_type = M::NAME;
 824
 825                    match result {
 826                        Err(error) => {
 827                            tracing::error!(
 828                                ?error,
 829                                total_duration_ms,
 830                                processing_duration_ms,
 831                                queue_duration_ms,
 832                                payload_type,
 833                                "error handling message"
 834                            )
 835                        }
 836                        Ok(()) => tracing::info!(
 837                            total_duration_ms,
 838                            processing_duration_ms,
 839                            queue_duration_ms,
 840                            "finished handling message"
 841                        ),
 842                    }
 843                }
 844                .boxed()
 845            }),
 846        );
 847        if prev_handler.is_some() {
 848            panic!("registered a handler for the same message twice");
 849        }
 850        self
 851    }
 852
 853    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 854    where
 855        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 856        Fut: 'static + Send + Future<Output = Result<()>>,
 857        M: EnvelopedMessage,
 858    {
 859        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 860        self
 861    }
 862
 863    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 864    where
 865        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 866        Fut: Send + Future<Output = Result<()>>,
 867        M: RequestMessage,
 868    {
 869        let handler = Arc::new(handler);
 870        self.add_handler(move |envelope, session| {
 871            let receipt = envelope.receipt();
 872            let handler = handler.clone();
 873            async move {
 874                let peer = session.peer.clone();
 875                let responded = Arc::new(AtomicBool::default());
 876                let response = Response {
 877                    peer: peer.clone(),
 878                    responded: responded.clone(),
 879                    receipt,
 880                };
 881                match (handler)(envelope.payload, response, session).await {
 882                    Ok(()) => {
 883                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 884                            Ok(())
 885                        } else {
 886                            Err(anyhow!("handler did not send a response"))?
 887                        }
 888                    }
 889                    Err(error) => {
 890                        let proto_err = match &error {
 891                            Error::Internal(err) => err.to_proto(),
 892                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 893                        };
 894                        peer.respond_with_error(receipt, proto_err)?;
 895                        Err(error)
 896                    }
 897                }
 898            }
 899        })
 900    }
 901
 902    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 903    where
 904        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 905        Fut: Send + Future<Output = Result<()>>,
 906        M: RequestMessage,
 907    {
 908        let handler = Arc::new(handler);
 909        self.add_handler(move |envelope, session| {
 910            let receipt = envelope.receipt();
 911            let handler = handler.clone();
 912            async move {
 913                let peer = session.peer.clone();
 914                let response = StreamingResponse {
 915                    peer: peer.clone(),
 916                    receipt,
 917                };
 918                match (handler)(envelope.payload, response, session).await {
 919                    Ok(()) => {
 920                        peer.end_stream(receipt)?;
 921                        Ok(())
 922                    }
 923                    Err(error) => {
 924                        let proto_err = match &error {
 925                            Error::Internal(err) => err.to_proto(),
 926                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 927                        };
 928                        peer.respond_with_error(receipt, proto_err)?;
 929                        Err(error)
 930                    }
 931                }
 932            }
 933        })
 934    }
 935
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// The returned future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers `connection` with the peer and builds its [`Session`];
    /// 3. sends the initial client update (see `send_initial_client_update`);
    /// 4. dispatches incoming messages to the registered handlers until the
    ///    connection closes or its I/O fails, then signs out via
    ///    `connection_lost`.
    ///
    /// If the teardown flag changes while messages are being processed, the
    /// future returns without calling `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later (connection_id below,
        // the principal-related ones by `update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return, and the one after
            // `send_initial_client_update` below, skip `connection_lost`, so the
            // connection just registered with the peer (and any pool entry added
            // by `send_initial_client_update`) is not cleaned up here — confirm
            // this is intended.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only build a Supermaven admin client when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective requests complete, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the interest of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers for this connection may be in flight: a permit is
            // acquired before reading each message and released when its handler
            // finishes (or at the end of the select arm when no handler is registered).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then connection I/O, then progress
                // on in-flight foreground handlers, and only then a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                // The handler is invoked while the span is entered; the returned
                                // future is instrumented with the same span below.
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Hold the permit until the handler's future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1082
1083    async fn send_initial_client_update(
1084        &self,
1085        connection_id: ConnectionId,
1086        principal: &Principal,
1087        zed_version: ZedVersion,
1088        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1089        session: &Session,
1090    ) -> Result<()> {
1091        self.peer.send(
1092            connection_id,
1093            proto::Hello {
1094                peer_id: Some(connection_id.into()),
1095            },
1096        )?;
1097        tracing::info!("sent hello message");
1098        if let Some(send_connection_id) = send_connection_id.take() {
1099            let _ = send_connection_id.send(connection_id);
1100        }
1101
1102        match principal {
1103            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1104                if !user.connected_once {
1105                    self.peer.send(connection_id, proto::ShowContacts {})?;
1106                    self.app_state
1107                        .db
1108                        .set_user_connected_once(user.id, true)
1109                        .await?;
1110                }
1111
1112                let (contacts, dev_server_projects) = future::try_join(
1113                    self.app_state.db.get_contacts(user.id),
1114                    self.app_state.db.dev_server_projects_update(user.id),
1115                )
1116                .await?;
1117
1118                {
1119                    let mut pool = self.connection_pool.lock();
1120                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1121                    self.peer.send(
1122                        connection_id,
1123                        build_initial_contacts_update(contacts, &pool),
1124                    )?;
1125                }
1126
1127                if should_auto_subscribe_to_channels(zed_version) {
1128                    subscribe_user_to_channels(user.id, session).await?;
1129                }
1130
1131                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1132
1133                if let Some(incoming_call) =
1134                    self.app_state.db.incoming_call_for_user(user.id).await?
1135                {
1136                    self.peer.send(connection_id, incoming_call)?;
1137                }
1138
1139                update_user_contacts(user.id, &session).await?;
1140            }
1141            Principal::DevServer(dev_server) => {
1142                {
1143                    let mut pool = self.connection_pool.lock();
1144                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1145                    {
1146                        self.peer.send(
1147                            stale_connection_id,
1148                            proto::ShutdownDevServer {
1149                                reason: Some(
1150                                    "another dev server connected with the same token".to_string(),
1151                                ),
1152                            },
1153                        )?;
1154                        pool.remove_connection(stale_connection_id)?;
1155                    };
1156                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1157                }
1158
1159                let projects = self
1160                    .app_state
1161                    .db
1162                    .get_projects_for_dev_server(dev_server.id)
1163                    .await?;
1164                self.peer
1165                    .send(connection_id, proto::DevServerInstructions { projects })?;
1166
1167                let status = self
1168                    .app_state
1169                    .db
1170                    .dev_server_projects_update(dev_server.user_id)
1171                    .await?;
1172                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1173            }
1174        }
1175
1176        Ok(())
1177    }
1178
1179    pub async fn invite_code_redeemed(
1180        self: &Arc<Self>,
1181        inviter_id: UserId,
1182        invitee_id: UserId,
1183    ) -> Result<()> {
1184        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1185            if let Some(code) = &user.invite_code {
1186                let pool = self.connection_pool.lock();
1187                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1188                for connection_id in pool.user_connection_ids(inviter_id) {
1189                    self.peer.send(
1190                        connection_id,
1191                        proto::UpdateContacts {
1192                            contacts: vec![invitee_contact.clone()],
1193                            ..Default::default()
1194                        },
1195                    )?;
1196                    self.peer.send(
1197                        connection_id,
1198                        proto::UpdateInviteInfo {
1199                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1200                            count: user.invite_count as u32,
1201                        },
1202                    )?;
1203                }
1204            }
1205        }
1206        Ok(())
1207    }
1208
1209    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1210        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1211            if let Some(invite_code) = &user.invite_code {
1212                let pool = self.connection_pool.lock();
1213                for connection_id in pool.user_connection_ids(user_id) {
1214                    self.peer.send(
1215                        connection_id,
1216                        proto::UpdateInviteInfo {
1217                            url: format!(
1218                                "{}{}",
1219                                self.app_state.config.invite_link_prefix, invite_code
1220                            ),
1221                            count: user.invite_count as u32,
1222                        },
1223                    )?;
1224                }
1225            }
1226        }
1227        Ok(())
1228    }
1229
1230    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1231        ServerSnapshot {
1232            connection_pool: ConnectionPoolGuard {
1233                guard: self.connection_pool.lock(),
1234                _not_send: PhantomData,
1235            },
1236            peer: &self.peer,
1237        }
1238    }
1239}
1240
1241impl<'a> Deref for ConnectionPoolGuard<'a> {
1242    type Target = ConnectionPool;
1243
1244    fn deref(&self) -> &Self::Target {
1245        &self.guard
1246    }
1247}
1248
1249impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1250    fn deref_mut(&mut self) -> &mut Self::Target {
1251        &mut self.guard
1252    }
1253}
1254
1255impl<'a> Drop for ConnectionPoolGuard<'a> {
1256    fn drop(&mut self) {
1257        #[cfg(test)]
1258        self.check_invariants();
1259    }
1260}
1261
1262fn broadcast<F>(
1263    sender_id: Option<ConnectionId>,
1264    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1265    mut f: F,
1266) where
1267    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1268{
1269    for receiver_id in receiver_ids {
1270        if Some(receiver_id) != sender_id {
1271            if let Err(error) = f(receiver_id) {
1272                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1273            }
1274        }
1275    }
1276}
1277
1278pub struct ProtocolVersion(u32);
1279
1280impl Header for ProtocolVersion {
1281    fn name() -> &'static HeaderName {
1282        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1283        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1284    }
1285
1286    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1287    where
1288        Self: Sized,
1289        I: Iterator<Item = &'i axum::http::HeaderValue>,
1290    {
1291        let version = values
1292            .next()
1293            .ok_or_else(axum::headers::Error::invalid)?
1294            .to_str()
1295            .map_err(|_| axum::headers::Error::invalid())?
1296            .parse()
1297            .map_err(|_| axum::headers::Error::invalid())?;
1298        Ok(Self(version))
1299    }
1300
1301    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1302        values.extend([self.0.to_string().parse().unwrap()]);
1303    }
1304}
1305
1306pub struct AppVersionHeader(SemanticVersion);
1307impl Header for AppVersionHeader {
1308    fn name() -> &'static HeaderName {
1309        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1311    }
1312
1313    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314    where
1315        Self: Sized,
1316        I: Iterator<Item = &'i axum::http::HeaderValue>,
1317    {
1318        let version = values
1319            .next()
1320            .ok_or_else(axum::headers::Error::invalid)?
1321            .to_str()
1322            .map_err(|_| axum::headers::Error::invalid())?
1323            .parse()
1324            .map_err(|_| axum::headers::Error::invalid())?;
1325        Ok(Self(version))
1326    }
1327
1328    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329        values.extend([self.0.to_string().parse().unwrap()]);
1330    }
1331}
1332
/// Builds the router for the RPC endpoints.
///
/// Axum's `Router::layer` only wraps routes that were registered before it, so:
/// * `/rpc` is wrapped by the `AppState` extension and `auth::validate_header`;
/// * `/metrics` is registered afterwards and is not behind that middleware;
/// * the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Only `/rpc` exists at this point, so only it gets these layers.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1344
1345pub async fn handle_websocket_request(
1346    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1347    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1348    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1349    Extension(server): Extension<Arc<Server>>,
1350    Extension(principal): Extension<Principal>,
1351    ws: WebSocketUpgrade,
1352) -> axum::response::Response {
1353    if protocol_version != rpc::PROTOCOL_VERSION {
1354        return (
1355            StatusCode::UPGRADE_REQUIRED,
1356            "client must be upgraded".to_string(),
1357        )
1358            .into_response();
1359    }
1360
1361    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1362        return (
1363            StatusCode::UPGRADE_REQUIRED,
1364            "no version header found".to_string(),
1365        )
1366            .into_response();
1367    };
1368
1369    if !version.can_collaborate() {
1370        return (
1371            StatusCode::UPGRADE_REQUIRED,
1372            "client must be upgraded".to_string(),
1373        )
1374            .into_response();
1375    }
1376
1377    let socket_address = socket_address.to_string();
1378    ws.on_upgrade(move |socket| {
1379        let socket = socket
1380            .map_ok(to_tungstenite_message)
1381            .err_into()
1382            .with(|message| async move { Ok(to_axum_message(message)) });
1383        let connection = Connection::new(Box::pin(socket));
1384        async move {
1385            server
1386                .handle_connection(
1387                    connection,
1388                    socket_address,
1389                    principal,
1390                    version,
1391                    None,
1392                    Executor::Production,
1393                )
1394                .await;
1395        }
1396    })
1397}
1398
1399pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1400    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401    let connections_metric = CONNECTIONS_METRIC
1402        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1403
1404    let connections = server
1405        .connection_pool
1406        .lock()
1407        .connections()
1408        .filter(|connection| !connection.admin)
1409        .count();
1410    connections_metric.set(connections as _);
1411
1412    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1413    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1414        register_int_gauge!(
1415            "shared_projects",
1416            "number of open projects with one or more guests"
1417        )
1418        .unwrap()
1419    });
1420
1421    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1422    shared_projects_metric.set(shared_projects as _);
1423
1424    let encoder = prometheus::TextEncoder::new();
1425    let metric_families = prometheus::gather();
1426    let encoded_metrics = encoder
1427        .encode_to_string(&metric_families)
1428        .map_err(|err| anyhow!("{}", err))?;
1429    Ok(encoded_metrics)
1430}
1431
/// Signs a session out after its connection has ended.
///
/// First, immediately:
/// * disconnects the connection from the peer;
/// * removes it from the connection pool (errors are propagated);
/// * reports it lost to the database (errors are only logged).
///
/// It then waits for whichever happens first:
/// * `RECONNECT_TIMEOUT` elapses. The session's remaining state is then
///   released. For users this means leaving the room and channel buffers,
///   declining a pending call if the user has no other connection online, and
///   refreshing contacts. For dev servers it calls `lost_dev_server_connection`.
/// * The server's teardown flag changes. Nothing further is done.
///
/// NOTE(review): the delay presumably gives the client a window to reconnect
/// before its resources are released — confirm against `RECONNECT_TIMEOUT`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the timeout branch (cleanup) wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so this conversion
                    // is presumably infallible here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this user remains.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1486
1487/// Acknowledges a ping from a client, used to keep the connection alive.
1488async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1489    response.send(proto::Ack {})?;
1490    Ok(())
1491}
1492
1493/// Creates a new room for calling (outside of channels)
1494async fn create_room(
1495    _request: proto::CreateRoom,
1496    response: Response<proto::CreateRoom>,
1497    session: UserSession,
1498) -> Result<()> {
1499    let live_kit_room = nanoid::nanoid!(30);
1500
1501    let live_kit_connection_info = util::maybe!(async {
1502        let live_kit = session.live_kit_client.as_ref();
1503        let live_kit = live_kit?;
1504        let user_id = session.user_id().to_string();
1505
1506        let token = live_kit
1507            .room_token(&live_kit_room, &user_id.to_string())
1508            .trace_err()?;
1509
1510        Some(proto::LiveKitConnectionInfo {
1511            server_url: live_kit.url().into(),
1512            token,
1513            can_publish: true,
1514        })
1515    })
1516    .await;
1517
1518    let room = session
1519        .db()
1520        .await
1521        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1522        .await?;
1523
1524    response.send(proto::CreateRoomResponse {
1525        room: Some(room.clone()),
1526        live_kit_connection_info,
1527    })?;
1528
1529    update_user_contacts(session.user_id(), &session).await?;
1530    Ok(())
1531}
1532
1533/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1534async fn join_room(
1535    request: proto::JoinRoom,
1536    response: Response<proto::JoinRoom>,
1537    session: UserSession,
1538) -> Result<()> {
1539    let room_id = RoomId::from_proto(request.id);
1540
1541    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1542
1543    if let Some(channel_id) = channel_id {
1544        return join_channel_internal(channel_id, Box::new(response), session).await;
1545    }
1546
1547    let joined_room = {
1548        let room = session
1549            .db()
1550            .await
1551            .join_room(room_id, session.user_id(), session.connection_id)
1552            .await?;
1553        room_updated(&room.room, &session.peer);
1554        room.into_inner()
1555    };
1556
1557    for connection_id in session
1558        .connection_pool()
1559        .await
1560        .user_connection_ids(session.user_id())
1561    {
1562        session
1563            .peer
1564            .send(
1565                connection_id,
1566                proto::CallCanceled {
1567                    room_id: room_id.to_proto(),
1568                },
1569            )
1570            .trace_err();
1571    }
1572
1573    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1574        if let Some(token) = live_kit
1575            .room_token(
1576                &joined_room.room.live_kit_room,
1577                &session.user_id().to_string(),
1578            )
1579            .trace_err()
1580        {
1581            Some(proto::LiveKitConnectionInfo {
1582                server_url: live_kit.url().into(),
1583                token,
1584                can_publish: true,
1585            })
1586        } else {
1587            None
1588        }
1589    } else {
1590        None
1591    };
1592
1593    response.send(proto::JoinRoomResponse {
1594        room: Some(joined_room.room),
1595        channel_id: None,
1596        live_kit_connection_info,
1597    })?;
1598
1599    update_user_contacts(session.user_id(), &session).await?;
1600    Ok(())
1601}
1602
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoin returns the room together with two lists of projects,
/// `reshared_projects` and `rejoined_projects` (presumably projects the caller
/// hosts and projects it had joined as a guest, respectively; TODO confirm).
///
/// Order of operations:
/// 1. The caller receives the response.
/// 2. The room is re-broadcast.
/// 3. Collaborators of each reshared project are told the caller's new peer id
///    and forwarded the project's worktrees.
/// 4. `notify_rejoined_projects` handles the rejoined projects.
/// 5. If the room belongs to a channel, `channel_updated` is called.
/// 6. `update_user_contacts` is run for the caller.
///
/// The database result only lives inside an inner scope, so it is gone before
/// the connection pool is awaited for the channel update.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Moved out of the database result at the end of the scope below.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id changed from the
            // old connection id to the current one (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators with the
            // caller as sender (`broadcast` is given the caller's connection as
            // the sender id, presumably so it is skipped).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel rooms also need their channel state refreshed.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1694
1695fn notify_rejoined_projects(
1696    rejoined_projects: &mut Vec<RejoinedProject>,
1697    session: &UserSession,
1698) -> Result<()> {
1699    for project in rejoined_projects.iter() {
1700        for collaborator in &project.collaborators {
1701            session
1702                .peer
1703                .send(
1704                    collaborator.connection_id,
1705                    proto::UpdateProjectCollaborator {
1706                        project_id: project.id.to_proto(),
1707                        old_peer_id: Some(project.old_connection_id.into()),
1708                        new_peer_id: Some(session.connection_id.into()),
1709                    },
1710                )
1711                .trace_err();
1712        }
1713    }
1714
1715    for project in rejoined_projects {
1716        for worktree in mem::take(&mut project.worktrees) {
1717            #[cfg(any(test, feature = "test-support"))]
1718            const MAX_CHUNK_SIZE: usize = 2;
1719            #[cfg(not(any(test, feature = "test-support")))]
1720            const MAX_CHUNK_SIZE: usize = 256;
1721
1722            // Stream this worktree's entries.
1723            let message = proto::UpdateWorktree {
1724                project_id: project.id.to_proto(),
1725                worktree_id: worktree.id,
1726                abs_path: worktree.abs_path.clone(),
1727                root_name: worktree.root_name,
1728                updated_entries: worktree.updated_entries,
1729                removed_entries: worktree.removed_entries,
1730                scan_id: worktree.scan_id,
1731                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1732                updated_repositories: worktree.updated_repositories,
1733                removed_repositories: worktree.removed_repositories,
1734            };
1735            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1736                session.peer.send(session.connection_id, update.clone())?;
1737            }
1738
1739            // Stream this worktree's diagnostics.
1740            for summary in worktree.diagnostic_summaries {
1741                session.peer.send(
1742                    session.connection_id,
1743                    proto::UpdateDiagnosticSummary {
1744                        project_id: project.id.to_proto(),
1745                        worktree_id: worktree.id,
1746                        summary: Some(summary),
1747                    },
1748                )?;
1749            }
1750
1751            for settings_file in worktree.settings_files {
1752                session.peer.send(
1753                    session.connection_id,
1754                    proto::UpdateWorktreeSettings {
1755                        project_id: project.id.to_proto(),
1756                        worktree_id: worktree.id,
1757                        path: settings_file.path,
1758                        content: Some(settings_file.content),
1759                    },
1760                )?;
1761            }
1762        }
1763
1764        for language_server in &project.language_servers {
1765            session.peer.send(
1766                session.connection_id,
1767                proto::UpdateLanguageServer {
1768                    project_id: project.id.to_proto(),
1769                    language_server_id: language_server.id,
1770                    variant: Some(
1771                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1772                            proto::LspDiskBasedDiagnosticsUpdated {},
1773                        ),
1774                    ),
1775                },
1776            )?;
1777        }
1778    }
1779    Ok(())
1780}
1781
1782/// leave room disconnects from the room.
1783async fn leave_room(
1784    _: proto::LeaveRoom,
1785    response: Response<proto::LeaveRoom>,
1786    session: UserSession,
1787) -> Result<()> {
1788    leave_room_for_session(&session, session.connection_id).await?;
1789    response.send(proto::Ack {})?;
1790    Ok(())
1791}
1792
1793/// Updates the permissions of someone else in the room.
1794async fn set_room_participant_role(
1795    request: proto::SetRoomParticipantRole,
1796    response: Response<proto::SetRoomParticipantRole>,
1797    session: UserSession,
1798) -> Result<()> {
1799    let user_id = UserId::from_proto(request.user_id);
1800    let role = ChannelRole::from(request.role());
1801
1802    let (live_kit_room, can_publish) = {
1803        let room = session
1804            .db()
1805            .await
1806            .set_room_participant_role(
1807                session.user_id(),
1808                RoomId::from_proto(request.room_id),
1809                user_id,
1810                role,
1811            )
1812            .await?;
1813
1814        let live_kit_room = room.live_kit_room.clone();
1815        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1816        room_updated(&room, &session.peer);
1817        (live_kit_room, can_publish)
1818    };
1819
1820    if let Some(live_kit) = session.live_kit_client.as_ref() {
1821        live_kit
1822            .update_participant(
1823                live_kit_room.clone(),
1824                request.user_id.to_string(),
1825                live_kit_server::proto::ParticipantPermission {
1826                    can_subscribe: true,
1827                    can_publish,
1828                    can_publish_data: can_publish,
1829                    hidden: false,
1830                    recorder: false,
1831                },
1832            )
1833            .await
1834            .trace_err();
1835    }
1836
1837    response.send(proto::Ack {})?;
1838    Ok(())
1839}
1840
1841/// Call someone else into the current room
1842async fn call(
1843    request: proto::Call,
1844    response: Response<proto::Call>,
1845    session: UserSession,
1846) -> Result<()> {
1847    let room_id = RoomId::from_proto(request.room_id);
1848    let calling_user_id = session.user_id();
1849    let calling_connection_id = session.connection_id;
1850    let called_user_id = UserId::from_proto(request.called_user_id);
1851    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1852    if !session
1853        .db()
1854        .await
1855        .has_contact(calling_user_id, called_user_id)
1856        .await?
1857    {
1858        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1859    }
1860
1861    let incoming_call = {
1862        let (room, incoming_call) = &mut *session
1863            .db()
1864            .await
1865            .call(
1866                room_id,
1867                calling_user_id,
1868                calling_connection_id,
1869                called_user_id,
1870                initial_project_id,
1871            )
1872            .await?;
1873        room_updated(&room, &session.peer);
1874        mem::take(incoming_call)
1875    };
1876    update_user_contacts(called_user_id, &session).await?;
1877
1878    let mut calls = session
1879        .connection_pool()
1880        .await
1881        .user_connection_ids(called_user_id)
1882        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1883        .collect::<FuturesUnordered<_>>();
1884
1885    while let Some(call_response) = calls.next().await {
1886        match call_response.as_ref() {
1887            Ok(_) => {
1888                response.send(proto::Ack {})?;
1889                return Ok(());
1890            }
1891            Err(_) => {
1892                call_response.trace_err();
1893            }
1894        }
1895    }
1896
1897    {
1898        let room = session
1899            .db()
1900            .await
1901            .call_failed(room_id, called_user_id)
1902            .await?;
1903        room_updated(&room, &session.peer);
1904    }
1905    update_user_contacts(called_user_id, &session).await?;
1906
1907    Err(anyhow!("failed to ring user"))?
1908}
1909
1910/// Cancel an outgoing call.
1911async fn cancel_call(
1912    request: proto::CancelCall,
1913    response: Response<proto::CancelCall>,
1914    session: UserSession,
1915) -> Result<()> {
1916    let called_user_id = UserId::from_proto(request.called_user_id);
1917    let room_id = RoomId::from_proto(request.room_id);
1918    {
1919        let room = session
1920            .db()
1921            .await
1922            .cancel_call(room_id, session.connection_id, called_user_id)
1923            .await?;
1924        room_updated(&room, &session.peer);
1925    }
1926
1927    for connection_id in session
1928        .connection_pool()
1929        .await
1930        .user_connection_ids(called_user_id)
1931    {
1932        session
1933            .peer
1934            .send(
1935                connection_id,
1936                proto::CallCanceled {
1937                    room_id: room_id.to_proto(),
1938                },
1939            )
1940            .trace_err();
1941    }
1942    response.send(proto::Ack {})?;
1943
1944    update_user_contacts(called_user_id, &session).await?;
1945    Ok(())
1946}
1947
1948/// Decline an incoming call.
1949async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1950    let room_id = RoomId::from_proto(message.room_id);
1951    {
1952        let room = session
1953            .db()
1954            .await
1955            .decline_call(Some(room_id), session.user_id())
1956            .await?
1957            .ok_or_else(|| anyhow!("failed to decline call"))?;
1958        room_updated(&room, &session.peer);
1959    }
1960
1961    for connection_id in session
1962        .connection_pool()
1963        .await
1964        .user_connection_ids(session.user_id())
1965    {
1966        session
1967            .peer
1968            .send(
1969                connection_id,
1970                proto::CallCanceled {
1971                    room_id: room_id.to_proto(),
1972                },
1973            )
1974            .trace_err();
1975    }
1976    update_user_contacts(session.user_id(), &session).await?;
1977    Ok(())
1978}
1979
1980/// Updates other participants in the room with your current location.
1981async fn update_participant_location(
1982    request: proto::UpdateParticipantLocation,
1983    response: Response<proto::UpdateParticipantLocation>,
1984    session: UserSession,
1985) -> Result<()> {
1986    let room_id = RoomId::from_proto(request.room_id);
1987    let location = request
1988        .location
1989        .ok_or_else(|| anyhow!("invalid location"))?;
1990
1991    let db = session.db().await;
1992    let room = db
1993        .update_room_participant_location(room_id, session.connection_id, location)
1994        .await?;
1995
1996    room_updated(&room, &session.peer);
1997    response.send(proto::Ack {})?;
1998    Ok(())
1999}
2000
2001/// Share a project into the room.
2002async fn share_project(
2003    request: proto::ShareProject,
2004    response: Response<proto::ShareProject>,
2005    session: UserSession,
2006) -> Result<()> {
2007    let (project_id, room) = &*session
2008        .db()
2009        .await
2010        .share_project(
2011            RoomId::from_proto(request.room_id),
2012            session.connection_id,
2013            &request.worktrees,
2014            request
2015                .dev_server_project_id
2016                .map(|id| DevServerProjectId::from_proto(id)),
2017        )
2018        .await?;
2019    response.send(proto::ShareProjectResponse {
2020        project_id: project_id.to_proto(),
2021    })?;
2022    room_updated(&room, &session.peer);
2023
2024    Ok(())
2025}
2026
2027/// Unshare a project from the room.
2028async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2029    let project_id = ProjectId::from_proto(message.project_id);
2030    unshare_project_internal(
2031        project_id,
2032        session.connection_id,
2033        session.user_id(),
2034        &session,
2035    )
2036    .await
2037}
2038
2039async fn unshare_project_internal(
2040    project_id: ProjectId,
2041    connection_id: ConnectionId,
2042    user_id: Option<UserId>,
2043    session: &Session,
2044) -> Result<()> {
2045    let delete = {
2046        let room_guard = session
2047            .db()
2048            .await
2049            .unshare_project(project_id, connection_id, user_id)
2050            .await?;
2051
2052        let (delete, room, guest_connection_ids) = &*room_guard;
2053
2054        let message = proto::UnshareProject {
2055            project_id: project_id.to_proto(),
2056        };
2057
2058        broadcast(
2059            Some(connection_id),
2060            guest_connection_ids.iter().copied(),
2061            |conn_id| session.peer.send(conn_id, message.clone()),
2062        );
2063        if let Some(room) = room {
2064            room_updated(room, &session.peer);
2065        }
2066
2067        *delete
2068    };
2069
2070    if delete {
2071        let db = session.db().await;
2072        db.delete_project(project_id).await?;
2073    }
2074
2075    Ok(())
2076}
2077
2078/// DevServer makes a project available online
2079async fn share_dev_server_project(
2080    request: proto::ShareDevServerProject,
2081    response: Response<proto::ShareDevServerProject>,
2082    session: DevServerSession,
2083) -> Result<()> {
2084    let (dev_server_project, user_id, status) = session
2085        .db()
2086        .await
2087        .share_dev_server_project(
2088            DevServerProjectId::from_proto(request.dev_server_project_id),
2089            session.dev_server_id(),
2090            session.connection_id,
2091            &request.worktrees,
2092        )
2093        .await?;
2094    let Some(project_id) = dev_server_project.project_id else {
2095        return Err(anyhow!("failed to share remote project"))?;
2096    };
2097
2098    send_dev_server_projects_update(user_id, status, &session).await;
2099
2100    response.send(proto::ShareProjectResponse { project_id })?;
2101
2102    Ok(())
2103}
2104
2105/// Join someone elses shared project.
2106async fn join_project(
2107    request: proto::JoinProject,
2108    response: Response<proto::JoinProject>,
2109    session: UserSession,
2110) -> Result<()> {
2111    let project_id = ProjectId::from_proto(request.project_id);
2112
2113    tracing::info!(%project_id, "join project");
2114
2115    let db = session.db().await;
2116    let (project, replica_id) = &mut *db
2117        .join_project(project_id, session.connection_id, session.user_id())
2118        .await?;
2119    drop(db);
2120    tracing::info!(%project_id, "join remote project");
2121    join_project_internal(response, session, project, replica_id)
2122}
2123
/// A response handle that can answer with a `proto::JoinProjectResponse`.
///
/// Lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests, which both reply with that payload type.
trait JoinProjectInternalResponse {
    /// Sends `result` back to the requester, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Each impl delegates to `Response::<T>::send`. The explicit path makes clear
// which `send` is meant, since this trait also declares a method named `send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2137
2138fn join_project_internal(
2139    response: impl JoinProjectInternalResponse,
2140    session: UserSession,
2141    project: &mut Project,
2142    replica_id: &ReplicaId,
2143) -> Result<()> {
2144    let collaborators = project
2145        .collaborators
2146        .iter()
2147        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2148        .map(|collaborator| collaborator.to_proto())
2149        .collect::<Vec<_>>();
2150    let project_id = project.id;
2151    let guest_user_id = session.user_id();
2152
2153    let worktrees = project
2154        .worktrees
2155        .iter()
2156        .map(|(id, worktree)| proto::WorktreeMetadata {
2157            id: *id,
2158            root_name: worktree.root_name.clone(),
2159            visible: worktree.visible,
2160            abs_path: worktree.abs_path.clone(),
2161        })
2162        .collect::<Vec<_>>();
2163
2164    let add_project_collaborator = proto::AddProjectCollaborator {
2165        project_id: project_id.to_proto(),
2166        collaborator: Some(proto::Collaborator {
2167            peer_id: Some(session.connection_id.into()),
2168            replica_id: replica_id.0 as u32,
2169            user_id: guest_user_id.to_proto(),
2170        }),
2171    };
2172
2173    for collaborator in &collaborators {
2174        session
2175            .peer
2176            .send(
2177                collaborator.peer_id.unwrap().into(),
2178                add_project_collaborator.clone(),
2179            )
2180            .trace_err();
2181    }
2182
2183    // First, we send the metadata associated with each worktree.
2184    response.send(proto::JoinProjectResponse {
2185        project_id: project.id.0 as u64,
2186        worktrees: worktrees.clone(),
2187        replica_id: replica_id.0 as u32,
2188        collaborators: collaborators.clone(),
2189        language_servers: project.language_servers.clone(),
2190        role: project.role.into(),
2191        dev_server_project_id: project
2192            .dev_server_project_id
2193            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2194    })?;
2195
2196    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2197        #[cfg(any(test, feature = "test-support"))]
2198        const MAX_CHUNK_SIZE: usize = 2;
2199        #[cfg(not(any(test, feature = "test-support")))]
2200        const MAX_CHUNK_SIZE: usize = 256;
2201
2202        // Stream this worktree's entries.
2203        let message = proto::UpdateWorktree {
2204            project_id: project_id.to_proto(),
2205            worktree_id,
2206            abs_path: worktree.abs_path.clone(),
2207            root_name: worktree.root_name,
2208            updated_entries: worktree.entries,
2209            removed_entries: Default::default(),
2210            scan_id: worktree.scan_id,
2211            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2212            updated_repositories: worktree.repository_entries.into_values().collect(),
2213            removed_repositories: Default::default(),
2214        };
2215        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2216            session.peer.send(session.connection_id, update.clone())?;
2217        }
2218
2219        // Stream this worktree's diagnostics.
2220        for summary in worktree.diagnostic_summaries {
2221            session.peer.send(
2222                session.connection_id,
2223                proto::UpdateDiagnosticSummary {
2224                    project_id: project_id.to_proto(),
2225                    worktree_id: worktree.id,
2226                    summary: Some(summary),
2227                },
2228            )?;
2229        }
2230
2231        for settings_file in worktree.settings_files {
2232            session.peer.send(
2233                session.connection_id,
2234                proto::UpdateWorktreeSettings {
2235                    project_id: project_id.to_proto(),
2236                    worktree_id: worktree.id,
2237                    path: settings_file.path,
2238                    content: Some(settings_file.content),
2239                },
2240            )?;
2241        }
2242    }
2243
2244    for language_server in &project.language_servers {
2245        session.peer.send(
2246            session.connection_id,
2247            proto::UpdateLanguageServer {
2248                project_id: project_id.to_proto(),
2249                language_server_id: language_server.id,
2250                variant: Some(
2251                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2252                        proto::LspDiskBasedDiagnosticsUpdated {},
2253                    ),
2254                ),
2255            },
2256        )?;
2257    }
2258
2259    Ok(())
2260}
2261
2262/// Leave someone elses shared project.
2263async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2264    let sender_id = session.connection_id;
2265    let project_id = ProjectId::from_proto(request.project_id);
2266    let db = session.db().await;
2267    if db.is_hosted_project(project_id).await? {
2268        let project = db.leave_hosted_project(project_id, sender_id).await?;
2269        project_left(&project, &session);
2270        return Ok(());
2271    }
2272
2273    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2274    tracing::info!(
2275        %project_id,
2276        "leave project"
2277    );
2278
2279    project_left(&project, &session);
2280    if let Some(room) = room {
2281        room_updated(&room, &session.peer);
2282    }
2283
2284    Ok(())
2285}
2286
2287async fn join_hosted_project(
2288    request: proto::JoinHostedProject,
2289    response: Response<proto::JoinHostedProject>,
2290    session: UserSession,
2291) -> Result<()> {
2292    let (mut project, replica_id) = session
2293        .db()
2294        .await
2295        .join_hosted_project(
2296            ProjectId(request.project_id as i32),
2297            session.user_id(),
2298            session.connection_id,
2299        )
2300        .await?;
2301
2302    join_project_internal(response, session, &mut project, &replica_id)
2303}
2304
2305async fn create_dev_server_project(
2306    request: proto::CreateDevServerProject,
2307    response: Response<proto::CreateDevServerProject>,
2308    session: UserSession,
2309) -> Result<()> {
2310    let dev_server_id = DevServerId(request.dev_server_id as i32);
2311    let dev_server_connection_id = session
2312        .connection_pool()
2313        .await
2314        .dev_server_connection_id(dev_server_id);
2315    let Some(dev_server_connection_id) = dev_server_connection_id else {
2316        Err(ErrorCode::DevServerOffline
2317            .message("Cannot create a remote project when the dev server is offline".to_string())
2318            .anyhow())?
2319    };
2320
2321    let path = request.path.clone();
2322    //Check that the path exists on the dev server
2323    session
2324        .peer
2325        .forward_request(
2326            session.connection_id,
2327            dev_server_connection_id,
2328            proto::ValidateDevServerProjectRequest { path: path.clone() },
2329        )
2330        .await?;
2331
2332    let (dev_server_project, update) = session
2333        .db()
2334        .await
2335        .create_dev_server_project(
2336            DevServerId(request.dev_server_id as i32),
2337            &request.path,
2338            session.user_id(),
2339        )
2340        .await?;
2341
2342    let projects = session
2343        .db()
2344        .await
2345        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2346        .await?;
2347
2348    session.peer.send(
2349        dev_server_connection_id,
2350        proto::DevServerInstructions { projects },
2351    )?;
2352
2353    send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355    response.send(proto::CreateDevServerProjectResponse {
2356        dev_server_project: Some(dev_server_project.to_proto(None)),
2357    })?;
2358    Ok(())
2359}
2360
2361async fn create_dev_server(
2362    request: proto::CreateDevServer,
2363    response: Response<proto::CreateDevServer>,
2364    session: UserSession,
2365) -> Result<()> {
2366    let access_token = auth::random_token();
2367    let hashed_access_token = auth::hash_access_token(&access_token);
2368
2369    if request.name.is_empty() {
2370        return Err(proto::ErrorCode::Forbidden
2371            .message("Dev server name cannot be empty".to_string())
2372            .anyhow())?;
2373    }
2374
2375    let (dev_server, status) = session
2376        .db()
2377        .await
2378        .create_dev_server(
2379            &request.name,
2380            request.ssh_connection_string.as_deref(),
2381            &hashed_access_token,
2382            session.user_id(),
2383        )
2384        .await?;
2385
2386    send_dev_server_projects_update(session.user_id(), status, &session).await;
2387
2388    response.send(proto::CreateDevServerResponse {
2389        dev_server_id: dev_server.id.0 as u64,
2390        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2391        name: request.name,
2392    })?;
2393    Ok(())
2394}
2395
2396async fn regenerate_dev_server_token(
2397    request: proto::RegenerateDevServerToken,
2398    response: Response<proto::RegenerateDevServerToken>,
2399    session: UserSession,
2400) -> Result<()> {
2401    let dev_server_id = DevServerId(request.dev_server_id as i32);
2402    let access_token = auth::random_token();
2403    let hashed_access_token = auth::hash_access_token(&access_token);
2404
2405    let connection_id = session
2406        .connection_pool()
2407        .await
2408        .dev_server_connection_id(dev_server_id);
2409    if let Some(connection_id) = connection_id {
2410        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2411        session.peer.send(
2412            connection_id,
2413            proto::ShutdownDevServer {
2414                reason: Some("dev server token was regenerated".to_string()),
2415            },
2416        )?;
2417        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2418    }
2419
2420    let status = session
2421        .db()
2422        .await
2423        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2424        .await?;
2425
2426    send_dev_server_projects_update(session.user_id(), status, &session).await;
2427
2428    response.send(proto::RegenerateDevServerTokenResponse {
2429        dev_server_id: dev_server_id.to_proto(),
2430        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2431    })?;
2432    Ok(())
2433}
2434
2435async fn rename_dev_server(
2436    request: proto::RenameDevServer,
2437    response: Response<proto::RenameDevServer>,
2438    session: UserSession,
2439) -> Result<()> {
2440    if request.name.trim().is_empty() {
2441        return Err(proto::ErrorCode::Forbidden
2442            .message("Dev server name cannot be empty".to_string())
2443            .anyhow())?;
2444    }
2445
2446    let dev_server_id = DevServerId(request.dev_server_id as i32);
2447    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2448    if dev_server.user_id != session.user_id() {
2449        return Err(anyhow!(ErrorCode::Forbidden))?;
2450    }
2451
2452    let status = session
2453        .db()
2454        .await
2455        .rename_dev_server(
2456            dev_server_id,
2457            &request.name,
2458            request.ssh_connection_string.as_deref(),
2459            session.user_id(),
2460        )
2461        .await?;
2462
2463    send_dev_server_projects_update(session.user_id(), status, &session).await;
2464
2465    response.send(proto::Ack {})?;
2466    Ok(())
2467}
2468
2469async fn delete_dev_server(
2470    request: proto::DeleteDevServer,
2471    response: Response<proto::DeleteDevServer>,
2472    session: UserSession,
2473) -> Result<()> {
2474    let dev_server_id = DevServerId(request.dev_server_id as i32);
2475    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2476    if dev_server.user_id != session.user_id() {
2477        return Err(anyhow!(ErrorCode::Forbidden))?;
2478    }
2479
2480    let connection_id = session
2481        .connection_pool()
2482        .await
2483        .dev_server_connection_id(dev_server_id);
2484    if let Some(connection_id) = connection_id {
2485        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2486        session.peer.send(
2487            connection_id,
2488            proto::ShutdownDevServer {
2489                reason: Some("dev server was deleted".to_string()),
2490            },
2491        )?;
2492        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2493    }
2494
2495    let status = session
2496        .db()
2497        .await
2498        .delete_dev_server(dev_server_id, session.user_id())
2499        .await?;
2500
2501    send_dev_server_projects_update(session.user_id(), status, &session).await;
2502
2503    response.send(proto::Ack {})?;
2504    Ok(())
2505}
2506
/// Deletes a dev server project on behalf of the owner of its dev server.
///
/// Ownership is checked against the dev server the project belongs to (`Forbidden`
/// otherwise). If that dev server is currently connected, the project is first unshared
/// from the dev server's connection; the result of `find_dev_server_project` is only
/// used when it is `Ok`, and an error there is ignored. The project is then deleted in
/// the database, the connected dev server (if any) is sent its remaining projects via
/// `DevServerInstructions`, and the owner's dev-server project list is refreshed.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Authorization is based on the owner of the dev server hosting the project.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    // Only an online dev server can have this project shared; unshare it before deletion.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // Errors from the lookup are ignored here (best effort).
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Tell the connected dev server which projects it still has.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2568
2569async fn rejoin_dev_server_projects(
2570    request: proto::RejoinRemoteProjects,
2571    response: Response<proto::RejoinRemoteProjects>,
2572    session: UserSession,
2573) -> Result<()> {
2574    let mut rejoined_projects = {
2575        let db = session.db().await;
2576        db.rejoin_dev_server_projects(
2577            &request.rejoined_projects,
2578            session.user_id(),
2579            session.0.connection_id,
2580        )
2581        .await?
2582    };
2583    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2584
2585    response.send(proto::RejoinRemoteProjectsResponse {
2586        rejoined_projects: rejoined_projects
2587            .into_iter()
2588            .map(|project| project.to_proto())
2589            .collect(),
2590    })
2591}
2592
/// Handles a dev server reconnecting (presumably after its previous connection dropped).
///
/// `reshare_dev_server_projects` returns the projects to reshare on this dev server's
/// current connection; each carries its previous host connection (`old_connection_id`),
/// its collaborators and its worktrees. For every such project, each collaborator is
/// told that the host's peer id moved from the old connection to this one, and is then
/// sent the project's current worktrees. The response lists every reshared project
/// together with its collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Point each collaborator at the host's new peer id; send failures are only logged.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        // Resend the project's worktrees to every collaborator (excluding this connection).
        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2658
2659async fn shutdown_dev_server(
2660    _: proto::ShutdownDevServer,
2661    response: Response<proto::ShutdownDevServer>,
2662    session: DevServerSession,
2663) -> Result<()> {
2664    response.send(proto::Ack {})?;
2665    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2666    remove_dev_server_connection(session.dev_server_id(), &session).await
2667}
2668
2669async fn shutdown_dev_server_internal(
2670    dev_server_id: DevServerId,
2671    connection_id: ConnectionId,
2672    session: &Session,
2673) -> Result<()> {
2674    let (dev_server_projects, dev_server) = {
2675        let db = session.db().await;
2676        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2677        let dev_server = db.get_dev_server(dev_server_id).await?;
2678        (dev_server_projects, dev_server)
2679    };
2680
2681    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2682        unshare_project_internal(
2683            ProjectId::from_proto(project_id),
2684            connection_id,
2685            None,
2686            session,
2687        )
2688        .await?;
2689    }
2690
2691    session
2692        .connection_pool()
2693        .await
2694        .set_dev_server_offline(dev_server_id);
2695
2696    let status = session
2697        .db()
2698        .await
2699        .dev_server_projects_update(dev_server.user_id)
2700        .await?;
2701    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2702
2703    Ok(())
2704}
2705
2706async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2707    let dev_server_connection = session
2708        .connection_pool()
2709        .await
2710        .dev_server_connection_id(dev_server_id);
2711
2712    if let Some(dev_server_connection) = dev_server_connection {
2713        session
2714            .connection_pool()
2715            .await
2716            .remove_connection(dev_server_connection)?;
2717    }
2718    Ok(())
2719}
2720
/// Updates other participants with changes to the project
///
/// Passes the request's worktrees to `db.update_project` together with the sender's
/// connection id. The request is then forwarded unchanged to every guest connection the
/// database returns. When the database also returns a room, the updated room state is
/// broadcast. Replies with an `Ack`.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // `&*` borrows through the value returned by the database. Temporary lifetime
    // extension keeps that value alive until the end of this function, so the fan-out
    // and the `Ack` below both happen while it is still held.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(&room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2749
2750/// Updates other participants with changes to the worktree
2751async fn update_worktree(
2752    request: proto::UpdateWorktree,
2753    response: Response<proto::UpdateWorktree>,
2754    session: Session,
2755) -> Result<()> {
2756    let guest_connection_ids = session
2757        .db()
2758        .await
2759        .update_worktree(&request, session.connection_id)
2760        .await?;
2761
2762    broadcast(
2763        Some(session.connection_id),
2764        guest_connection_ids.iter().copied(),
2765        |connection_id| {
2766            session
2767                .peer
2768                .forward_send(session.connection_id, connection_id, request.clone())
2769        },
2770    );
2771    response.send(proto::Ack {})?;
2772    Ok(())
2773}
2774
2775/// Updates other participants with changes to the diagnostics
2776async fn update_diagnostic_summary(
2777    message: proto::UpdateDiagnosticSummary,
2778    session: Session,
2779) -> Result<()> {
2780    let guest_connection_ids = session
2781        .db()
2782        .await
2783        .update_diagnostic_summary(&message, session.connection_id)
2784        .await?;
2785
2786    broadcast(
2787        Some(session.connection_id),
2788        guest_connection_ids.iter().copied(),
2789        |connection_id| {
2790            session
2791                .peer
2792                .forward_send(session.connection_id, connection_id, message.clone())
2793        },
2794    );
2795
2796    Ok(())
2797}
2798
2799/// Updates other participants with changes to the worktree settings
2800async fn update_worktree_settings(
2801    message: proto::UpdateWorktreeSettings,
2802    session: Session,
2803) -> Result<()> {
2804    let guest_connection_ids = session
2805        .db()
2806        .await
2807        .update_worktree_settings(&message, session.connection_id)
2808        .await?;
2809
2810    broadcast(
2811        Some(session.connection_id),
2812        guest_connection_ids.iter().copied(),
2813        |connection_id| {
2814            session
2815                .peer
2816                .forward_send(session.connection_id, connection_id, message.clone())
2817        },
2818    );
2819
2820    Ok(())
2821}
2822
2823/// Notify other participants that a  language server has started.
2824async fn start_language_server(
2825    request: proto::StartLanguageServer,
2826    session: Session,
2827) -> Result<()> {
2828    let guest_connection_ids = session
2829        .db()
2830        .await
2831        .start_language_server(&request, session.connection_id)
2832        .await?;
2833
2834    broadcast(
2835        Some(session.connection_id),
2836        guest_connection_ids.iter().copied(),
2837        |connection_id| {
2838            session
2839                .peer
2840                .forward_send(session.connection_id, connection_id, request.clone())
2841        },
2842    );
2843    Ok(())
2844}
2845
2846/// Notify other participants that a language server has changed.
2847async fn update_language_server(
2848    request: proto::UpdateLanguageServer,
2849    session: Session,
2850) -> Result<()> {
2851    let project_id = ProjectId::from_proto(request.project_id);
2852    let project_connection_ids = session
2853        .db()
2854        .await
2855        .project_connection_ids(project_id, session.connection_id, true)
2856        .await?;
2857    broadcast(
2858        Some(session.connection_id),
2859        project_connection_ids.iter().copied(),
2860        |connection_id| {
2861            session
2862                .peer
2863                .forward_send(session.connection_id, connection_id, request.clone())
2864        },
2865    );
2866    Ok(())
2867}
2868
2869/// forward a project request to the host. These requests should be read only
2870/// as guests are allowed to send them.
2871async fn forward_read_only_project_request<T>(
2872    request: T,
2873    response: Response<T>,
2874    session: UserSession,
2875) -> Result<()>
2876where
2877    T: EntityMessage + RequestMessage,
2878{
2879    let project_id = ProjectId::from_proto(request.remote_entity_id());
2880    let host_connection_id = session
2881        .db()
2882        .await
2883        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2884        .await?;
2885    let payload = session
2886        .peer
2887        .forward_request(session.connection_id, host_connection_id, request)
2888        .await?;
2889    response.send(payload)?;
2890    Ok(())
2891}
2892
2893/// forward a project request to the dev server. Only allowed
2894/// if it's your dev server.
2895async fn forward_project_request_for_owner<T>(
2896    request: T,
2897    response: Response<T>,
2898    session: UserSession,
2899) -> Result<()>
2900where
2901    T: EntityMessage + RequestMessage,
2902{
2903    let project_id = ProjectId::from_proto(request.remote_entity_id());
2904
2905    let host_connection_id = session
2906        .db()
2907        .await
2908        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2909        .await?;
2910    let payload = session
2911        .peer
2912        .forward_request(session.connection_id, host_connection_id, request)
2913        .await?;
2914    response.send(payload)?;
2915    Ok(())
2916}
2917
2918/// forward a project request to the host. These requests are disallowed
2919/// for guests.
2920async fn forward_mutating_project_request<T>(
2921    request: T,
2922    response: Response<T>,
2923    session: UserSession,
2924) -> Result<()>
2925where
2926    T: EntityMessage + RequestMessage,
2927{
2928    let project_id = ProjectId::from_proto(request.remote_entity_id());
2929
2930    let host_connection_id = session
2931        .db()
2932        .await
2933        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2934        .await?;
2935    let payload = session
2936        .peer
2937        .forward_request(session.connection_id, host_connection_id, request)
2938        .await?;
2939    response.send(payload)?;
2940    Ok(())
2941}
2942
2943/// forward a project request to the host. These requests are disallowed
2944/// for guests.
2945async fn forward_versioned_mutating_project_request<T>(
2946    request: T,
2947    response: Response<T>,
2948    session: UserSession,
2949) -> Result<()>
2950where
2951    T: EntityMessage + RequestMessage + VersionedMessage,
2952{
2953    let project_id = ProjectId::from_proto(request.remote_entity_id());
2954
2955    let host_connection_id = session
2956        .db()
2957        .await
2958        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2959        .await?;
2960    if let Some(host_version) = session
2961        .connection_pool()
2962        .await
2963        .connection(host_connection_id)
2964        .map(|c| c.zed_version)
2965    {
2966        if let Some(min_required_version) = request.required_host_version() {
2967            if min_required_version > host_version {
2968                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2969                    .with_tag("required", &min_required_version.to_string())))?;
2970            }
2971        }
2972    }
2973
2974    let payload = session
2975        .peer
2976        .forward_request(session.connection_id, host_connection_id, request)
2977        .await?;
2978    response.send(payload)?;
2979    Ok(())
2980}
2981
2982/// Notify other participants that a new buffer has been created
2983async fn create_buffer_for_peer(
2984    request: proto::CreateBufferForPeer,
2985    session: Session,
2986) -> Result<()> {
2987    session
2988        .db()
2989        .await
2990        .check_user_is_project_host(
2991            ProjectId::from_proto(request.project_id),
2992            session.connection_id,
2993        )
2994        .await?;
2995    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2996    session
2997        .peer
2998        .forward_send(session.connection_id, peer_id.into(), request)?;
2999    Ok(())
3000}
3001
3002/// Notify other participants that a buffer has been updated. This is
3003/// allowed for guests as long as the update is limited to selections.
3004async fn update_buffer(
3005    request: proto::UpdateBuffer,
3006    response: Response<proto::UpdateBuffer>,
3007    session: Session,
3008) -> Result<()> {
3009    let project_id = ProjectId::from_proto(request.project_id);
3010    let mut capability = Capability::ReadOnly;
3011
3012    for op in request.operations.iter() {
3013        match op.variant {
3014            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3015            Some(_) => capability = Capability::ReadWrite,
3016        }
3017    }
3018
3019    let host = {
3020        let guard = session
3021            .db()
3022            .await
3023            .connections_for_buffer_update(
3024                project_id,
3025                session.principal_id(),
3026                session.connection_id,
3027                capability,
3028            )
3029            .await?;
3030
3031        let (host, guests) = &*guard;
3032
3033        broadcast(
3034            Some(session.connection_id),
3035            guests.clone(),
3036            |connection_id| {
3037                session
3038                    .peer
3039                    .forward_send(session.connection_id, connection_id, request.clone())
3040            },
3041        );
3042
3043        *host
3044    };
3045
3046    if host != session.connection_id {
3047        session
3048            .peer
3049            .forward_request(session.connection_id, host, request.clone())
3050            .await?;
3051    }
3052
3053    response.send(proto::Ack {})?;
3054    Ok(())
3055}
3056
3057/// Notify other participants that a project has been updated.
3058async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3059    request: T,
3060    session: Session,
3061) -> Result<()> {
3062    let project_id = ProjectId::from_proto(request.remote_entity_id());
3063    let project_connection_ids = session
3064        .db()
3065        .await
3066        .project_connection_ids(project_id, session.connection_id, false)
3067        .await?;
3068
3069    broadcast(
3070        Some(session.connection_id),
3071        project_connection_ids.iter().copied(),
3072        |connection_id| {
3073            session
3074                .peer
3075                .forward_send(session.connection_id, connection_id, request.clone())
3076        },
3077    );
3078    Ok(())
3079}
3080
3081/// Start following another user in a call.
3082async fn follow(
3083    request: proto::Follow,
3084    response: Response<proto::Follow>,
3085    session: UserSession,
3086) -> Result<()> {
3087    let room_id = RoomId::from_proto(request.room_id);
3088    let project_id = request.project_id.map(ProjectId::from_proto);
3089    let leader_id = request
3090        .leader_id
3091        .ok_or_else(|| anyhow!("invalid leader id"))?
3092        .into();
3093    let follower_id = session.connection_id;
3094
3095    session
3096        .db()
3097        .await
3098        .check_room_participants(room_id, leader_id, session.connection_id)
3099        .await?;
3100
3101    let response_payload = session
3102        .peer
3103        .forward_request(session.connection_id, leader_id, request)
3104        .await?;
3105    response.send(response_payload)?;
3106
3107    if let Some(project_id) = project_id {
3108        let room = session
3109            .db()
3110            .await
3111            .follow(room_id, project_id, leader_id, follower_id)
3112            .await?;
3113        room_updated(&room, &session.peer);
3114    }
3115
3116    Ok(())
3117}
3118
3119/// Stop following another user in a call.
3120async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3121    let room_id = RoomId::from_proto(request.room_id);
3122    let project_id = request.project_id.map(ProjectId::from_proto);
3123    let leader_id = request
3124        .leader_id
3125        .ok_or_else(|| anyhow!("invalid leader id"))?
3126        .into();
3127    let follower_id = session.connection_id;
3128
3129    session
3130        .db()
3131        .await
3132        .check_room_participants(room_id, leader_id, session.connection_id)
3133        .await?;
3134
3135    session
3136        .peer
3137        .forward_send(session.connection_id, leader_id, request)?;
3138
3139    if let Some(project_id) = project_id {
3140        let room = session
3141            .db()
3142            .await
3143            .unfollow(room_id, project_id, leader_id, follower_id)
3144            .await?;
3145        room_updated(&room, &session.peer);
3146    }
3147
3148    Ok(())
3149}
3150
3151/// Notify everyone following you of your current location.
3152async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3153    let room_id = RoomId::from_proto(request.room_id);
3154    let database = session.db.lock().await;
3155
3156    let connection_ids = if let Some(project_id) = request.project_id {
3157        let project_id = ProjectId::from_proto(project_id);
3158        database
3159            .project_connection_ids(project_id, session.connection_id, true)
3160            .await?
3161    } else {
3162        database
3163            .room_connection_ids(room_id, session.connection_id)
3164            .await?
3165    };
3166
3167    // For now, don't send view update messages back to that view's current leader.
3168    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3169        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3170        _ => None,
3171    });
3172
3173    for connection_id in connection_ids.iter().cloned() {
3174        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3175            session
3176                .peer
3177                .forward_send(session.connection_id, connection_id, request.clone())?;
3178        }
3179    }
3180    Ok(())
3181}
3182
3183/// Get public data about users.
3184async fn get_users(
3185    request: proto::GetUsers,
3186    response: Response<proto::GetUsers>,
3187    session: Session,
3188) -> Result<()> {
3189    let user_ids = request
3190        .user_ids
3191        .into_iter()
3192        .map(UserId::from_proto)
3193        .collect();
3194    let users = session
3195        .db()
3196        .await
3197        .get_users_by_ids(user_ids)
3198        .await?
3199        .into_iter()
3200        .map(|user| proto::User {
3201            id: user.id.to_proto(),
3202            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3203            github_login: user.github_login,
3204        })
3205        .collect();
3206    response.send(proto::UsersResponse { users })?;
3207    Ok(())
3208}
3209
3210/// Search for users (to invite) buy Github login
3211async fn fuzzy_search_users(
3212    request: proto::FuzzySearchUsers,
3213    response: Response<proto::FuzzySearchUsers>,
3214    session: UserSession,
3215) -> Result<()> {
3216    let query = request.query;
3217    let users = match query.len() {
3218        0 => vec![],
3219        1 | 2 => session
3220            .db()
3221            .await
3222            .get_user_by_github_login(&query)
3223            .await?
3224            .into_iter()
3225            .collect(),
3226        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3227    };
3228    let users = users
3229        .into_iter()
3230        .filter(|user| user.id != session.user_id())
3231        .map(|user| proto::User {
3232            id: user.id.to_proto(),
3233            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3234            github_login: user.github_login,
3235        })
3236        .collect();
3237    response.send(proto::UsersResponse { users })?;
3238    Ok(())
3239}
3240
3241/// Send a contact request to another user.
3242async fn request_contact(
3243    request: proto::RequestContact,
3244    response: Response<proto::RequestContact>,
3245    session: UserSession,
3246) -> Result<()> {
3247    let requester_id = session.user_id();
3248    let responder_id = UserId::from_proto(request.responder_id);
3249    if requester_id == responder_id {
3250        return Err(anyhow!("cannot add yourself as a contact"))?;
3251    }
3252
3253    let notifications = session
3254        .db()
3255        .await
3256        .send_contact_request(requester_id, responder_id)
3257        .await?;
3258
3259    // Update outgoing contact requests of requester
3260    let mut update = proto::UpdateContacts::default();
3261    update.outgoing_requests.push(responder_id.to_proto());
3262    for connection_id in session
3263        .connection_pool()
3264        .await
3265        .user_connection_ids(requester_id)
3266    {
3267        session.peer.send(connection_id, update.clone())?;
3268    }
3269
3270    // Update incoming contact requests of responder
3271    let mut update = proto::UpdateContacts::default();
3272    update
3273        .incoming_requests
3274        .push(proto::IncomingContactRequest {
3275            requester_id: requester_id.to_proto(),
3276        });
3277    let connection_pool = session.connection_pool().await;
3278    for connection_id in connection_pool.user_connection_ids(responder_id) {
3279        session.peer.send(connection_id, update.clone())?;
3280    }
3281
3282    send_notifications(&connection_pool, &session.peer, notifications);
3283
3284    response.send(proto::Ack {})?;
3285    Ok(())
3286}
3287
/// Accept or decline a contact request
///
/// A `Dismiss` response only dismisses the contact notification for the responder.
/// Otherwise the request is accepted or declined in the database, and both users'
/// connections are updated. On accept, each side gains the other as a contact (with
/// that user's current busy status). In either case, the pending incoming/outgoing
/// request is removed. Any notifications produced by the database are then delivered.
///
/// Note: the database handle `db` stays held for the whole function, including while
/// the connection pool is acquired and messages are sent.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Accept` (and `Dismiss`, handled above) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3345
3346/// Remove a contact.
3347async fn remove_contact(
3348    request: proto::RemoveContact,
3349    response: Response<proto::RemoveContact>,
3350    session: UserSession,
3351) -> Result<()> {
3352    let requester_id = session.user_id();
3353    let responder_id = UserId::from_proto(request.user_id);
3354    let db = session.db().await;
3355    let (contact_accepted, deleted_notification_id) =
3356        db.remove_contact(requester_id, responder_id).await?;
3357
3358    let pool = session.connection_pool().await;
3359    // Update outgoing contact requests of requester
3360    let mut update = proto::UpdateContacts::default();
3361    if contact_accepted {
3362        update.remove_contacts.push(responder_id.to_proto());
3363    } else {
3364        update
3365            .remove_outgoing_requests
3366            .push(responder_id.to_proto());
3367    }
3368    for connection_id in pool.user_connection_ids(requester_id) {
3369        session.peer.send(connection_id, update.clone())?;
3370    }
3371
3372    // Update incoming contact requests of responder
3373    let mut update = proto::UpdateContacts::default();
3374    if contact_accepted {
3375        update.remove_contacts.push(requester_id.to_proto());
3376    } else {
3377        update
3378            .remove_incoming_requests
3379            .push(requester_id.to_proto());
3380    }
3381    for connection_id in pool.user_connection_ids(responder_id) {
3382        session.peer.send(connection_id, update.clone())?;
3383        if let Some(notification_id) = deleted_notification_id {
3384            session.peer.send(
3385                connection_id,
3386                proto::DeleteNotification {
3387                    notification_id: notification_id.to_proto(),
3388                },
3389            )?;
3390        }
3391    }
3392
3393    response.send(proto::Ack {})?;
3394    Ok(())
3395}
3396
3397fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3398    version.0.minor() < 139
3399}
3400
3401async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3402    subscribe_user_to_channels(
3403        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3404        &session,
3405    )
3406    .await?;
3407    Ok(())
3408}
3409
3410async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3411    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3412    let mut pool = session.connection_pool().await;
3413    for membership in &channels_for_user.channel_memberships {
3414        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3415    }
3416    session.peer.send(
3417        session.connection_id,
3418        build_update_user_channels(&channels_for_user),
3419    )?;
3420    session.peer.send(
3421        session.connection_id,
3422        build_channels_update(channels_for_user),
3423    )?;
3424    Ok(())
3425}
3426
3427/// Creates a new channel.
3428async fn create_channel(
3429    request: proto::CreateChannel,
3430    response: Response<proto::CreateChannel>,
3431    session: UserSession,
3432) -> Result<()> {
3433    let db = session.db().await;
3434
3435    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3436    let (channel, membership) = db
3437        .create_channel(&request.name, parent_id, session.user_id())
3438        .await?;
3439
3440    let root_id = channel.root_id();
3441    let channel = Channel::from_model(channel);
3442
3443    response.send(proto::CreateChannelResponse {
3444        channel: Some(channel.to_proto()),
3445        parent_id: request.parent_id,
3446    })?;
3447
3448    let mut connection_pool = session.connection_pool().await;
3449    if let Some(membership) = membership {
3450        connection_pool.subscribe_to_channel(
3451            membership.user_id,
3452            membership.channel_id,
3453            membership.role,
3454        );
3455        let update = proto::UpdateUserChannels {
3456            channel_memberships: vec![proto::ChannelMembership {
3457                channel_id: membership.channel_id.to_proto(),
3458                role: membership.role.into(),
3459            }],
3460            ..Default::default()
3461        };
3462        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3463            session.peer.send(connection_id, update.clone())?;
3464        }
3465    }
3466
3467    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3468        if !role.can_see_channel(channel.visibility) {
3469            continue;
3470        }
3471
3472        let update = proto::UpdateChannels {
3473            channels: vec![channel.to_proto()],
3474            ..Default::default()
3475        };
3476        session.peer.send(connection_id, update.clone())?;
3477    }
3478
3479    Ok(())
3480}
3481
3482/// Delete a channel
3483async fn delete_channel(
3484    request: proto::DeleteChannel,
3485    response: Response<proto::DeleteChannel>,
3486    session: UserSession,
3487) -> Result<()> {
3488    let db = session.db().await;
3489
3490    let channel_id = request.channel_id;
3491    let (root_channel, removed_channels) = db
3492        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3493        .await?;
3494    response.send(proto::Ack {})?;
3495
3496    // Notify members of removed channels
3497    let mut update = proto::UpdateChannels::default();
3498    update
3499        .delete_channels
3500        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3501
3502    let connection_pool = session.connection_pool().await;
3503    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3504        session.peer.send(connection_id, update.clone())?;
3505    }
3506
3507    Ok(())
3508}
3509
3510/// Invite someone to join a channel.
3511async fn invite_channel_member(
3512    request: proto::InviteChannelMember,
3513    response: Response<proto::InviteChannelMember>,
3514    session: UserSession,
3515) -> Result<()> {
3516    let db = session.db().await;
3517    let channel_id = ChannelId::from_proto(request.channel_id);
3518    let invitee_id = UserId::from_proto(request.user_id);
3519    let InviteMemberResult {
3520        channel,
3521        notifications,
3522    } = db
3523        .invite_channel_member(
3524            channel_id,
3525            invitee_id,
3526            session.user_id(),
3527            request.role().into(),
3528        )
3529        .await?;
3530
3531    let update = proto::UpdateChannels {
3532        channel_invitations: vec![channel.to_proto()],
3533        ..Default::default()
3534    };
3535
3536    let connection_pool = session.connection_pool().await;
3537    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3538        session.peer.send(connection_id, update.clone())?;
3539    }
3540
3541    send_notifications(&connection_pool, &session.peer, notifications);
3542
3543    response.send(proto::Ack {})?;
3544    Ok(())
3545}
3546
3547/// remove someone from a channel
3548async fn remove_channel_member(
3549    request: proto::RemoveChannelMember,
3550    response: Response<proto::RemoveChannelMember>,
3551    session: UserSession,
3552) -> Result<()> {
3553    let db = session.db().await;
3554    let channel_id = ChannelId::from_proto(request.channel_id);
3555    let member_id = UserId::from_proto(request.user_id);
3556
3557    let RemoveChannelMemberResult {
3558        membership_update,
3559        notification_id,
3560    } = db
3561        .remove_channel_member(channel_id, member_id, session.user_id())
3562        .await?;
3563
3564    let mut connection_pool = session.connection_pool().await;
3565    notify_membership_updated(
3566        &mut connection_pool,
3567        membership_update,
3568        member_id,
3569        &session.peer,
3570    );
3571    for connection_id in connection_pool.user_connection_ids(member_id) {
3572        if let Some(notification_id) = notification_id {
3573            session
3574                .peer
3575                .send(
3576                    connection_id,
3577                    proto::DeleteNotification {
3578                        notification_id: notification_id.to_proto(),
3579                    },
3580                )
3581                .trace_err();
3582        }
3583    }
3584
3585    response.send(proto::Ack {})?;
3586    Ok(())
3587}
3588
3589/// Toggle the channel between public and private.
3590/// Care is taken to maintain the invariant that public channels only descend from public channels,
3591/// (though members-only channels can appear at any point in the hierarchy).
3592async fn set_channel_visibility(
3593    request: proto::SetChannelVisibility,
3594    response: Response<proto::SetChannelVisibility>,
3595    session: UserSession,
3596) -> Result<()> {
3597    let db = session.db().await;
3598    let channel_id = ChannelId::from_proto(request.channel_id);
3599    let visibility = request.visibility().into();
3600
3601    let channel_model = db
3602        .set_channel_visibility(channel_id, visibility, session.user_id())
3603        .await?;
3604    let root_id = channel_model.root_id();
3605    let channel = Channel::from_model(channel_model);
3606
3607    let mut connection_pool = session.connection_pool().await;
3608    for (user_id, role) in connection_pool
3609        .channel_user_ids(root_id)
3610        .collect::<Vec<_>>()
3611        .into_iter()
3612    {
3613        let update = if role.can_see_channel(channel.visibility) {
3614            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3615            proto::UpdateChannels {
3616                channels: vec![channel.to_proto()],
3617                ..Default::default()
3618            }
3619        } else {
3620            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3621            proto::UpdateChannels {
3622                delete_channels: vec![channel.id.to_proto()],
3623                ..Default::default()
3624            }
3625        };
3626
3627        for connection_id in connection_pool.user_connection_ids(user_id) {
3628            session.peer.send(connection_id, update.clone())?;
3629        }
3630    }
3631
3632    response.send(proto::Ack {})?;
3633    Ok(())
3634}
3635
3636/// Alter the role for a user in the channel.
3637async fn set_channel_member_role(
3638    request: proto::SetChannelMemberRole,
3639    response: Response<proto::SetChannelMemberRole>,
3640    session: UserSession,
3641) -> Result<()> {
3642    let db = session.db().await;
3643    let channel_id = ChannelId::from_proto(request.channel_id);
3644    let member_id = UserId::from_proto(request.user_id);
3645    let result = db
3646        .set_channel_member_role(
3647            channel_id,
3648            session.user_id(),
3649            member_id,
3650            request.role().into(),
3651        )
3652        .await?;
3653
3654    match result {
3655        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3656            let mut connection_pool = session.connection_pool().await;
3657            notify_membership_updated(
3658                &mut connection_pool,
3659                membership_update,
3660                member_id,
3661                &session.peer,
3662            )
3663        }
3664        db::SetMemberRoleResult::InviteUpdated(channel) => {
3665            let update = proto::UpdateChannels {
3666                channel_invitations: vec![channel.to_proto()],
3667                ..Default::default()
3668            };
3669
3670            for connection_id in session
3671                .connection_pool()
3672                .await
3673                .user_connection_ids(member_id)
3674            {
3675                session.peer.send(connection_id, update.clone())?;
3676            }
3677        }
3678    }
3679
3680    response.send(proto::Ack {})?;
3681    Ok(())
3682}
3683
3684/// Change the name of a channel
3685async fn rename_channel(
3686    request: proto::RenameChannel,
3687    response: Response<proto::RenameChannel>,
3688    session: UserSession,
3689) -> Result<()> {
3690    let db = session.db().await;
3691    let channel_id = ChannelId::from_proto(request.channel_id);
3692    let channel_model = db
3693        .rename_channel(channel_id, session.user_id(), &request.name)
3694        .await?;
3695    let root_id = channel_model.root_id();
3696    let channel = Channel::from_model(channel_model);
3697
3698    response.send(proto::RenameChannelResponse {
3699        channel: Some(channel.to_proto()),
3700    })?;
3701
3702    let connection_pool = session.connection_pool().await;
3703    let update = proto::UpdateChannels {
3704        channels: vec![channel.to_proto()],
3705        ..Default::default()
3706    };
3707    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3708        if role.can_see_channel(channel.visibility) {
3709            session.peer.send(connection_id, update.clone())?;
3710        }
3711    }
3712
3713    Ok(())
3714}
3715
3716/// Move a channel to a new parent.
3717async fn move_channel(
3718    request: proto::MoveChannel,
3719    response: Response<proto::MoveChannel>,
3720    session: UserSession,
3721) -> Result<()> {
3722    let channel_id = ChannelId::from_proto(request.channel_id);
3723    let to = ChannelId::from_proto(request.to);
3724
3725    let (root_id, channels) = session
3726        .db()
3727        .await
3728        .move_channel(channel_id, to, session.user_id())
3729        .await?;
3730
3731    let connection_pool = session.connection_pool().await;
3732    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3733        let channels = channels
3734            .iter()
3735            .filter_map(|channel| {
3736                if role.can_see_channel(channel.visibility) {
3737                    Some(channel.to_proto())
3738                } else {
3739                    None
3740                }
3741            })
3742            .collect::<Vec<_>>();
3743        if channels.is_empty() {
3744            continue;
3745        }
3746
3747        let update = proto::UpdateChannels {
3748            channels,
3749            ..Default::default()
3750        };
3751
3752        session.peer.send(connection_id, update.clone())?;
3753    }
3754
3755    response.send(Ack {})?;
3756    Ok(())
3757}
3758
3759/// Get the list of channel members
3760async fn get_channel_members(
3761    request: proto::GetChannelMembers,
3762    response: Response<proto::GetChannelMembers>,
3763    session: UserSession,
3764) -> Result<()> {
3765    let db = session.db().await;
3766    let channel_id = ChannelId::from_proto(request.channel_id);
3767    let limit = if request.limit == 0 {
3768        u16::MAX as u64
3769    } else {
3770        request.limit
3771    };
3772    let (members, users) = db
3773        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3774        .await?;
3775    response.send(proto::GetChannelMembersResponse { members, users })?;
3776    Ok(())
3777}
3778
3779/// Accept or decline a channel invitation.
3780async fn respond_to_channel_invite(
3781    request: proto::RespondToChannelInvite,
3782    response: Response<proto::RespondToChannelInvite>,
3783    session: UserSession,
3784) -> Result<()> {
3785    let db = session.db().await;
3786    let channel_id = ChannelId::from_proto(request.channel_id);
3787    let RespondToChannelInvite {
3788        membership_update,
3789        notifications,
3790    } = db
3791        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3792        .await?;
3793
3794    let mut connection_pool = session.connection_pool().await;
3795    if let Some(membership_update) = membership_update {
3796        notify_membership_updated(
3797            &mut connection_pool,
3798            membership_update,
3799            session.user_id(),
3800            &session.peer,
3801        );
3802    } else {
3803        let update = proto::UpdateChannels {
3804            remove_channel_invitations: vec![channel_id.to_proto()],
3805            ..Default::default()
3806        };
3807
3808        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3809            session.peer.send(connection_id, update.clone())?;
3810        }
3811    };
3812
3813    send_notifications(&connection_pool, &session.peer, notifications);
3814
3815    response.send(proto::Ack {})?;
3816
3817    Ok(())
3818}
3819
3820/// Join the channels' room
3821async fn join_channel(
3822    request: proto::JoinChannel,
3823    response: Response<proto::JoinChannel>,
3824    session: UserSession,
3825) -> Result<()> {
3826    let channel_id = ChannelId::from_proto(request.channel_id);
3827    join_channel_internal(channel_id, Box::new(response), session).await
3828}
3829
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`, letting `join_channel_internal` serve both
/// `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully qualified `Response::<...>::send` paths below resolve to
// `Response`'s inherent `send` (inherent items take precedence), not back into
// this trait's `send`; rewriting them as `self.send(..)` or `Self::send(..)`
// would be easy to misread as recursion.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3843
/// Shared implementation for joining a channel's room. It is used by
/// `join_channel`, and presumably by the `JoinRoom` handler via the
/// `Response<proto::JoinRoom>` impl of `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The database guard is released while doing so.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update and the user's channel role.
/// 3. If a LiveKit client is configured, build connection info:
///    - guests get a guest token and `can_publish: false`;
///    - every other role gets a room token and `can_publish: true`;
///    - a failed token yields no LiveKit info rather than an error.
/// 4. Reply to the requester, propagate any membership update, and broadcast
///    the room update.
/// 5. After the inner scope ends, which drops the `db` and `connection_pool`
///    guards bound inside it, broadcast the channel update and refresh the
///    user's contacts.
///
/// Returns an error if the joined room carries no channel. Note that the
/// response has already been sent by the time that check runs.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `trace_err()?` turns a token failure into `None`, so the join still
        // succeeds, just without LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards acquired in the block above have been dropped; a fresh
    // connection-pool guard is taken just for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3934
3935/// Start editing the channel notes
3936async fn join_channel_buffer(
3937    request: proto::JoinChannelBuffer,
3938    response: Response<proto::JoinChannelBuffer>,
3939    session: UserSession,
3940) -> Result<()> {
3941    let db = session.db().await;
3942    let channel_id = ChannelId::from_proto(request.channel_id);
3943
3944    let open_response = db
3945        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3946        .await?;
3947
3948    let collaborators = open_response.collaborators.clone();
3949    response.send(open_response)?;
3950
3951    let update = UpdateChannelBufferCollaborators {
3952        channel_id: channel_id.to_proto(),
3953        collaborators: collaborators.clone(),
3954    };
3955    channel_buffer_updated(
3956        session.connection_id,
3957        collaborators
3958            .iter()
3959            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3960        &update,
3961        &session.peer,
3962    );
3963
3964    Ok(())
3965}
3966
3967/// Edit the channel notes
3968async fn update_channel_buffer(
3969    request: proto::UpdateChannelBuffer,
3970    session: UserSession,
3971) -> Result<()> {
3972    let db = session.db().await;
3973    let channel_id = ChannelId::from_proto(request.channel_id);
3974
3975    let (collaborators, epoch, version) = db
3976        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3977        .await?;
3978
3979    channel_buffer_updated(
3980        session.connection_id,
3981        collaborators.clone(),
3982        &proto::UpdateChannelBuffer {
3983            channel_id: channel_id.to_proto(),
3984            operations: request.operations,
3985        },
3986        &session.peer,
3987    );
3988
3989    let pool = &*session.connection_pool().await;
3990
3991    let non_collaborators =
3992        pool.channel_connection_ids(channel_id)
3993            .filter_map(|(connection_id, _)| {
3994                if collaborators.contains(&connection_id) {
3995                    None
3996                } else {
3997                    Some(connection_id)
3998                }
3999            });
4000
4001    broadcast(None, non_collaborators, |peer_id| {
4002        session.peer.send(
4003            peer_id,
4004            proto::UpdateChannels {
4005                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4006                    channel_id: channel_id.to_proto(),
4007                    epoch: epoch as u64,
4008                    version: version.clone(),
4009                }],
4010                ..Default::default()
4011            },
4012        )
4013    });
4014
4015    Ok(())
4016}
4017
4018/// Rejoin the channel notes after a connection blip
4019async fn rejoin_channel_buffers(
4020    request: proto::RejoinChannelBuffers,
4021    response: Response<proto::RejoinChannelBuffers>,
4022    session: UserSession,
4023) -> Result<()> {
4024    let db = session.db().await;
4025    let buffers = db
4026        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4027        .await?;
4028
4029    for rejoined_buffer in &buffers {
4030        let collaborators_to_notify = rejoined_buffer
4031            .buffer
4032            .collaborators
4033            .iter()
4034            .filter_map(|c| Some(c.peer_id?.into()));
4035        channel_buffer_updated(
4036            session.connection_id,
4037            collaborators_to_notify,
4038            &proto::UpdateChannelBufferCollaborators {
4039                channel_id: rejoined_buffer.buffer.channel_id,
4040                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4041            },
4042            &session.peer,
4043        );
4044    }
4045
4046    response.send(proto::RejoinChannelBuffersResponse {
4047        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4048    })?;
4049
4050    Ok(())
4051}
4052
4053/// Stop editing the channel notes
4054async fn leave_channel_buffer(
4055    request: proto::LeaveChannelBuffer,
4056    response: Response<proto::LeaveChannelBuffer>,
4057    session: UserSession,
4058) -> Result<()> {
4059    let db = session.db().await;
4060    let channel_id = ChannelId::from_proto(request.channel_id);
4061
4062    let left_buffer = db
4063        .leave_channel_buffer(channel_id, session.connection_id)
4064        .await?;
4065
4066    response.send(Ack {})?;
4067
4068    channel_buffer_updated(
4069        session.connection_id,
4070        left_buffer.connections,
4071        &proto::UpdateChannelBufferCollaborators {
4072            channel_id: channel_id.to_proto(),
4073            collaborators: left_buffer.collaborators,
4074        },
4075        &session.peer,
4076    );
4077
4078    Ok(())
4079}
4080
4081fn channel_buffer_updated<T: EnvelopedMessage>(
4082    sender_id: ConnectionId,
4083    collaborators: impl IntoIterator<Item = ConnectionId>,
4084    message: &T,
4085    peer: &Peer,
4086) {
4087    broadcast(Some(sender_id), collaborators, |peer_id| {
4088        peer.send(peer_id, message.clone())
4089    });
4090}
4091
4092fn send_notifications(
4093    connection_pool: &ConnectionPool,
4094    peer: &Peer,
4095    notifications: db::NotificationBatch,
4096) {
4097    for (user_id, notification) in notifications {
4098        for connection_id in connection_pool.user_connection_ids(user_id) {
4099            if let Err(error) = peer.send(
4100                connection_id,
4101                proto::AddNotification {
4102                    notification: Some(notification.clone()),
4103                },
4104            ) {
4105                tracing::error!(
4106                    "failed to send notification to {:?} {}",
4107                    connection_id,
4108                    error
4109                );
4110            }
4111        }
4112    }
4113}
4114
4115/// Send a message to the channel
4116async fn send_channel_message(
4117    request: proto::SendChannelMessage,
4118    response: Response<proto::SendChannelMessage>,
4119    session: UserSession,
4120) -> Result<()> {
4121    // Validate the message body.
4122    let body = request.body.trim().to_string();
4123    if body.len() > MAX_MESSAGE_LEN {
4124        return Err(anyhow!("message is too long"))?;
4125    }
4126    if body.is_empty() {
4127        return Err(anyhow!("message can't be blank"))?;
4128    }
4129
4130    // TODO: adjust mentions if body is trimmed
4131
4132    let timestamp = OffsetDateTime::now_utc();
4133    let nonce = request
4134        .nonce
4135        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4136
4137    let channel_id = ChannelId::from_proto(request.channel_id);
4138    let CreatedChannelMessage {
4139        message_id,
4140        participant_connection_ids,
4141        notifications,
4142    } = session
4143        .db()
4144        .await
4145        .create_channel_message(
4146            channel_id,
4147            session.user_id(),
4148            &body,
4149            &request.mentions,
4150            timestamp,
4151            nonce.clone().into(),
4152            match request.reply_to_message_id {
4153                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4154                None => None,
4155            },
4156        )
4157        .await?;
4158
4159    let message = proto::ChannelMessage {
4160        sender_id: session.user_id().to_proto(),
4161        id: message_id.to_proto(),
4162        body,
4163        mentions: request.mentions,
4164        timestamp: timestamp.unix_timestamp() as u64,
4165        nonce: Some(nonce),
4166        reply_to_message_id: request.reply_to_message_id,
4167        edited_at: None,
4168    };
4169    broadcast(
4170        Some(session.connection_id),
4171        participant_connection_ids.clone(),
4172        |connection| {
4173            session.peer.send(
4174                connection,
4175                proto::ChannelMessageSent {
4176                    channel_id: channel_id.to_proto(),
4177                    message: Some(message.clone()),
4178                },
4179            )
4180        },
4181    );
4182    response.send(proto::SendChannelMessageResponse {
4183        message: Some(message),
4184    })?;
4185
4186    let pool = &*session.connection_pool().await;
4187    let non_participants =
4188        pool.channel_connection_ids(channel_id)
4189            .filter_map(|(connection_id, _)| {
4190                if participant_connection_ids.contains(&connection_id) {
4191                    None
4192                } else {
4193                    Some(connection_id)
4194                }
4195            });
4196    broadcast(None, non_participants, |peer_id| {
4197        session.peer.send(
4198            peer_id,
4199            proto::UpdateChannels {
4200                latest_channel_message_ids: vec![proto::ChannelMessageId {
4201                    channel_id: channel_id.to_proto(),
4202                    message_id: message_id.to_proto(),
4203                }],
4204                ..Default::default()
4205            },
4206        )
4207    });
4208    send_notifications(pool, &session.peer, notifications);
4209
4210    Ok(())
4211}
4212
4213/// Delete a channel message
4214async fn remove_channel_message(
4215    request: proto::RemoveChannelMessage,
4216    response: Response<proto::RemoveChannelMessage>,
4217    session: UserSession,
4218) -> Result<()> {
4219    let channel_id = ChannelId::from_proto(request.channel_id);
4220    let message_id = MessageId::from_proto(request.message_id);
4221    let (connection_ids, existing_notification_ids) = session
4222        .db()
4223        .await
4224        .remove_channel_message(channel_id, message_id, session.user_id())
4225        .await?;
4226
4227    broadcast(
4228        Some(session.connection_id),
4229        connection_ids,
4230        move |connection| {
4231            session.peer.send(connection, request.clone())?;
4232
4233            for notification_id in &existing_notification_ids {
4234                session.peer.send(
4235                    connection,
4236                    proto::DeleteNotification {
4237                        notification_id: (*notification_id).to_proto(),
4238                    },
4239                )?;
4240            }
4241
4242            Ok(())
4243        },
4244    );
4245    response.send(proto::Ack {})?;
4246    Ok(())
4247}
4248
4249async fn update_channel_message(
4250    request: proto::UpdateChannelMessage,
4251    response: Response<proto::UpdateChannelMessage>,
4252    session: UserSession,
4253) -> Result<()> {
4254    let channel_id = ChannelId::from_proto(request.channel_id);
4255    let message_id = MessageId::from_proto(request.message_id);
4256    let updated_at = OffsetDateTime::now_utc();
4257    let UpdatedChannelMessage {
4258        message_id,
4259        participant_connection_ids,
4260        notifications,
4261        reply_to_message_id,
4262        timestamp,
4263        deleted_mention_notification_ids,
4264        updated_mention_notifications,
4265    } = session
4266        .db()
4267        .await
4268        .update_channel_message(
4269            channel_id,
4270            message_id,
4271            session.user_id(),
4272            request.body.as_str(),
4273            &request.mentions,
4274            updated_at,
4275        )
4276        .await?;
4277
4278    let nonce = request
4279        .nonce
4280        .clone()
4281        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4282
4283    let message = proto::ChannelMessage {
4284        sender_id: session.user_id().to_proto(),
4285        id: message_id.to_proto(),
4286        body: request.body.clone(),
4287        mentions: request.mentions.clone(),
4288        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4289        nonce: Some(nonce),
4290        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4291        edited_at: Some(updated_at.unix_timestamp() as u64),
4292    };
4293
4294    response.send(proto::Ack {})?;
4295
4296    let pool = &*session.connection_pool().await;
4297    broadcast(
4298        Some(session.connection_id),
4299        participant_connection_ids,
4300        |connection| {
4301            session.peer.send(
4302                connection,
4303                proto::ChannelMessageUpdate {
4304                    channel_id: channel_id.to_proto(),
4305                    message: Some(message.clone()),
4306                },
4307            )?;
4308
4309            for notification_id in &deleted_mention_notification_ids {
4310                session.peer.send(
4311                    connection,
4312                    proto::DeleteNotification {
4313                        notification_id: (*notification_id).to_proto(),
4314                    },
4315                )?;
4316            }
4317
4318            for notification in &updated_mention_notifications {
4319                session.peer.send(
4320                    connection,
4321                    proto::UpdateNotification {
4322                        notification: Some(notification.clone()),
4323                    },
4324                )?;
4325            }
4326
4327            Ok(())
4328        },
4329    );
4330
4331    send_notifications(pool, &session.peer, notifications);
4332
4333    Ok(())
4334}
4335
4336/// Mark a channel message as read
4337async fn acknowledge_channel_message(
4338    request: proto::AckChannelMessage,
4339    session: UserSession,
4340) -> Result<()> {
4341    let channel_id = ChannelId::from_proto(request.channel_id);
4342    let message_id = MessageId::from_proto(request.message_id);
4343    let notifications = session
4344        .db()
4345        .await
4346        .observe_channel_message(channel_id, session.user_id(), message_id)
4347        .await?;
4348    send_notifications(
4349        &*session.connection_pool().await,
4350        &session.peer,
4351        notifications,
4352    );
4353    Ok(())
4354}
4355
4356/// Mark a buffer version as synced
4357async fn acknowledge_buffer_version(
4358    request: proto::AckBufferOperation,
4359    session: UserSession,
4360) -> Result<()> {
4361    let buffer_id = BufferId::from_proto(request.buffer_id);
4362    session
4363        .db()
4364        .await
4365        .observe_buffer_version(
4366            buffer_id,
4367            session.user_id(),
4368            request.epoch as i32,
4369            &request.version,
4370        )
4371        .await?;
4372    Ok(())
4373}
4374
4375struct CompleteWithLanguageModelRateLimit;
4376
4377impl RateLimit for CompleteWithLanguageModelRateLimit {
4378    fn capacity() -> usize {
4379        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4380            .ok()
4381            .and_then(|v| v.parse().ok())
4382            .unwrap_or(120) // Picked arbitrarily
4383    }
4384
4385    fn refill_duration() -> chrono::Duration {
4386        chrono::Duration::hours(1)
4387    }
4388
4389    fn db_name() -> &'static str {
4390        "complete-with-language-model"
4391    }
4392}
4393
4394async fn complete_with_language_model(
4395    request: proto::CompleteWithLanguageModel,
4396    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4397    session: Session,
4398    open_ai_api_key: Option<Arc<str>>,
4399    google_ai_api_key: Option<Arc<str>>,
4400    anthropic_api_key: Option<Arc<str>>,
4401) -> Result<()> {
4402    let Some(session) = session.for_user() else {
4403        return Err(anyhow!("user not found"))?;
4404    };
4405    authorize_access_to_language_models(&session).await?;
4406    session
4407        .rate_limiter
4408        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4409        .await?;
4410
4411    if request.model.starts_with("gpt") {
4412        let api_key =
4413            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4414        complete_with_open_ai(request, response, session, api_key).await?;
4415    } else if request.model.starts_with("gemini") {
4416        let api_key = google_ai_api_key
4417            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4418        complete_with_google_ai(request, response, session, api_key).await?;
4419    } else if request.model.starts_with("claude") {
4420        let api_key = anthropic_api_key
4421            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4422        complete_with_anthropic(request, response, session, api_key).await?;
4423    }
4424
4425    Ok(())
4426}
4427
4428async fn complete_with_open_ai(
4429    request: proto::CompleteWithLanguageModel,
4430    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4431    session: UserSession,
4432    api_key: Arc<str>,
4433) -> Result<()> {
4434    let mut completion_stream = open_ai::stream_completion(
4435        session.http_client.as_ref(),
4436        OPEN_AI_API_URL,
4437        &api_key,
4438        crate::ai::language_model_request_to_open_ai(request)?,
4439        None,
4440    )
4441    .await
4442    .context("open_ai::stream_completion request failed within collab")?;
4443
4444    while let Some(event) = completion_stream.next().await {
4445        let event = event?;
4446        response.send(proto::LanguageModelResponse {
4447            choices: event
4448                .choices
4449                .into_iter()
4450                .map(|choice| proto::LanguageModelChoiceDelta {
4451                    index: choice.index,
4452                    delta: Some(proto::LanguageModelResponseMessage {
4453                        role: choice.delta.role.map(|role| match role {
4454                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4455                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4456                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4457                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4458                        } as i32),
4459                        content: choice.delta.content,
4460                        tool_calls: choice
4461                            .delta
4462                            .tool_calls
4463                            .into_iter()
4464                            .map(|delta| proto::ToolCallDelta {
4465                                index: delta.index as u32,
4466                                id: delta.id,
4467                                variant: match delta.function {
4468                                    Some(function) => {
4469                                        let name = function.name;
4470                                        let arguments = function.arguments;
4471
4472                                        Some(proto::tool_call_delta::Variant::Function(
4473                                            proto::tool_call_delta::FunctionCallDelta {
4474                                                name,
4475                                                arguments,
4476                                            },
4477                                        ))
4478                                    }
4479                                    None => None,
4480                                },
4481                            })
4482                            .collect(),
4483                    }),
4484                    finish_reason: choice.finish_reason,
4485                })
4486                .collect(),
4487        })?;
4488    }
4489
4490    Ok(())
4491}
4492
4493async fn complete_with_google_ai(
4494    request: proto::CompleteWithLanguageModel,
4495    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4496    session: UserSession,
4497    api_key: Arc<str>,
4498) -> Result<()> {
4499    let mut stream = google_ai::stream_generate_content(
4500        session.http_client.clone(),
4501        google_ai::API_URL,
4502        api_key.as_ref(),
4503        crate::ai::language_model_request_to_google_ai(request)?,
4504    )
4505    .await
4506    .context("google_ai::stream_generate_content request failed")?;
4507
4508    while let Some(event) = stream.next().await {
4509        let event = event?;
4510        response.send(proto::LanguageModelResponse {
4511            choices: event
4512                .candidates
4513                .unwrap_or_default()
4514                .into_iter()
4515                .map(|candidate| proto::LanguageModelChoiceDelta {
4516                    index: candidate.index as u32,
4517                    delta: Some(proto::LanguageModelResponseMessage {
4518                        role: Some(match candidate.content.role {
4519                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4520                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4521                        } as i32),
4522                        content: Some(
4523                            candidate
4524                                .content
4525                                .parts
4526                                .into_iter()
4527                                .filter_map(|part| match part {
4528                                    google_ai::Part::TextPart(part) => Some(part.text),
4529                                    google_ai::Part::InlineDataPart(_) => None,
4530                                })
4531                                .collect(),
4532                        ),
4533                        // Tool calls are not supported for Google
4534                        tool_calls: Vec::new(),
4535                    }),
4536                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4537                })
4538                .collect(),
4539        })?;
4540    }
4541
4542    Ok(())
4543}
4544
/// Stream a completion from Anthropic back to the client.
///
/// Request messages are converted as follows:
/// - User and assistant messages are passed through.
/// - System messages are concatenated (separated by a blank line) into the
///   request's separate `system` field.
/// - Tool messages are dropped.
///
/// Streamed events are forwarded as `proto::LanguageModelResponse`s:
/// - Text from content-block starts (if non-empty) and text deltas is sent as
///   content, tagged with the most recently announced role.
/// - A stop reason is sent as a final choice with `delta: None`.
/// - All other events are ignored.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    // Unknown model ids are rejected here, before any request is made.
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates the content of every system message (see the match below).
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 looks like it may be a typo for 4096; confirm
            // against the intended output-token limit for these models.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to forwarded text. It defaults to assistant and is updated
    // from each `MessageStart` event.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Role strings other than "assistant"/"user" leave the current role unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // A block may open with initial text; empty text is not forwarded.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // Only the stop reason is forwarded, as a choice with no content delta.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing that is forwarded to the client.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4666
4667struct CountTokensWithLanguageModelRateLimit;
4668
4669impl RateLimit for CountTokensWithLanguageModelRateLimit {
4670    fn capacity() -> usize {
4671        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4672            .ok()
4673            .and_then(|v| v.parse().ok())
4674            .unwrap_or(600) // Picked arbitrarily
4675    }
4676
4677    fn refill_duration() -> chrono::Duration {
4678        chrono::Duration::hours(1)
4679    }
4680
4681    fn db_name() -> &'static str {
4682        "count-tokens-with-language-model"
4683    }
4684}
4685
4686async fn count_tokens_with_language_model(
4687    request: proto::CountTokensWithLanguageModel,
4688    response: Response<proto::CountTokensWithLanguageModel>,
4689    session: UserSession,
4690    google_ai_api_key: Option<Arc<str>>,
4691) -> Result<()> {
4692    authorize_access_to_language_models(&session).await?;
4693
4694    if !request.model.starts_with("gemini") {
4695        return Err(anyhow!(
4696            "counting tokens for model: {:?} is not supported",
4697            request.model
4698        ))?;
4699    }
4700
4701    session
4702        .rate_limiter
4703        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4704        .await?;
4705
4706    let api_key = google_ai_api_key
4707        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4708    let tokens_response = google_ai::count_tokens(
4709        session.http_client.as_ref(),
4710        google_ai::API_URL,
4711        &api_key,
4712        crate::ai::count_tokens_request_to_google_ai(request)?,
4713    )
4714    .await?;
4715    response.send(proto::CountTokensResponse {
4716        token_count: tokens_response.total_tokens as u32,
4717    })?;
4718    Ok(())
4719}
4720
4721struct ComputeEmbeddingsRateLimit;
4722
4723impl RateLimit for ComputeEmbeddingsRateLimit {
4724    fn capacity() -> usize {
4725        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4726            .ok()
4727            .and_then(|v| v.parse().ok())
4728            .unwrap_or(5000) // Picked arbitrarily
4729    }
4730
4731    fn refill_duration() -> chrono::Duration {
4732        chrono::Duration::hours(1)
4733    }
4734
4735    fn db_name() -> &'static str {
4736        "compute-embeddings"
4737    }
4738}
4739
4740async fn compute_embeddings(
4741    request: proto::ComputeEmbeddings,
4742    response: Response<proto::ComputeEmbeddings>,
4743    session: UserSession,
4744    api_key: Option<Arc<str>>,
4745) -> Result<()> {
4746    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4747    authorize_access_to_language_models(&session).await?;
4748
4749    session
4750        .rate_limiter
4751        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4752        .await?;
4753
4754    let embeddings = match request.model.as_str() {
4755        "openai/text-embedding-3-small" => {
4756            open_ai::embed(
4757                session.http_client.as_ref(),
4758                OPEN_AI_API_URL,
4759                &api_key,
4760                OpenAiEmbeddingModel::TextEmbedding3Small,
4761                request.texts.iter().map(|text| text.as_str()),
4762            )
4763            .await?
4764        }
4765        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4766    };
4767
4768    let embeddings = request
4769        .texts
4770        .iter()
4771        .map(|text| {
4772            let mut hasher = sha2::Sha256::new();
4773            hasher.update(text.as_bytes());
4774            let result = hasher.finalize();
4775            result.to_vec()
4776        })
4777        .zip(
4778            embeddings
4779                .data
4780                .into_iter()
4781                .map(|embedding| embedding.embedding),
4782        )
4783        .collect::<HashMap<_, _>>();
4784
4785    let db = session.db().await;
4786    db.save_embeddings(&request.model, &embeddings)
4787        .await
4788        .context("failed to save embeddings")
4789        .trace_err();
4790
4791    response.send(proto::ComputeEmbeddingsResponse {
4792        embeddings: embeddings
4793            .into_iter()
4794            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4795            .collect(),
4796    })?;
4797    Ok(())
4798}
4799
4800async fn get_cached_embeddings(
4801    request: proto::GetCachedEmbeddings,
4802    response: Response<proto::GetCachedEmbeddings>,
4803    session: UserSession,
4804) -> Result<()> {
4805    authorize_access_to_language_models(&session).await?;
4806
4807    let db = session.db().await;
4808    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4809
4810    response.send(proto::GetCachedEmbeddingsResponse {
4811        embeddings: embeddings
4812            .into_iter()
4813            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4814            .collect(),
4815    })?;
4816    Ok(())
4817}
4818
4819async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4820    let db = session.db().await;
4821    let flags = db.get_user_flags(session.user_id()).await?;
4822    if flags.iter().any(|flag| flag == "language-models") {
4823        Ok(())
4824    } else {
4825        Err(anyhow!("permission denied"))?
4826    }
4827}
4828
4829/// Get a Supermaven API key for the user
4830async fn get_supermaven_api_key(
4831    _request: proto::GetSupermavenApiKey,
4832    response: Response<proto::GetSupermavenApiKey>,
4833    session: UserSession,
4834) -> Result<()> {
4835    let user_id: String = session.user_id().to_string();
4836    if !session.is_staff() {
4837        return Err(anyhow!("supermaven not enabled for this account"))?;
4838    }
4839
4840    let email = session
4841        .email()
4842        .ok_or_else(|| anyhow!("user must have an email"))?;
4843
4844    let supermaven_admin_api = session
4845        .supermaven_client
4846        .as_ref()
4847        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4848
4849    let result = supermaven_admin_api
4850        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4851        .await?;
4852
4853    response.send(proto::GetSupermavenApiKeyResponse {
4854        api_key: result.api_key,
4855    })?;
4856
4857    Ok(())
4858}
4859
4860/// Start receiving chat updates for a channel
4861async fn join_channel_chat(
4862    request: proto::JoinChannelChat,
4863    response: Response<proto::JoinChannelChat>,
4864    session: UserSession,
4865) -> Result<()> {
4866    let channel_id = ChannelId::from_proto(request.channel_id);
4867
4868    let db = session.db().await;
4869    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4870        .await?;
4871    let messages = db
4872        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4873        .await?;
4874    response.send(proto::JoinChannelChatResponse {
4875        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4876        messages,
4877    })?;
4878    Ok(())
4879}
4880
4881/// Stop receiving chat updates for a channel
4882async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4883    let channel_id = ChannelId::from_proto(request.channel_id);
4884    session
4885        .db()
4886        .await
4887        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4888        .await?;
4889    Ok(())
4890}
4891
4892/// Retrieve the chat history for a channel
4893async fn get_channel_messages(
4894    request: proto::GetChannelMessages,
4895    response: Response<proto::GetChannelMessages>,
4896    session: UserSession,
4897) -> Result<()> {
4898    let channel_id = ChannelId::from_proto(request.channel_id);
4899    let messages = session
4900        .db()
4901        .await
4902        .get_channel_messages(
4903            channel_id,
4904            session.user_id(),
4905            MESSAGE_COUNT_PER_PAGE,
4906            Some(MessageId::from_proto(request.before_message_id)),
4907        )
4908        .await?;
4909    response.send(proto::GetChannelMessagesResponse {
4910        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4911        messages,
4912    })?;
4913    Ok(())
4914}
4915
4916/// Retrieve specific chat messages
4917async fn get_channel_messages_by_id(
4918    request: proto::GetChannelMessagesById,
4919    response: Response<proto::GetChannelMessagesById>,
4920    session: UserSession,
4921) -> Result<()> {
4922    let message_ids = request
4923        .message_ids
4924        .iter()
4925        .map(|id| MessageId::from_proto(*id))
4926        .collect::<Vec<_>>();
4927    let messages = session
4928        .db()
4929        .await
4930        .get_channel_messages_by_id(session.user_id(), &message_ids)
4931        .await?;
4932    response.send(proto::GetChannelMessagesResponse {
4933        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4934        messages,
4935    })?;
4936    Ok(())
4937}
4938
4939/// Retrieve the current users notifications
4940async fn get_notifications(
4941    request: proto::GetNotifications,
4942    response: Response<proto::GetNotifications>,
4943    session: UserSession,
4944) -> Result<()> {
4945    let notifications = session
4946        .db()
4947        .await
4948        .get_notifications(
4949            session.user_id(),
4950            NOTIFICATION_COUNT_PER_PAGE,
4951            request
4952                .before_id
4953                .map(|id| db::NotificationId::from_proto(id)),
4954        )
4955        .await?;
4956    response.send(proto::GetNotificationsResponse {
4957        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4958        notifications,
4959    })?;
4960    Ok(())
4961}
4962
4963/// Mark notifications as read
4964async fn mark_notification_as_read(
4965    request: proto::MarkNotificationRead,
4966    response: Response<proto::MarkNotificationRead>,
4967    session: UserSession,
4968) -> Result<()> {
4969    let database = &session.db().await;
4970    let notifications = database
4971        .mark_notification_as_read_by_id(
4972            session.user_id(),
4973            NotificationId::from_proto(request.notification_id),
4974        )
4975        .await?;
4976    send_notifications(
4977        &*session.connection_pool().await,
4978        &session.peer,
4979        notifications,
4980    );
4981    response.send(proto::Ack {})?;
4982    Ok(())
4983}
4984
4985/// Get the current users information
4986async fn get_private_user_info(
4987    _request: proto::GetPrivateUserInfo,
4988    response: Response<proto::GetPrivateUserInfo>,
4989    session: UserSession,
4990) -> Result<()> {
4991    let db = session.db().await;
4992
4993    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4994    let user = db
4995        .get_user_by_id(session.user_id())
4996        .await?
4997        .ok_or_else(|| anyhow!("user not found"))?;
4998    let flags = db.get_user_flags(session.user_id()).await?;
4999
5000    response.send(proto::GetPrivateUserInfoResponse {
5001        metrics_id,
5002        staff: user.admin,
5003        flags,
5004    })?;
5005    Ok(())
5006}
5007
5008fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5009    match message {
5010        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5011        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5012        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5013        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5014        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5015            code: frame.code.into(),
5016            reason: frame.reason,
5017        })),
5018    }
5019}
5020
5021fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5022    match message {
5023        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5024        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5025        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5026        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5027        AxumMessage::Close(frame) => {
5028            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5029                code: frame.code.into(),
5030                reason: frame.reason,
5031            }))
5032        }
5033    }
5034}
5035
5036fn notify_membership_updated(
5037    connection_pool: &mut ConnectionPool,
5038    result: MembershipUpdated,
5039    user_id: UserId,
5040    peer: &Peer,
5041) {
5042    for membership in &result.new_channels.channel_memberships {
5043        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5044    }
5045    for channel_id in &result.removed_channels {
5046        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5047    }
5048
5049    let user_channels_update = proto::UpdateUserChannels {
5050        channel_memberships: result
5051            .new_channels
5052            .channel_memberships
5053            .iter()
5054            .map(|cm| proto::ChannelMembership {
5055                channel_id: cm.channel_id.to_proto(),
5056                role: cm.role.into(),
5057            })
5058            .collect(),
5059        ..Default::default()
5060    };
5061
5062    let mut update = build_channels_update(result.new_channels);
5063    update.delete_channels = result
5064        .removed_channels
5065        .into_iter()
5066        .map(|id| id.to_proto())
5067        .collect();
5068    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5069
5070    for connection_id in connection_pool.user_connection_ids(user_id) {
5071        peer.send(connection_id, user_channels_update.clone())
5072            .trace_err();
5073        peer.send(connection_id, update.clone()).trace_err();
5074    }
5075}
5076
5077fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5078    proto::UpdateUserChannels {
5079        channel_memberships: channels
5080            .channel_memberships
5081            .iter()
5082            .map(|m| proto::ChannelMembership {
5083                channel_id: m.channel_id.to_proto(),
5084                role: m.role.into(),
5085            })
5086            .collect(),
5087        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5088        observed_channel_message_id: channels.observed_channel_messages.clone(),
5089    }
5090}
5091
5092fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5093    let mut update = proto::UpdateChannels::default();
5094
5095    for channel in channels.channels {
5096        update.channels.push(channel.to_proto());
5097    }
5098
5099    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5100    update.latest_channel_message_ids = channels.latest_channel_messages;
5101
5102    for (channel_id, participants) in channels.channel_participants {
5103        update
5104            .channel_participants
5105            .push(proto::ChannelParticipants {
5106                channel_id: channel_id.to_proto(),
5107                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5108            });
5109    }
5110
5111    for channel in channels.invited_channels {
5112        update.channel_invitations.push(channel.to_proto());
5113    }
5114
5115    update.hosted_projects = channels.hosted_projects;
5116    update
5117}
5118
5119fn build_initial_contacts_update(
5120    contacts: Vec<db::Contact>,
5121    pool: &ConnectionPool,
5122) -> proto::UpdateContacts {
5123    let mut update = proto::UpdateContacts::default();
5124
5125    for contact in contacts {
5126        match contact {
5127            db::Contact::Accepted { user_id, busy } => {
5128                update.contacts.push(contact_for_user(user_id, busy, &pool));
5129            }
5130            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5131            db::Contact::Incoming { user_id } => {
5132                update
5133                    .incoming_requests
5134                    .push(proto::IncomingContactRequest {
5135                        requester_id: user_id.to_proto(),
5136                    })
5137            }
5138        }
5139    }
5140
5141    update
5142}
5143
5144fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5145    proto::Contact {
5146        user_id: user_id.to_proto(),
5147        online: pool.is_user_online(user_id),
5148        busy,
5149    }
5150}
5151
5152fn room_updated(room: &proto::Room, peer: &Peer) {
5153    broadcast(
5154        None,
5155        room.participants
5156            .iter()
5157            .filter_map(|participant| Some(participant.peer_id?.into())),
5158        |peer_id| {
5159            peer.send(
5160                peer_id,
5161                proto::RoomUpdated {
5162                    room: Some(room.clone()),
5163                },
5164            )
5165        },
5166    );
5167}
5168
5169fn channel_updated(
5170    channel: &db::channel::Model,
5171    room: &proto::Room,
5172    peer: &Peer,
5173    pool: &ConnectionPool,
5174) {
5175    let participants = room
5176        .participants
5177        .iter()
5178        .map(|p| p.user_id)
5179        .collect::<Vec<_>>();
5180
5181    broadcast(
5182        None,
5183        pool.channel_connection_ids(channel.root_id())
5184            .filter_map(|(channel_id, role)| {
5185                role.can_see_channel(channel.visibility).then(|| channel_id)
5186            }),
5187        |peer_id| {
5188            peer.send(
5189                peer_id,
5190                proto::UpdateChannels {
5191                    channel_participants: vec![proto::ChannelParticipants {
5192                        channel_id: channel.id.to_proto(),
5193                        participant_user_ids: participants.clone(),
5194                    }],
5195                    ..Default::default()
5196                },
5197            )
5198        },
5199    );
5200}
5201
5202async fn send_dev_server_projects_update(
5203    user_id: UserId,
5204    mut status: proto::DevServerProjectsUpdate,
5205    session: &Session,
5206) {
5207    let pool = session.connection_pool().await;
5208    for dev_server in &mut status.dev_servers {
5209        dev_server.status =
5210            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5211    }
5212    let connections = pool.user_connection_ids(user_id);
5213    for connection_id in connections {
5214        session.peer.send(connection_id, status.clone()).trace_err();
5215    }
5216}
5217
5218async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5219    let db = session.db().await;
5220
5221    let contacts = db.get_contacts(user_id).await?;
5222    let busy = db.is_user_busy(user_id).await?;
5223
5224    let pool = session.connection_pool().await;
5225    let updated_contact = contact_for_user(user_id, busy, &pool);
5226    for contact in contacts {
5227        if let db::Contact::Accepted {
5228            user_id: contact_user_id,
5229            ..
5230        } = contact
5231        {
5232            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5233                session
5234                    .peer
5235                    .send(
5236                        contact_conn_id,
5237                        proto::UpdateContacts {
5238                            contacts: vec![updated_contact.clone()],
5239                            remove_contacts: Default::default(),
5240                            incoming_requests: Default::default(),
5241                            remove_incoming_requests: Default::default(),
5242                            outgoing_requests: Default::default(),
5243                            remove_outgoing_requests: Default::default(),
5244                        },
5245                    )
5246                    .trace_err();
5247            }
5248        }
5249    }
5250    Ok(())
5251}
5252
5253async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5254    log::info!("lost dev server connection, unsharing projects");
5255    let project_ids = session
5256        .db()
5257        .await
5258        .get_stale_dev_server_projects(session.connection_id)
5259        .await?;
5260
5261    for project_id in project_ids {
5262        // not unshare re-checks the connection ids match, so we get away with no transaction
5263        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5264    }
5265
5266    let user_id = session.dev_server().user_id;
5267    let update = session
5268        .db()
5269        .await
5270        .dev_server_projects_update(user_id)
5271        .await?;
5272
5273    send_dev_server_projects_update(user_id, update, session).await;
5274
5275    Ok(())
5276}
5277
5278async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5279    let mut contacts_to_update = HashSet::default();
5280
5281    let room_id;
5282    let canceled_calls_to_user_ids;
5283    let live_kit_room;
5284    let delete_live_kit_room;
5285    let room;
5286    let channel;
5287
5288    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5289        contacts_to_update.insert(session.user_id());
5290
5291        for project in left_room.left_projects.values() {
5292            project_left(project, session);
5293        }
5294
5295        room_id = RoomId::from_proto(left_room.room.id);
5296        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5297        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5298        delete_live_kit_room = left_room.deleted;
5299        room = mem::take(&mut left_room.room);
5300        channel = mem::take(&mut left_room.channel);
5301
5302        room_updated(&room, &session.peer);
5303    } else {
5304        return Ok(());
5305    }
5306
5307    if let Some(channel) = channel {
5308        channel_updated(
5309            &channel,
5310            &room,
5311            &session.peer,
5312            &*session.connection_pool().await,
5313        );
5314    }
5315
5316    {
5317        let pool = session.connection_pool().await;
5318        for canceled_user_id in canceled_calls_to_user_ids {
5319            for connection_id in pool.user_connection_ids(canceled_user_id) {
5320                session
5321                    .peer
5322                    .send(
5323                        connection_id,
5324                        proto::CallCanceled {
5325                            room_id: room_id.to_proto(),
5326                        },
5327                    )
5328                    .trace_err();
5329            }
5330            contacts_to_update.insert(canceled_user_id);
5331        }
5332    }
5333
5334    for contact_user_id in contacts_to_update {
5335        update_user_contacts(contact_user_id, &session).await?;
5336    }
5337
5338    if let Some(live_kit) = session.live_kit_client.as_ref() {
5339        live_kit
5340            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5341            .await
5342            .trace_err();
5343
5344        if delete_live_kit_room {
5345            live_kit.delete_room(live_kit_room).await.trace_err();
5346        }
5347    }
5348
5349    Ok(())
5350}
5351
5352async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5353    let left_channel_buffers = session
5354        .db()
5355        .await
5356        .leave_channel_buffers(session.connection_id)
5357        .await?;
5358
5359    for left_buffer in left_channel_buffers {
5360        channel_buffer_updated(
5361            session.connection_id,
5362            left_buffer.connections,
5363            &proto::UpdateChannelBufferCollaborators {
5364                channel_id: left_buffer.channel_id.to_proto(),
5365                collaborators: left_buffer.collaborators,
5366            },
5367            &session.peer,
5368        );
5369    }
5370
5371    Ok(())
5372}
5373
5374fn project_left(project: &db::LeftProject, session: &UserSession) {
5375    for connection_id in &project.connection_ids {
5376        if project.should_unshare {
5377            session
5378                .peer
5379                .send(
5380                    *connection_id,
5381                    proto::UnshareProject {
5382                        project_id: project.id.to_proto(),
5383                    },
5384                )
5385                .trace_err();
5386        } else {
5387            session
5388                .peer
5389                .send(
5390                    *connection_id,
5391                    proto::RemoveProjectCollaborator {
5392                        project_id: project.id.to_proto(),
5393                        peer_id: Some(session.connection_id.into()),
5394                    },
5395                )
5396                .trace_err();
5397        }
5398    }
5399}
5400
5401pub trait ResultExt {
5402    type Ok;
5403
5404    fn trace_err(self) -> Option<Self::Ok>;
5405}
5406
5407impl<T, E> ResultExt for Result<T, E>
5408where
5409    E: std::fmt::Debug,
5410{
5411    type Ok = T;
5412
5413    #[track_caller]
5414    fn trace_err(self) -> Option<T> {
5415        match self {
5416            Ok(value) => Some(value),
5417            Err(error) => {
5418                tracing::error!("{:?}", error);
5419                None
5420            }
5421        }
5422    }
5423}