rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The collaboration RPC server.
///
/// Holds the peer, the shared connection pool, application state, and the handlers
/// registered in `Server::new`.
pub struct Server {
    // Behind a mutex, presumably so the id can be replaced after construction — confirm.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message type, keyed by the message's `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Starts as `false` in `Server::new`; presumably flipped to signal teardown — confirm.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the [`ConnectionPool`] returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so a future that keeps this guard alive across an
    // `.await` cannot be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_project_request_for_owner::<proto::TaskTemplates>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetHover>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetDefinition>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::GetTypeDefinition>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetReferences>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::SearchProject>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetProjectSymbols>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::OpenBufferById>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::InlayHints>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferByPath>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::GetCompletions>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCodeActions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCodeAction>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::PrepareRename>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::PerformRename>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::ReloadBuffers>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::FormatBuffers>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CreateProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::RenameProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::CopyProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::OnTypeFormatting>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::BlameBuffer>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::MultiLspQuery>,
 547            ))
 548            .add_message_handler(create_buffer_for_peer)
 549            .add_request_handler(update_buffer)
 550            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 551            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 552            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 553            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 554            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 555            .add_request_handler(get_users)
 556            .add_request_handler(user_handler(fuzzy_search_users))
 557            .add_request_handler(user_handler(request_contact))
 558            .add_request_handler(user_handler(remove_contact))
 559            .add_request_handler(user_handler(respond_to_contact_request))
 560            .add_request_handler(user_handler(create_channel))
 561            .add_request_handler(user_handler(delete_channel))
 562            .add_request_handler(user_handler(invite_channel_member))
 563            .add_request_handler(user_handler(remove_channel_member))
 564            .add_request_handler(user_handler(set_channel_member_role))
 565            .add_request_handler(user_handler(set_channel_visibility))
 566            .add_request_handler(user_handler(rename_channel))
 567            .add_request_handler(user_handler(join_channel_buffer))
 568            .add_request_handler(user_handler(leave_channel_buffer))
 569            .add_message_handler(user_message_handler(update_channel_buffer))
 570            .add_request_handler(user_handler(rejoin_channel_buffers))
 571            .add_request_handler(user_handler(get_channel_members))
 572            .add_request_handler(user_handler(respond_to_channel_invite))
 573            .add_request_handler(user_handler(join_channel))
 574            .add_request_handler(user_handler(join_channel_chat))
 575            .add_message_handler(user_message_handler(leave_channel_chat))
 576            .add_request_handler(user_handler(send_channel_message))
 577            .add_request_handler(user_handler(remove_channel_message))
 578            .add_request_handler(user_handler(update_channel_message))
 579            .add_request_handler(user_handler(get_channel_messages))
 580            .add_request_handler(user_handler(get_channel_messages_by_id))
 581            .add_request_handler(user_handler(get_notifications))
 582            .add_request_handler(user_handler(mark_notification_as_read))
 583            .add_request_handler(user_handler(move_channel))
 584            .add_request_handler(user_handler(follow))
 585            .add_message_handler(user_message_handler(unfollow))
 586            .add_message_handler(user_message_handler(update_followers))
 587            .add_request_handler(user_handler(get_private_user_info))
 588            .add_message_handler(user_message_handler(acknowledge_channel_message))
 589            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 590            .add_request_handler(user_handler(get_supermaven_api_key))
 591            .add_streaming_request_handler({
 592                let app_state = app_state.clone();
 593                move |request, response, session| {
 594                    complete_with_language_model(
 595                        request,
 596                        response,
 597                        session,
 598                        app_state.config.openai_api_key.clone(),
 599                        app_state.config.google_ai_api_key.clone(),
 600                        app_state.config.anthropic_api_key.clone(),
 601                    )
 602                }
 603            })
 604            .add_request_handler({
 605                let app_state = app_state.clone();
 606                user_handler(move |request, response, session| {
 607                    count_tokens_with_language_model(
 608                        request,
 609                        response,
 610                        session,
 611                        app_state.config.google_ai_api_key.clone(),
 612                    )
 613                })
 614            })
 615            .add_request_handler({
 616                user_handler(move |request, response, session| {
 617                    get_cached_embeddings(request, response, session)
 618                })
 619            })
 620            .add_request_handler({
 621                let app_state = app_state.clone();
 622                user_handler(move |request, response, session| {
 623                    compute_embeddings(
 624                        request,
 625                        response,
 626                        session,
 627                        app_state.config.openai_api_key.clone(),
 628                    )
 629                })
 630            });
 631
 632        Arc::new(server)
 633    }
 634
 635    pub async fn start(&self) -> Result<()> {
 636        let server_id = *self.id.lock();
 637        let app_state = self.app_state.clone();
 638        let peer = self.peer.clone();
 639        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 640        let pool = self.connection_pool.clone();
 641        let live_kit_client = self.app_state.live_kit_client.clone();
 642
 643        let span = info_span!("start server");
 644        self.app_state.executor.spawn_detached(
 645            async move {
 646                tracing::info!("waiting for cleanup timeout");
 647                timeout.await;
 648                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 649                if let Some((room_ids, channel_ids)) = app_state
 650                    .db
 651                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 652                    .await
 653                    .trace_err()
 654                {
 655                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 656                    tracing::info!(
 657                        stale_channel_buffer_count = channel_ids.len(),
 658                        "retrieved stale channel buffers"
 659                    );
 660
 661                    for channel_id in channel_ids {
 662                        if let Some(refreshed_channel_buffer) = app_state
 663                            .db
 664                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 665                            .await
 666                            .trace_err()
 667                        {
 668                            for connection_id in refreshed_channel_buffer.connection_ids {
 669                                peer.send(
 670                                    connection_id,
 671                                    proto::UpdateChannelBufferCollaborators {
 672                                        channel_id: channel_id.to_proto(),
 673                                        collaborators: refreshed_channel_buffer
 674                                            .collaborators
 675                                            .clone(),
 676                                    },
 677                                )
 678                                .trace_err();
 679                            }
 680                        }
 681                    }
 682
 683                    for room_id in room_ids {
 684                        let mut contacts_to_update = HashSet::default();
 685                        let mut canceled_calls_to_user_ids = Vec::new();
 686                        let mut live_kit_room = String::new();
 687                        let mut delete_live_kit_room = false;
 688
 689                        if let Some(mut refreshed_room) = app_state
 690                            .db
 691                            .clear_stale_room_participants(room_id, server_id)
 692                            .await
 693                            .trace_err()
 694                        {
 695                            tracing::info!(
 696                                room_id = room_id.0,
 697                                new_participant_count = refreshed_room.room.participants.len(),
 698                                "refreshed room"
 699                            );
 700                            room_updated(&refreshed_room.room, &peer);
 701                            if let Some(channel) = refreshed_room.channel.as_ref() {
 702                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 703                            }
 704                            contacts_to_update
 705                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 706                            contacts_to_update
 707                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 708                            canceled_calls_to_user_ids =
 709                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 710                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 711                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 712                        }
 713
 714                        {
 715                            let pool = pool.lock();
 716                            for canceled_user_id in canceled_calls_to_user_ids {
 717                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 718                                    peer.send(
 719                                        connection_id,
 720                                        proto::CallCanceled {
 721                                            room_id: room_id.to_proto(),
 722                                        },
 723                                    )
 724                                    .trace_err();
 725                                }
 726                            }
 727                        }
 728
 729                        for user_id in contacts_to_update {
 730                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 731                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 732                            if let Some((busy, contacts)) = busy.zip(contacts) {
 733                                let pool = pool.lock();
 734                                let updated_contact = contact_for_user(user_id, busy, &pool);
 735                                for contact in contacts {
 736                                    if let db::Contact::Accepted {
 737                                        user_id: contact_user_id,
 738                                        ..
 739                                    } = contact
 740                                    {
 741                                        for contact_conn_id in
 742                                            pool.user_connection_ids(contact_user_id)
 743                                        {
 744                                            peer.send(
 745                                                contact_conn_id,
 746                                                proto::UpdateContacts {
 747                                                    contacts: vec![updated_contact.clone()],
 748                                                    remove_contacts: Default::default(),
 749                                                    incoming_requests: Default::default(),
 750                                                    remove_incoming_requests: Default::default(),
 751                                                    outgoing_requests: Default::default(),
 752                                                    remove_outgoing_requests: Default::default(),
 753                                                },
 754                                            )
 755                                            .trace_err();
 756                                        }
 757                                    }
 758                                }
 759                            }
 760                        }
 761
 762                        if let Some(live_kit) = live_kit_client.as_ref() {
 763                            if delete_live_kit_room {
 764                                live_kit.delete_room(live_kit_room).await.trace_err();
 765                            }
 766                        }
 767                    }
 768                }
 769
 770                app_state
 771                    .db
 772                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 773                    .await
 774                    .trace_err();
 775            }
 776            .instrument(span),
 777        );
 778        Ok(())
 779    }
 780
 781    pub fn teardown(&self) {
 782        self.peer.teardown();
 783        self.connection_pool.lock().reset();
 784        let _ = self.teardown.send(true);
 785    }
 786
 787    #[cfg(test)]
 788    pub fn reset(&self, id: ServerId) {
 789        self.teardown();
 790        *self.id.lock() = id;
 791        self.peer.reset(id.0 as u32);
 792        let _ = self.teardown.send(false);
 793    }
 794
 795    #[cfg(test)]
 796    pub fn id(&self) -> ServerId {
 797        *self.id.lock()
 798    }
 799
 800    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 801    where
 802        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 803        Fut: 'static + Send + Future<Output = Result<()>>,
 804        M: EnvelopedMessage,
 805    {
 806        let prev_handler = self.handlers.insert(
 807            TypeId::of::<M>(),
 808            Box::new(move |envelope, session| {
 809                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 810                let received_at = envelope.received_at;
 811                tracing::info!("message received");
 812                let start_time = Instant::now();
 813                let future = (handler)(*envelope, session);
 814                async move {
 815                    let result = future.await;
 816                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 817                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 818                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 819                    let payload_type = M::NAME;
 820
 821                    match result {
 822                        Err(error) => {
 823                            tracing::error!(
 824                                ?error,
 825                                total_duration_ms,
 826                                processing_duration_ms,
 827                                queue_duration_ms,
 828                                payload_type,
 829                                "error handling message"
 830                            )
 831                        }
 832                        Ok(()) => tracing::info!(
 833                            total_duration_ms,
 834                            processing_duration_ms,
 835                            queue_duration_ms,
 836                            "finished handling message"
 837                        ),
 838                    }
 839                }
 840                .boxed()
 841            }),
 842        );
 843        if prev_handler.is_some() {
 844            panic!("registered a handler for the same message twice");
 845        }
 846        self
 847    }
 848
 849    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 850    where
 851        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 852        Fut: 'static + Send + Future<Output = Result<()>>,
 853        M: EnvelopedMessage,
 854    {
 855        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 856        self
 857    }
 858
 859    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 860    where
 861        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 862        Fut: Send + Future<Output = Result<()>>,
 863        M: RequestMessage,
 864    {
 865        let handler = Arc::new(handler);
 866        self.add_handler(move |envelope, session| {
 867            let receipt = envelope.receipt();
 868            let handler = handler.clone();
 869            async move {
 870                let peer = session.peer.clone();
 871                let responded = Arc::new(AtomicBool::default());
 872                let response = Response {
 873                    peer: peer.clone(),
 874                    responded: responded.clone(),
 875                    receipt,
 876                };
 877                match (handler)(envelope.payload, response, session).await {
 878                    Ok(()) => {
 879                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 880                            Ok(())
 881                        } else {
 882                            Err(anyhow!("handler did not send a response"))?
 883                        }
 884                    }
 885                    Err(error) => {
 886                        let proto_err = match &error {
 887                            Error::Internal(err) => err.to_proto(),
 888                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 889                        };
 890                        peer.respond_with_error(receipt, proto_err)?;
 891                        Err(error)
 892                    }
 893                }
 894            }
 895        })
 896    }
 897
 898    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 899    where
 900        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 901        Fut: Send + Future<Output = Result<()>>,
 902        M: RequestMessage,
 903    {
 904        let handler = Arc::new(handler);
 905        self.add_handler(move |envelope, session| {
 906            let receipt = envelope.receipt();
 907            let handler = handler.clone();
 908            async move {
 909                let peer = session.peer.clone();
 910                let response = StreamingResponse {
 911                    peer: peer.clone(),
 912                    receipt,
 913                };
 914                match (handler)(envelope.payload, response, session).await {
 915                    Ok(()) => {
 916                        peer.end_stream(receipt)?;
 917                        Ok(())
 918                    }
 919                    Err(error) => {
 920                        let proto_err = match &error {
 921                            Error::Internal(err) => err.to_proto(),
 922                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 923                        };
 924                        peer.respond_with_error(receipt, proto_err)?;
 925                        Err(error)
 926                    }
 927                }
 928            }
 929        })
 930    }
 931
    /// Drives one client connection from handshake to sign-out.
    ///
    /// - Registers `connection` with the peer, builds the per-connection
    ///   `Session`, and sends the initial client state via
    ///   `send_initial_client_update`.
    /// - Dispatches incoming messages to the registered handlers until the
    ///   connection closes or its I/O future completes.
    /// - Afterwards, runs `connection_lost` cleanup.
    ///
    /// If the server's teardown signal is already set, or changes while the
    /// loop runs, the future returns immediately and `connection_lost` is NOT
    /// run.
    ///
    /// `send_connection_id`, when provided, is forwarded to
    /// `send_initial_client_update`, which reports the assigned
    /// `ConnectionId` through it.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the
        // peer assigns it, the principal fields by `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribed at call time rather than when the future is first polled.
        // A teardown already in effect is caught by the `borrow()` check below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after
            // `send_initial_client_update` below) skips `connection_lost`, so
            // nothing here undoes `add_connection` or any pool registration
            // already made. Confirm whether this is cleaned up elsewhere.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // A Supermaven admin client is only created when an admin API key
            // is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight
            // for this connection. A permit is acquired *before* the next
            // message is read, so reading pauses while the limit is reached.
            // Each permit is released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order written:
                // teardown, then connection I/O, then progress on foreground
                // handlers, then the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is
                            // constructed; the future itself is instrumented below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future, so
                                // the concurrency slot is held until it completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached and are not
                                    // polled by (or cancelled with) this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // `permit` is dropped at the end of this scope.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop (i.e. cancel) any foreground handlers still in flight before
            // running the sign-out cleanup.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1078
    /// Sends a freshly connected client the state it needs before normal
    /// message handling starts, and registers the connection in the pool.
    ///
    /// It always sends `Hello` (carrying the peer id) first. If
    /// `send_connection_id` is provided, it reports `connection_id` through it.
    /// The remaining work depends on `principal`:
    /// - Users (including impersonated users):
    ///   - on the user's first-ever connection, send `ShowContacts` and mark
    ///     the user as connected once;
    ///   - add the connection to the pool and subscribe it to the user's
    ///     channels;
    ///   - send the initial contacts, user-channel and channel updates;
    ///   - send the dev server projects update and any incoming call;
    ///   - finally run `update_user_contacts`.
    /// - Dev servers:
    ///   - shut down and remove any existing connection for the same dev
    ///     server, then register this one;
    ///   - send `DevServerInstructions` with the dev server's projects;
    ///   - send a dev server projects update for the dev server's `user_id`.
    ///
    /// # Errors
    ///
    /// Returns the first failing send, database query, or pool operation.
    /// Steps already completed (e.g. adding the connection to the pool) are
    /// not rolled back.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may have gone away; that is not an error for us.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Fetch everything concurrently before taking the pool lock.
                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.dev_server_projects_update(user.id),
                    )
                    .await?;

                // The pool lock is held for the sends in this scope and no
                // `.await` occurs inside it.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server is kept; an existing one
                    // is told to shut down and dropped from the pool.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1183
