rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_read_only_project_request::<proto::GetHover>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_read_only_project_request::<proto::GetDefinition>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetTypeDefinition>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetReferences>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::SearchProject>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::GetProjectSymbols>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::OpenBufferById>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::InlayHints>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::OpenBufferByPath>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_mutating_project_request::<proto::GetCompletions>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_mutating_project_request::<proto::GetCodeActions>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ApplyCodeAction>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::PrepareRename>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::PerformRename>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::ReloadBuffers>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::FormatBuffers>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::CreateProjectEntry>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::RenameProjectEntry>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CopyProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::OnTypeFormatting>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::BlameBuffer>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_mutating_project_request::<proto::MultiLspQuery>,
 541            ))
 542            .add_message_handler(create_buffer_for_peer)
 543            .add_request_handler(update_buffer)
 544            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 545            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 546            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 547            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 548            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 549            .add_request_handler(get_users)
 550            .add_request_handler(user_handler(fuzzy_search_users))
 551            .add_request_handler(user_handler(request_contact))
 552            .add_request_handler(user_handler(remove_contact))
 553            .add_request_handler(user_handler(respond_to_contact_request))
 554            .add_request_handler(user_handler(create_channel))
 555            .add_request_handler(user_handler(delete_channel))
 556            .add_request_handler(user_handler(invite_channel_member))
 557            .add_request_handler(user_handler(remove_channel_member))
 558            .add_request_handler(user_handler(set_channel_member_role))
 559            .add_request_handler(user_handler(set_channel_visibility))
 560            .add_request_handler(user_handler(rename_channel))
 561            .add_request_handler(user_handler(join_channel_buffer))
 562            .add_request_handler(user_handler(leave_channel_buffer))
 563            .add_message_handler(user_message_handler(update_channel_buffer))
 564            .add_request_handler(user_handler(rejoin_channel_buffers))
 565            .add_request_handler(user_handler(get_channel_members))
 566            .add_request_handler(user_handler(respond_to_channel_invite))
 567            .add_request_handler(user_handler(join_channel))
 568            .add_request_handler(user_handler(join_channel_chat))
 569            .add_message_handler(user_message_handler(leave_channel_chat))
 570            .add_request_handler(user_handler(send_channel_message))
 571            .add_request_handler(user_handler(remove_channel_message))
 572            .add_request_handler(user_handler(update_channel_message))
 573            .add_request_handler(user_handler(get_channel_messages))
 574            .add_request_handler(user_handler(get_channel_messages_by_id))
 575            .add_request_handler(user_handler(get_notifications))
 576            .add_request_handler(user_handler(mark_notification_as_read))
 577            .add_request_handler(user_handler(move_channel))
 578            .add_request_handler(user_handler(follow))
 579            .add_message_handler(user_message_handler(unfollow))
 580            .add_message_handler(user_message_handler(update_followers))
 581            .add_request_handler(user_handler(get_private_user_info))
 582            .add_message_handler(user_message_handler(acknowledge_channel_message))
 583            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 584            .add_request_handler(user_handler(get_supermaven_api_key))
 585            .add_streaming_request_handler({
 586                let app_state = app_state.clone();
 587                move |request, response, session| {
 588                    complete_with_language_model(
 589                        request,
 590                        response,
 591                        session,
 592                        app_state.config.openai_api_key.clone(),
 593                        app_state.config.google_ai_api_key.clone(),
 594                        app_state.config.anthropic_api_key.clone(),
 595                    )
 596                }
 597            })
 598            .add_request_handler({
 599                let app_state = app_state.clone();
 600                user_handler(move |request, response, session| {
 601                    count_tokens_with_language_model(
 602                        request,
 603                        response,
 604                        session,
 605                        app_state.config.google_ai_api_key.clone(),
 606                    )
 607                })
 608            })
 609            .add_request_handler({
 610                user_handler(move |request, response, session| {
 611                    get_cached_embeddings(request, response, session)
 612                })
 613            })
 614            .add_request_handler({
 615                let app_state = app_state.clone();
 616                user_handler(move |request, response, session| {
 617                    compute_embeddings(
 618                        request,
 619                        response,
 620                        session,
 621                        app_state.config.openai_api_key.clone(),
 622                    )
 623                })
 624            });
 625
 626        Arc::new(server)
 627    }
 628
 629    pub async fn start(&self) -> Result<()> {
 630        let server_id = *self.id.lock();
 631        let app_state = self.app_state.clone();
 632        let peer = self.peer.clone();
 633        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 634        let pool = self.connection_pool.clone();
 635        let live_kit_client = self.app_state.live_kit_client.clone();
 636
 637        let span = info_span!("start server");
 638        self.app_state.executor.spawn_detached(
 639            async move {
 640                tracing::info!("waiting for cleanup timeout");
 641                timeout.await;
 642                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 643                if let Some((room_ids, channel_ids)) = app_state
 644                    .db
 645                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 646                    .await
 647                    .trace_err()
 648                {
 649                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 650                    tracing::info!(
 651                        stale_channel_buffer_count = channel_ids.len(),
 652                        "retrieved stale channel buffers"
 653                    );
 654
 655                    for channel_id in channel_ids {
 656                        if let Some(refreshed_channel_buffer) = app_state
 657                            .db
 658                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 659                            .await
 660                            .trace_err()
 661                        {
 662                            for connection_id in refreshed_channel_buffer.connection_ids {
 663                                peer.send(
 664                                    connection_id,
 665                                    proto::UpdateChannelBufferCollaborators {
 666                                        channel_id: channel_id.to_proto(),
 667                                        collaborators: refreshed_channel_buffer
 668                                            .collaborators
 669                                            .clone(),
 670                                    },
 671                                )
 672                                .trace_err();
 673                            }
 674                        }
 675                    }
 676
 677                    for room_id in room_ids {
 678                        let mut contacts_to_update = HashSet::default();
 679                        let mut canceled_calls_to_user_ids = Vec::new();
 680                        let mut live_kit_room = String::new();
 681                        let mut delete_live_kit_room = false;
 682
 683                        if let Some(mut refreshed_room) = app_state
 684                            .db
 685                            .clear_stale_room_participants(room_id, server_id)
 686                            .await
 687                            .trace_err()
 688                        {
 689                            tracing::info!(
 690                                room_id = room_id.0,
 691                                new_participant_count = refreshed_room.room.participants.len(),
 692                                "refreshed room"
 693                            );
 694                            room_updated(&refreshed_room.room, &peer);
 695                            if let Some(channel) = refreshed_room.channel.as_ref() {
 696                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 697                            }
 698                            contacts_to_update
 699                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 700                            contacts_to_update
 701                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 702                            canceled_calls_to_user_ids =
 703                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 704                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 705                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 706                        }
 707
 708                        {
 709                            let pool = pool.lock();
 710                            for canceled_user_id in canceled_calls_to_user_ids {
 711                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 712                                    peer.send(
 713                                        connection_id,
 714                                        proto::CallCanceled {
 715                                            room_id: room_id.to_proto(),
 716                                        },
 717                                    )
 718                                    .trace_err();
 719                                }
 720                            }
 721                        }
 722
 723                        for user_id in contacts_to_update {
 724                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 725                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 726                            if let Some((busy, contacts)) = busy.zip(contacts) {
 727                                let pool = pool.lock();
 728                                let updated_contact = contact_for_user(user_id, busy, &pool);
 729                                for contact in contacts {
 730                                    if let db::Contact::Accepted {
 731                                        user_id: contact_user_id,
 732                                        ..
 733                                    } = contact
 734                                    {
 735                                        for contact_conn_id in
 736                                            pool.user_connection_ids(contact_user_id)
 737                                        {
 738                                            peer.send(
 739                                                contact_conn_id,
 740                                                proto::UpdateContacts {
 741                                                    contacts: vec![updated_contact.clone()],
 742                                                    remove_contacts: Default::default(),
 743                                                    incoming_requests: Default::default(),
 744                                                    remove_incoming_requests: Default::default(),
 745                                                    outgoing_requests: Default::default(),
 746                                                    remove_outgoing_requests: Default::default(),
 747                                                },
 748                                            )
 749                                            .trace_err();
 750                                        }
 751                                    }
 752                                }
 753                            }
 754                        }
 755
 756                        if let Some(live_kit) = live_kit_client.as_ref() {
 757                            if delete_live_kit_room {
 758                                live_kit.delete_room(live_kit_room).await.trace_err();
 759                            }
 760                        }
 761                    }
 762                }
 763
 764                app_state
 765                    .db
 766                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 767                    .await
 768                    .trace_err();
 769            }
 770            .instrument(span),
 771        );
 772        Ok(())
 773    }
 774
 775    pub fn teardown(&self) {
 776        self.peer.teardown();
 777        self.connection_pool.lock().reset();
 778        let _ = self.teardown.send(true);
 779    }
 780
 781    #[cfg(test)]
 782    pub fn reset(&self, id: ServerId) {
 783        self.teardown();
 784        *self.id.lock() = id;
 785        self.peer.reset(id.0 as u32);
 786        let _ = self.teardown.send(false);
 787    }
 788
 789    #[cfg(test)]
 790    pub fn id(&self) -> ServerId {
 791        *self.id.lock()
 792    }
 793
 794    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 795    where
 796        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 797        Fut: 'static + Send + Future<Output = Result<()>>,
 798        M: EnvelopedMessage,
 799    {
 800        let prev_handler = self.handlers.insert(
 801            TypeId::of::<M>(),
 802            Box::new(move |envelope, session| {
 803                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 804                let received_at = envelope.received_at;
 805                tracing::info!("message received");
 806                let start_time = Instant::now();
 807                let future = (handler)(*envelope, session);
 808                async move {
 809                    let result = future.await;
 810                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 811                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 812                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 813                    let payload_type = M::NAME;
 814
 815                    match result {
 816                        Err(error) => {
 817                            tracing::error!(
 818                                ?error,
 819                                total_duration_ms,
 820                                processing_duration_ms,
 821                                queue_duration_ms,
 822                                payload_type,
 823                                "error handling message"
 824                            )
 825                        }
 826                        Ok(()) => tracing::info!(
 827                            total_duration_ms,
 828                            processing_duration_ms,
 829                            queue_duration_ms,
 830                            "finished handling message"
 831                        ),
 832                    }
 833                }
 834                .boxed()
 835            }),
 836        );
 837        if prev_handler.is_some() {
 838            panic!("registered a handler for the same message twice");
 839        }
 840        self
 841    }
 842
 843    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 844    where
 845        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 846        Fut: 'static + Send + Future<Output = Result<()>>,
 847        M: EnvelopedMessage,
 848    {
 849        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 850        self
 851    }
 852
 853    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 854    where
 855        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 856        Fut: Send + Future<Output = Result<()>>,
 857        M: RequestMessage,
 858    {
 859        let handler = Arc::new(handler);
 860        self.add_handler(move |envelope, session| {
 861            let receipt = envelope.receipt();
 862            let handler = handler.clone();
 863            async move {
 864                let peer = session.peer.clone();
 865                let responded = Arc::new(AtomicBool::default());
 866                let response = Response {
 867                    peer: peer.clone(),
 868                    responded: responded.clone(),
 869                    receipt,
 870                };
 871                match (handler)(envelope.payload, response, session).await {
 872                    Ok(()) => {
 873                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 874                            Ok(())
 875                        } else {
 876                            Err(anyhow!("handler did not send a response"))?
 877                        }
 878                    }
 879                    Err(error) => {
 880                        let proto_err = match &error {
 881                            Error::Internal(err) => err.to_proto(),
 882                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 883                        };
 884                        peer.respond_with_error(receipt, proto_err)?;
 885                        Err(error)
 886                    }
 887                }
 888            }
 889        })
 890    }
 891
 892    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 893    where
 894        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 895        Fut: Send + Future<Output = Result<()>>,
 896        M: RequestMessage,
 897    {
 898        let handler = Arc::new(handler);
 899        self.add_handler(move |envelope, session| {
 900            let receipt = envelope.receipt();
 901            let handler = handler.clone();
 902            async move {
 903                let peer = session.peer.clone();
 904                let response = StreamingResponse {
 905                    peer: peer.clone(),
 906                    receipt,
 907                };
 908                match (handler)(envelope.payload, response, session).await {
 909                    Ok(()) => {
 910                        peer.end_stream(receipt)?;
 911                        Ok(())
 912                    }
 913                    Err(error) => {
 914                        let proto_err = match &error {
 915                            Error::Internal(err) => err.to_proto(),
 916                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 917                        };
 918                        peer.respond_with_error(receipt, proto_err)?;
 919                        Err(error)
 920                    }
 921                }
 922            }
 923        })
 924    }
 925
 926    #[allow(clippy::too_many_arguments)]
 927    pub fn handle_connection(
 928        self: &Arc<Self>,
 929        connection: Connection,
 930        address: String,
 931        principal: Principal,
 932        zed_version: ZedVersion,
 933        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 934        executor: Executor,
 935    ) -> impl Future<Output = ()> {
 936        let this = self.clone();
 937        let span = info_span!("handle connection", %address,
 938            connection_id=field::Empty,
 939            user_id=field::Empty,
 940            login=field::Empty,
 941            impersonator=field::Empty,
 942            dev_server_id=field::Empty
 943        );
 944        principal.update_span(&span);
 945
 946        let mut teardown = self.teardown.subscribe();
 947        async move {
 948            if *teardown.borrow() {
 949                tracing::error!("server is tearing down");
 950                return
 951            }
 952            let (connection_id, handle_io, mut incoming_rx) = this
 953                .peer
 954                .add_connection(connection, {
 955                    let executor = executor.clone();
 956                    move |duration| executor.sleep(duration)
 957                });
 958            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 959            tracing::info!("connection opened");
 960
 961            let http_client = match IsahcHttpClient::new() {
 962                Ok(http_client) => Arc::new(http_client),
 963                Err(error) => {
 964                    tracing::error!(?error, "failed to create HTTP client");
 965                    return;
 966                }
 967            };
 968
 969            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 970                Some(Arc::new(SupermavenAdminApi::new(
 971                    supermaven_admin_api_key.to_string(),
 972                    http_client.clone(),
 973                )))
 974            } else {
 975                None
 976            };
 977
 978            let session = Session {
 979                principal: principal.clone(),
 980                connection_id,
 981                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 982                peer: this.peer.clone(),
 983                connection_pool: this.connection_pool.clone(),
 984                live_kit_client: this.app_state.live_kit_client.clone(),
 985                http_client,
 986                rate_limiter: this.app_state.rate_limiter.clone(),
 987                _executor: executor.clone(),
 988                supermaven_client,
 989            };
 990
 991            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 992                tracing::error!(?error, "failed to send initial client update");
 993                return;
 994            }
 995
 996            let handle_io = handle_io.fuse();
 997            futures::pin_mut!(handle_io);
 998
 999            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1000            // This prevents deadlocks when e.g., client A performs a request to client B and