1184    pub async fn invite_code_redeemed(
1185        self: &Arc<Self>,
1186        inviter_id: UserId,
1187        invitee_id: UserId,
1188    ) -> Result<()> {
1189        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1190            if let Some(code) = &user.invite_code {
1191                let pool = self.connection_pool.lock();
1192                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1193                for connection_id in pool.user_connection_ids(inviter_id) {
1194                    self.peer.send(
1195                        connection_id,
1196                        proto::UpdateContacts {
1197                            contacts: vec![invitee_contact.clone()],
1198                            ..Default::default()
1199                        },
1200                    )?;
1201                    self.peer.send(
1202                        connection_id,
1203                        proto::UpdateInviteInfo {
1204                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1205                            count: user.invite_count as u32,
1206                        },
1207                    )?;
1208                }
1209            }
1210        }
1211        Ok(())
1212    }
1213
1214    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1215        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1216            if let Some(invite_code) = &user.invite_code {
1217                let pool = self.connection_pool.lock();
1218                for connection_id in pool.user_connection_ids(user_id) {
1219                    self.peer.send(
1220                        connection_id,
1221                        proto::UpdateInviteInfo {
1222                            url: format!(
1223                                "{}{}",
1224                                self.app_state.config.invite_link_prefix, invite_code
1225                            ),
1226                            count: user.invite_count as u32,
1227                        },
1228                    )?;
1229                }
1230            }
1231        }
1232        Ok(())
1233    }
1234
1235    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1236        ServerSnapshot {
1237            connection_pool: ConnectionPoolGuard {
1238                guard: self.connection_pool.lock(),
1239                _not_send: PhantomData,
1240            },
1241            peer: &self.peer,
1242        }
1243    }
1244}
1245
1246impl<'a> Deref for ConnectionPoolGuard<'a> {
1247    type Target = ConnectionPool;
1248
1249    fn deref(&self) -> &Self::Target {
1250        &self.guard
1251    }
1252}
1253
1254impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1255    fn deref_mut(&mut self) -> &mut Self::Target {
1256        &mut self.guard
1257    }
1258}
1259
1260impl<'a> Drop for ConnectionPoolGuard<'a> {
1261    fn drop(&mut self) {
1262        #[cfg(test)]
1263        self.check_invariants();
1264    }
1265}
1266
1267fn broadcast<F>(
1268    sender_id: Option<ConnectionId>,
1269    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1270    mut f: F,
1271) where
1272    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1273{
1274    for receiver_id in receiver_ids {
1275        if Some(receiver_id) != sender_id {
1276            if let Err(error) = f(receiver_id) {
1277                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1278            }
1279        }
1280    }
1281}
1282
1283pub struct ProtocolVersion(u32);
1284
1285impl Header for ProtocolVersion {
1286    fn name() -> &'static HeaderName {
1287        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1288        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1289    }
1290
1291    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1292    where
1293        Self: Sized,
1294        I: Iterator<Item = &'i axum::http::HeaderValue>,
1295    {
1296        let version = values
1297            .next()
1298            .ok_or_else(axum::headers::Error::invalid)?
1299            .to_str()
1300            .map_err(|_| axum::headers::Error::invalid())?
1301            .parse()
1302            .map_err(|_| axum::headers::Error::invalid())?;
1303        Ok(Self(version))
1304    }
1305
1306    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1307        values.extend([self.0.to_string().parse().unwrap()]);
1308    }
1309}
1310
1311pub struct AppVersionHeader(SemanticVersion);
1312impl Header for AppVersionHeader {
1313    fn name() -> &'static HeaderName {
1314        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1315        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1316    }
1317
1318    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1319    where
1320        Self: Sized,
1321        I: Iterator<Item = &'i axum::http::HeaderValue>,
1322    {
1323        let version = values
1324            .next()
1325            .ok_or_else(axum::headers::Error::invalid)?
1326            .to_str()
1327            .map_err(|_| axum::headers::Error::invalid())?
1328            .parse()
1329            .map_err(|_| axum::headers::Error::invalid())?;
1330        Ok(Self(version))
1331    }
1332
1333    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1334        values.extend([self.0.to_string().parse().unwrap()]);
1335    }
1336}
1337
1338pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1339    Router::new()
1340        .route("/rpc", get(handle_websocket_request))
1341        .layer(
1342            ServiceBuilder::new()
1343                .layer(Extension(server.app_state.clone()))
1344                .layer(middleware::from_fn(auth::validate_header)),
1345        )
1346        .route("/metrics", get(handle_metrics))
1347        .layer(Extension(server))
1348}
1349
1350pub async fn handle_websocket_request(
1351    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1352    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1353    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1354    Extension(server): Extension<Arc<Server>>,
1355    Extension(principal): Extension<Principal>,
1356    ws: WebSocketUpgrade,
1357) -> axum::response::Response {
1358    if protocol_version != rpc::PROTOCOL_VERSION {
1359        return (
1360            StatusCode::UPGRADE_REQUIRED,
1361            "client must be upgraded".to_string(),
1362        )
1363            .into_response();
1364    }
1365
1366    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1367        return (
1368            StatusCode::UPGRADE_REQUIRED,
1369            "no version header found".to_string(),
1370        )
1371            .into_response();
1372    };
1373
1374    if !version.can_collaborate() {
1375        return (
1376            StatusCode::UPGRADE_REQUIRED,
1377            "client must be upgraded".to_string(),
1378        )
1379            .into_response();
1380    }
1381
1382    let socket_address = socket_address.to_string();
1383    ws.on_upgrade(move |socket| {
1384        let socket = socket
1385            .map_ok(to_tungstenite_message)
1386            .err_into()
1387            .with(|message| async move { Ok(to_axum_message(message)) });
1388        let connection = Connection::new(Box::pin(socket));
1389        async move {
1390            server
1391                .handle_connection(
1392                    connection,
1393                    socket_address,
1394                    principal,
1395                    version,
1396                    None,
1397                    Executor::Production,
1398                )
1399                .await;
1400        }
1401    })
1402}
1403
1404pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1405    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1406    let connections_metric = CONNECTIONS_METRIC
1407        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1408
1409    let connections = server
1410        .connection_pool
1411        .lock()
1412        .connections()
1413        .filter(|connection| !connection.admin)
1414        .count();
1415    connections_metric.set(connections as _);
1416
1417    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1418    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1419        register_int_gauge!(
1420            "shared_projects",
1421            "number of open projects with one or more guests"
1422        )
1423        .unwrap()
1424    });
1425
1426    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1427    shared_projects_metric.set(shared_projects as _);
1428
1429    let encoder = prometheus::TextEncoder::new();
1430    let metric_families = prometheus::gather();
1431    let encoded_metrics = encoder
1432        .encode_to_string(&metric_families)
1433        .map_err(|err| anyhow!("{}", err))?;
1434    Ok(encoded_metrics)
1435}
1436
/// Cleans up after a connection has gone away.
///
/// Right away, it disconnects the peer and removes the connection from the
/// pool. It also records the lost connection in the database; a database
/// error there is only traced.
///
/// It then waits `RECONNECT_TIMEOUT`. If the teardown signal changes first,
/// it returns without further cleanup. Otherwise it releases the principal's
/// remaining resources:
/// - Users (including impersonated users):
///   - leave their room and channel buffers;
///   - if the user has no connection left online, call `decline_call` and
///     broadcast the room update when one is returned;
///   - refresh their contacts.
/// - Dev servers: run `lost_dev_server_connection`.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails (in that
/// case the database is not updated), or if the contact refresh or dev-server
/// cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client `RECONNECT_TIMEOUT` to come back before releasing its
    // resources, unless the server is tearing down.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Presumably `for_user` only fails for non-user principals,
                    // which this arm excludes — TODO confirm.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline the call once the user has no other
                    // connection left online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1491
1492/// Acknowledges a ping from a client, used to keep the connection alive.
1493async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1494    response.send(proto::Ack {})?;
1495    Ok(())
1496}
1497
1498/// Creates a new room for calling (outside of channels)
1499async fn create_room(
1500    _request: proto::CreateRoom,
1501    response: Response<proto::CreateRoom>,
1502    session: UserSession,
1503) -> Result<()> {
1504    let live_kit_room = nanoid::nanoid!(30);
1505
1506    let live_kit_connection_info = util::maybe!(async {
1507        let live_kit = session.live_kit_client.as_ref();
1508        let live_kit = live_kit?;
1509        let user_id = session.user_id().to_string();
1510
1511        let token = live_kit
1512            .room_token(&live_kit_room, &user_id.to_string())
1513            .trace_err()?;
1514
1515        Some(proto::LiveKitConnectionInfo {
1516            server_url: live_kit.url().into(),
1517            token,
1518            can_publish: true,
1519        })
1520    })
1521    .await;
1522
1523    let room = session
1524        .db()
1525        .await
1526        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1527        .await?;
1528
1529    response.send(proto::CreateRoomResponse {
1530        room: Some(room.clone()),
1531        live_kit_connection_info,
1532    })?;
1533
1534    update_user_contacts(session.user_id(), &session).await?;
1535    Ok(())
1536}
1537
1538/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1539async fn join_room(
1540    request: proto::JoinRoom,
1541    response: Response<proto::JoinRoom>,
1542    session: UserSession,
1543) -> Result<()> {
1544    let room_id = RoomId::from_proto(request.id);
1545
1546    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1547
1548    if let Some(channel_id) = channel_id {
1549        return join_channel_internal(channel_id, Box::new(response), session).await;
1550    }
1551
1552    let joined_room = {
1553        let room = session
1554            .db()
1555            .await
1556            .join_room(room_id, session.user_id(), session.connection_id)
1557            .await?;
1558        room_updated(&room.room, &session.peer);
1559        room.into_inner()
1560    };
1561
1562    for connection_id in session
1563        .connection_pool()
1564        .await
1565        .user_connection_ids(session.user_id())
1566    {
1567        session
1568            .peer
1569            .send(
1570                connection_id,
1571                proto::CallCanceled {
1572                    room_id: room_id.to_proto(),
1573                },
1574            )
1575            .trace_err();
1576    }
1577
1578    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1579        if let Some(token) = live_kit
1580            .room_token(
1581                &joined_room.room.live_kit_room,
1582                &session.user_id().to_string(),
1583            )
1584            .trace_err()
1585        {
1586            Some(proto::LiveKitConnectionInfo {
1587                server_url: live_kit.url().into(),
1588                token,
1589                can_publish: true,
1590            })
1591        } else {
1592            None
1593        }
1594    } else {
1595        None
1596    };
1597
1598    response.send(proto::JoinRoomResponse {
1599        room: Some(joined_room.room),
1600        channel_id: None,
1601        live_kit_connection_info,
1602    })?;
1603
1604    update_user_contacts(session.user_id(), &session).await?;
1605    Ok(())
1606}
1607
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-associates the caller's current connection with the room and
/// returns the room together with two project lists:
/// - "reshared" projects, presumably ones the caller hosts (TODO confirm). Their
///   collaborators are told the caller's peer id changed, and the project's
///   worktree list is forwarded to them.
/// - "rejoined" projects, whose state is re-streamed to the caller via
///   `notify_rejoined_projects`.
///
/// The caller receives a `RejoinRoomResponse` first and the room is broadcast.
/// The room guard lives only inside the inner block, so it is dropped before
/// the channel update and the contact refresh run.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled from the guarded result so they outlive the guard's scope.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator to remap the caller's old peer id to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to its collaborators,
            // sent as coming from the caller's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Only channel-backed rooms need a channel update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1699
1700fn notify_rejoined_projects(
1701    rejoined_projects: &mut Vec<RejoinedProject>,
1702    session: &UserSession,
1703) -> Result<()> {
1704    for project in rejoined_projects.iter() {
1705        for collaborator in &project.collaborators {
1706            session
1707                .peer
1708                .send(
1709                    collaborator.connection_id,
1710                    proto::UpdateProjectCollaborator {
1711                        project_id: project.id.to_proto(),
1712                        old_peer_id: Some(project.old_connection_id.into()),
1713                        new_peer_id: Some(session.connection_id.into()),
1714                    },
1715                )
1716                .trace_err();
1717        }
1718    }
1719
1720    for project in rejoined_projects {
1721        for worktree in mem::take(&mut project.worktrees) {
1722            #[cfg(any(test, feature = "test-support"))]
1723            const MAX_CHUNK_SIZE: usize = 2;
1724            #[cfg(not(any(test, feature = "test-support")))]
1725            const MAX_CHUNK_SIZE: usize = 256;
1726
1727            // Stream this worktree's entries.
1728            let message = proto::UpdateWorktree {
1729                project_id: project.id.to_proto(),
1730                worktree_id: worktree.id,
1731                abs_path: worktree.abs_path.clone(),
1732                root_name: worktree.root_name,
1733                updated_entries: worktree.updated_entries,
1734                removed_entries: worktree.removed_entries,
1735                scan_id: worktree.scan_id,
1736                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1737                updated_repositories: worktree.updated_repositories,
1738                removed_repositories: worktree.removed_repositories,
1739            };
1740            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1741                session.peer.send(session.connection_id, update.clone())?;
1742            }
1743
1744            // Stream this worktree's diagnostics.
1745            for summary in worktree.diagnostic_summaries {
1746                session.peer.send(
1747                    session.connection_id,
1748                    proto::UpdateDiagnosticSummary {
1749                        project_id: project.id.to_proto(),
1750                        worktree_id: worktree.id,
1751                        summary: Some(summary),
1752                    },
1753                )?;
1754            }
1755
1756            for settings_file in worktree.settings_files {
1757                session.peer.send(
1758                    session.connection_id,
1759                    proto::UpdateWorktreeSettings {
1760                        project_id: project.id.to_proto(),
1761                        worktree_id: worktree.id,
1762                        path: settings_file.path,
1763                        content: Some(settings_file.content),
1764                    },
1765                )?;
1766            }
1767        }
1768
1769        for language_server in &project.language_servers {
1770            session.peer.send(
1771                session.connection_id,
1772                proto::UpdateLanguageServer {
1773                    project_id: project.id.to_proto(),
1774                    language_server_id: language_server.id,
1775                    variant: Some(
1776                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1777                            proto::LspDiskBasedDiagnosticsUpdated {},
1778                        ),
1779                    ),
1780                },
1781            )?;
1782        }
1783    }
1784    Ok(())
1785}
1786
1787/// leave room disconnects from the room.
1788async fn leave_room(
1789    _: proto::LeaveRoom,
1790    response: Response<proto::LeaveRoom>,
1791    session: UserSession,
1792) -> Result<()> {
1793    leave_room_for_session(&session, session.connection_id).await?;
1794    response.send(proto::Ack {})?;
1795    Ok(())
1796}
1797
1798/// Updates the permissions of someone else in the room.
1799async fn set_room_participant_role(
1800    request: proto::SetRoomParticipantRole,
1801    response: Response<proto::SetRoomParticipantRole>,
1802    session: UserSession,
1803) -> Result<()> {
1804    let user_id = UserId::from_proto(request.user_id);
1805    let role = ChannelRole::from(request.role());
1806
1807    let (live_kit_room, can_publish) = {
1808        let room = session
1809            .db()
1810            .await
1811            .set_room_participant_role(
1812                session.user_id(),
1813                RoomId::from_proto(request.room_id),
1814                user_id,
1815                role,
1816            )
1817            .await?;
1818
1819        let live_kit_room = room.live_kit_room.clone();
1820        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1821        room_updated(&room, &session.peer);
1822        (live_kit_room, can_publish)
1823    };
1824
1825    if let Some(live_kit) = session.live_kit_client.as_ref() {
1826        live_kit
1827            .update_participant(
1828                live_kit_room.clone(),
1829                request.user_id.to_string(),
1830                live_kit_server::proto::ParticipantPermission {
1831                    can_subscribe: true,
1832                    can_publish,
1833                    can_publish_data: can_publish,
1834                    hidden: false,
1835                    recorder: false,
1836                },
1837            )
1838            .await
1839            .trace_err();
1840    }
1841
1842    response.send(proto::Ack {})?;
1843    Ok(())
1844}
1845
/// Call someone else into the current room
///
/// Only contacts may be called. The database records the pending call and yields
/// the updated room plus the incoming-call message for the callee. The room is
/// broadcast and the callee's contacts are refreshed. The incoming call is then
/// sent as a request to every connection of the callee, and the first successful
/// response acknowledges the caller.
///
/// If no connection answers successfully, the call is marked failed in the
/// database, the room is broadcast again, the callee's contacts are refreshed,
/// and "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The room guard is confined to this block; the incoming call is moved out
    // with `mem::take` so it survives after the guard is dropped.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // A failing connection is only reported; keep waiting on the rest.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered successfully: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1914