1001            // client B performs a request to client A. If both clients stop processing further
1002            // messages until their respective request completes, they won't have a chance to
1003            // respond to the other client's request and cause a deadlock.
1004            //
1005            // This arrangement ensures we will attempt to process earlier messages first, but fall
1006            // back to processing messages arrived later in the spirit of making progress.
1007            let mut foreground_message_handlers = FuturesUnordered::new();
1008            let concurrent_handlers = Arc::new(Semaphore::new(256));
1009            loop {
1010                let next_message = async {
1011                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1012                    let message = incoming_rx.next().await;
1013                    (permit, message)
1014                }.fuse();
1015                futures::pin_mut!(next_message);
1016                futures::select_biased! {
1017                    _ = teardown.changed().fuse() => return,
1018                    result = handle_io => {
1019                        if let Err(error) = result {
1020                            tracing::error!(?error, "error handling I/O");
1021                        }
1022                        break;
1023                    }
1024                    _ = foreground_message_handlers.next() => {}
1025                    next_message = next_message => {
1026                        let (permit, message) = next_message;
1027                        if let Some(message) = message {
1028                            let type_name = message.payload_type_name();
1029                            // note: we copy all the fields from the parent span so we can query them in the logs.
1030                            // (https://github.com/tokio-rs/tracing/issues/2670).
1031                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1032                                user_id=field::Empty,
1033                                login=field::Empty,
1034                                impersonator=field::Empty,
1035                                dev_server_id=field::Empty
1036                            );
1037                            principal.update_span(&span);
1038                            let span_enter = span.enter();
1039                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1040                                let is_background = message.is_background();
1041                                let handle_message = (handler)(message, session.clone());
1042                                drop(span_enter);
1043
1044                                let handle_message = async move {
1045                                    handle_message.await;
1046                                    drop(permit);
1047                                }.instrument(span);
1048                                if is_background {
1049                                    executor.spawn_detached(handle_message);
1050                                } else {
1051                                    foreground_message_handlers.push(handle_message);
1052                                }
1053                            } else {
1054                                tracing::error!("no message handler");
1055                            }
1056                        } else {
1057                            tracing::info!("connection closed");
1058                            break;
1059                        }
1060                    }
1061                }
1062            }
1063
1064            drop(foreground_message_handlers);
1065            tracing::info!("signing out");
1066            if let Err(error) = connection_lost(session, teardown, executor).await {
1067                tracing::error!(?error, "error signing out");
1068            }
1069
1070        }.instrument(span)
1071    }
1072
1073    async fn send_initial_client_update(
1074        &self,
1075        connection_id: ConnectionId,
1076        principal: &Principal,
1077        zed_version: ZedVersion,
1078        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1079        session: &Session,
1080    ) -> Result<()> {
1081        self.peer.send(
1082            connection_id,
1083            proto::Hello {
1084                peer_id: Some(connection_id.into()),
1085            },
1086        )?;
1087        tracing::info!("sent hello message");
1088        if let Some(send_connection_id) = send_connection_id.take() {
1089            let _ = send_connection_id.send(connection_id);
1090        }
1091
1092        match principal {
1093            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1094                if !user.connected_once {
1095                    self.peer.send(connection_id, proto::ShowContacts {})?;
1096                    self.app_state
1097                        .db
1098                        .set_user_connected_once(user.id, true)
1099                        .await?;
1100                }
1101
1102                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1103                    future::try_join4(
1104                        self.app_state.db.get_contacts(user.id),
1105                        self.app_state.db.get_channels_for_user(user.id),
1106                        self.app_state.db.get_channel_invites_for_user(user.id),
1107                        self.app_state.db.dev_server_projects_update(user.id),
1108                    )
1109                    .await?;
1110
1111                {
1112                    let mut pool = self.connection_pool.lock();
1113                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1114                    for membership in &channels_for_user.channel_memberships {
1115                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1116                    }
1117                    self.peer.send(
1118                        connection_id,
1119                        build_initial_contacts_update(contacts, &pool),
1120                    )?;
1121                    self.peer.send(
1122                        connection_id,
1123                        build_update_user_channels(&channels_for_user),
1124                    )?;
1125                    self.peer.send(
1126                        connection_id,
1127                        build_channels_update(channels_for_user, channel_invites),
1128                    )?;
1129                }
1130                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1131
1132                if let Some(incoming_call) =
1133                    self.app_state.db.incoming_call_for_user(user.id).await?
1134                {
1135                    self.peer.send(connection_id, incoming_call)?;
1136                }
1137
1138                update_user_contacts(user.id, &session).await?;
1139            }
1140            Principal::DevServer(dev_server) => {
1141                {
1142                    let mut pool = self.connection_pool.lock();
1143                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1144                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1145                    };
1146                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1147                }
1148
1149                let projects = self
1150                    .app_state
1151                    .db
1152                    .get_projects_for_dev_server(dev_server.id)
1153                    .await?;
1154                self.peer
1155                    .send(connection_id, proto::DevServerInstructions { projects })?;
1156
1157                let status = self
1158                    .app_state
1159                    .db
1160                    .dev_server_projects_update(dev_server.user_id)
1161                    .await?;
1162                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1163            }
1164        }
1165
1166        Ok(())
1167    }
1168
1169    pub async fn invite_code_redeemed(
1170        self: &Arc<Self>,
1171        inviter_id: UserId,
1172        invitee_id: UserId,
1173    ) -> Result<()> {
1174        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1175            if let Some(code) = &user.invite_code {
1176                let pool = self.connection_pool.lock();
1177                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1178                for connection_id in pool.user_connection_ids(inviter_id) {
1179                    self.peer.send(
1180                        connection_id,
1181                        proto::UpdateContacts {
1182                            contacts: vec![invitee_contact.clone()],
1183                            ..Default::default()
1184                        },
1185                    )?;
1186                    self.peer.send(
1187                        connection_id,
1188                        proto::UpdateInviteInfo {
1189                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1190                            count: user.invite_count as u32,
1191                        },
1192                    )?;
1193                }
1194            }
1195        }
1196        Ok(())
1197    }
1198
1199    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1200        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1201            if let Some(invite_code) = &user.invite_code {
1202                let pool = self.connection_pool.lock();
1203                for connection_id in pool.user_connection_ids(user_id) {
1204                    self.peer.send(
1205                        connection_id,
1206                        proto::UpdateInviteInfo {
1207                            url: format!(
1208                                "{}{}",
1209                                self.app_state.config.invite_link_prefix, invite_code
1210                            ),
1211                            count: user.invite_count as u32,
1212                        },
1213                    )?;
1214                }
1215            }
1216        }
1217        Ok(())
1218    }
1219
1220    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1221        ServerSnapshot {
1222            connection_pool: ConnectionPoolGuard {
1223                guard: self.connection_pool.lock(),
1224                _not_send: PhantomData,
1225            },
1226            peer: &self.peer,
1227        }
1228    }
1229}
1230
1231impl<'a> Deref for ConnectionPoolGuard<'a> {
1232    type Target = ConnectionPool;
1233
1234    fn deref(&self) -> &Self::Target {
1235        &self.guard
1236    }
1237}
1238
1239impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1240    fn deref_mut(&mut self) -> &mut Self::Target {
1241        &mut self.guard
1242    }
1243}
1244
/// Releasing the guard: in test builds this first runs the pool's `check_invariants`;
/// in other builds `drop` has no body of its own.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Test-only sanity check, run every time a guard is released.
        #[cfg(test)]
        self.check_invariants();
    }
}
1251
1252fn broadcast<F>(
1253    sender_id: Option<ConnectionId>,
1254    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1255    mut f: F,
1256) where
1257    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1258{
1259    for receiver_id in receiver_ids {
1260        if Some(receiver_id) != sender_id {
1261            if let Err(error) = f(receiver_id) {
1262                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1263            }
1264        }
1265    }
1266}
1267
1268pub struct ProtocolVersion(u32);
1269
1270impl Header for ProtocolVersion {
1271    fn name() -> &'static HeaderName {
1272        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1273        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1274    }
1275
1276    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1277    where
1278        Self: Sized,
1279        I: Iterator<Item = &'i axum::http::HeaderValue>,
1280    {
1281        let version = values
1282            .next()
1283            .ok_or_else(axum::headers::Error::invalid)?
1284            .to_str()
1285            .map_err(|_| axum::headers::Error::invalid())?
1286            .parse()
1287            .map_err(|_| axum::headers::Error::invalid())?;
1288        Ok(Self(version))
1289    }
1290
1291    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1292        values.extend([self.0.to_string().parse().unwrap()]);
1293    }
1294}
1295
1296pub struct AppVersionHeader(SemanticVersion);
1297impl Header for AppVersionHeader {
1298    fn name() -> &'static HeaderName {
1299        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1300        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1301    }
1302
1303    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1304    where
1305        Self: Sized,
1306        I: Iterator<Item = &'i axum::http::HeaderValue>,
1307    {
1308        let version = values
1309            .next()
1310            .ok_or_else(axum::headers::Error::invalid)?
1311            .to_str()
1312            .map_err(|_| axum::headers::Error::invalid())?
1313            .parse()
1314            .map_err(|_| axum::headers::Error::invalid())?;
1315        Ok(Self(version))
1316    }
1317
1318    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1319        values.extend([self.0.to_string().parse().unwrap()]);
1320    }
1321}
1322
1323pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1324    Router::new()
1325        .route("/rpc", get(handle_websocket_request))
1326        .layer(
1327            ServiceBuilder::new()
1328                .layer(Extension(server.app_state.clone()))
1329                .layer(middleware::from_fn(auth::validate_header)),
1330        )
1331        .route("/metrics", get(handle_metrics))
1332        .layer(Extension(server))
1333}
1334
1335pub async fn handle_websocket_request(
1336    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1337    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1338    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1339    Extension(server): Extension<Arc<Server>>,
1340    Extension(principal): Extension<Principal>,
1341    ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343    if protocol_version != rpc::PROTOCOL_VERSION {
1344        return (
1345            StatusCode::UPGRADE_REQUIRED,
1346            "client must be upgraded".to_string(),
1347        )
1348            .into_response();
1349    }
1350
1351    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352        return (
1353            StatusCode::UPGRADE_REQUIRED,
1354            "no version header found".to_string(),
1355        )
1356            .into_response();
1357    };
1358
1359    if !version.can_collaborate() {
1360        return (
1361            StatusCode::UPGRADE_REQUIRED,
1362            "client must be upgraded".to_string(),
1363        )
1364            .into_response();
1365    }
1366
1367    let socket_address = socket_address.to_string();
1368    ws.on_upgrade(move |socket| {
1369        let socket = socket
1370            .map_ok(to_tungstenite_message)
1371            .err_into()
1372            .with(|message| async move { Ok(to_axum_message(message)) });
1373        let connection = Connection::new(Box::pin(socket));
1374        async move {
1375            server
1376                .handle_connection(
1377                    connection,
1378                    socket_address,
1379                    principal,
1380                    version,
1381                    None,
1382                    Executor::Production,
1383                )
1384                .await;
1385        }
1386    })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391    let connections_metric = CONNECTIONS_METRIC
1392        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394    let connections = server
1395        .connection_pool
1396        .lock()
1397        .connections()
1398        .filter(|connection| !connection.admin)
1399        .count();
1400    connections_metric.set(connections as _);
1401
1402    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404        register_int_gauge!(
1405            "shared_projects",
1406            "number of open projects with one or more guests"
1407        )
1408        .unwrap()
1409    });
1410
1411    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412    shared_projects_metric.set(shared_projects as _);
1413
1414    let encoder = prometheus::TextEncoder::new();
1415    let metric_families = prometheus::gather();
1416    let encoded_metrics = encoder
1417        .encode_to_string(&metric_families)
1418        .map_err(|err| anyhow!("{}", err))?;
1419    Ok(encoded_metrics)
1420}
1421
/// Cleans up after a client connection has ended.
///
/// Immediate steps:
/// - unregisters the connection from the peer;
/// - removes it from the connection pool, returning an error and doing nothing else if
///   the pool has no such connection;
/// - tells the database the connection was lost (failures there are only traced).
///
/// It then waits for `RECONNECT_TIMEOUT` (presumably a grace period for the client to
/// reconnect — confirm at the definition). If the server tears down first, it returns
/// without further cleanup. Otherwise the rest depends on the principal:
/// - users: leave the current room and channel buffers; if the user has no remaining
///   connection, decline any pending call (broadcasting the room update); then refresh
///   the user's contacts;
/// - dev servers: run `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is traced but doesn't abort the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if the timeout and a teardown are both ready, the timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline pending calls once the user has no other live connection.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Server teardown: skip the per-principal cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479    response.send(proto::Ack {})?;
1480    Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485    _request: proto::CreateRoom,
1486    response: Response<proto::CreateRoom>,
1487    session: UserSession,
1488) -> Result<()> {
1489    let live_kit_room = nanoid::nanoid!(30);
1490
1491    let live_kit_connection_info = util::maybe!(async {
1492        let live_kit = session.live_kit_client.as_ref();
1493        let live_kit = live_kit?;
1494        let user_id = session.user_id().to_string();
1495
1496        let token = live_kit
1497            .room_token(&live_kit_room, &user_id.to_string())
1498            .trace_err()?;
1499
1500        Some(proto::LiveKitConnectionInfo {
1501            server_url: live_kit.url().into(),
1502            token,
1503            can_publish: true,
1504        })
1505    })
1506    .await;
1507
1508    let room = session
1509        .db()
1510        .await
1511        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512        .await?;
1513
1514    response.send(proto::CreateRoomResponse {
1515        room: Some(room.clone()),
1516        live_kit_connection_info,
1517    })?;
1518
1519    update_user_contacts(session.user_id(), &session).await?;
1520    Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525    request: proto::JoinRoom,
1526    response: Response<proto::JoinRoom>,
1527    session: UserSession,
1528) -> Result<()> {
1529    let room_id = RoomId::from_proto(request.id);
1530
1531    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533    if let Some(channel_id) = channel_id {
1534        return join_channel_internal(channel_id, Box::new(response), session).await;
1535    }
1536
1537    let joined_room = {
1538        let room = session
1539            .db()
1540            .await
1541            .join_room(room_id, session.user_id(), session.connection_id)
1542            .await?;
1543        room_updated(&room.room, &session.peer);
1544        room.into_inner()
1545    };
1546
1547    for connection_id in session
1548        .connection_pool()
1549        .await
1550        .user_connection_ids(session.user_id())
1551    {
1552        session
1553            .peer
1554            .send(
1555                connection_id,
1556                proto::CallCanceled {
1557                    room_id: room_id.to_proto(),
1558                },
1559            )
1560            .trace_err();
1561    }
1562
1563    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1564        if let Some(token) = live_kit
1565            .room_token(
1566                &joined_room.room.live_kit_room,
1567                &session.user_id().to_string(),
1568            )
1569            .trace_err()
1570        {
1571            Some(proto::LiveKitConnectionInfo {
1572                server_url: live_kit.url().into(),
1573                token,
1574                can_publish: true,
1575            })
1576        } else {
1577            None
1578        }
1579    } else {
1580        None
1581    };
1582
1583    response.send(proto::JoinRoomResponse {
1584        room: Some(joined_room.room),
1585        channel_id: None,
1586        live_kit_connection_info,
1587    })?;
1588
1589    update_user_contacts(session.user_id(), &session).await?;
1590    Ok(())
1591}
1592
/// Reconnects this connection to a room after connection errors.
///
/// The database's `rejoin_room` consumes the client's `request` and returns the
/// room together with `reshared_projects` and `rejoined_projects`. The response
/// describing both lists is sent first. Then:
/// - `room_updated` is called for the room;
/// - for each reshared project, every collaborator is sent an
///   `UpdateProjectCollaborator` mapping the project's old connection id to
///   this connection, and the project's worktrees are forwarded to the
///   collaborators via `broadcast` with this connection as the sender;
/// - `notify_rejoined_projects` handles the rejoined projects.
///
/// After the room guard has been consumed with `into_inner`, the channel update
/// is broadcast if the room belongs to a channel. Finally the user's contacts
/// are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block below so they outlive the room guard scoped there.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Send the rejoined state back to the reconnecting client.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this project's host moved from the old
            // connection id to the current one; send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators, with this
            // connection passed as the broadcast sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-announce this connection in rejoined projects and stream their state
        // back to it (worktrees are taken out of `rejoined_projects`).
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard; only the room and its channel are needed afterwards.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1684
1685fn notify_rejoined_projects(
1686    rejoined_projects: &mut Vec<RejoinedProject>,
1687    session: &UserSession,
1688) -> Result<()> {
1689    for project in rejoined_projects.iter() {
1690        for collaborator in &project.collaborators {
1691            session
1692                .peer
1693                .send(
1694                    collaborator.connection_id,
1695                    proto::UpdateProjectCollaborator {
1696                        project_id: project.id.to_proto(),
1697                        old_peer_id: Some(project.old_connection_id.into()),
1698                        new_peer_id: Some(session.connection_id.into()),
1699                    },
1700                )
1701                .trace_err();
1702        }
1703    }
1704
1705    for project in rejoined_projects {
1706        for worktree in mem::take(&mut project.worktrees) {
1707            #[cfg(any(test, feature = "test-support"))]
1708            const MAX_CHUNK_SIZE: usize = 2;
1709            #[cfg(not(any(test, feature = "test-support")))]
1710            const MAX_CHUNK_SIZE: usize = 256;
1711
1712            // Stream this worktree's entries.
1713            let message = proto::UpdateWorktree {
1714                project_id: project.id.to_proto(),
1715                worktree_id: worktree.id,
1716                abs_path: worktree.abs_path.clone(),
1717                root_name: worktree.root_name,
1718                updated_entries: worktree.updated_entries,
1719                removed_entries: worktree.removed_entries,
1720                scan_id: worktree.scan_id,
1721                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722                updated_repositories: worktree.updated_repositories,
1723                removed_repositories: worktree.removed_repositories,
1724            };
1725            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726                session.peer.send(session.connection_id, update.clone())?;
1727            }
1728
1729            // Stream this worktree's diagnostics.
1730            for summary in worktree.diagnostic_summaries {
1731                session.peer.send(
1732                    session.connection_id,
1733                    proto::UpdateDiagnosticSummary {
1734                        project_id: project.id.to_proto(),
1735                        worktree_id: worktree.id,
1736                        summary: Some(summary),
1737                    },
1738                )?;
1739            }
1740
1741            for settings_file in worktree.settings_files {
1742                session.peer.send(
1743                    session.connection_id,
1744                    proto::UpdateWorktreeSettings {
1745                        project_id: project.id.to_proto(),
1746                        worktree_id: worktree.id,
1747                        path: settings_file.path,
1748                        content: Some(settings_file.content),
1749                    },
1750                )?;
1751            }
1752        }
1753
1754        for language_server in &project.language_servers {
1755            session.peer.send(
1756                session.connection_id,
1757                proto::UpdateLanguageServer {
1758                    project_id: project.id.to_proto(),
1759                    language_server_id: language_server.id,
1760                    variant: Some(
1761                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1762                            proto::LspDiskBasedDiagnosticsUpdated {},
1763                        ),
1764                    ),
1765                },
1766            )?;
1767        }
1768    }
1769    Ok(())
1770}
1771
1772/// leave room disconnects from the room.
1773async fn leave_room(
1774    _: proto::LeaveRoom,
1775    response: Response<proto::LeaveRoom>,
1776    session: UserSession,
1777) -> Result<()> {
1778    leave_room_for_session(&session, session.connection_id).await?;
1779    response.send(proto::Ack {})?;
1780    Ok(())
1781}
1782
1783/// Updates the permissions of someone else in the room.
1784async fn set_room_participant_role(
1785    request: proto::SetRoomParticipantRole,
1786    response: Response<proto::SetRoomParticipantRole>,
1787    session: UserSession,
1788) -> Result<()> {
1789    let user_id = UserId::from_proto(request.user_id);
1790    let role = ChannelRole::from(request.role());
1791
1792    let (live_kit_room, can_publish) = {
1793        let room = session
1794            .db()
1795            .await
1796            .set_room_participant_role(
1797                session.user_id(),
1798                RoomId::from_proto(request.room_id),
1799                user_id,
1800                role,
1801            )
1802            .await?;
1803
1804        let live_kit_room = room.live_kit_room.clone();
1805        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1806        room_updated(&room, &session.peer);
1807        (live_kit_room, can_publish)
1808    };
1809
1810    if let Some(live_kit) = session.live_kit_client.as_ref() {
1811        live_kit
1812            .update_participant(
1813                live_kit_room.clone(),
1814                request.user_id.to_string(),
1815                live_kit_server::proto::ParticipantPermission {
1816                    can_subscribe: true,
1817                    can_publish,
1818                    can_publish_data: can_publish,
1819                    hidden: false,
1820                    recorder: false,
1821                },
1822            )
1823            .await
1824            .trace_err();
1825    }
1826
1827    response.send(proto::Ack {})?;
1828    Ok(())
1829}
1830
1831/// Call someone else into the current room
1832async fn call(
1833    request: proto::Call,
1834    response: Response<proto::Call>,
1835    session: UserSession,
1836) -> Result<()> {
1837    let room_id = RoomId::from_proto(request.room_id);
1838    let calling_user_id = session.user_id();
1839    let calling_connection_id = session.connection_id;
1840    let called_user_id = UserId::from_proto(request.called_user_id);
1841    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1842    if !session
1843        .db()
1844        .await
1845        .has_contact(calling_user_id, called_user_id)
1846        .await?
1847    {
1848        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1849    }
1850
1851    let incoming_call = {
1852        let (room, incoming_call) = &mut *session
1853            .db()
1854            .await
1855            .call(
1856                room_id,
1857                calling_user_id,
1858                calling_connection_id,
1859                called_user_id,
1860                initial_project_id,
1861            )
1862            .await?;
1863        room_updated(&room, &session.peer);
1864        mem::take(incoming_call)
1865    };
1866    update_user_contacts(called_user_id, &session).await?;
1867
1868    let mut calls = session
1869        .connection_pool()
1870        .await
1871        .user_connection_ids(called_user_id)
1872        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1873        .collect::<FuturesUnordered<_>>();
1874
1875    while let Some(call_response) = calls.next().await {
1876        match call_response.as_ref() {
1877            Ok(_) => {
1878                response.send(proto::Ack {})?;
1879                return Ok(());
1880            }
1881            Err(_) => {
1882                call_response.trace_err();
1883            }
1884        }
1885    }
1886
1887    {
1888        let room = session
1889            .db()
1890            .await
1891            .call_failed(room_id, called_user_id)
1892            .await?;
1893        room_updated(&room, &session.peer);
1894    }
1895    update_user_contacts(called_user_id, &session).await?;
1896
1897    Err(anyhow!("failed to ring user"))?
1898}
1899
1900/// Cancel an outgoing call.
1901async fn cancel_call(
1902    request: proto::CancelCall,
1903    response: Response<proto::CancelCall>,
1904    session: UserSession,
1905) -> Result<()> {
1906    let called_user_id = UserId::from_proto(request.called_user_id);
1907    let room_id = RoomId::from_proto(request.room_id);
1908    {
1909        let room = session
1910            .db()
1911            .await
1912            .cancel_call(room_id, session.connection_id, called_user_id)
1913            .await?;
1914        room_updated(&room, &session.peer);
1915    }
1916
1917    for connection_id in session
1918        .connection_pool()
1919        .await
1920        .user_connection_ids(called_user_id)
1921    {
1922        session
1923            .peer
1924            .send(
1925                connection_id,
1926                proto::CallCanceled {
1927                    room_id: room_id.to_proto(),
1928                },
1929            )
1930            .trace_err();
1931    }
1932    response.send(proto::Ack {})?;
1933
1934    update_user_contacts(called_user_id, &session).await?;
1935    Ok(())
1936}
1937
1938/// Decline an incoming call.
1939async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1940    let room_id = RoomId::from_proto(message.room_id);
1941    {
1942        let room = session
1943            .db()
1944            .await
1945            .decline_call(Some(room_id), session.user_id())
1946            .await?
1947            .ok_or_else(|| anyhow!("failed to decline call"))?;
1948        room_updated(&room, &session.peer);
1949    }
1950
1951    for connection_id in session
1952        .connection_pool()
1953        .await
1954        .user_connection_ids(session.user_id())
1955    {
1956        session
1957            .peer
1958            .send(
1959                connection_id,
1960                proto::CallCanceled {
1961                    room_id: room_id.to_proto(),
1962                },
1963            )
1964            .trace_err();
1965    }
1966    update_user_contacts(session.user_id(), &session).await?;
1967    Ok(())
1968}
1969
1970/// Updates other participants in the room with your current location.
1971async fn update_participant_location(
1972    request: proto::UpdateParticipantLocation,
1973    response: Response<proto::UpdateParticipantLocation>,
1974    session: UserSession,
1975) -> Result<()> {
1976    let room_id = RoomId::from_proto(request.room_id);
1977    let location = request
1978        .location
1979        .ok_or_else(|| anyhow!("invalid location"))?;
1980
1981    let db = session.db().await;
1982    let room = db
1983        .update_room_participant_location(room_id, session.connection_id, location)
1984        .await?;
1985
1986    room_updated(&room, &session.peer);
1987    response.send(proto::Ack {})?;
1988    Ok(())
1989}
1990
1991/// Share a project into the room.
1992async fn share_project(
1993    request: proto::ShareProject,
1994    response: Response<proto::ShareProject>,
1995    session: UserSession,
1996) -> Result<()> {
1997    let (project_id, room) = &*session
1998        .db()
1999        .await
2000        .share_project(
2001            RoomId::from_proto(request.room_id),
2002            session.connection_id,
2003            &request.worktrees,
2004            request
2005                .dev_server_project_id
2006                .map(|id| DevServerProjectId::from_proto(id)),
2007        )
2008        .await?;
2009    response.send(proto::ShareProjectResponse {
2010        project_id: project_id.to_proto(),
2011    })?;
2012    room_updated(&room, &session.peer);
2013
2014    Ok(())
2015}
2016
2017/// Unshare a project from the room.
2018async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2019    let project_id = ProjectId::from_proto(message.project_id);
2020    unshare_project_internal(
2021        project_id,
2022        session.connection_id,
2023        session.user_id(),
2024        &session,
2025    )
2026    .await
2027}
2028
2029async fn unshare_project_internal(
2030    project_id: ProjectId,
2031    connection_id: ConnectionId,
2032    user_id: Option<UserId>,
2033    session: &Session,
2034) -> Result<()> {
2035    let delete = {
2036        let room_guard = session
2037            .db()
2038            .await
2039            .unshare_project(project_id, connection_id, user_id)
2040            .await?;
2041
2042        let (delete, room, guest_connection_ids) = &*room_guard;
2043
2044        let message = proto::UnshareProject {
2045            project_id: project_id.to_proto(),
2046        };
2047
2048        broadcast(
2049            Some(connection_id),
2050            guest_connection_ids.iter().copied(),
2051            |conn_id| session.peer.send(conn_id, message.clone()),
2052        );
2053        if let Some(room) = room {
2054            room_updated(room, &session.peer);
2055        }
2056
2057        *delete
2058    };
2059
2060    if delete {
2061        let db = session.db().await;
2062        db.delete_project(project_id).await?;
2063    }
2064
2065    Ok(())
2066}
2067
2068/// DevServer makes a project available online
2069async fn share_dev_server_project(
2070    request: proto::ShareDevServerProject,
2071    response: Response<proto::ShareDevServerProject>,
2072    session: DevServerSession,
2073) -> Result<()> {
2074    let (dev_server_project, user_id, status) = session
2075        .db()
2076        .await
2077        .share_dev_server_project(
2078            DevServerProjectId::from_proto(request.dev_server_project_id),
2079            session.dev_server_id(),
2080            session.connection_id,
2081            &request.worktrees,
2082        )
2083        .await?;
2084    let Some(project_id) = dev_server_project.project_id else {
2085        return Err(anyhow!("failed to share remote project"))?;
2086    };
2087
2088    send_dev_server_projects_update(user_id, status, &session).await;
2089
2090    response.send(proto::ShareProjectResponse { project_id })?;
2091
2092    Ok(())
2093}
2094
/// Joins someone else's shared project.
///
/// Calls the database's `join_project` with this connection and user, then
/// delegates to `join_project_internal`. That call notifies the existing
/// collaborators and streams the project's contents to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `&mut *` borrows from the guard returned by `join_project`; temporary
    // lifetime extension keeps that guard alive for the rest of the function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Drop the database handle obtained from the session before streaming the
    // project; the guard above remains valid without it.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2113