1915/// Cancel an outgoing call.
1916async fn cancel_call(
1917    request: proto::CancelCall,
1918    response: Response<proto::CancelCall>,
1919    session: UserSession,
1920) -> Result<()> {
1921    let called_user_id = UserId::from_proto(request.called_user_id);
1922    let room_id = RoomId::from_proto(request.room_id);
1923    {
1924        let room = session
1925            .db()
1926            .await
1927            .cancel_call(room_id, session.connection_id, called_user_id)
1928            .await?;
1929        room_updated(&room, &session.peer);
1930    }
1931
1932    for connection_id in session
1933        .connection_pool()
1934        .await
1935        .user_connection_ids(called_user_id)
1936    {
1937        session
1938            .peer
1939            .send(
1940                connection_id,
1941                proto::CallCanceled {
1942                    room_id: room_id.to_proto(),
1943                },
1944            )
1945            .trace_err();
1946    }
1947    response.send(proto::Ack {})?;
1948
1949    update_user_contacts(called_user_id, &session).await?;
1950    Ok(())
1951}
1952
1953/// Decline an incoming call.
1954async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1955    let room_id = RoomId::from_proto(message.room_id);
1956    {
1957        let room = session
1958            .db()
1959            .await
1960            .decline_call(Some(room_id), session.user_id())
1961            .await?
1962            .ok_or_else(|| anyhow!("failed to decline call"))?;
1963        room_updated(&room, &session.peer);
1964    }
1965
1966    for connection_id in session
1967        .connection_pool()
1968        .await
1969        .user_connection_ids(session.user_id())
1970    {
1971        session
1972            .peer
1973            .send(
1974                connection_id,
1975                proto::CallCanceled {
1976                    room_id: room_id.to_proto(),
1977                },
1978            )
1979            .trace_err();
1980    }
1981    update_user_contacts(session.user_id(), &session).await?;
1982    Ok(())
1983}
1984
1985/// Updates other participants in the room with your current location.
1986async fn update_participant_location(
1987    request: proto::UpdateParticipantLocation,
1988    response: Response<proto::UpdateParticipantLocation>,
1989    session: UserSession,
1990) -> Result<()> {
1991    let room_id = RoomId::from_proto(request.room_id);
1992    let location = request
1993        .location
1994        .ok_or_else(|| anyhow!("invalid location"))?;
1995
1996    let db = session.db().await;
1997    let room = db
1998        .update_room_participant_location(room_id, session.connection_id, location)
1999        .await?;
2000
2001    room_updated(&room, &session.peer);
2002    response.send(proto::Ack {})?;
2003    Ok(())
2004}
2005
2006/// Share a project into the room.
2007async fn share_project(
2008    request: proto::ShareProject,
2009    response: Response<proto::ShareProject>,
2010    session: UserSession,
2011) -> Result<()> {
2012    let (project_id, room) = &*session
2013        .db()
2014        .await
2015        .share_project(
2016            RoomId::from_proto(request.room_id),
2017            session.connection_id,
2018            &request.worktrees,
2019            request
2020                .dev_server_project_id
2021                .map(|id| DevServerProjectId::from_proto(id)),
2022        )
2023        .await?;
2024    response.send(proto::ShareProjectResponse {
2025        project_id: project_id.to_proto(),
2026    })?;
2027    room_updated(&room, &session.peer);
2028
2029    Ok(())
2030}
2031
2032/// Unshare a project from the room.
2033async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2034    let project_id = ProjectId::from_proto(message.project_id);
2035    unshare_project_internal(
2036        project_id,
2037        session.connection_id,
2038        session.user_id(),
2039        &session,
2040    )
2041    .await
2042}
2043
/// Stops sharing `project_id` on behalf of `connection_id`.
///
/// The database returns three things:
/// - whether the project should be deleted,
/// - the room it was shared in (if any),
/// - the guest connection ids.
///
/// Guests are sent `UnshareProject`; `connection_id` is passed to `broadcast` as
/// the sender, presumably to exclude it. The room, if any, is broadcast. The room
/// guard is dropped before the project is (optionally) deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Confine the room guard to this block; only the delete flag escapes it.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2082
2083/// DevServer makes a project available online
2084async fn share_dev_server_project(
2085    request: proto::ShareDevServerProject,
2086    response: Response<proto::ShareDevServerProject>,
2087    session: DevServerSession,
2088) -> Result<()> {
2089    let (dev_server_project, user_id, status) = session
2090        .db()
2091        .await
2092        .share_dev_server_project(
2093            DevServerProjectId::from_proto(request.dev_server_project_id),
2094            session.dev_server_id(),
2095            session.connection_id,
2096            &request.worktrees,
2097        )
2098        .await?;
2099    let Some(project_id) = dev_server_project.project_id else {
2100        return Err(anyhow!("failed to share remote project"))?;
2101    };
2102
2103    send_dev_server_projects_update(user_id, status, &session).await;
2104
2105    response.send(proto::ShareProjectResponse { project_id })?;
2106
2107    Ok(())
2108}
2109
/// Join someone else's shared project.
///
/// Registers the caller's connection as a guest of `request.project_id` in the
/// database, then hands the project and the assigned replica id to
/// `join_project_internal`, which replies and streams the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the db handle before streaming; the joined project (a temporary
    // extended by the `&mut *` borrow) stays alive.
    drop(db);
    // NOTE(review): this fires for every join, not only remote ones; the message
    // may be misleading.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2128