2114trait JoinProjectInternalResponse {
2115    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2116}
2117impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2118    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2119        Response::<proto::JoinProject>::send(self, result)
2120    }
2121}
2122impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2123    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2124        Response::<proto::JoinHostedProject>::send(self, result)
2125    }
2126}
2127
2128fn join_project_internal(
2129    response: impl JoinProjectInternalResponse,
2130    session: UserSession,
2131    project: &mut Project,
2132    replica_id: &ReplicaId,
2133) -> Result<()> {
2134    let collaborators = project
2135        .collaborators
2136        .iter()
2137        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2138        .map(|collaborator| collaborator.to_proto())
2139        .collect::<Vec<_>>();
2140    let project_id = project.id;
2141    let guest_user_id = session.user_id();
2142
2143    let worktrees = project
2144        .worktrees
2145        .iter()
2146        .map(|(id, worktree)| proto::WorktreeMetadata {
2147            id: *id,
2148            root_name: worktree.root_name.clone(),
2149            visible: worktree.visible,
2150            abs_path: worktree.abs_path.clone(),
2151        })
2152        .collect::<Vec<_>>();
2153
2154    let add_project_collaborator = proto::AddProjectCollaborator {
2155        project_id: project_id.to_proto(),
2156        collaborator: Some(proto::Collaborator {
2157            peer_id: Some(session.connection_id.into()),
2158            replica_id: replica_id.0 as u32,
2159            user_id: guest_user_id.to_proto(),
2160        }),
2161    };
2162
2163    for collaborator in &collaborators {
2164        session
2165            .peer
2166            .send(
2167                collaborator.peer_id.unwrap().into(),
2168                add_project_collaborator.clone(),
2169            )
2170            .trace_err();
2171    }
2172
2173    // First, we send the metadata associated with each worktree.
2174    response.send(proto::JoinProjectResponse {
2175        project_id: project.id.0 as u64,
2176        worktrees: worktrees.clone(),
2177        replica_id: replica_id.0 as u32,
2178        collaborators: collaborators.clone(),
2179        language_servers: project.language_servers.clone(),
2180        role: project.role.into(),
2181        dev_server_project_id: project
2182            .dev_server_project_id
2183            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2184    })?;
2185
2186    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2187        #[cfg(any(test, feature = "test-support"))]
2188        const MAX_CHUNK_SIZE: usize = 2;
2189        #[cfg(not(any(test, feature = "test-support")))]
2190        const MAX_CHUNK_SIZE: usize = 256;
2191
2192        // Stream this worktree's entries.
2193        let message = proto::UpdateWorktree {
2194            project_id: project_id.to_proto(),
2195            worktree_id,
2196            abs_path: worktree.abs_path.clone(),
2197            root_name: worktree.root_name,
2198            updated_entries: worktree.entries,
2199            removed_entries: Default::default(),
2200            scan_id: worktree.scan_id,
2201            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2202            updated_repositories: worktree.repository_entries.into_values().collect(),
2203            removed_repositories: Default::default(),
2204        };
2205        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2206            session.peer.send(session.connection_id, update.clone())?;
2207        }
2208
2209        // Stream this worktree's diagnostics.
2210        for summary in worktree.diagnostic_summaries {
2211            session.peer.send(
2212                session.connection_id,
2213                proto::UpdateDiagnosticSummary {
2214                    project_id: project_id.to_proto(),
2215                    worktree_id: worktree.id,
2216                    summary: Some(summary),
2217                },
2218            )?;
2219        }
2220
2221        for settings_file in worktree.settings_files {
2222            session.peer.send(
2223                session.connection_id,
2224                proto::UpdateWorktreeSettings {
2225                    project_id: project_id.to_proto(),
2226                    worktree_id: worktree.id,
2227                    path: settings_file.path,
2228                    content: Some(settings_file.content),
2229                },
2230            )?;
2231        }
2232    }
2233
2234    for language_server in &project.language_servers {
2235        session.peer.send(
2236            session.connection_id,
2237            proto::UpdateLanguageServer {
2238                project_id: project_id.to_proto(),
2239                language_server_id: language_server.id,
2240                variant: Some(
2241                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2242                        proto::LspDiskBasedDiagnosticsUpdated {},
2243                    ),
2244                ),
2245            },
2246        )?;
2247    }
2248
2249    Ok(())
2250}
2251
2252/// Leave someone elses shared project.
2253async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2254    let sender_id = session.connection_id;
2255    let project_id = ProjectId::from_proto(request.project_id);
2256    let db = session.db().await;
2257    if db.is_hosted_project(project_id).await? {
2258        let project = db.leave_hosted_project(project_id, sender_id).await?;
2259        project_left(&project, &session);
2260        return Ok(());
2261    }
2262
2263    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2264    tracing::info!(
2265        %project_id,
2266        "leave project"
2267    );
2268
2269    project_left(&project, &session);
2270    if let Some(room) = room {
2271        room_updated(&room, &session.peer);
2272    }
2273
2274    Ok(())
2275}
2276
2277async fn join_hosted_project(
2278    request: proto::JoinHostedProject,
2279    response: Response<proto::JoinHostedProject>,
2280    session: UserSession,
2281) -> Result<()> {
2282    let (mut project, replica_id) = session
2283        .db()
2284        .await
2285        .join_hosted_project(
2286            ProjectId(request.project_id as i32),
2287            session.user_id(),
2288            session.connection_id,
2289        )
2290        .await?;
2291
2292    join_project_internal(response, session, &mut project, &replica_id)
2293}
2294
2295async fn create_dev_server_project(
2296    request: proto::CreateDevServerProject,
2297    response: Response<proto::CreateDevServerProject>,
2298    session: UserSession,
2299) -> Result<()> {
2300    let dev_server_id = DevServerId(request.dev_server_id as i32);
2301    let dev_server_connection_id = session
2302        .connection_pool()
2303        .await
2304        .dev_server_connection_id(dev_server_id);
2305    let Some(dev_server_connection_id) = dev_server_connection_id else {
2306        Err(ErrorCode::DevServerOffline
2307            .message("Cannot create a remote project when the dev server is offline".to_string())
2308            .anyhow())?
2309    };
2310
2311    let path = request.path.clone();
2312    //Check that the path exists on the dev server
2313    session
2314        .peer
2315        .forward_request(
2316            session.connection_id,
2317            dev_server_connection_id,
2318            proto::ValidateDevServerProjectRequest { path: path.clone() },
2319        )
2320        .await?;
2321
2322    let (dev_server_project, update) = session
2323        .db()
2324        .await
2325        .create_dev_server_project(
2326            DevServerId(request.dev_server_id as i32),
2327            &request.path,
2328            session.user_id(),
2329        )
2330        .await?;
2331
2332    let projects = session
2333        .db()
2334        .await
2335        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2336        .await?;
2337
2338    session.peer.send(
2339        dev_server_connection_id,
2340        proto::DevServerInstructions { projects },
2341    )?;
2342
2343    send_dev_server_projects_update(session.user_id(), update, &session).await;
2344
2345    response.send(proto::CreateDevServerProjectResponse {
2346        dev_server_project: Some(dev_server_project.to_proto(None)),
2347    })?;
2348    Ok(())
2349}
2350
2351async fn create_dev_server(
2352    request: proto::CreateDevServer,
2353    response: Response<proto::CreateDevServer>,
2354    session: UserSession,
2355) -> Result<()> {
2356    let access_token = auth::random_token();
2357    let hashed_access_token = auth::hash_access_token(&access_token);
2358
2359    if request.name.is_empty() {
2360        return Err(proto::ErrorCode::Forbidden
2361            .message("Dev server name cannot be empty".to_string())
2362            .anyhow())?;
2363    }
2364
2365    let (dev_server, status) = session
2366        .db()
2367        .await
2368        .create_dev_server(
2369            &request.name,
2370            request.ssh_connection_string.as_deref(),
2371            &hashed_access_token,
2372            session.user_id(),
2373        )
2374        .await?;
2375
2376    send_dev_server_projects_update(session.user_id(), status, &session).await;
2377
2378    response.send(proto::CreateDevServerResponse {
2379        dev_server_id: dev_server.id.0 as u64,
2380        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2381        name: request.name,
2382    })?;
2383    Ok(())
2384}
2385
2386async fn regenerate_dev_server_token(
2387    request: proto::RegenerateDevServerToken,
2388    response: Response<proto::RegenerateDevServerToken>,
2389    session: UserSession,
2390) -> Result<()> {
2391    let dev_server_id = DevServerId(request.dev_server_id as i32);
2392    let access_token = auth::random_token();
2393    let hashed_access_token = auth::hash_access_token(&access_token);
2394
2395    let connection_id = session
2396        .connection_pool()
2397        .await
2398        .dev_server_connection_id(dev_server_id);
2399    if let Some(connection_id) = connection_id {
2400        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2401        session
2402            .peer
2403            .send(connection_id, proto::ShutdownDevServer {})?;
2404        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2405    }
2406
2407    let status = session
2408        .db()
2409        .await
2410        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2411        .await?;
2412
2413    send_dev_server_projects_update(session.user_id(), status, &session).await;
2414
2415    response.send(proto::RegenerateDevServerTokenResponse {
2416        dev_server_id: dev_server_id.to_proto(),
2417        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2418    })?;
2419    Ok(())
2420}
2421
2422async fn rename_dev_server(
2423    request: proto::RenameDevServer,
2424    response: Response<proto::RenameDevServer>,
2425    session: UserSession,
2426) -> Result<()> {
2427    if request.name.trim().is_empty() {
2428        return Err(proto::ErrorCode::Forbidden
2429            .message("Dev server name cannot be empty".to_string())
2430            .anyhow())?;
2431    }
2432
2433    let dev_server_id = DevServerId(request.dev_server_id as i32);
2434    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2435    if dev_server.user_id != session.user_id() {
2436        return Err(anyhow!(ErrorCode::Forbidden))?;
2437    }
2438
2439    let status = session
2440        .db()
2441        .await
2442        .rename_dev_server(
2443            dev_server_id,
2444            &request.name,
2445            request.ssh_connection_string.as_deref(),
2446            session.user_id(),
2447        )
2448        .await?;
2449
2450    send_dev_server_projects_update(session.user_id(), status, &session).await;
2451
2452    response.send(proto::Ack {})?;
2453    Ok(())
2454}
2455
2456async fn delete_dev_server(
2457    request: proto::DeleteDevServer,
2458    response: Response<proto::DeleteDevServer>,
2459    session: UserSession,
2460) -> Result<()> {
2461    let dev_server_id = DevServerId(request.dev_server_id as i32);
2462    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2463    if dev_server.user_id != session.user_id() {
2464        return Err(anyhow!(ErrorCode::Forbidden))?;
2465    }
2466
2467    let connection_id = session
2468        .connection_pool()
2469        .await
2470        .dev_server_connection_id(dev_server_id);
2471    if let Some(connection_id) = connection_id {
2472        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2473        session
2474            .peer
2475            .send(connection_id, proto::ShutdownDevServer {})?;
2476        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2477    }
2478
2479    let status = session
2480        .db()
2481        .await
2482        .delete_dev_server(dev_server_id, session.user_id())
2483        .await?;
2484
2485    send_dev_server_projects_update(session.user_id(), status, &session).await;
2486
2487    response.send(proto::Ack {})?;
2488    Ok(())
2489}
2490
2491async fn delete_dev_server_project(
2492    request: proto::DeleteDevServerProject,
2493    response: Response<proto::DeleteDevServerProject>,
2494    session: UserSession,
2495) -> Result<()> {
2496    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2497    let dev_server_project = session
2498        .db()
2499        .await
2500        .get_dev_server_project(dev_server_project_id)
2501        .await?;
2502
2503    let dev_server = session
2504        .db()
2505        .await
2506        .get_dev_server(dev_server_project.dev_server_id)
2507        .await?;
2508    if dev_server.user_id != session.user_id() {
2509        return Err(anyhow!(ErrorCode::Forbidden))?;
2510    }
2511
2512    let dev_server_connection_id = session
2513        .connection_pool()
2514        .await
2515        .dev_server_connection_id(dev_server.id);
2516
2517    if let Some(dev_server_connection_id) = dev_server_connection_id {
2518        let project = session
2519            .db()
2520            .await
2521            .find_dev_server_project(dev_server_project_id)
2522            .await;
2523        if let Ok(project) = project {
2524            unshare_project_internal(
2525                project.id,
2526                dev_server_connection_id,
2527                Some(session.user_id()),
2528                &session,
2529            )
2530            .await?;
2531        }
2532    }
2533
2534    let (projects, status) = session
2535        .db()
2536        .await
2537        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2538        .await?;
2539
2540    if let Some(dev_server_connection_id) = dev_server_connection_id {
2541        session.peer.send(
2542            dev_server_connection_id,
2543            proto::DevServerInstructions { projects },
2544        )?;
2545    }
2546
2547    send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549    response.send(proto::Ack {})?;
2550    Ok(())
2551}
2552
2553async fn rejoin_dev_server_projects(
2554    request: proto::RejoinRemoteProjects,
2555    response: Response<proto::RejoinRemoteProjects>,
2556    session: UserSession,
2557) -> Result<()> {
2558    let mut rejoined_projects = {
2559        let db = session.db().await;
2560        db.rejoin_dev_server_projects(
2561            &request.rejoined_projects,
2562            session.user_id(),
2563            session.0.connection_id,
2564        )
2565        .await?
2566    };
2567    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2568
2569    response.send(proto::RejoinRemoteProjectsResponse {
2570        rejoined_projects: rejoined_projects
2571            .into_iter()
2572            .map(|project| project.to_proto())
2573            .collect(),
2574    })
2575}
2576
2577async fn reconnect_dev_server(
2578    request: proto::ReconnectDevServer,
2579    response: Response<proto::ReconnectDevServer>,
2580    session: DevServerSession,
2581) -> Result<()> {
2582    let reshared_projects = {
2583        let db = session.db().await;
2584        db.reshare_dev_server_projects(
2585            &request.reshared_projects,
2586            session.dev_server_id(),
2587            session.0.connection_id,
2588        )
2589        .await?
2590    };
2591
2592    for project in &reshared_projects {
2593        for collaborator in &project.collaborators {
2594            session
2595                .peer
2596                .send(
2597                    collaborator.connection_id,
2598                    proto::UpdateProjectCollaborator {
2599                        project_id: project.id.to_proto(),
2600                        old_peer_id: Some(project.old_connection_id.into()),
2601                        new_peer_id: Some(session.connection_id.into()),
2602                    },
2603                )
2604                .trace_err();
2605        }
2606
2607        broadcast(
2608            Some(session.connection_id),
2609            project
2610                .collaborators
2611                .iter()
2612                .map(|collaborator| collaborator.connection_id),
2613            |connection_id| {
2614                session.peer.forward_send(
2615                    session.connection_id,
2616                    connection_id,
2617                    proto::UpdateProject {
2618                        project_id: project.id.to_proto(),
2619                        worktrees: project.worktrees.clone(),
2620                    },
2621                )
2622            },
2623        );
2624    }
2625
2626    response.send(proto::ReconnectDevServerResponse {
2627        reshared_projects: reshared_projects
2628            .iter()
2629            .map(|project| proto::ResharedProject {
2630                id: project.id.to_proto(),
2631                collaborators: project
2632                    .collaborators
2633                    .iter()
2634                    .map(|collaborator| collaborator.to_proto())
2635                    .collect(),
2636            })
2637            .collect(),
2638    })?;
2639
2640    Ok(())
2641}
2642
2643async fn shutdown_dev_server(
2644    _: proto::ShutdownDevServer,
2645    response: Response<proto::ShutdownDevServer>,
2646    session: DevServerSession,
2647) -> Result<()> {
2648    response.send(proto::Ack {})?;
2649    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2650    remove_dev_server_connection(session.dev_server_id(), &session).await
2651}
2652
2653async fn shutdown_dev_server_internal(
2654    dev_server_id: DevServerId,
2655    connection_id: ConnectionId,
2656    session: &Session,
2657) -> Result<()> {
2658    let (dev_server_projects, dev_server) = {
2659        let db = session.db().await;
2660        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2661        let dev_server = db.get_dev_server(dev_server_id).await?;
2662        (dev_server_projects, dev_server)
2663    };
2664
2665    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2666        unshare_project_internal(
2667            ProjectId::from_proto(project_id),
2668            connection_id,
2669            None,
2670            session,
2671        )
2672        .await?;
2673    }
2674
2675    session
2676        .connection_pool()
2677        .await
2678        .set_dev_server_offline(dev_server_id);
2679
2680    let status = session
2681        .db()
2682        .await
2683        .dev_server_projects_update(dev_server.user_id)
2684        .await?;
2685    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2686
2687    Ok(())
2688}
2689
2690async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2691    let dev_server_connection = session
2692        .connection_pool()
2693        .await
2694        .dev_server_connection_id(dev_server_id);
2695
2696    if let Some(dev_server_connection) = dev_server_connection {
2697        session
2698            .connection_pool()
2699            .await
2700            .remove_connection(dev_server_connection)?;
2701    }
2702    Ok(())
2703}
2704
2705/// Updates other participants with changes to the project
2706async fn update_project(
2707    request: proto::UpdateProject,
2708    response: Response<proto::UpdateProject>,
2709    session: Session,
2710) -> Result<()> {
2711    let project_id = ProjectId::from_proto(request.project_id);
2712    let (room, guest_connection_ids) = &*session
2713        .db()
2714        .await
2715        .update_project(project_id, session.connection_id, &request.worktrees)
2716        .await?;
2717    broadcast(
2718        Some(session.connection_id),
2719        guest_connection_ids.iter().copied(),
2720        |connection_id| {
2721            session
2722                .peer
2723                .forward_send(session.connection_id, connection_id, request.clone())
2724        },
2725    );
2726    if let Some(room) = room {
2727        room_updated(&room, &session.peer);
2728    }
2729    response.send(proto::Ack {})?;
2730
2731    Ok(())
2732}
2733
2734/// Updates other participants with changes to the worktree
2735async fn update_worktree(
2736    request: proto::UpdateWorktree,
2737    response: Response<proto::UpdateWorktree>,
2738    session: Session,
2739) -> Result<()> {
2740    let guest_connection_ids = session
2741        .db()
2742        .await
2743        .update_worktree(&request, session.connection_id)
2744        .await?;
2745
2746    broadcast(
2747        Some(session.connection_id),
2748        guest_connection_ids.iter().copied(),
2749        |connection_id| {
2750            session
2751                .peer
2752                .forward_send(session.connection_id, connection_id, request.clone())
2753        },
2754    );
2755    response.send(proto::Ack {})?;
2756    Ok(())
2757}
2758
2759/// Updates other participants with changes to the diagnostics
2760async fn update_diagnostic_summary(
2761    message: proto::UpdateDiagnosticSummary,
2762    session: Session,
2763) -> Result<()> {
2764    let guest_connection_ids = session
2765        .db()
2766        .await
2767        .update_diagnostic_summary(&message, session.connection_id)
2768        .await?;
2769
2770    broadcast(
2771        Some(session.connection_id),
2772        guest_connection_ids.iter().copied(),
2773        |connection_id| {
2774            session
2775                .peer
2776                .forward_send(session.connection_id, connection_id, message.clone())
2777        },
2778    );
2779
2780    Ok(())
2781}
2782
2783/// Updates other participants with changes to the worktree settings
2784async fn update_worktree_settings(
2785    message: proto::UpdateWorktreeSettings,
2786    session: Session,
2787) -> Result<()> {
2788    let guest_connection_ids = session
2789        .db()
2790        .await
2791        .update_worktree_settings(&message, session.connection_id)
2792        .await?;
2793
2794    broadcast(
2795        Some(session.connection_id),
2796        guest_connection_ids.iter().copied(),
2797        |connection_id| {
2798            session
2799                .peer
2800                .forward_send(session.connection_id, connection_id, message.clone())
2801        },
2802    );
2803
2804    Ok(())
2805}
2806
2807/// Notify other participants that a  language server has started.
2808async fn start_language_server(
2809    request: proto::StartLanguageServer,
2810    session: Session,
2811) -> Result<()> {
2812    let guest_connection_ids = session
2813        .db()
2814        .await
2815        .start_language_server(&request, session.connection_id)
2816        .await?;
2817
2818    broadcast(
2819        Some(session.connection_id),
2820        guest_connection_ids.iter().copied(),
2821        |connection_id| {
2822            session
2823                .peer
2824                .forward_send(session.connection_id, connection_id, request.clone())
2825        },
2826    );
2827    Ok(())
2828}
2829
2830/// Notify other participants that a language server has changed.
2831async fn update_language_server(
2832    request: proto::UpdateLanguageServer,
2833    session: Session,
2834) -> Result<()> {
2835    let project_id = ProjectId::from_proto(request.project_id);
2836    let project_connection_ids = session
2837        .db()
2838        .await
2839        .project_connection_ids(project_id, session.connection_id, true)
2840        .await?;
2841    broadcast(
2842        Some(session.connection_id),
2843        project_connection_ids.iter().copied(),
2844        |connection_id| {
2845            session
2846                .peer
2847                .forward_send(session.connection_id, connection_id, request.clone())
2848        },
2849    );
2850    Ok(())
2851}
2852
2853/// forward a project request to the host. These requests should be read only
2854/// as guests are allowed to send them.
2855async fn forward_read_only_project_request<T>(
2856    request: T,
2857    response: Response<T>,
2858    session: UserSession,
2859) -> Result<()>
2860where
2861    T: EntityMessage + RequestMessage,
2862{
2863    let project_id = ProjectId::from_proto(request.remote_entity_id());
2864    let host_connection_id = session
2865        .db()
2866        .await
2867        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2868        .await?;
2869    let payload = session
2870        .peer
2871        .forward_request(session.connection_id, host_connection_id, request)
2872        .await?;
2873    response.send(payload)?;
2874    Ok(())
2875}
2876
2877/// forward a project request to the host. These requests are disallowed
2878/// for guests.
2879async fn forward_mutating_project_request<T>(
2880    request: T,
2881    response: Response<T>,
2882    session: UserSession,
2883) -> Result<()>
2884where
2885    T: EntityMessage + RequestMessage,
2886{
2887    let project_id = ProjectId::from_proto(request.remote_entity_id());
2888
2889    let host_connection_id = session
2890        .db()
2891        .await
2892        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2893        .await?;
2894    let payload = session
2895        .peer
2896        .forward_request(session.connection_id, host_connection_id, request)
2897        .await?;
2898    response.send(payload)?;
2899    Ok(())
2900}
2901
2902/// forward a project request to the host. These requests are disallowed
2903/// for guests.
2904async fn forward_versioned_mutating_project_request<T>(
2905    request: T,
2906    response: Response<T>,
2907    session: UserSession,
2908) -> Result<()>
2909where
2910    T: EntityMessage + RequestMessage + VersionedMessage,
2911{
2912    let project_id = ProjectId::from_proto(request.remote_entity_id());
2913
2914    let host_connection_id = session
2915        .db()
2916        .await
2917        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2918        .await?;
2919    if let Some(host_version) = session
2920        .connection_pool()
2921        .await
2922        .connection(host_connection_id)
2923        .map(|c| c.zed_version)
2924    {
2925        if let Some(min_required_version) = request.required_host_version() {
2926            if min_required_version > host_version {
2927                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2928                    .with_tag("required", &min_required_version.to_string())))?;
2929            }
2930        }
2931    }
2932
2933    let payload = session
2934        .peer
2935        .forward_request(session.connection_id, host_connection_id, request)
2936        .await?;
2937    response.send(payload)?;
2938    Ok(())
2939}
2940
2941/// Notify other participants that a new buffer has been created
2942async fn create_buffer_for_peer(
2943    request: proto::CreateBufferForPeer,
2944    session: Session,
2945) -> Result<()> {
2946    session
2947        .db()
2948        .await
2949        .check_user_is_project_host(
2950            ProjectId::from_proto(request.project_id),
2951            session.connection_id,
2952        )
2953        .await?;
2954    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2955    session
2956        .peer
2957        .forward_send(session.connection_id, peer_id.into(), request)?;
2958    Ok(())
2959}
2960
2961/// Notify other participants that a buffer has been updated. This is
2962/// allowed for guests as long as the update is limited to selections.
2963async fn update_buffer(
2964    request: proto::UpdateBuffer,
2965    response: Response<proto::UpdateBuffer>,
2966    session: Session,
2967) -> Result<()> {
2968    let project_id = ProjectId::from_proto(request.project_id);
2969    let mut capability = Capability::ReadOnly;
2970
2971    for op in request.operations.iter() {
2972        match op.variant {
2973            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2974            Some(_) => capability = Capability::ReadWrite,
2975        }
2976    }
2977
2978    let host = {
2979        let guard = session
2980            .db()
2981            .await
2982            .connections_for_buffer_update(
2983                project_id,
2984                session.principal_id(),
2985                session.connection_id,
2986                capability,
2987            )
2988            .await?;
2989
2990        let (host, guests) = &*guard;
2991
2992        broadcast(
2993            Some(session.connection_id),
2994            guests.clone(),
2995            |connection_id| {
2996                session
2997                    .peer
2998                    .forward_send(session.connection_id, connection_id, request.clone())
2999            },
3000        );
3001
3002        *host
3003    };
3004
3005    if host != session.connection_id {
3006        session
3007            .peer
3008            .forward_request(session.connection_id, host, request.clone())
3009            .await?;
3010    }
3011
3012    response.send(proto::Ack {})?;
3013    Ok(())
3014}
3015
3016/// Notify other participants that a project has been updated.
3017async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3018    request: T,
3019    session: Session,
3020) -> Result<()> {
3021    let project_id = ProjectId::from_proto(request.remote_entity_id());
3022    let project_connection_ids = session
3023        .db()
3024        .await
3025        .project_connection_ids(project_id, session.connection_id, false)
3026        .await?;
3027
3028    broadcast(
3029        Some(session.connection_id),
3030        project_connection_ids.iter().copied(),
3031        |connection_id| {
3032            session
3033                .peer
3034                .forward_send(session.connection_id, connection_id, request.clone())
3035        },
3036    );
3037    Ok(())
3038}
3039
3040/// Start following another user in a call.
3041async fn follow(
3042    request: proto::Follow,
3043    response: Response<proto::Follow>,
3044    session: UserSession,
3045) -> Result<()> {
3046    let room_id = RoomId::from_proto(request.room_id);
3047    let project_id = request.project_id.map(ProjectId::from_proto);
3048    let leader_id = request
3049        .leader_id
3050        .ok_or_else(|| anyhow!("invalid leader id"))?
3051        .into();
3052    let follower_id = session.connection_id;
3053
3054    session
3055        .db()
3056        .await
3057        .check_room_participants(room_id, leader_id, session.connection_id)
3058        .await?;
3059
3060    let response_payload = session
3061        .peer
3062        .forward_request(session.connection_id, leader_id, request)
3063        .await?;
3064    response.send(response_payload)?;
3065
3066    if let Some(project_id) = project_id {
3067        let room = session
3068            .db()
3069            .await
3070            .follow(room_id, project_id, leader_id, follower_id)
3071            .await?;
3072        room_updated(&room, &session.peer);
3073    }
3074
3075    Ok(())
3076}
3077
3078/// Stop following another user in a call.
3079async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3080    let room_id = RoomId::from_proto(request.room_id);
3081    let project_id = request.project_id.map(ProjectId::from_proto);
3082    let leader_id = request
3083        .leader_id
3084        .ok_or_else(|| anyhow!("invalid leader id"))?
3085        .into();
3086    let follower_id = session.connection_id;
3087
3088    session
3089        .db()
3090        .await
3091        .check_room_participants(room_id, leader_id, session.connection_id)
3092        .await?;
3093
3094    session
3095        .peer
3096        .forward_send(session.connection_id, leader_id, request)?;
3097
3098    if let Some(project_id) = project_id {
3099        let room = session
3100            .db()
3101            .await
3102            .unfollow(room_id, project_id, leader_id, follower_id)
3103            .await?;
3104        room_updated(&room, &session.peer);
3105    }
3106
3107    Ok(())
3108}
3109
3110/// Notify everyone following you of your current location.
3111async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3112    let room_id = RoomId::from_proto(request.room_id);
3113    let database = session.db.lock().await;
3114
3115    let connection_ids = if let Some(project_id) = request.project_id {
3116        let project_id = ProjectId::from_proto(project_id);
3117        database
3118            .project_connection_ids(project_id, session.connection_id, true)
3119            .await?
3120    } else {
3121        database
3122            .room_connection_ids(room_id, session.connection_id)
3123            .await?
3124    };
3125
3126    // For now, don't send view update messages back to that view's current leader.
3127    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3128        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3129        _ => None,
3130    });
3131
3132    for connection_id in connection_ids.iter().cloned() {
3133        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3134            session
3135                .peer
3136                .forward_send(session.connection_id, connection_id, request.clone())?;
3137        }
3138    }
3139    Ok(())
3140}
3141
3142/// Get public data about users.
3143async fn get_users(
3144    request: proto::GetUsers,
3145    response: Response<proto::GetUsers>,
3146    session: Session,
3147) -> Result<()> {
3148    let user_ids = request
3149        .user_ids
3150        .into_iter()
3151        .map(UserId::from_proto)
3152        .collect();
3153    let users = session
3154        .db()
3155        .await
3156        .get_users_by_ids(user_ids)
3157        .await?
3158        .into_iter()
3159        .map(|user| proto::User {
3160            id: user.id.to_proto(),
3161            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3162            github_login: user.github_login,
3163        })
3164        .collect();
3165    response.send(proto::UsersResponse { users })?;
3166    Ok(())
3167}
3168
3169/// Search for users (to invite) buy Github login
3170async fn fuzzy_search_users(
3171    request: proto::FuzzySearchUsers,
3172    response: Response<proto::FuzzySearchUsers>,
3173    session: UserSession,
3174) -> Result<()> {
3175    let query = request.query;
3176    let users = match query.len() {
3177        0 => vec![],
3178        1 | 2 => session
3179            .db()
3180            .await
3181            .get_user_by_github_login(&query)
3182            .await?
3183            .into_iter()
3184            .collect(),
3185        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3186    };
3187    let users = users
3188        .into_iter()
3189        .filter(|user| user.id != session.user_id())
3190        .map(|user| proto::User {
3191            id: user.id.to_proto(),
3192            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3193            github_login: user.github_login,
3194        })
3195        .collect();
3196    response.send(proto::UsersResponse { users })?;
3197    Ok(())
3198}
3199
3200/// Send a contact request to another user.
3201async fn request_contact(
3202    request: proto::RequestContact,
3203    response: Response<proto::RequestContact>,
3204    session: UserSession,
3205) -> Result<()> {
3206    let requester_id = session.user_id();
3207    let responder_id = UserId::from_proto(request.responder_id);
3208    if requester_id == responder_id {
3209        return Err(anyhow!("cannot add yourself as a contact"))?;
3210    }
3211
3212    let notifications = session
3213        .db()
3214        .await
3215        .send_contact_request(requester_id, responder_id)
3216        .await?;
3217
3218    // Update outgoing contact requests of requester
3219    let mut update = proto::UpdateContacts::default();
3220    update.outgoing_requests.push(responder_id.to_proto());
3221    for connection_id in session
3222        .connection_pool()
3223        .await
3224        .user_connection_ids(requester_id)
3225    {
3226        session.peer.send(connection_id, update.clone())?;
3227    }
3228
3229    // Update incoming contact requests of responder
3230    let mut update = proto::UpdateContacts::default();
3231    update
3232        .incoming_requests
3233        .push(proto::IncomingContactRequest {
3234            requester_id: requester_id.to_proto(),
3235        });
3236    let connection_pool = session.connection_pool().await;
3237    for connection_id in connection_pool.user_connection_ids(responder_id) {
3238        session.peer.send(connection_id, update.clone())?;
3239    }
3240
3241    send_notifications(&connection_pool, &session.peer, notifications);
3242
3243    response.send(proto::Ack {})?;
3244    Ok(())
3245}
3246
3247/// Accept or decline a contact request
3248async fn respond_to_contact_request(
3249    request: proto::RespondToContactRequest,
3250    response: Response<proto::RespondToContactRequest>,
3251    session: UserSession,
3252) -> Result<()> {
3253    let responder_id = session.user_id();
3254    let requester_id = UserId::from_proto(request.requester_id);
3255    let db = session.db().await;
3256    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3257        db.dismiss_contact_notification(responder_id, requester_id)
3258            .await?;
3259    } else {
3260        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3261
3262        let notifications = db
3263            .respond_to_contact_request(responder_id, requester_id, accept)
3264            .await?;
3265        let requester_busy = db.is_user_busy(requester_id).await?;
3266        let responder_busy = db.is_user_busy(responder_id).await?;
3267
3268        let pool = session.connection_pool().await;
3269        // Update responder with new contact
3270        let mut update = proto::UpdateContacts::default();
3271        if accept {
3272            update
3273                .contacts
3274                .push(contact_for_user(requester_id, requester_busy, &pool));
3275        }
3276        update
3277            .remove_incoming_requests
3278            .push(requester_id.to_proto());
3279        for connection_id in pool.user_connection_ids(responder_id) {
3280            session.peer.send(connection_id, update.clone())?;
3281        }
3282
3283        // Update requester with new contact
3284        let mut update = proto::UpdateContacts::default();
3285        if accept {
3286            update
3287                .contacts
3288                .push(contact_for_user(responder_id, responder_busy, &pool));
3289        }
3290        update
3291            .remove_outgoing_requests
3292            .push(responder_id.to_proto());
3293
3294        for connection_id in pool.user_connection_ids(requester_id) {
3295            session.peer.send(connection_id, update.clone())?;
3296        }
3297
3298        send_notifications(&pool, &session.peer, notifications);
3299    }
3300
3301    response.send(proto::Ack {})?;
3302    Ok(())
3303}
3304
3305/// Remove a contact.
3306async fn remove_contact(
3307    request: proto::RemoveContact,
3308    response: Response<proto::RemoveContact>,
3309    session: UserSession,
3310) -> Result<()> {
3311    let requester_id = session.user_id();
3312    let responder_id = UserId::from_proto(request.user_id);
3313    let db = session.db().await;
3314    let (contact_accepted, deleted_notification_id) =
3315        db.remove_contact(requester_id, responder_id).await?;
3316
3317    let pool = session.connection_pool().await;
3318    // Update outgoing contact requests of requester
3319    let mut update = proto::UpdateContacts::default();
3320    if contact_accepted {
3321        update.remove_contacts.push(responder_id.to_proto());
3322    } else {
3323        update
3324            .remove_outgoing_requests
3325            .push(responder_id.to_proto());
3326    }
3327    for connection_id in pool.user_connection_ids(requester_id) {
3328        session.peer.send(connection_id, update.clone())?;
3329    }
3330
3331    // Update incoming contact requests of responder
3332    let mut update = proto::UpdateContacts::default();
3333    if contact_accepted {
3334        update.remove_contacts.push(requester_id.to_proto());
3335    } else {
3336        update
3337            .remove_incoming_requests
3338            .push(requester_id.to_proto());
3339    }
3340    for connection_id in pool.user_connection_ids(responder_id) {
3341        session.peer.send(connection_id, update.clone())?;
3342        if let Some(notification_id) = deleted_notification_id {
3343            session.peer.send(
3344                connection_id,
3345                proto::DeleteNotification {
3346                    notification_id: notification_id.to_proto(),
3347                },
3348            )?;
3349        }
3350    }
3351
3352    response.send(proto::Ack {})?;
3353    Ok(())
3354}
3355
3356/// Creates a new channel.
3357async fn create_channel(
3358    request: proto::CreateChannel,
3359    response: Response<proto::CreateChannel>,
3360    session: UserSession,
3361) -> Result<()> {
3362    let db = session.db().await;
3363
3364    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3365    let (channel, membership) = db
3366        .create_channel(&request.name, parent_id, session.user_id())
3367        .await?;
3368
3369    let root_id = channel.root_id();
3370    let channel = Channel::from_model(channel);
3371
3372    response.send(proto::CreateChannelResponse {
3373        channel: Some(channel.to_proto()),
3374        parent_id: request.parent_id,
3375    })?;
3376
3377    let mut connection_pool = session.connection_pool().await;
3378    if let Some(membership) = membership {
3379        connection_pool.subscribe_to_channel(
3380            membership.user_id,
3381            membership.channel_id,
3382            membership.role,
3383        );
3384        let update = proto::UpdateUserChannels {
3385            channel_memberships: vec![proto::ChannelMembership {
3386                channel_id: membership.channel_id.to_proto(),
3387                role: membership.role.into(),
3388            }],
3389            ..Default::default()
3390        };
3391        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3392            session.peer.send(connection_id, update.clone())?;
3393        }
3394    }
3395
3396    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3397        if !role.can_see_channel(channel.visibility) {
3398            continue;
3399        }
3400
3401        let update = proto::UpdateChannels {
3402            channels: vec![channel.to_proto()],
3403            ..Default::default()
3404        };
3405        session.peer.send(connection_id, update.clone())?;
3406    }
3407
3408    Ok(())
3409}
3410
3411/// Delete a channel
3412async fn delete_channel(
3413    request: proto::DeleteChannel,
3414    response: Response<proto::DeleteChannel>,
3415    session: UserSession,
3416) -> Result<()> {
3417    let db = session.db().await;
3418
3419    let channel_id = request.channel_id;
3420    let (root_channel, removed_channels) = db
3421        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3422        .await?;
3423    response.send(proto::Ack {})?;
3424
3425    // Notify members of removed channels
3426    let mut update = proto::UpdateChannels::default();
3427    update
3428        .delete_channels
3429        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3430
3431    let connection_pool = session.connection_pool().await;
3432    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3433        session.peer.send(connection_id, update.clone())?;
3434    }
3435
3436    Ok(())
3437}
3438
3439/// Invite someone to join a channel.
3440async fn invite_channel_member(
3441    request: proto::InviteChannelMember,
3442    response: Response<proto::InviteChannelMember>,
3443    session: UserSession,
3444) -> Result<()> {
3445    let db = session.db().await;
3446    let channel_id = ChannelId::from_proto(request.channel_id);
3447    let invitee_id = UserId::from_proto(request.user_id);
3448    let InviteMemberResult {
3449        channel,
3450        notifications,
3451    } = db
3452        .invite_channel_member(
3453            channel_id,
3454            invitee_id,
3455            session.user_id(),
3456            request.role().into(),
3457        )
3458        .await?;
3459
3460    let update = proto::UpdateChannels {
3461        channel_invitations: vec![channel.to_proto()],
3462        ..Default::default()
3463    };
3464
3465    let connection_pool = session.connection_pool().await;
3466    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3467        session.peer.send(connection_id, update.clone())?;
3468    }
3469
3470    send_notifications(&connection_pool, &session.peer, notifications);
3471
3472    response.send(proto::Ack {})?;
3473    Ok(())
3474}
3475
3476/// remove someone from a channel
3477async fn remove_channel_member(
3478    request: proto::RemoveChannelMember,
3479    response: Response<proto::RemoveChannelMember>,
3480    session: UserSession,
3481) -> Result<()> {
3482    let db = session.db().await;
3483    let channel_id = ChannelId::from_proto(request.channel_id);
3484    let member_id = UserId::from_proto(request.user_id);
3485
3486    let RemoveChannelMemberResult {
3487        membership_update,
3488        notification_id,
3489    } = db
3490        .remove_channel_member(channel_id, member_id, session.user_id())
3491        .await?;
3492
3493    let mut connection_pool = session.connection_pool().await;
3494    notify_membership_updated(
3495        &mut connection_pool,
3496        membership_update,
3497        member_id,
3498        &session.peer,
3499    );
3500    for connection_id in connection_pool.user_connection_ids(member_id) {
3501        if let Some(notification_id) = notification_id {
3502            session
3503                .peer
3504                .send(
3505                    connection_id,
3506                    proto::DeleteNotification {
3507                        notification_id: notification_id.to_proto(),
3508                    },
3509                )
3510                .trace_err();
3511        }
3512    }
3513
3514    response.send(proto::Ack {})?;
3515    Ok(())
3516}
3517
3518/// Toggle the channel between public and private.
3519/// Care is taken to maintain the invariant that public channels only descend from public channels,
3520/// (though members-only channels can appear at any point in the hierarchy).
3521async fn set_channel_visibility(
3522    request: proto::SetChannelVisibility,
3523    response: Response<proto::SetChannelVisibility>,
3524    session: UserSession,
3525) -> Result<()> {
3526    let db = session.db().await;
3527    let channel_id = ChannelId::from_proto(request.channel_id);
3528    let visibility = request.visibility().into();
3529
3530    let channel_model = db
3531        .set_channel_visibility(channel_id, visibility, session.user_id())
3532        .await?;
3533    let root_id = channel_model.root_id();
3534    let channel = Channel::from_model(channel_model);
3535
3536    let mut connection_pool = session.connection_pool().await;
3537    for (user_id, role) in connection_pool
3538        .channel_user_ids(root_id)
3539        .collect::<Vec<_>>()
3540        .into_iter()
3541    {
3542        let update = if role.can_see_channel(channel.visibility) {
3543            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3544            proto::UpdateChannels {
3545                channels: vec![channel.to_proto()],
3546                ..Default::default()
3547            }
3548        } else {
3549            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3550            proto::UpdateChannels {
3551                delete_channels: vec![channel.id.to_proto()],
3552                ..Default::default()
3553            }
3554        };
3555
3556        for connection_id in connection_pool.user_connection_ids(user_id) {
3557            session.peer.send(connection_id, update.clone())?;
3558        }
3559    }
3560
3561    response.send(proto::Ack {})?;
3562    Ok(())
3563}
3564
3565/// Alter the role for a user in the channel.
3566async fn set_channel_member_role(
3567    request: proto::SetChannelMemberRole,
3568    response: Response<proto::SetChannelMemberRole>,
3569    session: UserSession,
3570) -> Result<()> {
3571    let db = session.db().await;
3572    let channel_id = ChannelId::from_proto(request.channel_id);
3573    let member_id = UserId::from_proto(request.user_id);
3574    let result = db
3575        .set_channel_member_role(
3576            channel_id,
3577            session.user_id(),
3578            member_id,
3579            request.role().into(),
3580        )
3581        .await?;
3582
3583    match result {
3584        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3585            let mut connection_pool = session.connection_pool().await;
3586            notify_membership_updated(
3587                &mut connection_pool,
3588                membership_update,
3589                member_id,
3590                &session.peer,
3591            )
3592        }
3593        db::SetMemberRoleResult::InviteUpdated(channel) => {
3594            let update = proto::UpdateChannels {
3595                channel_invitations: vec![channel.to_proto()],
3596                ..Default::default()
3597            };
3598
3599            for connection_id in session
3600                .connection_pool()
3601                .await
3602                .user_connection_ids(member_id)
3603            {
3604                session.peer.send(connection_id, update.clone())?;
3605            }
3606        }
3607    }
3608
3609    response.send(proto::Ack {})?;
3610    Ok(())
3611}
3612
3613/// Change the name of a channel
3614async fn rename_channel(
3615    request: proto::RenameChannel,
3616    response: Response<proto::RenameChannel>,
3617    session: UserSession,
3618) -> Result<()> {
3619    let db = session.db().await;
3620    let channel_id = ChannelId::from_proto(request.channel_id);
3621    let channel_model = db
3622        .rename_channel(channel_id, session.user_id(), &request.name)
3623        .await?;
3624    let root_id = channel_model.root_id();
3625    let channel = Channel::from_model(channel_model);
3626
3627    response.send(proto::RenameChannelResponse {
3628        channel: Some(channel.to_proto()),
3629    })?;
3630
3631    let connection_pool = session.connection_pool().await;
3632    let update = proto::UpdateChannels {
3633        channels: vec![channel.to_proto()],
3634        ..Default::default()
3635    };
3636    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3637        if role.can_see_channel(channel.visibility) {
3638            session.peer.send(connection_id, update.clone())?;
3639        }
3640    }
3641
3642    Ok(())
3643}
3644
3645/// Move a channel to a new parent.
3646async fn move_channel(
3647    request: proto::MoveChannel,
3648    response: Response<proto::MoveChannel>,
3649    session: UserSession,
3650) -> Result<()> {
3651    let channel_id = ChannelId::from_proto(request.channel_id);
3652    let to = ChannelId::from_proto(request.to);
3653
3654    let (root_id, channels) = session
3655        .db()
3656        .await
3657        .move_channel(channel_id, to, session.user_id())
3658        .await?;
3659
3660    let connection_pool = session.connection_pool().await;
3661    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3662        let channels = channels
3663            .iter()
3664            .filter_map(|channel| {
3665                if role.can_see_channel(channel.visibility) {
3666                    Some(channel.to_proto())
3667                } else {
3668                    None
3669                }
3670            })
3671            .collect::<Vec<_>>();
3672        if channels.is_empty() {
3673            continue;
3674        }
3675
3676        let update = proto::UpdateChannels {
3677            channels,
3678            ..Default::default()
3679        };
3680
3681        session.peer.send(connection_id, update.clone())?;
3682    }
3683
3684    response.send(Ack {})?;
3685    Ok(())
3686}
3687
3688/// Get the list of channel members
3689async fn get_channel_members(
3690    request: proto::GetChannelMembers,
3691    response: Response<proto::GetChannelMembers>,
3692    session: UserSession,
3693) -> Result<()> {
3694    let db = session.db().await;
3695    let channel_id = ChannelId::from_proto(request.channel_id);
3696    let limit = if request.limit == 0 {
3697        u16::MAX as u64
3698    } else {
3699        request.limit
3700    };
3701    let (members, users) = db
3702        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3703        .await?;
3704    response.send(proto::GetChannelMembersResponse { members, users })?;
3705    Ok(())
3706}
3707
3708/// Accept or decline a channel invitation.
3709async fn respond_to_channel_invite(
3710    request: proto::RespondToChannelInvite,
3711    response: Response<proto::RespondToChannelInvite>,
3712    session: UserSession,
3713) -> Result<()> {
3714    let db = session.db().await;
3715    let channel_id = ChannelId::from_proto(request.channel_id);
3716    let RespondToChannelInvite {
3717        membership_update,
3718        notifications,
3719    } = db
3720        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3721        .await?;
3722
3723    let mut connection_pool = session.connection_pool().await;
3724    if let Some(membership_update) = membership_update {
3725        notify_membership_updated(
3726            &mut connection_pool,
3727            membership_update,
3728            session.user_id(),
3729            &session.peer,
3730        );
3731    } else {
3732        let update = proto::UpdateChannels {
3733            remove_channel_invitations: vec![channel_id.to_proto()],
3734            ..Default::default()
3735        };
3736
3737        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3738            session.peer.send(connection_id, update.clone())?;
3739        }
3740    };
3741
3742    send_notifications(&connection_pool, &session.peer, notifications);
3743
3744    response.send(proto::Ack {})?;
3745
3746    Ok(())
3747}
3748
3749/// Join the channels' room
3750async fn join_channel(
3751    request: proto::JoinChannel,
3752    response: Response<proto::JoinChannel>,
3753    session: UserSession,
3754) -> Result<()> {
3755    let channel_id = ChannelId::from_proto(request.channel_id);
3756    join_channel_internal(channel_id, Box::new(response), session).await
3757}
3758
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` be driven by more than one request type;
/// impls exist below for `JoinChannel` and `JoinRoom` responses.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The explicit `Response::<…>::send(self, result)` paths in these impls are
// meant to reach `Response`'s own `send` (inherent items take precedence over
// trait items in type-relative paths). A plain `self.send(result)` would make
// it much less obvious that this does not recurse into the trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3772
/// Shared implementation for joining the room of `channel_id`.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed from its room. The database handle is
///    released while `leave_room_for_session` runs and is reacquired after it.
/// 2. The channel is joined in the database. This yields the joined room, an
///    optional membership update and the user's role in the channel.
/// 3. If a LiveKit client is configured, a token is issued. Guests get a
///    guest token with `can_publish = false`; every other role gets a room
///    token with `can_publish = true`. If token creation fails, the error goes
///    through `trace_err()` and the reply carries no LiveKit info.
/// 4. The `JoinRoomResponse` is sent. Then the membership update (if any) and
///    the room update are propagated.
/// 5. Once the database and connection-pool handles from the inner block are
///    dropped, the channel update is broadcast and the user's contacts are
///    refreshed.
///
/// # Errors
/// Propagates failures from the stale-connection cleanup, the database join,
/// sending the response and refreshing contacts. Also errors with
/// "channel not returned" if the joined room has no associated channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Everything that needs the database handle (and, later, the connection
    // pool) happens inside this block so both are released when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the handle while leaving the old room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or token creation failed
        // (`trace_err()?` short-circuits the closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join the call but not publish media.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Runs after the block above has dropped its handles; the connection pool
    // is locked again only for the duration of this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3863
3864/// Start editing the channel notes
3865async fn join_channel_buffer(
3866    request: proto::JoinChannelBuffer,
3867    response: Response<proto::JoinChannelBuffer>,
3868    session: UserSession,
3869) -> Result<()> {
3870    let db = session.db().await;
3871    let channel_id = ChannelId::from_proto(request.channel_id);
3872
3873    let open_response = db
3874        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3875        .await?;
3876
3877    let collaborators = open_response.collaborators.clone();
3878    response.send(open_response)?;
3879
3880    let update = UpdateChannelBufferCollaborators {
3881        channel_id: channel_id.to_proto(),
3882        collaborators: collaborators.clone(),
3883    };
3884    channel_buffer_updated(
3885        session.connection_id,
3886        collaborators
3887            .iter()
3888            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3889        &update,
3890        &session.peer,
3891    );
3892
3893    Ok(())
3894}
3895
3896/// Edit the channel notes
3897async fn update_channel_buffer(
3898    request: proto::UpdateChannelBuffer,
3899    session: UserSession,
3900) -> Result<()> {
3901    let db = session.db().await;
3902    let channel_id = ChannelId::from_proto(request.channel_id);
3903
3904    let (collaborators, epoch, version) = db
3905        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3906        .await?;
3907
3908    channel_buffer_updated(
3909        session.connection_id,
3910        collaborators.clone(),
3911        &proto::UpdateChannelBuffer {
3912            channel_id: channel_id.to_proto(),
3913            operations: request.operations,
3914        },
3915        &session.peer,
3916    );
3917
3918    let pool = &*session.connection_pool().await;
3919
3920    let non_collaborators =
3921        pool.channel_connection_ids(channel_id)
3922            .filter_map(|(connection_id, _)| {
3923                if collaborators.contains(&connection_id) {
3924                    None
3925                } else {
3926                    Some(connection_id)
3927                }
3928            });
3929
3930    broadcast(None, non_collaborators, |peer_id| {
3931        session.peer.send(
3932            peer_id,
3933            proto::UpdateChannels {
3934                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3935                    channel_id: channel_id.to_proto(),
3936                    epoch: epoch as u64,
3937                    version: version.clone(),
3938                }],
3939                ..Default::default()
3940            },
3941        )
3942    });
3943
3944    Ok(())
3945}
3946
3947/// Rejoin the channel notes after a connection blip
3948async fn rejoin_channel_buffers(
3949    request: proto::RejoinChannelBuffers,
3950    response: Response<proto::RejoinChannelBuffers>,
3951    session: UserSession,
3952) -> Result<()> {
3953    let db = session.db().await;
3954    let buffers = db
3955        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3956        .await?;
3957
3958    for rejoined_buffer in &buffers {
3959        let collaborators_to_notify = rejoined_buffer
3960            .buffer
3961            .collaborators
3962            .iter()
3963            .filter_map(|c| Some(c.peer_id?.into()));
3964        channel_buffer_updated(
3965            session.connection_id,
3966            collaborators_to_notify,
3967            &proto::UpdateChannelBufferCollaborators {
3968                channel_id: rejoined_buffer.buffer.channel_id,
3969                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3970            },
3971            &session.peer,
3972        );
3973    }
3974
3975    response.send(proto::RejoinChannelBuffersResponse {
3976        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3977    })?;
3978
3979    Ok(())
3980}
3981
3982/// Stop editing the channel notes
3983async fn leave_channel_buffer(
3984    request: proto::LeaveChannelBuffer,
3985    response: Response<proto::LeaveChannelBuffer>,
3986    session: UserSession,
3987) -> Result<()> {
3988    let db = session.db().await;
3989    let channel_id = ChannelId::from_proto(request.channel_id);
3990
3991    let left_buffer = db
3992        .leave_channel_buffer(channel_id, session.connection_id)
3993        .await?;
3994
3995    response.send(Ack {})?;
3996
3997    channel_buffer_updated(
3998        session.connection_id,
3999        left_buffer.connections,
4000        &proto::UpdateChannelBufferCollaborators {
4001            channel_id: channel_id.to_proto(),
4002            collaborators: left_buffer.collaborators,
4003        },
4004        &session.peer,
4005    );
4006
4007    Ok(())
4008}
4009
4010fn channel_buffer_updated<T: EnvelopedMessage>(
4011    sender_id: ConnectionId,
4012    collaborators: impl IntoIterator<Item = ConnectionId>,
4013    message: &T,
4014    peer: &Peer,
4015) {
4016    broadcast(Some(sender_id), collaborators, |peer_id| {
4017        peer.send(peer_id, message.clone())
4018    });
4019}
4020
4021fn send_notifications(
4022    connection_pool: &ConnectionPool,
4023    peer: &Peer,
4024    notifications: db::NotificationBatch,
4025) {
4026    for (user_id, notification) in notifications {
4027        for connection_id in connection_pool.user_connection_ids(user_id) {
4028            if let Err(error) = peer.send(
4029                connection_id,
4030                proto::AddNotification {
4031                    notification: Some(notification.clone()),
4032                },
4033            ) {
4034                tracing::error!(
4035                    "failed to send notification to {:?} {}",
4036                    connection_id,
4037                    error
4038                );
4039            }
4040        }
4041    }
4042}
4043
4044/// Send a message to the channel
4045async fn send_channel_message(
4046    request: proto::SendChannelMessage,
4047    response: Response<proto::SendChannelMessage>,
4048    session: UserSession,
4049) -> Result<()> {
4050    // Validate the message body.
4051    let body = request.body.trim().to_string();
4052    if body.len() > MAX_MESSAGE_LEN {
4053        return Err(anyhow!("message is too long"))?;
4054    }
4055    if body.is_empty() {
4056        return Err(anyhow!("message can't be blank"))?;
4057    }
4058
4059    // TODO: adjust mentions if body is trimmed
4060
4061    let timestamp = OffsetDateTime::now_utc();
4062    let nonce = request
4063        .nonce
4064        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4065
4066    let channel_id = ChannelId::from_proto(request.channel_id);
4067    let CreatedChannelMessage {
4068        message_id,
4069        participant_connection_ids,
4070        notifications,
4071    } = session
4072        .db()
4073        .await
4074        .create_channel_message(
4075            channel_id,
4076            session.user_id(),
4077            &body,
4078            &request.mentions,
4079            timestamp,
4080            nonce.clone().into(),
4081            match request.reply_to_message_id {
4082                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4083                None => None,
4084            },
4085        )
4086        .await?;
4087
4088    let message = proto::ChannelMessage {
4089        sender_id: session.user_id().to_proto(),
4090        id: message_id.to_proto(),
4091        body,
4092        mentions: request.mentions,
4093        timestamp: timestamp.unix_timestamp() as u64,
4094        nonce: Some(nonce),
4095        reply_to_message_id: request.reply_to_message_id,
4096        edited_at: None,
4097    };
4098    broadcast(
4099        Some(session.connection_id),
4100        participant_connection_ids.clone(),
4101        |connection| {
4102            session.peer.send(
4103                connection,
4104                proto::ChannelMessageSent {
4105                    channel_id: channel_id.to_proto(),
4106                    message: Some(message.clone()),
4107                },
4108            )
4109        },
4110    );
4111    response.send(proto::SendChannelMessageResponse {
4112        message: Some(message),
4113    })?;
4114
4115    let pool = &*session.connection_pool().await;
4116    let non_participants =
4117        pool.channel_connection_ids(channel_id)
4118            .filter_map(|(connection_id, _)| {
4119                if participant_connection_ids.contains(&connection_id) {
4120                    None
4121                } else {
4122                    Some(connection_id)
4123                }
4124            });
4125    broadcast(None, non_participants, |peer_id| {
4126        session.peer.send(
4127            peer_id,
4128            proto::UpdateChannels {
4129                latest_channel_message_ids: vec![proto::ChannelMessageId {
4130                    channel_id: channel_id.to_proto(),
4131                    message_id: message_id.to_proto(),
4132                }],
4133                ..Default::default()
4134            },
4135        )
4136    });
4137    send_notifications(pool, &session.peer, notifications);
4138
4139    Ok(())
4140}
4141
4142/// Delete a channel message
4143async fn remove_channel_message(
4144    request: proto::RemoveChannelMessage,
4145    response: Response<proto::RemoveChannelMessage>,
4146    session: UserSession,
4147) -> Result<()> {
4148    let channel_id = ChannelId::from_proto(request.channel_id);
4149    let message_id = MessageId::from_proto(request.message_id);
4150    let (connection_ids, existing_notification_ids) = session
4151        .db()
4152        .await
4153        .remove_channel_message(channel_id, message_id, session.user_id())
4154        .await?;
4155
4156    broadcast(
4157        Some(session.connection_id),
4158        connection_ids,
4159        move |connection| {
4160            session.peer.send(connection, request.clone())?;
4161
4162            for notification_id in &existing_notification_ids {
4163                session.peer.send(
4164                    connection,
4165                    proto::DeleteNotification {
4166                        notification_id: (*notification_id).to_proto(),
4167                    },
4168                )?;
4169            }
4170
4171            Ok(())
4172        },
4173    );
4174    response.send(proto::Ack {})?;
4175    Ok(())
4176}
4177
4178async fn update_channel_message(
4179    request: proto::UpdateChannelMessage,
4180    response: Response<proto::UpdateChannelMessage>,
4181    session: UserSession,
4182) -> Result<()> {
4183    let channel_id = ChannelId::from_proto(request.channel_id);
4184    let message_id = MessageId::from_proto(request.message_id);
4185    let updated_at = OffsetDateTime::now_utc();
4186    let UpdatedChannelMessage {
4187        message_id,
4188        participant_connection_ids,
4189        notifications,
4190        reply_to_message_id,
4191        timestamp,
4192        deleted_mention_notification_ids,
4193        updated_mention_notifications,
4194    } = session
4195        .db()
4196        .await
4197        .update_channel_message(
4198            channel_id,
4199            message_id,
4200            session.user_id(),
4201            request.body.as_str(),
4202            &request.mentions,
4203            updated_at,
4204        )
4205        .await?;
4206
4207    let nonce = request
4208        .nonce
4209        .clone()
4210        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4211
4212    let message = proto::ChannelMessage {
4213        sender_id: session.user_id().to_proto(),
4214        id: message_id.to_proto(),
4215        body: request.body.clone(),
4216        mentions: request.mentions.clone(),
4217        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4218        nonce: Some(nonce),
4219        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4220        edited_at: Some(updated_at.unix_timestamp() as u64),
4221    };
4222
4223    response.send(proto::Ack {})?;
4224
4225    let pool = &*session.connection_pool().await;
4226    broadcast(
4227        Some(session.connection_id),
4228        participant_connection_ids,
4229        |connection| {
4230            session.peer.send(
4231                connection,
4232                proto::ChannelMessageUpdate {
4233                    channel_id: channel_id.to_proto(),
4234                    message: Some(message.clone()),
4235                },
4236            )?;
4237
4238            for notification_id in &deleted_mention_notification_ids {
4239                session.peer.send(
4240                    connection,
4241                    proto::DeleteNotification {
4242                        notification_id: (*notification_id).to_proto(),
4243                    },
4244                )?;
4245            }
4246
4247            for notification in &updated_mention_notifications {
4248                session.peer.send(
4249                    connection,
4250                    proto::UpdateNotification {
4251                        notification: Some(notification.clone()),
4252                    },
4253                )?;
4254            }
4255
4256            Ok(())
4257        },
4258    );
4259
4260    send_notifications(pool, &session.peer, notifications);
4261
4262    Ok(())
4263}
4264
4265/// Mark a channel message as read
4266async fn acknowledge_channel_message(
4267    request: proto::AckChannelMessage,
4268    session: UserSession,
4269) -> Result<()> {
4270    let channel_id = ChannelId::from_proto(request.channel_id);
4271    let message_id = MessageId::from_proto(request.message_id);
4272    let notifications = session
4273        .db()
4274        .await
4275        .observe_channel_message(channel_id, session.user_id(), message_id)
4276        .await?;
4277    send_notifications(
4278        &*session.connection_pool().await,
4279        &session.peer,
4280        notifications,
4281    );
4282    Ok(())
4283}
4284
4285/// Mark a buffer version as synced
4286async fn acknowledge_buffer_version(
4287    request: proto::AckBufferOperation,
4288    session: UserSession,
4289) -> Result<()> {
4290    let buffer_id = BufferId::from_proto(request.buffer_id);
4291    session
4292        .db()
4293        .await
4294        .observe_buffer_version(
4295            buffer_id,
4296            session.user_id(),
4297            request.epoch as i32,
4298            &request.version,
4299        )
4300        .await?;
4301    Ok(())
4302}
4303
4304struct CompleteWithLanguageModelRateLimit;
4305
4306impl RateLimit for CompleteWithLanguageModelRateLimit {
4307    fn capacity() -> usize {
4308        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4309            .ok()
4310            .and_then(|v| v.parse().ok())
4311            .unwrap_or(120) // Picked arbitrarily
4312    }
4313
4314    fn refill_duration() -> chrono::Duration {
4315        chrono::Duration::hours(1)
4316    }
4317
4318    fn db_name() -> &'static str {
4319        "complete-with-language-model"
4320    }
4321}
4322
4323async fn complete_with_language_model(
4324    request: proto::CompleteWithLanguageModel,
4325    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4326    session: Session,
4327    open_ai_api_key: Option<Arc<str>>,
4328    google_ai_api_key: Option<Arc<str>>,
4329    anthropic_api_key: Option<Arc<str>>,
4330) -> Result<()> {
4331    let Some(session) = session.for_user() else {
4332        return Err(anyhow!("user not found"))?;
4333    };
4334    authorize_access_to_language_models(&session).await?;
4335    session
4336        .rate_limiter
4337        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4338        .await?;
4339
4340    if request.model.starts_with("gpt") {
4341        let api_key =
4342            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4343        complete_with_open_ai(request, response, session, api_key).await?;
4344    } else if request.model.starts_with("gemini") {
4345        let api_key = google_ai_api_key
4346            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4347        complete_with_google_ai(request, response, session, api_key).await?;
4348    } else if request.model.starts_with("claude") {
4349        let api_key = anthropic_api_key
4350            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4351        complete_with_anthropic(request, response, session, api_key).await?;
4352    }
4353
4354    Ok(())
4355}
4356
4357async fn complete_with_open_ai(
4358    request: proto::CompleteWithLanguageModel,
4359    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4360    session: UserSession,
4361    api_key: Arc<str>,
4362) -> Result<()> {
4363    let mut completion_stream = open_ai::stream_completion(
4364        session.http_client.as_ref(),
4365        OPEN_AI_API_URL,
4366        &api_key,
4367        crate::ai::language_model_request_to_open_ai(request)?,
4368        None,
4369    )
4370    .await
4371    .context("open_ai::stream_completion request failed within collab")?;
4372
4373    while let Some(event) = completion_stream.next().await {
4374        let event = event?;
4375        response.send(proto::LanguageModelResponse {
4376            choices: event
4377                .choices
4378                .into_iter()
4379                .map(|choice| proto::LanguageModelChoiceDelta {
4380                    index: choice.index,
4381                    delta: Some(proto::LanguageModelResponseMessage {
4382                        role: choice.delta.role.map(|role| match role {
4383                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4384                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4385                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4386                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4387                        } as i32),
4388                        content: choice.delta.content,
4389                        tool_calls: choice
4390                            .delta
4391                            .tool_calls
4392                            .into_iter()
4393                            .map(|delta| proto::ToolCallDelta {
4394                                index: delta.index as u32,
4395                                id: delta.id,
4396                                variant: match delta.function {
4397                                    Some(function) => {
4398                                        let name = function.name;
4399                                        let arguments = function.arguments;
4400
4401                                        Some(proto::tool_call_delta::Variant::Function(
4402                                            proto::tool_call_delta::FunctionCallDelta {
4403                                                name,
4404                                                arguments,
4405                                            },
4406                                        ))
4407                                    }
4408                                    None => None,
4409                                },
4410                            })
4411                            .collect(),
4412                    }),
4413                    finish_reason: choice.finish_reason,
4414                })
4415                .collect(),
4416        })?;
4417    }
4418
4419    Ok(())
4420}
4421
4422async fn complete_with_google_ai(
4423    request: proto::CompleteWithLanguageModel,
4424    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4425    session: UserSession,
4426    api_key: Arc<str>,
4427) -> Result<()> {
4428    let mut stream = google_ai::stream_generate_content(
4429        session.http_client.clone(),
4430        google_ai::API_URL,
4431        api_key.as_ref(),
4432        crate::ai::language_model_request_to_google_ai(request)?,
4433    )
4434    .await
4435    .context("google_ai::stream_generate_content request failed")?;
4436
4437    while let Some(event) = stream.next().await {
4438        let event = event?;
4439        response.send(proto::LanguageModelResponse {
4440            choices: event
4441                .candidates
4442                .unwrap_or_default()
4443                .into_iter()
4444                .map(|candidate| proto::LanguageModelChoiceDelta {
4445                    index: candidate.index as u32,
4446                    delta: Some(proto::LanguageModelResponseMessage {
4447                        role: Some(match candidate.content.role {
4448                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4449                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4450                        } as i32),
4451                        content: Some(
4452                            candidate
4453                                .content
4454                                .parts
4455                                .into_iter()
4456                                .filter_map(|part| match part {
4457                                    google_ai::Part::TextPart(part) => Some(part.text),
4458                                    google_ai::Part::InlineDataPart(_) => None,
4459                                })
4460                                .collect(),
4461                        ),
4462                        // Tool calls are not supported for Google
4463                        tool_calls: Vec::new(),
4464                    }),
4465                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4466                })
4467                .collect(),
4468        })?;
4469    }
4470
4471    Ok(())
4472}
4473
4474async fn complete_with_anthropic(
4475    request: proto::CompleteWithLanguageModel,
4476    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4477    session: UserSession,
4478    api_key: Arc<str>,
4479) -> Result<()> {
4480    let model = anthropic::Model::from_id(&request.model)?;
4481
4482    let mut system_message = String::new();
4483    let messages = request
4484        .messages
4485        .into_iter()
4486        .filter_map(|message| {
4487            match message.role() {
4488                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4489                    role: anthropic::Role::User,
4490                    content: message.content,
4491                }),
4492                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4493                    role: anthropic::Role::Assistant,
4494                    content: message.content,
4495                }),
4496                // Anthropic's API breaks system instructions out as a separate field rather
4497                // than having a system message role.
4498                LanguageModelRole::LanguageModelSystem => {
4499                    if !system_message.is_empty() {
4500                        system_message.push_str("\n\n");
4501                    }
4502                    system_message.push_str(&message.content);
4503
4504                    None
4505                }
4506                // We don't yet support tool calls for Anthropic
4507                LanguageModelRole::LanguageModelTool => None,
4508            }
4509        })
4510        .collect();
4511
4512    let mut stream = anthropic::stream_completion(
4513        session.http_client.as_ref(),
4514        anthropic::ANTHROPIC_API_URL,
4515        &api_key,
4516        anthropic::Request {
4517            model,
4518            messages,
4519            stream: true,
4520            system: system_message,
4521            max_tokens: 4092,
4522        },
4523        None,
4524    )
4525    .await?;
4526
4527    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4528
4529    while let Some(event) = stream.next().await {
4530        let event = event?;
4531
4532        match event {
4533            anthropic::ResponseEvent::MessageStart { message } => {
4534                if let Some(role) = message.role {
4535                    if role == "assistant" {
4536                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4537                    } else if role == "user" {
4538                        current_role = proto::LanguageModelRole::LanguageModelUser;
4539                    }
4540                }
4541            }
4542            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4543                match content_block {
4544                    anthropic::ContentBlock::Text { text } => {
4545                        if !text.is_empty() {
4546                            response.send(proto::LanguageModelResponse {
4547                                choices: vec![proto::LanguageModelChoiceDelta {
4548                                    index: 0,
4549                                    delta: Some(proto::LanguageModelResponseMessage {
4550                                        role: Some(current_role as i32),
4551                                        content: Some(text),
4552                                        tool_calls: Vec::new(),
4553                                    }),
4554                                    finish_reason: None,
4555                                }],
4556                            })?;
4557                        }
4558                    }
4559                }
4560            }
4561            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4562                anthropic::TextDelta::TextDelta { text } => {
4563                    response.send(proto::LanguageModelResponse {
4564                        choices: vec![proto::LanguageModelChoiceDelta {
4565                            index: 0,
4566                            delta: Some(proto::LanguageModelResponseMessage {
4567                                role: Some(current_role as i32),
4568                                content: Some(text),
4569                                tool_calls: Vec::new(),
4570                            }),
4571                            finish_reason: None,
4572                        }],
4573                    })?;
4574                }
4575            },
4576            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4577                if let Some(stop_reason) = delta.stop_reason {
4578                    response.send(proto::LanguageModelResponse {
4579                        choices: vec![proto::LanguageModelChoiceDelta {
4580                            index: 0,
4581                            delta: None,
4582                            finish_reason: Some(stop_reason),
4583                        }],
4584                    })?;
4585                }
4586            }
4587            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4588            anthropic::ResponseEvent::MessageStop {} => {}
4589            anthropic::ResponseEvent::Ping {} => {}
4590        }
4591    }
4592
4593    Ok(())
4594}
4595
4596struct CountTokensWithLanguageModelRateLimit;
4597
4598impl RateLimit for CountTokensWithLanguageModelRateLimit {
4599    fn capacity() -> usize {
4600        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4601            .ok()
4602            .and_then(|v| v.parse().ok())
4603            .unwrap_or(600) // Picked arbitrarily
4604    }
4605
4606    fn refill_duration() -> chrono::Duration {
4607        chrono::Duration::hours(1)
4608    }
4609
4610    fn db_name() -> &'static str {
4611        "count-tokens-with-language-model"
4612    }
4613}
4614
4615async fn count_tokens_with_language_model(
4616    request: proto::CountTokensWithLanguageModel,
4617    response: Response<proto::CountTokensWithLanguageModel>,
4618    session: UserSession,
4619    google_ai_api_key: Option<Arc<str>>,
4620) -> Result<()> {
4621    authorize_access_to_language_models(&session).await?;
4622
4623    if !request.model.starts_with("gemini") {
4624        return Err(anyhow!(
4625            "counting tokens for model: {:?} is not supported",
4626            request.model
4627        ))?;
4628    }
4629
4630    session
4631        .rate_limiter
4632        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4633        .await?;
4634
4635    let api_key = google_ai_api_key
4636        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4637    let tokens_response = google_ai::count_tokens(
4638        session.http_client.as_ref(),
4639        google_ai::API_URL,
4640        &api_key,
4641        crate::ai::count_tokens_request_to_google_ai(request)?,
4642    )
4643    .await?;
4644    response.send(proto::CountTokensResponse {
4645        token_count: tokens_response.total_tokens as u32,
4646    })?;
4647    Ok(())
4648}
4649
4650struct ComputeEmbeddingsRateLimit;
4651
4652impl RateLimit for ComputeEmbeddingsRateLimit {
4653    fn capacity() -> usize {
4654        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4655            .ok()
4656            .and_then(|v| v.parse().ok())
4657            .unwrap_or(5000) // Picked arbitrarily
4658    }
4659
4660    fn refill_duration() -> chrono::Duration {
4661        chrono::Duration::hours(1)
4662    }
4663
4664    fn db_name() -> &'static str {
4665        "compute-embeddings"
4666    }
4667}
4668
4669async fn compute_embeddings(
4670    request: proto::ComputeEmbeddings,
4671    response: Response<proto::ComputeEmbeddings>,
4672    session: UserSession,
4673    api_key: Option<Arc<str>>,
4674) -> Result<()> {
4675    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4676    authorize_access_to_language_models(&session).await?;
4677
4678    session
4679        .rate_limiter
4680        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4681        .await?;
4682
4683    let embeddings = match request.model.as_str() {
4684        "openai/text-embedding-3-small" => {
4685            open_ai::embed(
4686                session.http_client.as_ref(),
4687                OPEN_AI_API_URL,
4688                &api_key,
4689                OpenAiEmbeddingModel::TextEmbedding3Small,
4690                request.texts.iter().map(|text| text.as_str()),
4691            )
4692            .await?
4693        }
4694        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4695    };
4696
4697    let embeddings = request
4698        .texts
4699        .iter()
4700        .map(|text| {
4701            let mut hasher = sha2::Sha256::new();
4702            hasher.update(text.as_bytes());
4703            let result = hasher.finalize();
4704            result.to_vec()
4705        })
4706        .zip(
4707            embeddings
4708                .data
4709                .into_iter()
4710                .map(|embedding| embedding.embedding),
4711        )
4712        .collect::<HashMap<_, _>>();
4713
4714    let db = session.db().await;
4715    db.save_embeddings(&request.model, &embeddings)
4716        .await
4717        .context("failed to save embeddings")
4718        .trace_err();
4719
4720    response.send(proto::ComputeEmbeddingsResponse {
4721        embeddings: embeddings
4722            .into_iter()
4723            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4724            .collect(),
4725    })?;
4726    Ok(())
4727}
4728
4729async fn get_cached_embeddings(
4730    request: proto::GetCachedEmbeddings,
4731    response: Response<proto::GetCachedEmbeddings>,
4732    session: UserSession,
4733) -> Result<()> {
4734    authorize_access_to_language_models(&session).await?;
4735
4736    let db = session.db().await;
4737    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4738
4739    response.send(proto::GetCachedEmbeddingsResponse {
4740        embeddings: embeddings
4741            .into_iter()
4742            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4743            .collect(),
4744    })?;
4745    Ok(())
4746}
4747
4748async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4749    let db = session.db().await;
4750    let flags = db.get_user_flags(session.user_id()).await?;
4751    if flags.iter().any(|flag| flag == "language-models") {
4752        Ok(())
4753    } else {
4754        Err(anyhow!("permission denied"))?
4755    }
4756}
4757
4758/// Get a Supermaven API key for the user
4759async fn get_supermaven_api_key(
4760    _request: proto::GetSupermavenApiKey,
4761    response: Response<proto::GetSupermavenApiKey>,
4762    session: UserSession,
4763) -> Result<()> {
4764    let user_id: String = session.user_id().to_string();
4765    if !session.is_staff() {
4766        return Err(anyhow!("supermaven not enabled for this account"))?;
4767    }
4768
4769    let email = session
4770        .email()
4771        .ok_or_else(|| anyhow!("user must have an email"))?;
4772
4773    let supermaven_admin_api = session
4774        .supermaven_client
4775        .as_ref()
4776        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4777
4778    let result = supermaven_admin_api
4779        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4780        .await?;
4781
4782    response.send(proto::GetSupermavenApiKeyResponse {
4783        api_key: result.api_key,
4784    })?;
4785
4786    Ok(())
4787}
4788
4789/// Start receiving chat updates for a channel
4790async fn join_channel_chat(
4791    request: proto::JoinChannelChat,
4792    response: Response<proto::JoinChannelChat>,
4793    session: UserSession,
4794) -> Result<()> {
4795    let channel_id = ChannelId::from_proto(request.channel_id);
4796
4797    let db = session.db().await;
4798    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4799        .await?;
4800    let messages = db
4801        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4802        .await?;
4803    response.send(proto::JoinChannelChatResponse {
4804        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4805        messages,
4806    })?;
4807    Ok(())
4808}
4809
4810/// Stop receiving chat updates for a channel
4811async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4812    let channel_id = ChannelId::from_proto(request.channel_id);
4813    session
4814        .db()
4815        .await
4816        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4817        .await?;
4818    Ok(())
4819}
4820
4821/// Retrieve the chat history for a channel
4822async fn get_channel_messages(
4823    request: proto::GetChannelMessages,
4824    response: Response<proto::GetChannelMessages>,
4825    session: UserSession,
4826) -> Result<()> {
4827    let channel_id = ChannelId::from_proto(request.channel_id);
4828    let messages = session
4829        .db()
4830        .await
4831        .get_channel_messages(
4832            channel_id,
4833            session.user_id(),
4834            MESSAGE_COUNT_PER_PAGE,
4835            Some(MessageId::from_proto(request.before_message_id)),
4836        )
4837        .await?;
4838    response.send(proto::GetChannelMessagesResponse {
4839        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4840        messages,
4841    })?;
4842    Ok(())
4843}
4844
4845/// Retrieve specific chat messages
4846async fn get_channel_messages_by_id(
4847    request: proto::GetChannelMessagesById,
4848    response: Response<proto::GetChannelMessagesById>,
4849    session: UserSession,
4850) -> Result<()> {
4851    let message_ids = request
4852        .message_ids
4853        .iter()
4854        .map(|id| MessageId::from_proto(*id))
4855        .collect::<Vec<_>>();
4856    let messages = session
4857        .db()
4858        .await
4859        .get_channel_messages_by_id(session.user_id(), &message_ids)
4860        .await?;
4861    response.send(proto::GetChannelMessagesResponse {
4862        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4863        messages,
4864    })?;
4865    Ok(())
4866}
4867
4868/// Retrieve the current users notifications
4869async fn get_notifications(
4870    request: proto::GetNotifications,
4871    response: Response<proto::GetNotifications>,
4872    session: UserSession,
4873) -> Result<()> {
4874    let notifications = session
4875        .db()
4876        .await
4877        .get_notifications(
4878            session.user_id(),
4879            NOTIFICATION_COUNT_PER_PAGE,
4880            request
4881                .before_id
4882                .map(|id| db::NotificationId::from_proto(id)),
4883        )
4884        .await?;
4885    response.send(proto::GetNotificationsResponse {
4886        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4887        notifications,
4888    })?;
4889    Ok(())
4890}
4891
4892/// Mark notifications as read
4893async fn mark_notification_as_read(
4894    request: proto::MarkNotificationRead,
4895    response: Response<proto::MarkNotificationRead>,
4896    session: UserSession,
4897) -> Result<()> {
4898    let database = &session.db().await;
4899    let notifications = database
4900        .mark_notification_as_read_by_id(
4901            session.user_id(),
4902            NotificationId::from_proto(request.notification_id),
4903        )
4904        .await?;
4905    send_notifications(
4906        &*session.connection_pool().await,
4907        &session.peer,
4908        notifications,
4909    );
4910    response.send(proto::Ack {})?;
4911    Ok(())
4912}
4913
4914/// Get the current users information
4915async fn get_private_user_info(
4916    _request: proto::GetPrivateUserInfo,
4917    response: Response<proto::GetPrivateUserInfo>,
4918    session: UserSession,
4919) -> Result<()> {
4920    let db = session.db().await;
4921
4922    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4923    let user = db
4924        .get_user_by_id(session.user_id())
4925        .await?
4926        .ok_or_else(|| anyhow!("user not found"))?;
4927    let flags = db.get_user_flags(session.user_id()).await?;
4928
4929    response.send(proto::GetPrivateUserInfoResponse {
4930        metrics_id,
4931        staff: user.admin,
4932        flags,
4933    })?;
4934    Ok(())
4935}
4936
4937fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4938    match message {
4939        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4940        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4941        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4942        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4943        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4944            code: frame.code.into(),
4945            reason: frame.reason,
4946        })),
4947    }
4948}
4949
4950fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4951    match message {
4952        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4953        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4954        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4955        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4956        AxumMessage::Close(frame) => {
4957            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4958                code: frame.code.into(),
4959                reason: frame.reason,
4960            }))
4961        }
4962    }
4963}
4964
4965fn notify_membership_updated(
4966    connection_pool: &mut ConnectionPool,
4967    result: MembershipUpdated,
4968    user_id: UserId,
4969    peer: &Peer,
4970) {
4971    for membership in &result.new_channels.channel_memberships {
4972        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4973    }
4974    for channel_id in &result.removed_channels {
4975        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4976    }
4977
4978    let user_channels_update = proto::UpdateUserChannels {
4979        channel_memberships: result
4980            .new_channels
4981            .channel_memberships
4982            .iter()
4983            .map(|cm| proto::ChannelMembership {
4984                channel_id: cm.channel_id.to_proto(),
4985                role: cm.role.into(),
4986            })
4987            .collect(),
4988        ..Default::default()
4989    };
4990
4991    let mut update = build_channels_update(result.new_channels, vec![]);
4992    update.delete_channels = result
4993        .removed_channels
4994        .into_iter()
4995        .map(|id| id.to_proto())
4996        .collect();
4997    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4998
4999    for connection_id in connection_pool.user_connection_ids(user_id) {
5000        peer.send(connection_id, user_channels_update.clone())
5001            .trace_err();
5002        peer.send(connection_id, update.clone()).trace_err();
5003    }
5004}
5005
5006fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5007    proto::UpdateUserChannels {
5008        channel_memberships: channels
5009            .channel_memberships
5010            .iter()
5011            .map(|m| proto::ChannelMembership {
5012                channel_id: m.channel_id.to_proto(),
5013                role: m.role.into(),
5014            })
5015            .collect(),
5016        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5017        observed_channel_message_id: channels.observed_channel_messages.clone(),
5018    }
5019}
5020
5021fn build_channels_update(
5022    channels: ChannelsForUser,
5023    channel_invites: Vec<db::Channel>,
5024) -> proto::UpdateChannels {
5025    let mut update = proto::UpdateChannels::default();
5026
5027    for channel in channels.channels {
5028        update.channels.push(channel.to_proto());
5029    }
5030
5031    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5032    update.latest_channel_message_ids = channels.latest_channel_messages;
5033
5034    for (channel_id, participants) in channels.channel_participants {
5035        update
5036            .channel_participants
5037            .push(proto::ChannelParticipants {
5038                channel_id: channel_id.to_proto(),
5039                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5040            });
5041    }
5042
5043    for channel in channel_invites {
5044        update.channel_invitations.push(channel.to_proto());
5045    }
5046
5047    update.hosted_projects = channels.hosted_projects;
5048    update
5049}
5050
5051fn build_initial_contacts_update(
5052    contacts: Vec<db::Contact>,
5053    pool: &ConnectionPool,
5054) -> proto::UpdateContacts {
5055    let mut update = proto::UpdateContacts::default();
5056
5057    for contact in contacts {
5058        match contact {
5059            db::Contact::Accepted { user_id, busy } => {
5060                update.contacts.push(contact_for_user(user_id, busy, &pool));
5061            }
5062            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5063            db::Contact::Incoming { user_id } => {
5064                update
5065                    .incoming_requests
5066                    .push(proto::IncomingContactRequest {
5067                        requester_id: user_id.to_proto(),
5068                    })
5069            }
5070        }
5071    }
5072
5073    update
5074}
5075
5076fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5077    proto::Contact {
5078        user_id: user_id.to_proto(),
5079        online: pool.is_user_online(user_id),
5080        busy,
5081    }
5082}
5083
5084fn room_updated(room: &proto::Room, peer: &Peer) {
5085    broadcast(
5086        None,
5087        room.participants
5088            .iter()
5089            .filter_map(|participant| Some(participant.peer_id?.into())),
5090        |peer_id| {
5091            peer.send(
5092                peer_id,
5093                proto::RoomUpdated {
5094                    room: Some(room.clone()),
5095                },
5096            )
5097        },
5098    );
5099}
5100
5101fn channel_updated(
5102    channel: &db::channel::Model,
5103    room: &proto::Room,
5104    peer: &Peer,
5105    pool: &ConnectionPool,
5106) {
5107    let participants = room
5108        .participants
5109        .iter()
5110        .map(|p| p.user_id)
5111        .collect::<Vec<_>>();
5112
5113    broadcast(
5114        None,
5115        pool.channel_connection_ids(channel.root_id())
5116            .filter_map(|(channel_id, role)| {
5117                role.can_see_channel(channel.visibility).then(|| channel_id)
5118            }),
5119        |peer_id| {
5120            peer.send(
5121                peer_id,
5122                proto::UpdateChannels {
5123                    channel_participants: vec![proto::ChannelParticipants {
5124                        channel_id: channel.id.to_proto(),
5125                        participant_user_ids: participants.clone(),
5126                    }],
5127                    ..Default::default()
5128                },
5129            )
5130        },
5131    );
5132}
5133
5134async fn send_dev_server_projects_update(
5135    user_id: UserId,
5136    mut status: proto::DevServerProjectsUpdate,
5137    session: &Session,
5138) {
5139    let pool = session.connection_pool().await;
5140    for dev_server in &mut status.dev_servers {
5141        dev_server.status =
5142            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5143    }
5144    let connections = pool.user_connection_ids(user_id);
5145    for connection_id in connections {
5146        session.peer.send(connection_id, status.clone()).trace_err();
5147    }
5148}
5149
5150async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5151    let db = session.db().await;
5152
5153    let contacts = db.get_contacts(user_id).await?;
5154    let busy = db.is_user_busy(user_id).await?;
5155
5156    let pool = session.connection_pool().await;
5157    let updated_contact = contact_for_user(user_id, busy, &pool);
5158    for contact in contacts {
5159        if let db::Contact::Accepted {
5160            user_id: contact_user_id,
5161            ..
5162        } = contact
5163        {
5164            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5165                session
5166                    .peer
5167                    .send(
5168                        contact_conn_id,
5169                        proto::UpdateContacts {
5170                            contacts: vec![updated_contact.clone()],
5171                            remove_contacts: Default::default(),
5172                            incoming_requests: Default::default(),
5173                            remove_incoming_requests: Default::default(),
5174                            outgoing_requests: Default::default(),
5175                            remove_outgoing_requests: Default::default(),
5176                        },
5177                    )
5178                    .trace_err();
5179            }
5180        }
5181    }
5182    Ok(())
5183}
5184
5185async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5186    log::info!("lost dev server connection, unsharing projects");
5187    let project_ids = session
5188        .db()
5189        .await
5190        .get_stale_dev_server_projects(session.connection_id)
5191        .await?;
5192
5193    for project_id in project_ids {
5194        // not unshare re-checks the connection ids match, so we get away with no transaction
5195        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5196    }
5197
5198    let user_id = session.dev_server().user_id;
5199    let update = session
5200        .db()
5201        .await
5202        .dev_server_projects_update(user_id)
5203        .await?;
5204
5205    send_dev_server_projects_update(user_id, update, session).await;
5206
5207    Ok(())
5208}
5209
5210async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5211    let mut contacts_to_update = HashSet::default();
5212
5213    let room_id;
5214    let canceled_calls_to_user_ids;
5215    let live_kit_room;
5216    let delete_live_kit_room;
5217    let room;
5218    let channel;
5219
5220    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5221        contacts_to_update.insert(session.user_id());
5222
5223        for project in left_room.left_projects.values() {
5224            project_left(project, session);
5225        }
5226
5227        room_id = RoomId::from_proto(left_room.room.id);
5228        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5229        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5230        delete_live_kit_room = left_room.deleted;
5231        room = mem::take(&mut left_room.room);
5232        channel = mem::take(&mut left_room.channel);
5233
5234        room_updated(&room, &session.peer);
5235    } else {
5236        return Ok(());
5237    }
5238
5239    if let Some(channel) = channel {
5240        channel_updated(
5241            &channel,
5242            &room,
5243            &session.peer,
5244            &*session.connection_pool().await,
5245        );
5246    }
5247
5248    {
5249        let pool = session.connection_pool().await;
5250        for canceled_user_id in canceled_calls_to_user_ids {
5251            for connection_id in pool.user_connection_ids(canceled_user_id) {
5252                session
5253                    .peer
5254                    .send(
5255                        connection_id,
5256                        proto::CallCanceled {
5257                            room_id: room_id.to_proto(),
5258                        },
5259                    )
5260                    .trace_err();
5261            }
5262            contacts_to_update.insert(canceled_user_id);
5263        }
5264    }
5265
5266    for contact_user_id in contacts_to_update {
5267        update_user_contacts(contact_user_id, &session).await?;
5268    }
5269
5270    if let Some(live_kit) = session.live_kit_client.as_ref() {
5271        live_kit
5272            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5273            .await
5274            .trace_err();
5275
5276        if delete_live_kit_room {
5277            live_kit.delete_room(live_kit_room).await.trace_err();
5278        }
5279    }
5280
5281    Ok(())
5282}
5283
5284async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5285    let left_channel_buffers = session
5286        .db()
5287        .await
5288        .leave_channel_buffers(session.connection_id)
5289        .await?;
5290
5291    for left_buffer in left_channel_buffers {
5292        channel_buffer_updated(
5293            session.connection_id,
5294            left_buffer.connections,
5295            &proto::UpdateChannelBufferCollaborators {
5296                channel_id: left_buffer.channel_id.to_proto(),
5297                collaborators: left_buffer.collaborators,
5298            },
5299            &session.peer,
5300        );
5301    }
5302
5303    Ok(())
5304}
5305
5306fn project_left(project: &db::LeftProject, session: &UserSession) {
5307    for connection_id in &project.connection_ids {
5308        if project.should_unshare {
5309            session
5310                .peer
5311                .send(
5312                    *connection_id,
5313                    proto::UnshareProject {
5314                        project_id: project.id.to_proto(),
5315                    },
5316                )
5317                .trace_err();
5318        } else {
5319            session
5320                .peer
5321                .send(
5322                    *connection_id,
5323                    proto::RemoveProjectCollaborator {
5324                        project_id: project.id.to_proto(),
5325                        peer_id: Some(session.connection_id.into()),
5326                    },
5327                )
5328                .trace_err();
5329        }
5330    }
5331}
5332
/// Extension for results whose error should be reported and then dropped instead of
/// propagated with `?`.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Converts `self` into an `Option`: the success value becomes `Some`, and an error
    /// becomes `None`. The implementation for `Result` logs the error before discarding it.
    fn trace_err(self) -> Option<Self::Ok>;
}
5338
5339impl<T, E> ResultExt for Result<T, E>
5340where
5341    E: std::fmt::Debug,
5342{
5343    type Ok = T;
5344
5345    #[track_caller]
5346    fn trace_err(self) -> Option<T> {
5347        match self {
5348            Ok(value) => Some(value),
5349            Err(error) => {
5350                tracing::error!("{:?}", error);
5351                None
5352            }
5353        }
5354    }
5355}