/// A reply handle that can answer with a `proto::JoinProjectResponse`.
///
/// Implemented for the `JoinProject` and `JoinHostedProject` responses, so that
/// `join_project_internal` can serve both request kinds.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Both impls delegate to the inherent `Response::send`. The fully-qualified path
// makes explicit that the inherent method is called, not this trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2142
2143fn join_project_internal(
2144    response: impl JoinProjectInternalResponse,
2145    session: UserSession,
2146    project: &mut Project,
2147    replica_id: &ReplicaId,
2148) -> Result<()> {
2149    let collaborators = project
2150        .collaborators
2151        .iter()
2152        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2153        .map(|collaborator| collaborator.to_proto())
2154        .collect::<Vec<_>>();
2155    let project_id = project.id;
2156    let guest_user_id = session.user_id();
2157
2158    let worktrees = project
2159        .worktrees
2160        .iter()
2161        .map(|(id, worktree)| proto::WorktreeMetadata {
2162            id: *id,
2163            root_name: worktree.root_name.clone(),
2164            visible: worktree.visible,
2165            abs_path: worktree.abs_path.clone(),
2166        })
2167        .collect::<Vec<_>>();
2168
2169    let add_project_collaborator = proto::AddProjectCollaborator {
2170        project_id: project_id.to_proto(),
2171        collaborator: Some(proto::Collaborator {
2172            peer_id: Some(session.connection_id.into()),
2173            replica_id: replica_id.0 as u32,
2174            user_id: guest_user_id.to_proto(),
2175        }),
2176    };
2177
2178    for collaborator in &collaborators {
2179        session
2180            .peer
2181            .send(
2182                collaborator.peer_id.unwrap().into(),
2183                add_project_collaborator.clone(),
2184            )
2185            .trace_err();
2186    }
2187
2188    // First, we send the metadata associated with each worktree.
2189    response.send(proto::JoinProjectResponse {
2190        project_id: project.id.0 as u64,
2191        worktrees: worktrees.clone(),
2192        replica_id: replica_id.0 as u32,
2193        collaborators: collaborators.clone(),
2194        language_servers: project.language_servers.clone(),
2195        role: project.role.into(),
2196        dev_server_project_id: project
2197            .dev_server_project_id
2198            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2199    })?;
2200
2201    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2202        #[cfg(any(test, feature = "test-support"))]
2203        const MAX_CHUNK_SIZE: usize = 2;
2204        #[cfg(not(any(test, feature = "test-support")))]
2205        const MAX_CHUNK_SIZE: usize = 256;
2206
2207        // Stream this worktree's entries.
2208        let message = proto::UpdateWorktree {
2209            project_id: project_id.to_proto(),
2210            worktree_id,
2211            abs_path: worktree.abs_path.clone(),
2212            root_name: worktree.root_name,
2213            updated_entries: worktree.entries,
2214            removed_entries: Default::default(),
2215            scan_id: worktree.scan_id,
2216            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2217            updated_repositories: worktree.repository_entries.into_values().collect(),
2218            removed_repositories: Default::default(),
2219        };
2220        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2221            session.peer.send(session.connection_id, update.clone())?;
2222        }
2223
2224        // Stream this worktree's diagnostics.
2225        for summary in worktree.diagnostic_summaries {
2226            session.peer.send(
2227                session.connection_id,
2228                proto::UpdateDiagnosticSummary {
2229                    project_id: project_id.to_proto(),
2230                    worktree_id: worktree.id,
2231                    summary: Some(summary),
2232                },
2233            )?;
2234        }
2235
2236        for settings_file in worktree.settings_files {
2237            session.peer.send(
2238                session.connection_id,
2239                proto::UpdateWorktreeSettings {
2240                    project_id: project_id.to_proto(),
2241                    worktree_id: worktree.id,
2242                    path: settings_file.path,
2243                    content: Some(settings_file.content),
2244                },
2245            )?;
2246        }
2247    }
2248
2249    for language_server in &project.language_servers {
2250        session.peer.send(
2251            session.connection_id,
2252            proto::UpdateLanguageServer {
2253                project_id: project_id.to_proto(),
2254                language_server_id: language_server.id,
2255                variant: Some(
2256                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2257                        proto::LspDiskBasedDiagnosticsUpdated {},
2258                    ),
2259                ),
2260            },
2261        )?;
2262    }
2263
2264    Ok(())
2265}
2266
2267/// Leave someone elses shared project.
2268async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2269    let sender_id = session.connection_id;
2270    let project_id = ProjectId::from_proto(request.project_id);
2271    let db = session.db().await;
2272    if db.is_hosted_project(project_id).await? {
2273        let project = db.leave_hosted_project(project_id, sender_id).await?;
2274        project_left(&project, &session);
2275        return Ok(());
2276    }
2277
2278    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2279    tracing::info!(
2280        %project_id,
2281        "leave project"
2282    );
2283
2284    project_left(&project, &session);
2285    if let Some(room) = room {
2286        room_updated(&room, &session.peer);
2287    }
2288
2289    Ok(())
2290}
2291
2292async fn join_hosted_project(
2293    request: proto::JoinHostedProject,
2294    response: Response<proto::JoinHostedProject>,
2295    session: UserSession,
2296) -> Result<()> {
2297    let (mut project, replica_id) = session
2298        .db()
2299        .await
2300        .join_hosted_project(
2301            ProjectId(request.project_id as i32),
2302            session.user_id(),
2303            session.connection_id,
2304        )
2305        .await?;
2306
2307    join_project_internal(response, session, &mut project, &replica_id)
2308}
2309
2310async fn create_dev_server_project(
2311    request: proto::CreateDevServerProject,
2312    response: Response<proto::CreateDevServerProject>,
2313    session: UserSession,
2314) -> Result<()> {
2315    let dev_server_id = DevServerId(request.dev_server_id as i32);
2316    let dev_server_connection_id = session
2317        .connection_pool()
2318        .await
2319        .dev_server_connection_id(dev_server_id);
2320    let Some(dev_server_connection_id) = dev_server_connection_id else {
2321        Err(ErrorCode::DevServerOffline
2322            .message("Cannot create a remote project when the dev server is offline".to_string())
2323            .anyhow())?
2324    };
2325
2326    let path = request.path.clone();
2327    //Check that the path exists on the dev server
2328    session
2329        .peer
2330        .forward_request(
2331            session.connection_id,
2332            dev_server_connection_id,
2333            proto::ValidateDevServerProjectRequest { path: path.clone() },
2334        )
2335        .await?;
2336
2337    let (dev_server_project, update) = session
2338        .db()
2339        .await
2340        .create_dev_server_project(
2341            DevServerId(request.dev_server_id as i32),
2342            &request.path,
2343            session.user_id(),
2344        )
2345        .await?;
2346
2347    let projects = session
2348        .db()
2349        .await
2350        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2351        .await?;
2352
2353    session.peer.send(
2354        dev_server_connection_id,
2355        proto::DevServerInstructions { projects },
2356    )?;
2357
2358    send_dev_server_projects_update(session.user_id(), update, &session).await;
2359
2360    response.send(proto::CreateDevServerProjectResponse {
2361        dev_server_project: Some(dev_server_project.to_proto(None)),
2362    })?;
2363    Ok(())
2364}
2365
2366async fn create_dev_server(
2367    request: proto::CreateDevServer,
2368    response: Response<proto::CreateDevServer>,
2369    session: UserSession,
2370) -> Result<()> {
2371    let access_token = auth::random_token();
2372    let hashed_access_token = auth::hash_access_token(&access_token);
2373
2374    if request.name.is_empty() {
2375        return Err(proto::ErrorCode::Forbidden
2376            .message("Dev server name cannot be empty".to_string())
2377            .anyhow())?;
2378    }
2379
2380    let (dev_server, status) = session
2381        .db()
2382        .await
2383        .create_dev_server(
2384            &request.name,
2385            request.ssh_connection_string.as_deref(),
2386            &hashed_access_token,
2387            session.user_id(),
2388        )
2389        .await?;
2390
2391    send_dev_server_projects_update(session.user_id(), status, &session).await;
2392
2393    response.send(proto::CreateDevServerResponse {
2394        dev_server_id: dev_server.id.0 as u64,
2395        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2396        name: request.name,
2397    })?;
2398    Ok(())
2399}
2400
2401async fn regenerate_dev_server_token(
2402    request: proto::RegenerateDevServerToken,
2403    response: Response<proto::RegenerateDevServerToken>,
2404    session: UserSession,
2405) -> Result<()> {
2406    let dev_server_id = DevServerId(request.dev_server_id as i32);
2407    let access_token = auth::random_token();
2408    let hashed_access_token = auth::hash_access_token(&access_token);
2409
2410    let connection_id = session
2411        .connection_pool()
2412        .await
2413        .dev_server_connection_id(dev_server_id);
2414    if let Some(connection_id) = connection_id {
2415        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2416        session.peer.send(
2417            connection_id,
2418            proto::ShutdownDevServer {
2419                reason: Some("dev server token was regenerated".to_string()),
2420            },
2421        )?;
2422        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2423    }
2424
2425    let status = session
2426        .db()
2427        .await
2428        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2429        .await?;
2430
2431    send_dev_server_projects_update(session.user_id(), status, &session).await;
2432
2433    response.send(proto::RegenerateDevServerTokenResponse {
2434        dev_server_id: dev_server_id.to_proto(),
2435        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2436    })?;
2437    Ok(())
2438}
2439
2440async fn rename_dev_server(
2441    request: proto::RenameDevServer,
2442    response: Response<proto::RenameDevServer>,
2443    session: UserSession,
2444) -> Result<()> {
2445    if request.name.trim().is_empty() {
2446        return Err(proto::ErrorCode::Forbidden
2447            .message("Dev server name cannot be empty".to_string())
2448            .anyhow())?;
2449    }
2450
2451    let dev_server_id = DevServerId(request.dev_server_id as i32);
2452    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2453    if dev_server.user_id != session.user_id() {
2454        return Err(anyhow!(ErrorCode::Forbidden))?;
2455    }
2456
2457    let status = session
2458        .db()
2459        .await
2460        .rename_dev_server(
2461            dev_server_id,
2462            &request.name,
2463            request.ssh_connection_string.as_deref(),
2464            session.user_id(),
2465        )
2466        .await?;
2467
2468    send_dev_server_projects_update(session.user_id(), status, &session).await;
2469
2470    response.send(proto::Ack {})?;
2471    Ok(())
2472}
2473
2474async fn delete_dev_server(
2475    request: proto::DeleteDevServer,
2476    response: Response<proto::DeleteDevServer>,
2477    session: UserSession,
2478) -> Result<()> {
2479    let dev_server_id = DevServerId(request.dev_server_id as i32);
2480    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2481    if dev_server.user_id != session.user_id() {
2482        return Err(anyhow!(ErrorCode::Forbidden))?;
2483    }
2484
2485    let connection_id = session
2486        .connection_pool()
2487        .await
2488        .dev_server_connection_id(dev_server_id);
2489    if let Some(connection_id) = connection_id {
2490        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2491        session.peer.send(
2492            connection_id,
2493            proto::ShutdownDevServer {
2494                reason: Some("dev server was deleted".to_string()),
2495            },
2496        )?;
2497        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2498    }
2499
2500    let status = session
2501        .db()
2502        .await
2503        .delete_dev_server(dev_server_id, session.user_id())
2504        .await?;
2505
2506    send_dev_server_projects_update(session.user_id(), status, &session).await;
2507
2508    response.send(proto::Ack {})?;
2509    Ok(())
2510}
2511
2512async fn delete_dev_server_project(
2513    request: proto::DeleteDevServerProject,
2514    response: Response<proto::DeleteDevServerProject>,
2515    session: UserSession,
2516) -> Result<()> {
2517    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2518    let dev_server_project = session
2519        .db()
2520        .await
2521        .get_dev_server_project(dev_server_project_id)
2522        .await?;
2523
2524    let dev_server = session
2525        .db()
2526        .await
2527        .get_dev_server(dev_server_project.dev_server_id)
2528        .await?;
2529    if dev_server.user_id != session.user_id() {
2530        return Err(anyhow!(ErrorCode::Forbidden))?;
2531    }
2532
2533    let dev_server_connection_id = session
2534        .connection_pool()
2535        .await
2536        .dev_server_connection_id(dev_server.id);
2537
2538    if let Some(dev_server_connection_id) = dev_server_connection_id {
2539        let project = session
2540            .db()
2541            .await
2542            .find_dev_server_project(dev_server_project_id)
2543            .await;
2544        if let Ok(project) = project {
2545            unshare_project_internal(
2546                project.id,
2547                dev_server_connection_id,
2548                Some(session.user_id()),
2549                &session,
2550            )
2551            .await?;
2552        }
2553    }
2554
2555    let (projects, status) = session
2556        .db()
2557        .await
2558        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2559        .await?;
2560
2561    if let Some(dev_server_connection_id) = dev_server_connection_id {
2562        session.peer.send(
2563            dev_server_connection_id,
2564            proto::DevServerInstructions { projects },
2565        )?;
2566    }
2567
2568    send_dev_server_projects_update(session.user_id(), status, &session).await;
2569
2570    response.send(proto::Ack {})?;
2571    Ok(())
2572}
2573
2574async fn rejoin_dev_server_projects(
2575    request: proto::RejoinRemoteProjects,
2576    response: Response<proto::RejoinRemoteProjects>,
2577    session: UserSession,
2578) -> Result<()> {
2579    let mut rejoined_projects = {
2580        let db = session.db().await;
2581        db.rejoin_dev_server_projects(
2582            &request.rejoined_projects,
2583            session.user_id(),
2584            session.0.connection_id,
2585        )
2586        .await?
2587    };
2588    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2589
2590    response.send(proto::RejoinRemoteProjectsResponse {
2591        rejoined_projects: rejoined_projects
2592            .into_iter()
2593            .map(|project| project.to_proto())
2594            .collect(),
2595    })
2596}
2597
2598async fn reconnect_dev_server(
2599    request: proto::ReconnectDevServer,
2600    response: Response<proto::ReconnectDevServer>,
2601    session: DevServerSession,
2602) -> Result<()> {
2603    let reshared_projects = {
2604        let db = session.db().await;
2605        db.reshare_dev_server_projects(
2606            &request.reshared_projects,
2607            session.dev_server_id(),
2608            session.0.connection_id,
2609        )
2610        .await?
2611    };
2612
2613    for project in &reshared_projects {
2614        for collaborator in &project.collaborators {
2615            session
2616                .peer
2617                .send(
2618                    collaborator.connection_id,
2619                    proto::UpdateProjectCollaborator {
2620                        project_id: project.id.to_proto(),
2621                        old_peer_id: Some(project.old_connection_id.into()),
2622                        new_peer_id: Some(session.connection_id.into()),
2623                    },
2624                )
2625                .trace_err();
2626        }
2627
2628        broadcast(
2629            Some(session.connection_id),
2630            project
2631                .collaborators
2632                .iter()
2633                .map(|collaborator| collaborator.connection_id),
2634            |connection_id| {
2635                session.peer.forward_send(
2636                    session.connection_id,
2637                    connection_id,
2638                    proto::UpdateProject {
2639                        project_id: project.id.to_proto(),
2640                        worktrees: project.worktrees.clone(),
2641                    },
2642                )
2643            },
2644        );
2645    }
2646
2647    response.send(proto::ReconnectDevServerResponse {
2648        reshared_projects: reshared_projects
2649            .iter()
2650            .map(|project| proto::ResharedProject {
2651                id: project.id.to_proto(),
2652                collaborators: project
2653                    .collaborators
2654                    .iter()
2655                    .map(|collaborator| collaborator.to_proto())
2656                    .collect(),
2657            })
2658            .collect(),
2659    })?;
2660
2661    Ok(())
2662}
2663
2664async fn shutdown_dev_server(
2665    _: proto::ShutdownDevServer,
2666    response: Response<proto::ShutdownDevServer>,
2667    session: DevServerSession,
2668) -> Result<()> {
2669    response.send(proto::Ack {})?;
2670    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2671    remove_dev_server_connection(session.dev_server_id(), &session).await
2672}
2673
2674async fn shutdown_dev_server_internal(
2675    dev_server_id: DevServerId,
2676    connection_id: ConnectionId,
2677    session: &Session,
2678) -> Result<()> {
2679    let (dev_server_projects, dev_server) = {
2680        let db = session.db().await;
2681        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2682        let dev_server = db.get_dev_server(dev_server_id).await?;
2683        (dev_server_projects, dev_server)
2684    };
2685
2686    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2687        unshare_project_internal(
2688            ProjectId::from_proto(project_id),
2689            connection_id,
2690            None,
2691            session,
2692        )
2693        .await?;
2694    }
2695
2696    session
2697        .connection_pool()
2698        .await
2699        .set_dev_server_offline(dev_server_id);
2700
2701    let status = session
2702        .db()
2703        .await
2704        .dev_server_projects_update(dev_server.user_id)
2705        .await?;
2706    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2707
2708    Ok(())
2709}
2710
2711async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2712    let dev_server_connection = session
2713        .connection_pool()
2714        .await
2715        .dev_server_connection_id(dev_server_id);
2716
2717    if let Some(dev_server_connection) = dev_server_connection {
2718        session
2719            .connection_pool()
2720            .await
2721            .remove_connection(dev_server_connection)?;
2722    }
2723    Ok(())
2724}
2725
2726/// Updates other participants with changes to the project
2727async fn update_project(
2728    request: proto::UpdateProject,
2729    response: Response<proto::UpdateProject>,
2730    session: Session,
2731) -> Result<()> {
2732    let project_id = ProjectId::from_proto(request.project_id);
2733    let (room, guest_connection_ids) = &*session
2734        .db()
2735        .await
2736        .update_project(project_id, session.connection_id, &request.worktrees)
2737        .await?;
2738    broadcast(
2739        Some(session.connection_id),
2740        guest_connection_ids.iter().copied(),
2741        |connection_id| {
2742            session
2743                .peer
2744                .forward_send(session.connection_id, connection_id, request.clone())
2745        },
2746    );
2747    if let Some(room) = room {
2748        room_updated(&room, &session.peer);
2749    }
2750    response.send(proto::Ack {})?;
2751
2752    Ok(())
2753}
2754
2755/// Updates other participants with changes to the worktree
2756async fn update_worktree(
2757    request: proto::UpdateWorktree,
2758    response: Response<proto::UpdateWorktree>,
2759    session: Session,
2760) -> Result<()> {
2761    let guest_connection_ids = session
2762        .db()
2763        .await
2764        .update_worktree(&request, session.connection_id)
2765        .await?;
2766
2767    broadcast(
2768        Some(session.connection_id),
2769        guest_connection_ids.iter().copied(),
2770        |connection_id| {
2771            session
2772                .peer
2773                .forward_send(session.connection_id, connection_id, request.clone())
2774        },
2775    );
2776    response.send(proto::Ack {})?;
2777    Ok(())
2778}
2779
2780/// Updates other participants with changes to the diagnostics
2781async fn update_diagnostic_summary(
2782    message: proto::UpdateDiagnosticSummary,
2783    session: Session,
2784) -> Result<()> {
2785    let guest_connection_ids = session
2786        .db()
2787        .await
2788        .update_diagnostic_summary(&message, session.connection_id)
2789        .await?;
2790
2791    broadcast(
2792        Some(session.connection_id),
2793        guest_connection_ids.iter().copied(),
2794        |connection_id| {
2795            session
2796                .peer
2797                .forward_send(session.connection_id, connection_id, message.clone())
2798        },
2799    );
2800
2801    Ok(())
2802}
2803
2804/// Updates other participants with changes to the worktree settings
2805async fn update_worktree_settings(
2806    message: proto::UpdateWorktreeSettings,
2807    session: Session,
2808) -> Result<()> {
2809    let guest_connection_ids = session
2810        .db()
2811        .await
2812        .update_worktree_settings(&message, session.connection_id)
2813        .await?;
2814
2815    broadcast(
2816        Some(session.connection_id),
2817        guest_connection_ids.iter().copied(),
2818        |connection_id| {
2819            session
2820                .peer
2821                .forward_send(session.connection_id, connection_id, message.clone())
2822        },
2823    );
2824
2825    Ok(())
2826}
2827
2828/// Notify other participants that a  language server has started.
2829async fn start_language_server(
2830    request: proto::StartLanguageServer,
2831    session: Session,
2832) -> Result<()> {
2833    let guest_connection_ids = session
2834        .db()
2835        .await
2836        .start_language_server(&request, session.connection_id)
2837        .await?;
2838
2839    broadcast(
2840        Some(session.connection_id),
2841        guest_connection_ids.iter().copied(),
2842        |connection_id| {
2843            session
2844                .peer
2845                .forward_send(session.connection_id, connection_id, request.clone())
2846        },
2847    );
2848    Ok(())
2849}
2850
2851/// Notify other participants that a language server has changed.
2852async fn update_language_server(
2853    request: proto::UpdateLanguageServer,
2854    session: Session,
2855) -> Result<()> {
2856    let project_id = ProjectId::from_proto(request.project_id);
2857    let project_connection_ids = session
2858        .db()
2859        .await
2860        .project_connection_ids(project_id, session.connection_id, true)
2861        .await?;
2862    broadcast(
2863        Some(session.connection_id),
2864        project_connection_ids.iter().copied(),
2865        |connection_id| {
2866            session
2867                .peer
2868                .forward_send(session.connection_id, connection_id, request.clone())
2869        },
2870    );
2871    Ok(())
2872}
2873
2874/// forward a project request to the host. These requests should be read only
2875/// as guests are allowed to send them.
2876async fn forward_read_only_project_request<T>(
2877    request: T,
2878    response: Response<T>,
2879    session: UserSession,
2880) -> Result<()>
2881where
2882    T: EntityMessage + RequestMessage,
2883{
2884    let project_id = ProjectId::from_proto(request.remote_entity_id());
2885    let host_connection_id = session
2886        .db()
2887        .await
2888        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2889        .await?;
2890    let payload = session
2891        .peer
2892        .forward_request(session.connection_id, host_connection_id, request)
2893        .await?;
2894    response.send(payload)?;
2895    Ok(())
2896}
2897
2898/// forward a project request to the dev server. Only allowed
2899/// if it's your dev server.
2900async fn forward_project_request_for_owner<T>(
2901    request: T,
2902    response: Response<T>,
2903    session: UserSession,
2904) -> Result<()>
2905where
2906    T: EntityMessage + RequestMessage,
2907{
2908    let project_id = ProjectId::from_proto(request.remote_entity_id());
2909
2910    let host_connection_id = session
2911        .db()
2912        .await
2913        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2914        .await?;
2915    let payload = session
2916        .peer
2917        .forward_request(session.connection_id, host_connection_id, request)
2918        .await?;
2919    response.send(payload)?;
2920    Ok(())
2921}
2922
2923/// forward a project request to the host. These requests are disallowed
2924/// for guests.
2925async fn forward_mutating_project_request<T>(
2926    request: T,
2927    response: Response<T>,
2928    session: UserSession,
2929) -> Result<()>
2930where
2931    T: EntityMessage + RequestMessage,
2932{
2933    let project_id = ProjectId::from_proto(request.remote_entity_id());
2934
2935    let host_connection_id = session
2936        .db()
2937        .await
2938        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2939        .await?;
2940    let payload = session
2941        .peer
2942        .forward_request(session.connection_id, host_connection_id, request)
2943        .await?;
2944    response.send(payload)?;
2945    Ok(())
2946}
2947
2948/// forward a project request to the host. These requests are disallowed
2949/// for guests.
2950async fn forward_versioned_mutating_project_request<T>(
2951    request: T,
2952    response: Response<T>,
2953    session: UserSession,
2954) -> Result<()>
2955where
2956    T: EntityMessage + RequestMessage + VersionedMessage,
2957{
2958    let project_id = ProjectId::from_proto(request.remote_entity_id());
2959
2960    let host_connection_id = session
2961        .db()
2962        .await
2963        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2964        .await?;
2965    if let Some(host_version) = session
2966        .connection_pool()
2967        .await
2968        .connection(host_connection_id)
2969        .map(|c| c.zed_version)
2970    {
2971        if let Some(min_required_version) = request.required_host_version() {
2972            if min_required_version > host_version {
2973                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2974                    .with_tag("required", &min_required_version.to_string())))?;
2975            }
2976        }
2977    }
2978
2979    let payload = session
2980        .peer
2981        .forward_request(session.connection_id, host_connection_id, request)
2982        .await?;
2983    response.send(payload)?;
2984    Ok(())
2985}
2986
2987/// Notify other participants that a new buffer has been created
2988async fn create_buffer_for_peer(
2989    request: proto::CreateBufferForPeer,
2990    session: Session,
2991) -> Result<()> {
2992    session
2993        .db()
2994        .await
2995        .check_user_is_project_host(
2996            ProjectId::from_proto(request.project_id),
2997            session.connection_id,
2998        )
2999        .await?;
3000    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3001    session
3002        .peer
3003        .forward_send(session.connection_id, peer_id.into(), request)?;
3004    Ok(())
3005}
3006
3007/// Notify other participants that a buffer has been updated. This is
3008/// allowed for guests as long as the update is limited to selections.
3009async fn update_buffer(
3010    request: proto::UpdateBuffer,
3011    response: Response<proto::UpdateBuffer>,
3012    session: Session,
3013) -> Result<()> {
3014    let project_id = ProjectId::from_proto(request.project_id);
3015    let mut capability = Capability::ReadOnly;
3016
3017    for op in request.operations.iter() {
3018        match op.variant {
3019            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3020            Some(_) => capability = Capability::ReadWrite,
3021        }
3022    }
3023
3024    let host = {
3025        let guard = session
3026            .db()
3027            .await
3028            .connections_for_buffer_update(
3029                project_id,
3030                session.principal_id(),
3031                session.connection_id,
3032                capability,
3033            )
3034            .await?;
3035
3036        let (host, guests) = &*guard;
3037
3038        broadcast(
3039            Some(session.connection_id),
3040            guests.clone(),
3041            |connection_id| {
3042                session
3043                    .peer
3044                    .forward_send(session.connection_id, connection_id, request.clone())
3045            },
3046        );
3047
3048        *host
3049    };
3050
3051    if host != session.connection_id {
3052        session
3053            .peer
3054            .forward_request(session.connection_id, host, request.clone())
3055            .await?;
3056    }
3057
3058    response.send(proto::Ack {})?;
3059    Ok(())
3060}
3061
3062/// Notify other participants that a project has been updated.
3063async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3064    request: T,
3065    session: Session,
3066) -> Result<()> {
3067    let project_id = ProjectId::from_proto(request.remote_entity_id());
3068    let project_connection_ids = session
3069        .db()
3070        .await
3071        .project_connection_ids(project_id, session.connection_id, false)
3072        .await?;
3073
3074    broadcast(
3075        Some(session.connection_id),
3076        project_connection_ids.iter().copied(),
3077        |connection_id| {
3078            session
3079                .peer
3080                .forward_send(session.connection_id, connection_id, request.clone())
3081        },
3082    );
3083    Ok(())
3084}
3085
3086/// Start following another user in a call.
3087async fn follow(
3088    request: proto::Follow,
3089    response: Response<proto::Follow>,
3090    session: UserSession,
3091) -> Result<()> {
3092    let room_id = RoomId::from_proto(request.room_id);
3093    let project_id = request.project_id.map(ProjectId::from_proto);
3094    let leader_id = request
3095        .leader_id
3096        .ok_or_else(|| anyhow!("invalid leader id"))?
3097        .into();
3098    let follower_id = session.connection_id;
3099
3100    session
3101        .db()
3102        .await
3103        .check_room_participants(room_id, leader_id, session.connection_id)
3104        .await?;
3105
3106    let response_payload = session
3107        .peer
3108        .forward_request(session.connection_id, leader_id, request)
3109        .await?;
3110    response.send(response_payload)?;
3111
3112    if let Some(project_id) = project_id {
3113        let room = session
3114            .db()
3115            .await
3116            .follow(room_id, project_id, leader_id, follower_id)
3117            .await?;
3118        room_updated(&room, &session.peer);
3119    }
3120
3121    Ok(())
3122}
3123
3124/// Stop following another user in a call.
3125async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3126    let room_id = RoomId::from_proto(request.room_id);
3127    let project_id = request.project_id.map(ProjectId::from_proto);
3128    let leader_id = request
3129        .leader_id
3130        .ok_or_else(|| anyhow!("invalid leader id"))?
3131        .into();
3132    let follower_id = session.connection_id;
3133
3134    session
3135        .db()
3136        .await
3137        .check_room_participants(room_id, leader_id, session.connection_id)
3138        .await?;
3139
3140    session
3141        .peer
3142        .forward_send(session.connection_id, leader_id, request)?;
3143
3144    if let Some(project_id) = project_id {
3145        let room = session
3146            .db()
3147            .await
3148            .unfollow(room_id, project_id, leader_id, follower_id)
3149            .await?;
3150        room_updated(&room, &session.peer);
3151    }
3152
3153    Ok(())
3154}
3155
3156/// Notify everyone following you of your current location.
3157async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3158    let room_id = RoomId::from_proto(request.room_id);
3159    let database = session.db.lock().await;
3160
3161    let connection_ids = if let Some(project_id) = request.project_id {
3162        let project_id = ProjectId::from_proto(project_id);
3163        database
3164            .project_connection_ids(project_id, session.connection_id, true)
3165            .await?
3166    } else {
3167        database
3168            .room_connection_ids(room_id, session.connection_id)
3169            .await?
3170    };
3171
3172    // For now, don't send view update messages back to that view's current leader.
3173    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3174        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3175        _ => None,
3176    });
3177
3178    for connection_id in connection_ids.iter().cloned() {
3179        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3180            session
3181                .peer
3182                .forward_send(session.connection_id, connection_id, request.clone())?;
3183        }
3184    }
3185    Ok(())
3186}
3187
3188/// Get public data about users.
3189async fn get_users(
3190    request: proto::GetUsers,
3191    response: Response<proto::GetUsers>,
3192    session: Session,
3193) -> Result<()> {
3194    let user_ids = request
3195        .user_ids
3196        .into_iter()
3197        .map(UserId::from_proto)
3198        .collect();
3199    let users = session
3200        .db()
3201        .await
3202        .get_users_by_ids(user_ids)
3203        .await?
3204        .into_iter()
3205        .map(|user| proto::User {
3206            id: user.id.to_proto(),
3207            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3208            github_login: user.github_login,
3209        })
3210        .collect();
3211    response.send(proto::UsersResponse { users })?;
3212    Ok(())
3213}
3214
3215/// Search for users (to invite) buy Github login
3216async fn fuzzy_search_users(
3217    request: proto::FuzzySearchUsers,
3218    response: Response<proto::FuzzySearchUsers>,
3219    session: UserSession,
3220) -> Result<()> {
3221    let query = request.query;
3222    let users = match query.len() {
3223        0 => vec![],
3224        1 | 2 => session
3225            .db()
3226            .await
3227            .get_user_by_github_login(&query)
3228            .await?
3229            .into_iter()
3230            .collect(),
3231        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3232    };
3233    let users = users
3234        .into_iter()
3235        .filter(|user| user.id != session.user_id())
3236        .map(|user| proto::User {
3237            id: user.id.to_proto(),
3238            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3239            github_login: user.github_login,
3240        })
3241        .collect();
3242    response.send(proto::UsersResponse { users })?;
3243    Ok(())
3244}
3245
3246/// Send a contact request to another user.
3247async fn request_contact(
3248    request: proto::RequestContact,
3249    response: Response<proto::RequestContact>,
3250    session: UserSession,
3251) -> Result<()> {
3252    let requester_id = session.user_id();
3253    let responder_id = UserId::from_proto(request.responder_id);
3254    if requester_id == responder_id {
3255        return Err(anyhow!("cannot add yourself as a contact"))?;
3256    }
3257
3258    let notifications = session
3259        .db()
3260        .await
3261        .send_contact_request(requester_id, responder_id)
3262        .await?;
3263
3264    // Update outgoing contact requests of requester
3265    let mut update = proto::UpdateContacts::default();
3266    update.outgoing_requests.push(responder_id.to_proto());
3267    for connection_id in session
3268        .connection_pool()
3269        .await
3270        .user_connection_ids(requester_id)
3271    {
3272        session.peer.send(connection_id, update.clone())?;
3273    }
3274
3275    // Update incoming contact requests of responder
3276    let mut update = proto::UpdateContacts::default();
3277    update
3278        .incoming_requests
3279        .push(proto::IncomingContactRequest {
3280            requester_id: requester_id.to_proto(),
3281        });
3282    let connection_pool = session.connection_pool().await;
3283    for connection_id in connection_pool.user_connection_ids(responder_id) {
3284        session.peer.send(connection_id, update.clone())?;
3285    }
3286
3287    send_notifications(&connection_pool, &session.peer, notifications);
3288
3289    response.send(proto::Ack {})?;
3290    Ok(())
3291}
3292
3293/// Accept or decline a contact request
3294async fn respond_to_contact_request(
3295    request: proto::RespondToContactRequest,
3296    response: Response<proto::RespondToContactRequest>,
3297    session: UserSession,
3298) -> Result<()> {
3299    let responder_id = session.user_id();
3300    let requester_id = UserId::from_proto(request.requester_id);
3301    let db = session.db().await;
3302    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3303        db.dismiss_contact_notification(responder_id, requester_id)
3304            .await?;
3305    } else {
3306        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3307
3308        let notifications = db
3309            .respond_to_contact_request(responder_id, requester_id, accept)
3310            .await?;
3311        let requester_busy = db.is_user_busy(requester_id).await?;
3312        let responder_busy = db.is_user_busy(responder_id).await?;
3313
3314        let pool = session.connection_pool().await;
3315        // Update responder with new contact
3316        let mut update = proto::UpdateContacts::default();
3317        if accept {
3318            update
3319                .contacts
3320                .push(contact_for_user(requester_id, requester_busy, &pool));
3321        }
3322        update
3323            .remove_incoming_requests
3324            .push(requester_id.to_proto());
3325        for connection_id in pool.user_connection_ids(responder_id) {
3326            session.peer.send(connection_id, update.clone())?;
3327        }
3328
3329        // Update requester with new contact
3330        let mut update = proto::UpdateContacts::default();
3331        if accept {
3332            update
3333                .contacts
3334                .push(contact_for_user(responder_id, responder_busy, &pool));
3335        }
3336        update
3337            .remove_outgoing_requests
3338            .push(responder_id.to_proto());
3339
3340        for connection_id in pool.user_connection_ids(requester_id) {
3341            session.peer.send(connection_id, update.clone())?;
3342        }
3343
3344        send_notifications(&pool, &session.peer, notifications);
3345    }
3346
3347    response.send(proto::Ack {})?;
3348    Ok(())
3349}
3350
3351/// Remove a contact.
3352async fn remove_contact(
3353    request: proto::RemoveContact,
3354    response: Response<proto::RemoveContact>,
3355    session: UserSession,
3356) -> Result<()> {
3357    let requester_id = session.user_id();
3358    let responder_id = UserId::from_proto(request.user_id);
3359    let db = session.db().await;
3360    let (contact_accepted, deleted_notification_id) =
3361        db.remove_contact(requester_id, responder_id).await?;
3362
3363    let pool = session.connection_pool().await;
3364    // Update outgoing contact requests of requester
3365    let mut update = proto::UpdateContacts::default();
3366    if contact_accepted {
3367        update.remove_contacts.push(responder_id.to_proto());
3368    } else {
3369        update
3370            .remove_outgoing_requests
3371            .push(responder_id.to_proto());
3372    }
3373    for connection_id in pool.user_connection_ids(requester_id) {
3374        session.peer.send(connection_id, update.clone())?;
3375    }
3376
3377    // Update incoming contact requests of responder
3378    let mut update = proto::UpdateContacts::default();
3379    if contact_accepted {
3380        update.remove_contacts.push(requester_id.to_proto());
3381    } else {
3382        update
3383            .remove_incoming_requests
3384            .push(requester_id.to_proto());
3385    }
3386    for connection_id in pool.user_connection_ids(responder_id) {
3387        session.peer.send(connection_id, update.clone())?;
3388        if let Some(notification_id) = deleted_notification_id {
3389            session.peer.send(
3390                connection_id,
3391                proto::DeleteNotification {
3392                    notification_id: notification_id.to_proto(),
3393                },
3394            )?;
3395        }
3396    }
3397
3398    response.send(proto::Ack {})?;
3399    Ok(())
3400}
3401
3402/// Creates a new channel.
3403async fn create_channel(
3404    request: proto::CreateChannel,
3405    response: Response<proto::CreateChannel>,
3406    session: UserSession,
3407) -> Result<()> {
3408    let db = session.db().await;
3409
3410    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3411    let (channel, membership) = db
3412        .create_channel(&request.name, parent_id, session.user_id())
3413        .await?;
3414
3415    let root_id = channel.root_id();
3416    let channel = Channel::from_model(channel);
3417
3418    response.send(proto::CreateChannelResponse {
3419        channel: Some(channel.to_proto()),
3420        parent_id: request.parent_id,
3421    })?;
3422
3423    let mut connection_pool = session.connection_pool().await;
3424    if let Some(membership) = membership {
3425        connection_pool.subscribe_to_channel(
3426            membership.user_id,
3427            membership.channel_id,
3428            membership.role,
3429        );
3430        let update = proto::UpdateUserChannels {
3431            channel_memberships: vec![proto::ChannelMembership {
3432                channel_id: membership.channel_id.to_proto(),
3433                role: membership.role.into(),
3434            }],
3435            ..Default::default()
3436        };
3437        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3438            session.peer.send(connection_id, update.clone())?;
3439        }
3440    }
3441
3442    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3443        if !role.can_see_channel(channel.visibility) {
3444            continue;
3445        }
3446
3447        let update = proto::UpdateChannels {
3448            channels: vec![channel.to_proto()],
3449            ..Default::default()
3450        };
3451        session.peer.send(connection_id, update.clone())?;
3452    }
3453
3454    Ok(())
3455}
3456
3457/// Delete a channel
3458async fn delete_channel(
3459    request: proto::DeleteChannel,
3460    response: Response<proto::DeleteChannel>,
3461    session: UserSession,
3462) -> Result<()> {
3463    let db = session.db().await;
3464
3465    let channel_id = request.channel_id;
3466    let (root_channel, removed_channels) = db
3467        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3468        .await?;
3469    response.send(proto::Ack {})?;
3470
3471    // Notify members of removed channels
3472    let mut update = proto::UpdateChannels::default();
3473    update
3474        .delete_channels
3475        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3476
3477    let connection_pool = session.connection_pool().await;
3478    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3479        session.peer.send(connection_id, update.clone())?;
3480    }
3481
3482    Ok(())
3483}
3484
3485/// Invite someone to join a channel.
3486async fn invite_channel_member(
3487    request: proto::InviteChannelMember,
3488    response: Response<proto::InviteChannelMember>,
3489    session: UserSession,
3490) -> Result<()> {
3491    let db = session.db().await;
3492    let channel_id = ChannelId::from_proto(request.channel_id);
3493    let invitee_id = UserId::from_proto(request.user_id);
3494    let InviteMemberResult {
3495        channel,
3496        notifications,
3497    } = db
3498        .invite_channel_member(
3499            channel_id,
3500            invitee_id,
3501            session.user_id(),
3502            request.role().into(),
3503        )
3504        .await?;
3505
3506    let update = proto::UpdateChannels {
3507        channel_invitations: vec![channel.to_proto()],
3508        ..Default::default()
3509    };
3510
3511    let connection_pool = session.connection_pool().await;
3512    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3513        session.peer.send(connection_id, update.clone())?;
3514    }
3515
3516    send_notifications(&connection_pool, &session.peer, notifications);
3517
3518    response.send(proto::Ack {})?;
3519    Ok(())
3520}
3521
3522/// remove someone from a channel
3523async fn remove_channel_member(
3524    request: proto::RemoveChannelMember,
3525    response: Response<proto::RemoveChannelMember>,
3526    session: UserSession,
3527) -> Result<()> {
3528    let db = session.db().await;
3529    let channel_id = ChannelId::from_proto(request.channel_id);
3530    let member_id = UserId::from_proto(request.user_id);
3531
3532    let RemoveChannelMemberResult {
3533        membership_update,
3534        notification_id,
3535    } = db
3536        .remove_channel_member(channel_id, member_id, session.user_id())
3537        .await?;
3538
3539    let mut connection_pool = session.connection_pool().await;
3540    notify_membership_updated(
3541        &mut connection_pool,
3542        membership_update,
3543        member_id,
3544        &session.peer,
3545    );
3546    for connection_id in connection_pool.user_connection_ids(member_id) {
3547        if let Some(notification_id) = notification_id {
3548            session
3549                .peer
3550                .send(
3551                    connection_id,
3552                    proto::DeleteNotification {
3553                        notification_id: notification_id.to_proto(),
3554                    },
3555                )
3556                .trace_err();
3557        }
3558    }
3559
3560    response.send(proto::Ack {})?;
3561    Ok(())
3562}
3563
3564/// Toggle the channel between public and private.
3565/// Care is taken to maintain the invariant that public channels only descend from public channels,
3566/// (though members-only channels can appear at any point in the hierarchy).
3567async fn set_channel_visibility(
3568    request: proto::SetChannelVisibility,
3569    response: Response<proto::SetChannelVisibility>,
3570    session: UserSession,
3571) -> Result<()> {
3572    let db = session.db().await;
3573    let channel_id = ChannelId::from_proto(request.channel_id);
3574    let visibility = request.visibility().into();
3575
3576    let channel_model = db
3577        .set_channel_visibility(channel_id, visibility, session.user_id())
3578        .await?;
3579    let root_id = channel_model.root_id();
3580    let channel = Channel::from_model(channel_model);
3581
3582    let mut connection_pool = session.connection_pool().await;
3583    for (user_id, role) in connection_pool
3584        .channel_user_ids(root_id)
3585        .collect::<Vec<_>>()
3586        .into_iter()
3587    {
3588        let update = if role.can_see_channel(channel.visibility) {
3589            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3590            proto::UpdateChannels {
3591                channels: vec![channel.to_proto()],
3592                ..Default::default()
3593            }
3594        } else {
3595            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3596            proto::UpdateChannels {
3597                delete_channels: vec![channel.id.to_proto()],
3598                ..Default::default()
3599            }
3600        };
3601
3602        for connection_id in connection_pool.user_connection_ids(user_id) {
3603            session.peer.send(connection_id, update.clone())?;
3604        }
3605    }
3606
3607    response.send(proto::Ack {})?;
3608    Ok(())
3609}
3610
3611/// Alter the role for a user in the channel.
3612async fn set_channel_member_role(
3613    request: proto::SetChannelMemberRole,
3614    response: Response<proto::SetChannelMemberRole>,
3615    session: UserSession,
3616) -> Result<()> {
3617    let db = session.db().await;
3618    let channel_id = ChannelId::from_proto(request.channel_id);
3619    let member_id = UserId::from_proto(request.user_id);
3620    let result = db
3621        .set_channel_member_role(
3622            channel_id,
3623            session.user_id(),
3624            member_id,
3625            request.role().into(),
3626        )
3627        .await?;
3628
3629    match result {
3630        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3631            let mut connection_pool = session.connection_pool().await;
3632            notify_membership_updated(
3633                &mut connection_pool,
3634                membership_update,
3635                member_id,
3636                &session.peer,
3637            )
3638        }
3639        db::SetMemberRoleResult::InviteUpdated(channel) => {
3640            let update = proto::UpdateChannels {
3641                channel_invitations: vec![channel.to_proto()],
3642                ..Default::default()
3643            };
3644
3645            for connection_id in session
3646                .connection_pool()
3647                .await
3648                .user_connection_ids(member_id)
3649            {
3650                session.peer.send(connection_id, update.clone())?;
3651            }
3652        }
3653    }
3654
3655    response.send(proto::Ack {})?;
3656    Ok(())
3657}
3658
3659/// Change the name of a channel
3660async fn rename_channel(
3661    request: proto::RenameChannel,
3662    response: Response<proto::RenameChannel>,
3663    session: UserSession,
3664) -> Result<()> {
3665    let db = session.db().await;
3666    let channel_id = ChannelId::from_proto(request.channel_id);
3667    let channel_model = db
3668        .rename_channel(channel_id, session.user_id(), &request.name)
3669        .await?;
3670    let root_id = channel_model.root_id();
3671    let channel = Channel::from_model(channel_model);
3672
3673    response.send(proto::RenameChannelResponse {
3674        channel: Some(channel.to_proto()),
3675    })?;
3676
3677    let connection_pool = session.connection_pool().await;
3678    let update = proto::UpdateChannels {
3679        channels: vec![channel.to_proto()],
3680        ..Default::default()
3681    };
3682    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3683        if role.can_see_channel(channel.visibility) {
3684            session.peer.send(connection_id, update.clone())?;
3685        }
3686    }
3687
3688    Ok(())
3689}
3690
3691/// Move a channel to a new parent.
3692async fn move_channel(
3693    request: proto::MoveChannel,
3694    response: Response<proto::MoveChannel>,
3695    session: UserSession,
3696) -> Result<()> {
3697    let channel_id = ChannelId::from_proto(request.channel_id);
3698    let to = ChannelId::from_proto(request.to);
3699
3700    let (root_id, channels) = session
3701        .db()
3702        .await
3703        .move_channel(channel_id, to, session.user_id())
3704        .await?;
3705
3706    let connection_pool = session.connection_pool().await;
3707    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3708        let channels = channels
3709            .iter()
3710            .filter_map(|channel| {
3711                if role.can_see_channel(channel.visibility) {
3712                    Some(channel.to_proto())
3713                } else {
3714                    None
3715                }
3716            })
3717            .collect::<Vec<_>>();
3718        if channels.is_empty() {
3719            continue;
3720        }
3721
3722        let update = proto::UpdateChannels {
3723            channels,
3724            ..Default::default()
3725        };
3726
3727        session.peer.send(connection_id, update.clone())?;
3728    }
3729
3730    response.send(Ack {})?;
3731    Ok(())
3732}
3733
3734/// Get the list of channel members
3735async fn get_channel_members(
3736    request: proto::GetChannelMembers,
3737    response: Response<proto::GetChannelMembers>,
3738    session: UserSession,
3739) -> Result<()> {
3740    let db = session.db().await;
3741    let channel_id = ChannelId::from_proto(request.channel_id);
3742    let limit = if request.limit == 0 {
3743        u16::MAX as u64
3744    } else {
3745        request.limit
3746    };
3747    let (members, users) = db
3748        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3749        .await?;
3750    response.send(proto::GetChannelMembersResponse { members, users })?;
3751    Ok(())
3752}
3753
3754/// Accept or decline a channel invitation.
3755async fn respond_to_channel_invite(
3756    request: proto::RespondToChannelInvite,
3757    response: Response<proto::RespondToChannelInvite>,
3758    session: UserSession,
3759) -> Result<()> {
3760    let db = session.db().await;
3761    let channel_id = ChannelId::from_proto(request.channel_id);
3762    let RespondToChannelInvite {
3763        membership_update,
3764        notifications,
3765    } = db
3766        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3767        .await?;
3768
3769    let mut connection_pool = session.connection_pool().await;
3770    if let Some(membership_update) = membership_update {
3771        notify_membership_updated(
3772            &mut connection_pool,
3773            membership_update,
3774            session.user_id(),
3775            &session.peer,
3776        );
3777    } else {
3778        let update = proto::UpdateChannels {
3779            remove_channel_invitations: vec![channel_id.to_proto()],
3780            ..Default::default()
3781        };
3782
3783        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3784            session.peer.send(connection_id, update.clone())?;
3785        }
3786    };
3787
3788    send_notifications(&connection_pool, &session.peer, notifications);
3789
3790    response.send(proto::Ack {})?;
3791
3792    Ok(())
3793}
3794
3795/// Join the channels' room
3796async fn join_channel(
3797    request: proto::JoinChannel,
3798    response: Response<proto::JoinChannel>,
3799    session: UserSession,
3800) -> Result<()> {
3801    let channel_id = ChannelId::from_proto(request.channel_id);
3802    join_channel_internal(channel_id, Box::new(response), session).await
3803}
3804
3805trait JoinChannelInternalResponse {
3806    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3807}
3808impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3809    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3810        Response::<proto::JoinChannel>::send(self, result)
3811    }
3812}
3813impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3814    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3815        Response::<proto::JoinRoom>::send(self, result)
3816    }
3817}
3818
/// Joins the room of `channel_id` for the session's user and replies through
/// `response`.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is removed with `leave_room_for_session`. The database guard
///    is dropped around that call and re-acquired afterwards.
/// 2. The user joins the channel's room in the database. This returns the
///    joined room, an optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, a token is minted:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`.
///    - Every other role gets a room token with `can_publish: true`.
///    A token error yields no LiveKit info rather than failing the join.
/// 4. The `JoinRoomResponse` is sent. Then any membership update is pushed
///    out, `room_updated` runs, and after the guards are released
///    `channel_updated` and `update_user_contacts` run.
///
/// Errors from the database steps are propagated. If the joined room carries no
/// channel, an error is returned, but only after the response was already sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before calling into a helper that takes the
            // whole session. NOTE(review): presumably it needs the db itself; confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Inside this closure, `?` on `trace_err()` turns a token-minting error into
        // `None`, i.e. the response simply carries no LiveKit connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join but not publish media; every other role may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Both `db` and `connection_pool` guards are released when this block ends.
        joined_room
    };

    // A fresh connection-pool guard is taken just for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3909
3910/// Start editing the channel notes
3911async fn join_channel_buffer(
3912    request: proto::JoinChannelBuffer,
3913    response: Response<proto::JoinChannelBuffer>,
3914    session: UserSession,
3915) -> Result<()> {
3916    let db = session.db().await;
3917    let channel_id = ChannelId::from_proto(request.channel_id);
3918
3919    let open_response = db
3920        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3921        .await?;
3922
3923    let collaborators = open_response.collaborators.clone();
3924    response.send(open_response)?;
3925
3926    let update = UpdateChannelBufferCollaborators {
3927        channel_id: channel_id.to_proto(),
3928        collaborators: collaborators.clone(),
3929    };
3930    channel_buffer_updated(
3931        session.connection_id,
3932        collaborators
3933            .iter()
3934            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3935        &update,
3936        &session.peer,
3937    );
3938
3939    Ok(())
3940}
3941
3942/// Edit the channel notes
3943async fn update_channel_buffer(
3944    request: proto::UpdateChannelBuffer,
3945    session: UserSession,
3946) -> Result<()> {
3947    let db = session.db().await;
3948    let channel_id = ChannelId::from_proto(request.channel_id);
3949
3950    let (collaborators, epoch, version) = db
3951        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3952        .await?;
3953
3954    channel_buffer_updated(
3955        session.connection_id,
3956        collaborators.clone(),
3957        &proto::UpdateChannelBuffer {
3958            channel_id: channel_id.to_proto(),
3959            operations: request.operations,
3960        },
3961        &session.peer,
3962    );
3963
3964    let pool = &*session.connection_pool().await;
3965
3966    let non_collaborators =
3967        pool.channel_connection_ids(channel_id)
3968            .filter_map(|(connection_id, _)| {
3969                if collaborators.contains(&connection_id) {
3970                    None
3971                } else {
3972                    Some(connection_id)
3973                }
3974            });
3975
3976    broadcast(None, non_collaborators, |peer_id| {
3977        session.peer.send(
3978            peer_id,
3979            proto::UpdateChannels {
3980                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3981                    channel_id: channel_id.to_proto(),
3982                    epoch: epoch as u64,
3983                    version: version.clone(),
3984                }],
3985                ..Default::default()
3986            },
3987        )
3988    });
3989
3990    Ok(())
3991}
3992
3993/// Rejoin the channel notes after a connection blip
3994async fn rejoin_channel_buffers(
3995    request: proto::RejoinChannelBuffers,
3996    response: Response<proto::RejoinChannelBuffers>,
3997    session: UserSession,
3998) -> Result<()> {
3999    let db = session.db().await;
4000    let buffers = db
4001        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4002        .await?;
4003
4004    for rejoined_buffer in &buffers {
4005        let collaborators_to_notify = rejoined_buffer
4006            .buffer
4007            .collaborators
4008            .iter()
4009            .filter_map(|c| Some(c.peer_id?.into()));
4010        channel_buffer_updated(
4011            session.connection_id,
4012            collaborators_to_notify,
4013            &proto::UpdateChannelBufferCollaborators {
4014                channel_id: rejoined_buffer.buffer.channel_id,
4015                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4016            },
4017            &session.peer,
4018        );
4019    }
4020
4021    response.send(proto::RejoinChannelBuffersResponse {
4022        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4023    })?;
4024
4025    Ok(())
4026}
4027
4028/// Stop editing the channel notes
4029async fn leave_channel_buffer(
4030    request: proto::LeaveChannelBuffer,
4031    response: Response<proto::LeaveChannelBuffer>,
4032    session: UserSession,
4033) -> Result<()> {
4034    let db = session.db().await;
4035    let channel_id = ChannelId::from_proto(request.channel_id);
4036
4037    let left_buffer = db
4038        .leave_channel_buffer(channel_id, session.connection_id)
4039        .await?;
4040
4041    response.send(Ack {})?;
4042
4043    channel_buffer_updated(
4044        session.connection_id,
4045        left_buffer.connections,
4046        &proto::UpdateChannelBufferCollaborators {
4047            channel_id: channel_id.to_proto(),
4048            collaborators: left_buffer.collaborators,
4049        },
4050        &session.peer,
4051    );
4052
4053    Ok(())
4054}
4055
4056fn channel_buffer_updated<T: EnvelopedMessage>(
4057    sender_id: ConnectionId,
4058    collaborators: impl IntoIterator<Item = ConnectionId>,
4059    message: &T,
4060    peer: &Peer,
4061) {
4062    broadcast(Some(sender_id), collaborators, |peer_id| {
4063        peer.send(peer_id, message.clone())
4064    });
4065}
4066
4067fn send_notifications(
4068    connection_pool: &ConnectionPool,
4069    peer: &Peer,
4070    notifications: db::NotificationBatch,
4071) {
4072    for (user_id, notification) in notifications {
4073        for connection_id in connection_pool.user_connection_ids(user_id) {
4074            if let Err(error) = peer.send(
4075                connection_id,
4076                proto::AddNotification {
4077                    notification: Some(notification.clone()),
4078                },
4079            ) {
4080                tracing::error!(
4081                    "failed to send notification to {:?} {}",
4082                    connection_id,
4083                    error
4084                );
4085            }
4086        }
4087    }
4088}
4089
4090/// Send a message to the channel
4091async fn send_channel_message(
4092    request: proto::SendChannelMessage,
4093    response: Response<proto::SendChannelMessage>,
4094    session: UserSession,
4095) -> Result<()> {
4096    // Validate the message body.
4097    let body = request.body.trim().to_string();
4098    if body.len() > MAX_MESSAGE_LEN {
4099        return Err(anyhow!("message is too long"))?;
4100    }
4101    if body.is_empty() {
4102        return Err(anyhow!("message can't be blank"))?;
4103    }
4104
4105    // TODO: adjust mentions if body is trimmed
4106
4107    let timestamp = OffsetDateTime::now_utc();
4108    let nonce = request
4109        .nonce
4110        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4111
4112    let channel_id = ChannelId::from_proto(request.channel_id);
4113    let CreatedChannelMessage {
4114        message_id,
4115        participant_connection_ids,
4116        notifications,
4117    } = session
4118        .db()
4119        .await
4120        .create_channel_message(
4121            channel_id,
4122            session.user_id(),
4123            &body,
4124            &request.mentions,
4125            timestamp,
4126            nonce.clone().into(),
4127            match request.reply_to_message_id {
4128                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4129                None => None,
4130            },
4131        )
4132        .await?;
4133
4134    let message = proto::ChannelMessage {
4135        sender_id: session.user_id().to_proto(),
4136        id: message_id.to_proto(),
4137        body,
4138        mentions: request.mentions,
4139        timestamp: timestamp.unix_timestamp() as u64,
4140        nonce: Some(nonce),
4141        reply_to_message_id: request.reply_to_message_id,
4142        edited_at: None,
4143    };
4144    broadcast(
4145        Some(session.connection_id),
4146        participant_connection_ids.clone(),
4147        |connection| {
4148            session.peer.send(
4149                connection,
4150                proto::ChannelMessageSent {
4151                    channel_id: channel_id.to_proto(),
4152                    message: Some(message.clone()),
4153                },
4154            )
4155        },
4156    );
4157    response.send(proto::SendChannelMessageResponse {
4158        message: Some(message),
4159    })?;
4160
4161    let pool = &*session.connection_pool().await;
4162    let non_participants =
4163        pool.channel_connection_ids(channel_id)
4164            .filter_map(|(connection_id, _)| {
4165                if participant_connection_ids.contains(&connection_id) {
4166                    None
4167                } else {
4168                    Some(connection_id)
4169                }
4170            });
4171    broadcast(None, non_participants, |peer_id| {
4172        session.peer.send(
4173            peer_id,
4174            proto::UpdateChannels {
4175                latest_channel_message_ids: vec![proto::ChannelMessageId {
4176                    channel_id: channel_id.to_proto(),
4177                    message_id: message_id.to_proto(),
4178                }],
4179                ..Default::default()
4180            },
4181        )
4182    });
4183    send_notifications(pool, &session.peer, notifications);
4184
4185    Ok(())
4186}
4187
4188/// Delete a channel message
4189async fn remove_channel_message(
4190    request: proto::RemoveChannelMessage,
4191    response: Response<proto::RemoveChannelMessage>,
4192    session: UserSession,
4193) -> Result<()> {
4194    let channel_id = ChannelId::from_proto(request.channel_id);
4195    let message_id = MessageId::from_proto(request.message_id);
4196    let (connection_ids, existing_notification_ids) = session
4197        .db()
4198        .await
4199        .remove_channel_message(channel_id, message_id, session.user_id())
4200        .await?;
4201
4202    broadcast(
4203        Some(session.connection_id),
4204        connection_ids,
4205        move |connection| {
4206            session.peer.send(connection, request.clone())?;
4207
4208            for notification_id in &existing_notification_ids {
4209                session.peer.send(
4210                    connection,
4211                    proto::DeleteNotification {
4212                        notification_id: (*notification_id).to_proto(),
4213                    },
4214                )?;
4215            }
4216
4217            Ok(())
4218        },
4219    );
4220    response.send(proto::Ack {})?;
4221    Ok(())
4222}
4223
4224async fn update_channel_message(
4225    request: proto::UpdateChannelMessage,
4226    response: Response<proto::UpdateChannelMessage>,
4227    session: UserSession,
4228) -> Result<()> {
4229    let channel_id = ChannelId::from_proto(request.channel_id);
4230    let message_id = MessageId::from_proto(request.message_id);
4231    let updated_at = OffsetDateTime::now_utc();
4232    let UpdatedChannelMessage {
4233        message_id,
4234        participant_connection_ids,
4235        notifications,
4236        reply_to_message_id,
4237        timestamp,
4238        deleted_mention_notification_ids,
4239        updated_mention_notifications,
4240    } = session
4241        .db()
4242        .await
4243        .update_channel_message(
4244            channel_id,
4245            message_id,
4246            session.user_id(),
4247            request.body.as_str(),
4248            &request.mentions,
4249            updated_at,
4250        )
4251        .await?;
4252
4253    let nonce = request
4254        .nonce
4255        .clone()
4256        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4257
4258    let message = proto::ChannelMessage {
4259        sender_id: session.user_id().to_proto(),
4260        id: message_id.to_proto(),
4261        body: request.body.clone(),
4262        mentions: request.mentions.clone(),
4263        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4264        nonce: Some(nonce),
4265        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4266        edited_at: Some(updated_at.unix_timestamp() as u64),
4267    };
4268
4269    response.send(proto::Ack {})?;
4270
4271    let pool = &*session.connection_pool().await;
4272    broadcast(
4273        Some(session.connection_id),
4274        participant_connection_ids,
4275        |connection| {
4276            session.peer.send(
4277                connection,
4278                proto::ChannelMessageUpdate {
4279                    channel_id: channel_id.to_proto(),
4280                    message: Some(message.clone()),
4281                },
4282            )?;
4283
4284            for notification_id in &deleted_mention_notification_ids {
4285                session.peer.send(
4286                    connection,
4287                    proto::DeleteNotification {
4288                        notification_id: (*notification_id).to_proto(),
4289                    },
4290                )?;
4291            }
4292
4293            for notification in &updated_mention_notifications {
4294                session.peer.send(
4295                    connection,
4296                    proto::UpdateNotification {
4297                        notification: Some(notification.clone()),
4298                    },
4299                )?;
4300            }
4301
4302            Ok(())
4303        },
4304    );
4305
4306    send_notifications(pool, &session.peer, notifications);
4307
4308    Ok(())
4309}
4310
4311/// Mark a channel message as read
4312async fn acknowledge_channel_message(
4313    request: proto::AckChannelMessage,
4314    session: UserSession,
4315) -> Result<()> {
4316    let channel_id = ChannelId::from_proto(request.channel_id);
4317    let message_id = MessageId::from_proto(request.message_id);
4318    let notifications = session
4319        .db()
4320        .await
4321        .observe_channel_message(channel_id, session.user_id(), message_id)
4322        .await?;
4323    send_notifications(
4324        &*session.connection_pool().await,
4325        &session.peer,
4326        notifications,
4327    );
4328    Ok(())
4329}
4330
4331/// Mark a buffer version as synced
4332async fn acknowledge_buffer_version(
4333    request: proto::AckBufferOperation,
4334    session: UserSession,
4335) -> Result<()> {
4336    let buffer_id = BufferId::from_proto(request.buffer_id);
4337    session
4338        .db()
4339        .await
4340        .observe_buffer_version(
4341            buffer_id,
4342            session.user_id(),
4343            request.epoch as i32,
4344            &request.version,
4345        )
4346        .await?;
4347    Ok(())
4348}
4349
4350struct CompleteWithLanguageModelRateLimit;
4351
4352impl RateLimit for CompleteWithLanguageModelRateLimit {
4353    fn capacity() -> usize {
4354        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4355            .ok()
4356            .and_then(|v| v.parse().ok())
4357            .unwrap_or(120) // Picked arbitrarily
4358    }
4359
4360    fn refill_duration() -> chrono::Duration {
4361        chrono::Duration::hours(1)
4362    }
4363
4364    fn db_name() -> &'static str {
4365        "complete-with-language-model"
4366    }
4367}
4368
4369async fn complete_with_language_model(
4370    request: proto::CompleteWithLanguageModel,
4371    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4372    session: Session,
4373    open_ai_api_key: Option<Arc<str>>,
4374    google_ai_api_key: Option<Arc<str>>,
4375    anthropic_api_key: Option<Arc<str>>,
4376) -> Result<()> {
4377    let Some(session) = session.for_user() else {
4378        return Err(anyhow!("user not found"))?;
4379    };
4380    authorize_access_to_language_models(&session).await?;
4381    session
4382        .rate_limiter
4383        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4384        .await?;
4385
4386    if request.model.starts_with("gpt") {
4387        let api_key =
4388            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4389        complete_with_open_ai(request, response, session, api_key).await?;
4390    } else if request.model.starts_with("gemini") {
4391        let api_key = google_ai_api_key
4392            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4393        complete_with_google_ai(request, response, session, api_key).await?;
4394    } else if request.model.starts_with("claude") {
4395        let api_key = anthropic_api_key
4396            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4397        complete_with_anthropic(request, response, session, api_key).await?;
4398    }
4399
4400    Ok(())
4401}
4402
4403async fn complete_with_open_ai(
4404    request: proto::CompleteWithLanguageModel,
4405    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4406    session: UserSession,
4407    api_key: Arc<str>,
4408) -> Result<()> {
4409    let mut completion_stream = open_ai::stream_completion(
4410        session.http_client.as_ref(),
4411        OPEN_AI_API_URL,
4412        &api_key,
4413        crate::ai::language_model_request_to_open_ai(request)?,
4414        None,
4415    )
4416    .await
4417    .context("open_ai::stream_completion request failed within collab")?;
4418
4419    while let Some(event) = completion_stream.next().await {
4420        let event = event?;
4421        response.send(proto::LanguageModelResponse {
4422            choices: event
4423                .choices
4424                .into_iter()
4425                .map(|choice| proto::LanguageModelChoiceDelta {
4426                    index: choice.index,
4427                    delta: Some(proto::LanguageModelResponseMessage {
4428                        role: choice.delta.role.map(|role| match role {
4429                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4430                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4431                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4432                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4433                        } as i32),
4434                        content: choice.delta.content,
4435                        tool_calls: choice
4436                            .delta
4437                            .tool_calls
4438                            .into_iter()
4439                            .map(|delta| proto::ToolCallDelta {
4440                                index: delta.index as u32,
4441                                id: delta.id,
4442                                variant: match delta.function {
4443                                    Some(function) => {
4444                                        let name = function.name;
4445                                        let arguments = function.arguments;
4446
4447                                        Some(proto::tool_call_delta::Variant::Function(
4448                                            proto::tool_call_delta::FunctionCallDelta {
4449                                                name,
4450                                                arguments,
4451                                            },
4452                                        ))
4453                                    }
4454                                    None => None,
4455                                },
4456                            })
4457                            .collect(),
4458                    }),
4459                    finish_reason: choice.finish_reason,
4460                })
4461                .collect(),
4462        })?;
4463    }
4464
4465    Ok(())
4466}
4467
4468async fn complete_with_google_ai(
4469    request: proto::CompleteWithLanguageModel,
4470    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4471    session: UserSession,
4472    api_key: Arc<str>,
4473) -> Result<()> {
4474    let mut stream = google_ai::stream_generate_content(
4475        session.http_client.clone(),
4476        google_ai::API_URL,
4477        api_key.as_ref(),
4478        crate::ai::language_model_request_to_google_ai(request)?,
4479    )
4480    .await
4481    .context("google_ai::stream_generate_content request failed")?;
4482
4483    while let Some(event) = stream.next().await {
4484        let event = event?;
4485        response.send(proto::LanguageModelResponse {
4486            choices: event
4487                .candidates
4488                .unwrap_or_default()
4489                .into_iter()
4490                .map(|candidate| proto::LanguageModelChoiceDelta {
4491                    index: candidate.index as u32,
4492                    delta: Some(proto::LanguageModelResponseMessage {
4493                        role: Some(match candidate.content.role {
4494                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4495                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4496                        } as i32),
4497                        content: Some(
4498                            candidate
4499                                .content
4500                                .parts
4501                                .into_iter()
4502                                .filter_map(|part| match part {
4503                                    google_ai::Part::TextPart(part) => Some(part.text),
4504                                    google_ai::Part::InlineDataPart(_) => None,
4505                                })
4506                                .collect(),
4507                        ),
4508                        // Tool calls are not supported for Google
4509                        tool_calls: Vec::new(),
4510                    }),
4511                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4512                })
4513                .collect(),
4514        })?;
4515    }
4516
4517    Ok(())
4518}
4519
4520async fn complete_with_anthropic(
4521    request: proto::CompleteWithLanguageModel,
4522    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4523    session: UserSession,
4524    api_key: Arc<str>,
4525) -> Result<()> {
4526    let model = anthropic::Model::from_id(&request.model)?;
4527
4528    let mut system_message = String::new();
4529    let messages = request
4530        .messages
4531        .into_iter()
4532        .filter_map(|message| {
4533            match message.role() {
4534                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4535                    role: anthropic::Role::User,
4536                    content: message.content,
4537                }),
4538                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4539                    role: anthropic::Role::Assistant,
4540                    content: message.content,
4541                }),
4542                // Anthropic's API breaks system instructions out as a separate field rather
4543                // than having a system message role.
4544                LanguageModelRole::LanguageModelSystem => {
4545                    if !system_message.is_empty() {
4546                        system_message.push_str("\n\n");
4547                    }
4548                    system_message.push_str(&message.content);
4549
4550                    None
4551                }
4552                // We don't yet support tool calls for Anthropic
4553                LanguageModelRole::LanguageModelTool => None,
4554            }
4555        })
4556        .collect();
4557
4558    let mut stream = anthropic::stream_completion(
4559        session.http_client.as_ref(),
4560        anthropic::ANTHROPIC_API_URL,
4561        &api_key,
4562        anthropic::Request {
4563            model,
4564            messages,
4565            stream: true,
4566            system: system_message,
4567            max_tokens: 4092,
4568        },
4569        None,
4570    )
4571    .await?;
4572
4573    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4574
4575    while let Some(event) = stream.next().await {
4576        let event = event?;
4577
4578        match event {
4579            anthropic::ResponseEvent::MessageStart { message } => {
4580                if let Some(role) = message.role {
4581                    if role == "assistant" {
4582                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4583                    } else if role == "user" {
4584                        current_role = proto::LanguageModelRole::LanguageModelUser;
4585                    }
4586                }
4587            }
4588            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4589                match content_block {
4590                    anthropic::ContentBlock::Text { text } => {
4591                        if !text.is_empty() {
4592                            response.send(proto::LanguageModelResponse {
4593                                choices: vec![proto::LanguageModelChoiceDelta {
4594                                    index: 0,
4595                                    delta: Some(proto::LanguageModelResponseMessage {
4596                                        role: Some(current_role as i32),
4597                                        content: Some(text),
4598                                        tool_calls: Vec::new(),
4599                                    }),
4600                                    finish_reason: None,
4601                                }],
4602                            })?;
4603                        }
4604                    }
4605                }
4606            }
4607            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4608                anthropic::TextDelta::TextDelta { text } => {
4609                    response.send(proto::LanguageModelResponse {
4610                        choices: vec![proto::LanguageModelChoiceDelta {
4611                            index: 0,
4612                            delta: Some(proto::LanguageModelResponseMessage {
4613                                role: Some(current_role as i32),
4614                                content: Some(text),
4615                                tool_calls: Vec::new(),
4616                            }),
4617                            finish_reason: None,
4618                        }],
4619                    })?;
4620                }
4621            },
4622            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4623                if let Some(stop_reason) = delta.stop_reason {
4624                    response.send(proto::LanguageModelResponse {
4625                        choices: vec![proto::LanguageModelChoiceDelta {
4626                            index: 0,
4627                            delta: None,
4628                            finish_reason: Some(stop_reason),
4629                        }],
4630                    })?;
4631                }
4632            }
4633            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4634            anthropic::ResponseEvent::MessageStop {} => {}
4635            anthropic::ResponseEvent::Ping {} => {}
4636        }
4637    }
4638
4639    Ok(())
4640}
4641
4642struct CountTokensWithLanguageModelRateLimit;
4643
4644impl RateLimit for CountTokensWithLanguageModelRateLimit {
4645    fn capacity() -> usize {
4646        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4647            .ok()
4648            .and_then(|v| v.parse().ok())
4649            .unwrap_or(600) // Picked arbitrarily
4650    }
4651
4652    fn refill_duration() -> chrono::Duration {
4653        chrono::Duration::hours(1)
4654    }
4655
4656    fn db_name() -> &'static str {
4657        "count-tokens-with-language-model"
4658    }
4659}
4660
4661async fn count_tokens_with_language_model(
4662    request: proto::CountTokensWithLanguageModel,
4663    response: Response<proto::CountTokensWithLanguageModel>,
4664    session: UserSession,
4665    google_ai_api_key: Option<Arc<str>>,
4666) -> Result<()> {
4667    authorize_access_to_language_models(&session).await?;
4668
4669    if !request.model.starts_with("gemini") {
4670        return Err(anyhow!(
4671            "counting tokens for model: {:?} is not supported",
4672            request.model
4673        ))?;
4674    }
4675
4676    session
4677        .rate_limiter
4678        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4679        .await?;
4680
4681    let api_key = google_ai_api_key
4682        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4683    let tokens_response = google_ai::count_tokens(
4684        session.http_client.as_ref(),
4685        google_ai::API_URL,
4686        &api_key,
4687        crate::ai::count_tokens_request_to_google_ai(request)?,
4688    )
4689    .await?;
4690    response.send(proto::CountTokensResponse {
4691        token_count: tokens_response.total_tokens as u32,
4692    })?;
4693    Ok(())
4694}
4695
4696struct ComputeEmbeddingsRateLimit;
4697
4698impl RateLimit for ComputeEmbeddingsRateLimit {
4699    fn capacity() -> usize {
4700        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4701            .ok()
4702            .and_then(|v| v.parse().ok())
4703            .unwrap_or(5000) // Picked arbitrarily
4704    }
4705
4706    fn refill_duration() -> chrono::Duration {
4707        chrono::Duration::hours(1)
4708    }
4709
4710    fn db_name() -> &'static str {
4711        "compute-embeddings"
4712    }
4713}
4714
4715async fn compute_embeddings(
4716    request: proto::ComputeEmbeddings,
4717    response: Response<proto::ComputeEmbeddings>,
4718    session: UserSession,
4719    api_key: Option<Arc<str>>,
4720) -> Result<()> {
4721    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4722    authorize_access_to_language_models(&session).await?;
4723
4724    session
4725        .rate_limiter
4726        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4727        .await?;
4728
4729    let embeddings = match request.model.as_str() {
4730        "openai/text-embedding-3-small" => {
4731            open_ai::embed(
4732                session.http_client.as_ref(),
4733                OPEN_AI_API_URL,
4734                &api_key,
4735                OpenAiEmbeddingModel::TextEmbedding3Small,
4736                request.texts.iter().map(|text| text.as_str()),
4737            )
4738            .await?
4739        }
4740        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4741    };
4742
4743    let embeddings = request
4744        .texts
4745        .iter()
4746        .map(|text| {
4747            let mut hasher = sha2::Sha256::new();
4748            hasher.update(text.as_bytes());
4749            let result = hasher.finalize();
4750            result.to_vec()
4751        })
4752        .zip(
4753            embeddings
4754                .data
4755                .into_iter()
4756                .map(|embedding| embedding.embedding),
4757        )
4758        .collect::<HashMap<_, _>>();
4759
4760    let db = session.db().await;
4761    db.save_embeddings(&request.model, &embeddings)
4762        .await
4763        .context("failed to save embeddings")
4764        .trace_err();
4765
4766    response.send(proto::ComputeEmbeddingsResponse {
4767        embeddings: embeddings
4768            .into_iter()
4769            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4770            .collect(),
4771    })?;
4772    Ok(())
4773}
4774
4775async fn get_cached_embeddings(
4776    request: proto::GetCachedEmbeddings,
4777    response: Response<proto::GetCachedEmbeddings>,
4778    session: UserSession,
4779) -> Result<()> {
4780    authorize_access_to_language_models(&session).await?;
4781
4782    let db = session.db().await;
4783    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4784
4785    response.send(proto::GetCachedEmbeddingsResponse {
4786        embeddings: embeddings
4787            .into_iter()
4788            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4789            .collect(),
4790    })?;
4791    Ok(())
4792}
4793
4794async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4795    let db = session.db().await;
4796    let flags = db.get_user_flags(session.user_id()).await?;
4797    if flags.iter().any(|flag| flag == "language-models") {
4798        Ok(())
4799    } else {
4800        Err(anyhow!("permission denied"))?
4801    }
4802}
4803
4804/// Get a Supermaven API key for the user
4805async fn get_supermaven_api_key(
4806    _request: proto::GetSupermavenApiKey,
4807    response: Response<proto::GetSupermavenApiKey>,
4808    session: UserSession,
4809) -> Result<()> {
4810    let user_id: String = session.user_id().to_string();
4811    if !session.is_staff() {
4812        return Err(anyhow!("supermaven not enabled for this account"))?;
4813    }
4814
4815    let email = session
4816        .email()
4817        .ok_or_else(|| anyhow!("user must have an email"))?;
4818
4819    let supermaven_admin_api = session
4820        .supermaven_client
4821        .as_ref()
4822        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4823
4824    let result = supermaven_admin_api
4825        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4826        .await?;
4827
4828    response.send(proto::GetSupermavenApiKeyResponse {
4829        api_key: result.api_key,
4830    })?;
4831
4832    Ok(())
4833}
4834
4835/// Start receiving chat updates for a channel
4836async fn join_channel_chat(
4837    request: proto::JoinChannelChat,
4838    response: Response<proto::JoinChannelChat>,
4839    session: UserSession,
4840) -> Result<()> {
4841    let channel_id = ChannelId::from_proto(request.channel_id);
4842
4843    let db = session.db().await;
4844    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4845        .await?;
4846    let messages = db
4847        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4848        .await?;
4849    response.send(proto::JoinChannelChatResponse {
4850        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4851        messages,
4852    })?;
4853    Ok(())
4854}
4855
4856/// Stop receiving chat updates for a channel
4857async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4858    let channel_id = ChannelId::from_proto(request.channel_id);
4859    session
4860        .db()
4861        .await
4862        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4863        .await?;
4864    Ok(())
4865}
4866
4867/// Retrieve the chat history for a channel
4868async fn get_channel_messages(
4869    request: proto::GetChannelMessages,
4870    response: Response<proto::GetChannelMessages>,
4871    session: UserSession,
4872) -> Result<()> {
4873    let channel_id = ChannelId::from_proto(request.channel_id);
4874    let messages = session
4875        .db()
4876        .await
4877        .get_channel_messages(
4878            channel_id,
4879            session.user_id(),
4880            MESSAGE_COUNT_PER_PAGE,
4881            Some(MessageId::from_proto(request.before_message_id)),
4882        )
4883        .await?;
4884    response.send(proto::GetChannelMessagesResponse {
4885        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4886        messages,
4887    })?;
4888    Ok(())
4889}
4890
4891/// Retrieve specific chat messages
4892async fn get_channel_messages_by_id(
4893    request: proto::GetChannelMessagesById,
4894    response: Response<proto::GetChannelMessagesById>,
4895    session: UserSession,
4896) -> Result<()> {
4897    let message_ids = request
4898        .message_ids
4899        .iter()
4900        .map(|id| MessageId::from_proto(*id))
4901        .collect::<Vec<_>>();
4902    let messages = session
4903        .db()
4904        .await
4905        .get_channel_messages_by_id(session.user_id(), &message_ids)
4906        .await?;
4907    response.send(proto::GetChannelMessagesResponse {
4908        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4909        messages,
4910    })?;
4911    Ok(())
4912}
4913
4914/// Retrieve the current users notifications
4915async fn get_notifications(
4916    request: proto::GetNotifications,
4917    response: Response<proto::GetNotifications>,
4918    session: UserSession,
4919) -> Result<()> {
4920    let notifications = session
4921        .db()
4922        .await
4923        .get_notifications(
4924            session.user_id(),
4925            NOTIFICATION_COUNT_PER_PAGE,
4926            request
4927                .before_id
4928                .map(|id| db::NotificationId::from_proto(id)),
4929        )
4930        .await?;
4931    response.send(proto::GetNotificationsResponse {
4932        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4933        notifications,
4934    })?;
4935    Ok(())
4936}
4937
4938/// Mark notifications as read
4939async fn mark_notification_as_read(
4940    request: proto::MarkNotificationRead,
4941    response: Response<proto::MarkNotificationRead>,
4942    session: UserSession,
4943) -> Result<()> {
4944    let database = &session.db().await;
4945    let notifications = database
4946        .mark_notification_as_read_by_id(
4947            session.user_id(),
4948            NotificationId::from_proto(request.notification_id),
4949        )
4950        .await?;
4951    send_notifications(
4952        &*session.connection_pool().await,
4953        &session.peer,
4954        notifications,
4955    );
4956    response.send(proto::Ack {})?;
4957    Ok(())
4958}
4959
4960/// Get the current users information
4961async fn get_private_user_info(
4962    _request: proto::GetPrivateUserInfo,
4963    response: Response<proto::GetPrivateUserInfo>,
4964    session: UserSession,
4965) -> Result<()> {
4966    let db = session.db().await;
4967
4968    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4969    let user = db
4970        .get_user_by_id(session.user_id())
4971        .await?
4972        .ok_or_else(|| anyhow!("user not found"))?;
4973    let flags = db.get_user_flags(session.user_id()).await?;
4974
4975    response.send(proto::GetPrivateUserInfoResponse {
4976        metrics_id,
4977        staff: user.admin,
4978        flags,
4979    })?;
4980    Ok(())
4981}
4982
4983fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4984    match message {
4985        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4986        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4987        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4988        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4989        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4990            code: frame.code.into(),
4991            reason: frame.reason,
4992        })),
4993    }
4994}
4995
4996fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4997    match message {
4998        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4999        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5000        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5001        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5002        AxumMessage::Close(frame) => {
5003            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5004                code: frame.code.into(),
5005                reason: frame.reason,
5006            }))
5007        }
5008    }
5009}
5010
5011fn notify_membership_updated(
5012    connection_pool: &mut ConnectionPool,
5013    result: MembershipUpdated,
5014    user_id: UserId,
5015    peer: &Peer,
5016) {
5017    for membership in &result.new_channels.channel_memberships {
5018        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5019    }
5020    for channel_id in &result.removed_channels {
5021        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5022    }
5023
5024    let user_channels_update = proto::UpdateUserChannels {
5025        channel_memberships: result
5026            .new_channels
5027            .channel_memberships
5028            .iter()
5029            .map(|cm| proto::ChannelMembership {
5030                channel_id: cm.channel_id.to_proto(),
5031                role: cm.role.into(),
5032            })
5033            .collect(),
5034        ..Default::default()
5035    };
5036
5037    let mut update = build_channels_update(result.new_channels, vec![]);
5038    update.delete_channels = result
5039        .removed_channels
5040        .into_iter()
5041        .map(|id| id.to_proto())
5042        .collect();
5043    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5044
5045    for connection_id in connection_pool.user_connection_ids(user_id) {
5046        peer.send(connection_id, user_channels_update.clone())
5047            .trace_err();
5048        peer.send(connection_id, update.clone()).trace_err();
5049    }
5050}
5051
5052fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5053    proto::UpdateUserChannels {
5054        channel_memberships: channels
5055            .channel_memberships
5056            .iter()
5057            .map(|m| proto::ChannelMembership {
5058                channel_id: m.channel_id.to_proto(),
5059                role: m.role.into(),
5060            })
5061            .collect(),
5062        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5063        observed_channel_message_id: channels.observed_channel_messages.clone(),
5064    }
5065}
5066
5067fn build_channels_update(
5068    channels: ChannelsForUser,
5069    channel_invites: Vec<db::Channel>,
5070) -> proto::UpdateChannels {
5071    let mut update = proto::UpdateChannels::default();
5072
5073    for channel in channels.channels {
5074        update.channels.push(channel.to_proto());
5075    }
5076
5077    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5078    update.latest_channel_message_ids = channels.latest_channel_messages;
5079
5080    for (channel_id, participants) in channels.channel_participants {
5081        update
5082            .channel_participants
5083            .push(proto::ChannelParticipants {
5084                channel_id: channel_id.to_proto(),
5085                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5086            });
5087    }
5088
5089    for channel in channel_invites {
5090        update.channel_invitations.push(channel.to_proto());
5091    }
5092
5093    update.hosted_projects = channels.hosted_projects;
5094    update
5095}
5096
5097fn build_initial_contacts_update(
5098    contacts: Vec<db::Contact>,
5099    pool: &ConnectionPool,
5100) -> proto::UpdateContacts {
5101    let mut update = proto::UpdateContacts::default();
5102
5103    for contact in contacts {
5104        match contact {
5105            db::Contact::Accepted { user_id, busy } => {
5106                update.contacts.push(contact_for_user(user_id, busy, &pool));
5107            }
5108            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5109            db::Contact::Incoming { user_id } => {
5110                update
5111                    .incoming_requests
5112                    .push(proto::IncomingContactRequest {
5113                        requester_id: user_id.to_proto(),
5114                    })
5115            }
5116        }
5117    }
5118
5119    update
5120}
5121
5122fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5123    proto::Contact {
5124        user_id: user_id.to_proto(),
5125        online: pool.is_user_online(user_id),
5126        busy,
5127    }
5128}
5129
5130fn room_updated(room: &proto::Room, peer: &Peer) {
5131    broadcast(
5132        None,
5133        room.participants
5134            .iter()
5135            .filter_map(|participant| Some(participant.peer_id?.into())),
5136        |peer_id| {
5137            peer.send(
5138                peer_id,
5139                proto::RoomUpdated {
5140                    room: Some(room.clone()),
5141                },
5142            )
5143        },
5144    );
5145}
5146
5147fn channel_updated(
5148    channel: &db::channel::Model,
5149    room: &proto::Room,
5150    peer: &Peer,
5151    pool: &ConnectionPool,
5152) {
5153    let participants = room
5154        .participants
5155        .iter()
5156        .map(|p| p.user_id)
5157        .collect::<Vec<_>>();
5158
5159    broadcast(
5160        None,
5161        pool.channel_connection_ids(channel.root_id())
5162            .filter_map(|(channel_id, role)| {
5163                role.can_see_channel(channel.visibility).then(|| channel_id)
5164            }),
5165        |peer_id| {
5166            peer.send(
5167                peer_id,
5168                proto::UpdateChannels {
5169                    channel_participants: vec![proto::ChannelParticipants {
5170                        channel_id: channel.id.to_proto(),
5171                        participant_user_ids: participants.clone(),
5172                    }],
5173                    ..Default::default()
5174                },
5175            )
5176        },
5177    );
5178}
5179
5180async fn send_dev_server_projects_update(
5181    user_id: UserId,
5182    mut status: proto::DevServerProjectsUpdate,
5183    session: &Session,
5184) {
5185    let pool = session.connection_pool().await;
5186    for dev_server in &mut status.dev_servers {
5187        dev_server.status =
5188            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5189    }
5190    let connections = pool.user_connection_ids(user_id);
5191    for connection_id in connections {
5192        session.peer.send(connection_id, status.clone()).trace_err();
5193    }
5194}
5195
5196async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5197    let db = session.db().await;
5198
5199    let contacts = db.get_contacts(user_id).await?;
5200    let busy = db.is_user_busy(user_id).await?;
5201
5202    let pool = session.connection_pool().await;
5203    let updated_contact = contact_for_user(user_id, busy, &pool);
5204    for contact in contacts {
5205        if let db::Contact::Accepted {
5206            user_id: contact_user_id,
5207            ..
5208        } = contact
5209        {
5210            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5211                session
5212                    .peer
5213                    .send(
5214                        contact_conn_id,
5215                        proto::UpdateContacts {
5216                            contacts: vec![updated_contact.clone()],
5217                            remove_contacts: Default::default(),
5218                            incoming_requests: Default::default(),
5219                            remove_incoming_requests: Default::default(),
5220                            outgoing_requests: Default::default(),
5221                            remove_outgoing_requests: Default::default(),
5222                        },
5223                    )
5224                    .trace_err();
5225            }
5226        }
5227    }
5228    Ok(())
5229}
5230
5231async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5232    log::info!("lost dev server connection, unsharing projects");
5233    let project_ids = session
5234        .db()
5235        .await
5236        .get_stale_dev_server_projects(session.connection_id)
5237        .await?;
5238
5239    for project_id in project_ids {
5240        // not unshare re-checks the connection ids match, so we get away with no transaction
5241        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5242    }
5243
5244    let user_id = session.dev_server().user_id;
5245    let update = session
5246        .db()
5247        .await
5248        .dev_server_projects_update(user_id)
5249        .await?;
5250
5251    send_dev_server_projects_update(user_id, update, session).await;
5252
5253    Ok(())
5254}
5255
5256async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5257    let mut contacts_to_update = HashSet::default();
5258
5259    let room_id;
5260    let canceled_calls_to_user_ids;
5261    let live_kit_room;
5262    let delete_live_kit_room;
5263    let room;
5264    let channel;
5265
5266    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5267        contacts_to_update.insert(session.user_id());
5268
5269        for project in left_room.left_projects.values() {
5270            project_left(project, session);
5271        }
5272
5273        room_id = RoomId::from_proto(left_room.room.id);
5274        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5275        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5276        delete_live_kit_room = left_room.deleted;
5277        room = mem::take(&mut left_room.room);
5278        channel = mem::take(&mut left_room.channel);
5279
5280        room_updated(&room, &session.peer);
5281    } else {
5282        return Ok(());
5283    }
5284
5285    if let Some(channel) = channel {
5286        channel_updated(
5287            &channel,
5288            &room,
5289            &session.peer,
5290            &*session.connection_pool().await,
5291        );
5292    }
5293
5294    {
5295        let pool = session.connection_pool().await;
5296        for canceled_user_id in canceled_calls_to_user_ids {
5297            for connection_id in pool.user_connection_ids(canceled_user_id) {
5298                session
5299                    .peer
5300                    .send(
5301                        connection_id,
5302                        proto::CallCanceled {
5303                            room_id: room_id.to_proto(),
5304                        },
5305                    )
5306                    .trace_err();
5307            }
5308            contacts_to_update.insert(canceled_user_id);
5309        }
5310    }
5311
5312    for contact_user_id in contacts_to_update {
5313        update_user_contacts(contact_user_id, &session).await?;
5314    }
5315
5316    if let Some(live_kit) = session.live_kit_client.as_ref() {
5317        live_kit
5318            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5319            .await
5320            .trace_err();
5321
5322        if delete_live_kit_room {
5323            live_kit.delete_room(live_kit_room).await.trace_err();
5324        }
5325    }
5326
5327    Ok(())
5328}
5329
5330async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5331    let left_channel_buffers = session
5332        .db()
5333        .await
5334        .leave_channel_buffers(session.connection_id)
5335        .await?;
5336
5337    for left_buffer in left_channel_buffers {
5338        channel_buffer_updated(
5339            session.connection_id,
5340            left_buffer.connections,
5341            &proto::UpdateChannelBufferCollaborators {
5342                channel_id: left_buffer.channel_id.to_proto(),
5343                collaborators: left_buffer.collaborators,
5344            },
5345            &session.peer,
5346        );
5347    }
5348
5349    Ok(())
5350}
5351
5352fn project_left(project: &db::LeftProject, session: &UserSession) {
5353    for connection_id in &project.connection_ids {
5354        if project.should_unshare {
5355            session
5356                .peer
5357                .send(
5358                    *connection_id,
5359                    proto::UnshareProject {
5360                        project_id: project.id.to_proto(),
5361                    },
5362                )
5363                .trace_err();
5364        } else {
5365            session
5366                .peer
5367                .send(
5368                    *connection_id,
5369                    proto::RemoveProjectCollaborator {
5370                        project_id: project.id.to_proto(),
5371                        peer_id: Some(session.connection_id.into()),
5372                    },
5373                )
5374                .trace_err();
5375        }
5376    }
5377}
5378
/// Extension trait for turning a fallible value into an `Option` while
/// recording the failure, instead of propagating it.
pub trait ResultExt {
    /// Value type produced on success.
    type Ok;

    /// Returns `Some` with the success value, or `None` after recording the failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
5384
5385impl<T, E> ResultExt for Result<T, E>
5386where
5387    E: std::fmt::Debug,
5388{
5389    type Ok = T;
5390
5391    #[track_caller]
5392    fn trace_err(self) -> Option<T> {
5393        match self {
5394            Ok(value) => Some(value),
5395            Err(error) => {
5396                tracing::error!("{:?}", error);
5397                None
5398            }
5399        }
5400    }
5401}