rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
/// Time allowed for a disconnected client to reconnect.
/// NOTE(review): not referenced in this part of the file; presumably consumed by the
/// connection-handling code to decide when a dropped peer's state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay that `Server::start` waits before cleaning up resources (rooms, channel
/// buffer collaborators, server rows) left behind by stale server instances.
// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not used in this chunk; the descriptions are
// inferred from their names — confirm against their call sites.
/// Page size for channel message history (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum length of a channel chat message (presumably; unit not established here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// A type-erased message handler. It receives a boxed envelope together with the
/// `Session` of the connection the message arrived on, and returns a `'static`
/// boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// Per-connection state handed to message handlers.
#[derive(Clone)]
struct Session {
    /// Who is connected: a user, an admin-impersonated user, or a dev server.
    principal: Principal,
    /// The connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; lock it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Messaging peer used to send to connections.
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; lock it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client, when configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client for outbound requests.
    http_client: Arc<IsahcHttpClient>,
    /// Rate limiter shared by handlers.
    rate_limiter: Arc<RateLimiter>,
    /// Executor handle; kept alive but not read in this part of the file.
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The collaboration RPC server: owns the message handlers and the pool of live
/// connections.
pub struct Server {
    /// This server instance's id. It sits behind a mutex so it can be read (and
    /// presumably replaced) through `&self`.
    id: parking_lot::Mutex<ServerId>,
    /// Messaging peer, created in `Server::new` from the server id.
    peer: Arc<Peer>,
    /// Live connections; presumably shared with each connection's `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Registered message handlers, keyed by the `TypeId` of the message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `Server::teardown`; receivers presumably subscribe elsewhere.
    teardown: watch::Sender<bool>,
}
 384
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so the guard cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server: the peer plus the locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through its `Deref` target (the `ConnectionPool`),
    // which requires a `Deref` impl for `ConnectionPoolGuard` defined elsewhere.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// The server starts with a peer created from `id`, an empty connection pool and
    /// handler map, and a teardown channel initialised to `false`.
    ///
    /// Handler wrappers used below:
    /// - `user_handler` / `user_message_handler` reject sessions that are not users.
    /// - `dev_server_handler` rejects sessions that are not dev servers.
    /// - Handlers registered without a wrapper receive the session as-is.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates out-of-range ids —
            // presumably server ids always fit in u32; confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped right away.
            teardown: watch::channel(false).0,
        };

        server
            // Liveness, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through `forward_read_only_project_request`
            // (defined elsewhere in this file).
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            // Project requests forwarded through the (versioned) mutating helpers
            // (defined elsewhere in this file).
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffer traffic and host-originated project broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following and acknowledgements.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Third-party AI services. The API keys are cloned out of the config on
            // every request.
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Not wrapped in `user_handler`, unlike the neighbouring model handlers.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 628
    /// Schedules cleanup of resources left behind by stale server instances.
    ///
    /// Returns `Ok(())` right after spawning a detached task. That task first waits
    /// `CLEANUP_TIMEOUT`. Then, for the room and channel ids returned by
    /// `stale_server_resource_ids` (for this environment and `server_id`), it:
    /// - clears stale channel-buffer collaborators and sends the refreshed
    ///   collaborator list to each remaining connection of that buffer;
    /// - clears stale room participants and broadcasts the room (and its channel,
    ///   if any); sends `CallCanceled` to users whose calls were canceled;
    /// - pushes an updated contact entry for each stale participant and each
    ///   canceled-call user to their accepted contacts' connections;
    /// - deletes the LiveKit room once it has no participants left.
    ///
    /// Finally it calls `delete_stale_servers`, even if fetching the stale ids
    /// failed. Database and send errors are only logged via `trace_err`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell every
                    // remaining connection about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                // The pool guard is a temporary dropped at the end
                                // of this statement.
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries finish before the pool is locked, so the
                            // guard is never held across an await.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 774
 775    pub fn teardown(&self) {
 776        self.peer.teardown();
 777        self.connection_pool.lock().reset();
 778        let _ = self.teardown.send(true);
 779    }
 780
 781    #[cfg(test)]
 782    pub fn reset(&self, id: ServerId) {
 783        self.teardown();
 784        *self.id.lock() = id;
 785        self.peer.reset(id.0 as u32);
 786        let _ = self.teardown.send(false);
 787    }
 788
 789    #[cfg(test)]
 790    pub fn id(&self) -> ServerId {
 791        *self.id.lock()
 792    }
 793
 794    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 795    where
 796        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 797        Fut: 'static + Send + Future<Output = Result<()>>,
 798        M: EnvelopedMessage,
 799    {
 800        let prev_handler = self.handlers.insert(
 801            TypeId::of::<M>(),
 802            Box::new(move |envelope, session| {
 803                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 804                let received_at = envelope.received_at;
 805                tracing::info!("message received");
 806                let start_time = Instant::now();
 807                let future = (handler)(*envelope, session);
 808                async move {
 809                    let result = future.await;
 810                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 811                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 812                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 813                    let payload_type = M::NAME;
 814
 815                    match result {
 816                        Err(error) => {
 817                            tracing::error!(
 818                                ?error,
 819                                total_duration_ms,
 820                                processing_duration_ms,
 821                                queue_duration_ms,
 822                                payload_type,
 823                                "error handling message"
 824                            )
 825                        }
 826                        Ok(()) => tracing::info!(
 827                            total_duration_ms,
 828                            processing_duration_ms,
 829                            queue_duration_ms,
 830                            "finished handling message"
 831                        ),
 832                    }
 833                }
 834                .boxed()
 835            }),
 836        );
 837        if prev_handler.is_some() {
 838            panic!("registered a handler for the same message twice");
 839        }
 840        self
 841    }
 842
 843    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 844    where
 845        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 846        Fut: 'static + Send + Future<Output = Result<()>>,
 847        M: EnvelopedMessage,
 848    {
 849        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 850        self
 851    }
 852
 853    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 854    where
 855        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 856        Fut: Send + Future<Output = Result<()>>,
 857        M: RequestMessage,
 858    {
 859        let handler = Arc::new(handler);
 860        self.add_handler(move |envelope, session| {
 861            let receipt = envelope.receipt();
 862            let handler = handler.clone();
 863            async move {
 864                let peer = session.peer.clone();
 865                let responded = Arc::new(AtomicBool::default());
 866                let response = Response {
 867                    peer: peer.clone(),
 868                    responded: responded.clone(),
 869                    receipt,
 870                };
 871                match (handler)(envelope.payload, response, session).await {
 872                    Ok(()) => {
 873                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 874                            Ok(())
 875                        } else {
 876                            Err(anyhow!("handler did not send a response"))?
 877                        }
 878                    }
 879                    Err(error) => {
 880                        let proto_err = match &error {
 881                            Error::Internal(err) => err.to_proto(),
 882                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 883                        };
 884                        peer.respond_with_error(receipt, proto_err)?;
 885                        Err(error)
 886                    }
 887                }
 888            }
 889        })
 890    }
 891
 892    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 893    where
 894        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 895        Fut: Send + Future<Output = Result<()>>,
 896        M: RequestMessage,
 897    {
 898        let handler = Arc::new(handler);
 899        self.add_handler(move |envelope, session| {
 900            let receipt = envelope.receipt();
 901            let handler = handler.clone();
 902            async move {
 903                let peer = session.peer.clone();
 904                let response = StreamingResponse {
 905                    peer: peer.clone(),
 906                    receipt,
 907                };
 908                match (handler)(envelope.payload, response, session).await {
 909                    Ok(()) => {
 910                        peer.end_stream(receipt)?;
 911                        Ok(())
 912                    }
 913                    Err(error) => {
 914                        let proto_err = match &error {
 915                            Error::Internal(err) => err.to_proto(),
 916                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 917                        };
 918                        peer.respond_with_error(receipt, proto_err)?;
 919                        Err(error)
 920                    }
 921                }
 922            }
 923        })
 924    }
 925
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a `handle connection` span, that:
    /// 1. returns immediately if the server's teardown flag is already set;
    /// 2. registers the connection with the peer and builds the per-connection [`Session`];
    /// 3. sends the initial client update via `send_initial_client_update`, forwarding
    ///    `send_connection_id` to it;
    /// 4. dispatches incoming messages to the registered handlers until the connection's I/O
    ///    completes, the incoming stream ends, or teardown is signalled;
    /// 5. after I/O completion or end of stream (but not after teardown), runs `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left empty here are filled in later via `Span::record` / `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribed eagerly (before the future is first polled) so a teardown signalled in the
        // meantime is still observed by `changed()` below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after the initial client update below)
            // happens after `peer.add_connection` but skips `connection_lost`, which is where
            // `peer.disconnect` and pool removal happen — confirm whether that leaves stale state.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is acquired before each message is read and released when its handler
            // finishes, so at most 256 handlers are in flight and reading pauses beyond that.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, I/O completion, driving in-flight
                // foreground handlers, and only then reading the next message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                // The handler is invoked inside the span; the returned future is
                                // instrumented with the same span below.
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler's future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are driven
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // `permit` is dropped at the end of this arm.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1072
1073    async fn send_initial_client_update(
1074        &self,
1075        connection_id: ConnectionId,
1076        principal: &Principal,
1077        zed_version: ZedVersion,
1078        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1079        session: &Session,
1080    ) -> Result<()> {
1081        self.peer.send(
1082            connection_id,
1083            proto::Hello {
1084                peer_id: Some(connection_id.into()),
1085            },
1086        )?;
1087        tracing::info!("sent hello message");
1088        if let Some(send_connection_id) = send_connection_id.take() {
1089            let _ = send_connection_id.send(connection_id);
1090        }
1091
1092        match principal {
1093            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1094                if !user.connected_once {
1095                    self.peer.send(connection_id, proto::ShowContacts {})?;
1096                    self.app_state
1097                        .db
1098                        .set_user_connected_once(user.id, true)
1099                        .await?;
1100                }
1101
1102                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1103                    future::try_join4(
1104                        self.app_state.db.get_contacts(user.id),
1105                        self.app_state.db.get_channels_for_user(user.id),
1106                        self.app_state.db.get_channel_invites_for_user(user.id),
1107                        self.app_state.db.dev_server_projects_update(user.id),
1108                    )
1109                    .await?;
1110
1111                {
1112                    let mut pool = self.connection_pool.lock();
1113                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1114                    for membership in &channels_for_user.channel_memberships {
1115                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1116                    }
1117                    self.peer.send(
1118                        connection_id,
1119                        build_initial_contacts_update(contacts, &pool),
1120                    )?;
1121                    self.peer.send(
1122                        connection_id,
1123                        build_update_user_channels(&channels_for_user),
1124                    )?;
1125                    self.peer.send(
1126                        connection_id,
1127                        build_channels_update(channels_for_user, channel_invites),
1128                    )?;
1129                }
1130                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1131
1132                if let Some(incoming_call) =
1133                    self.app_state.db.incoming_call_for_user(user.id).await?
1134                {
1135                    self.peer.send(connection_id, incoming_call)?;
1136                }
1137
1138                update_user_contacts(user.id, &session).await?;
1139            }
1140            Principal::DevServer(dev_server) => {
1141                {
1142                    let mut pool = self.connection_pool.lock();
1143                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1144                    {
1145                        self.peer.send(
1146                            stale_connection_id,
1147                            proto::ShutdownDevServer {
1148                                reason: Some(
1149                                    "another dev server connected with the same token".to_string(),
1150                                ),
1151                            },
1152                        )?;
1153                        pool.remove_connection(stale_connection_id)?;
1154                    };
1155                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1156                }
1157
1158                let projects = self
1159                    .app_state
1160                    .db
1161                    .get_projects_for_dev_server(dev_server.id)
1162                    .await?;
1163                self.peer
1164                    .send(connection_id, proto::DevServerInstructions { projects })?;
1165
1166                let status = self
1167                    .app_state
1168                    .db
1169                    .dev_server_projects_update(dev_server.user_id)
1170                    .await?;
1171                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1172            }
1173        }
1174
1175        Ok(())
1176    }
1177
1178    pub async fn invite_code_redeemed(
1179        self: &Arc<Self>,
1180        inviter_id: UserId,
1181        invitee_id: UserId,
1182    ) -> Result<()> {
1183        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1184            if let Some(code) = &user.invite_code {
1185                let pool = self.connection_pool.lock();
1186                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1187                for connection_id in pool.user_connection_ids(inviter_id) {
1188                    self.peer.send(
1189                        connection_id,
1190                        proto::UpdateContacts {
1191                            contacts: vec![invitee_contact.clone()],
1192                            ..Default::default()
1193                        },
1194                    )?;
1195                    self.peer.send(
1196                        connection_id,
1197                        proto::UpdateInviteInfo {
1198                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1199                            count: user.invite_count as u32,
1200                        },
1201                    )?;
1202                }
1203            }
1204        }
1205        Ok(())
1206    }
1207
1208    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1209        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1210            if let Some(invite_code) = &user.invite_code {
1211                let pool = self.connection_pool.lock();
1212                for connection_id in pool.user_connection_ids(user_id) {
1213                    self.peer.send(
1214                        connection_id,
1215                        proto::UpdateInviteInfo {
1216                            url: format!(
1217                                "{}{}",
1218                                self.app_state.config.invite_link_prefix, invite_code
1219                            ),
1220                            count: user.invite_count as u32,
1221                        },
1222                    )?;
1223                }
1224            }
1225        }
1226        Ok(())
1227    }
1228
1229    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1230        ServerSnapshot {
1231            connection_pool: ConnectionPoolGuard {
1232                guard: self.connection_pool.lock(),
1233                _not_send: PhantomData,
1234            },
1235            peer: &self.peer,
1236        }
1237    }
1238}
1239
1240impl<'a> Deref for ConnectionPoolGuard<'a> {
1241    type Target = ConnectionPool;
1242
1243    fn deref(&self) -> &Self::Target {
1244        &self.guard
1245    }
1246}
1247
1248impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1249    fn deref_mut(&mut self) -> &mut Self::Target {
1250        &mut self.guard
1251    }
1252}
1253
1254impl<'a> Drop for ConnectionPoolGuard<'a> {
1255    fn drop(&mut self) {
1256        #[cfg(test)]
1257        self.check_invariants();
1258    }
1259}
1260
1261fn broadcast<F>(
1262    sender_id: Option<ConnectionId>,
1263    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1264    mut f: F,
1265) where
1266    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1267{
1268    for receiver_id in receiver_ids {
1269        if Some(receiver_id) != sender_id {
1270            if let Err(error) = f(receiver_id) {
1271                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1272            }
1273        }
1274    }
1275}
1276
1277pub struct ProtocolVersion(u32);
1278
1279impl Header for ProtocolVersion {
1280    fn name() -> &'static HeaderName {
1281        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1282        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1283    }
1284
1285    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1286    where
1287        Self: Sized,
1288        I: Iterator<Item = &'i axum::http::HeaderValue>,
1289    {
1290        let version = values
1291            .next()
1292            .ok_or_else(axum::headers::Error::invalid)?
1293            .to_str()
1294            .map_err(|_| axum::headers::Error::invalid())?
1295            .parse()
1296            .map_err(|_| axum::headers::Error::invalid())?;
1297        Ok(Self(version))
1298    }
1299
1300    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1301        values.extend([self.0.to_string().parse().unwrap()]);
1302    }
1303}
1304
1305pub struct AppVersionHeader(SemanticVersion);
1306impl Header for AppVersionHeader {
1307    fn name() -> &'static HeaderName {
1308        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1309        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1310    }
1311
1312    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1313    where
1314        Self: Sized,
1315        I: Iterator<Item = &'i axum::http::HeaderValue>,
1316    {
1317        let version = values
1318            .next()
1319            .ok_or_else(axum::headers::Error::invalid)?
1320            .to_str()
1321            .map_err(|_| axum::headers::Error::invalid())?
1322            .parse()
1323            .map_err(|_| axum::headers::Error::invalid())?;
1324        Ok(Self(version))
1325    }
1326
1327    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1328        values.extend([self.0.to_string().parse().unwrap()]);
1329    }
1330}
1331
1332pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1333    Router::new()
1334        .route("/rpc", get(handle_websocket_request))
1335        .layer(
1336            ServiceBuilder::new()
1337                .layer(Extension(server.app_state.clone()))
1338                .layer(middleware::from_fn(auth::validate_header)),
1339        )
1340        .route("/metrics", get(handle_metrics))
1341        .layer(Extension(server))
1342}
1343
1344pub async fn handle_websocket_request(
1345    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1346    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1347    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1348    Extension(server): Extension<Arc<Server>>,
1349    Extension(principal): Extension<Principal>,
1350    ws: WebSocketUpgrade,
1351) -> axum::response::Response {
1352    if protocol_version != rpc::PROTOCOL_VERSION {
1353        return (
1354            StatusCode::UPGRADE_REQUIRED,
1355            "client must be upgraded".to_string(),
1356        )
1357            .into_response();
1358    }
1359
1360    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1361        return (
1362            StatusCode::UPGRADE_REQUIRED,
1363            "no version header found".to_string(),
1364        )
1365            .into_response();
1366    };
1367
1368    if !version.can_collaborate() {
1369        return (
1370            StatusCode::UPGRADE_REQUIRED,
1371            "client must be upgraded".to_string(),
1372        )
1373            .into_response();
1374    }
1375
1376    let socket_address = socket_address.to_string();
1377    ws.on_upgrade(move |socket| {
1378        let socket = socket
1379            .map_ok(to_tungstenite_message)
1380            .err_into()
1381            .with(|message| async move { Ok(to_axum_message(message)) });
1382        let connection = Connection::new(Box::pin(socket));
1383        async move {
1384            server
1385                .handle_connection(
1386                    connection,
1387                    socket_address,
1388                    principal,
1389                    version,
1390                    None,
1391                    Executor::Production,
1392                )
1393                .await;
1394        }
1395    })
1396}
1397
1398pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1399    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1400    let connections_metric = CONNECTIONS_METRIC
1401        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1402
1403    let connections = server
1404        .connection_pool
1405        .lock()
1406        .connections()
1407        .filter(|connection| !connection.admin)
1408        .count();
1409    connections_metric.set(connections as _);
1410
1411    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1412    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1413        register_int_gauge!(
1414            "shared_projects",
1415            "number of open projects with one or more guests"
1416        )
1417        .unwrap()
1418    });
1419
1420    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1421    shared_projects_metric.set(shared_projects as _);
1422
1423    let encoder = prometheus::TextEncoder::new();
1424    let metric_families = prometheus::gather();
1425    let encoded_metrics = encoder
1426        .encode_to_string(&metric_families)
1427        .map_err(|err| anyhow!("{}", err))?;
1428    Ok(encoded_metrics)
1429}
1430
/// Cleans up after a connection's message loop ends.
///
/// Immediately:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (a failure here returns early);
/// - tells the database the connection was lost (errors are only logged).
///
/// It then waits for `RECONNECT_TIMEOUT` unless the teardown flag changes first, in which case
/// nothing more is done. Once the timeout elapses:
/// - for users, the user leaves their room and channel buffers, a pending call is declined if
///   the user has no other connection online, and `update_user_contacts` is called;
/// - for dev servers, `lost_dev_server_connection` is called.
///
/// NOTE(review): if the connection was already removed from the pool (e.g. a stale dev-server
/// connection replaced in `send_initial_client_update`), `remove_connection` presumably fails and
/// the rest of the cleanup is skipped — confirm this is intended.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client `RECONNECT_TIMEOUT` to come back before releasing its resources; a
    // teardown signal cancels the wait (and the cleanup below).
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline pending calls if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1485
1486/// Acknowledges a ping from a client, used to keep the connection alive.
1487async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1488    response.send(proto::Ack {})?;
1489    Ok(())
1490}
1491
1492/// Creates a new room for calling (outside of channels)
1493async fn create_room(
1494    _request: proto::CreateRoom,
1495    response: Response<proto::CreateRoom>,
1496    session: UserSession,
1497) -> Result<()> {
1498    let live_kit_room = nanoid::nanoid!(30);
1499
1500    let live_kit_connection_info = util::maybe!(async {
1501        let live_kit = session.live_kit_client.as_ref();
1502        let live_kit = live_kit?;
1503        let user_id = session.user_id().to_string();
1504
1505        let token = live_kit
1506            .room_token(&live_kit_room, &user_id.to_string())
1507            .trace_err()?;
1508
1509        Some(proto::LiveKitConnectionInfo {
1510            server_url: live_kit.url().into(),
1511            token,
1512            can_publish: true,
1513        })
1514    })
1515    .await;
1516
1517    let room = session
1518        .db()
1519        .await
1520        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1521        .await?;
1522
1523    response.send(proto::CreateRoomResponse {
1524        room: Some(room.clone()),
1525        live_kit_connection_info,
1526    })?;
1527
1528    update_user_contacts(session.user_id(), &session).await?;
1529    Ok(())
1530}
1531
1532/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1533async fn join_room(
1534    request: proto::JoinRoom,
1535    response: Response<proto::JoinRoom>,
1536    session: UserSession,
1537) -> Result<()> {
1538    let room_id = RoomId::from_proto(request.id);
1539
1540    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1541
1542    if let Some(channel_id) = channel_id {
1543        return join_channel_internal(channel_id, Box::new(response), session).await;
1544    }
1545
1546    let joined_room = {
1547        let room = session
1548            .db()
1549            .await
1550            .join_room(room_id, session.user_id(), session.connection_id)
1551            .await?;
1552        room_updated(&room.room, &session.peer);
1553        room.into_inner()
1554    };
1555
1556    for connection_id in session
1557        .connection_pool()
1558        .await
1559        .user_connection_ids(session.user_id())
1560    {
1561        session
1562            .peer
1563            .send(
1564                connection_id,
1565                proto::CallCanceled {
1566                    room_id: room_id.to_proto(),
1567                },
1568            )
1569            .trace_err();
1570    }
1571
1572    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1573        if let Some(token) = live_kit
1574            .room_token(
1575                &joined_room.room.live_kit_room,
1576                &session.user_id().to_string(),
1577            )
1578            .trace_err()
1579        {
1580            Some(proto::LiveKitConnectionInfo {
1581                server_url: live_kit.url().into(),
1582                token,
1583                can_publish: true,
1584            })
1585        } else {
1586            None
1587        }
1588    } else {
1589        None
1590    };
1591
1592    response.send(proto::JoinRoomResponse {
1593        room: Some(joined_room.room),
1594        channel_id: None,
1595        live_kit_connection_info,
1596    })?;
1597
1598    update_user_contacts(session.user_id(), &session).await?;
1599    Ok(())
1600}
1601
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Ask the database to rejoin the room with this connection, then reply with the room,
///    the re-shared projects (id + collaborators) and the rejoined projects.
/// 2. Broadcast the room update.
/// 3. For every re-shared project, tell each collaborator that this user's peer id changed
///    from the project's `old_connection_id` to the current connection, and forward an
///    `UpdateProject` carrying the project's worktrees to the collaborators.
/// 4. Resynchronize the rejoined projects via `notify_rejoined_projects`.
/// 5. If the room belongs to a channel, broadcast the channel update.
/// 6. Refresh this user's contacts.
///
/// NOTE(review): `reshared_projects` are presumably projects this user hosts and
/// `rejoined_projects` projects they had joined as a guest — confirm against the db layer.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the scope below; used after the rejoin guard has been consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at this connection instead of the old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators, sent as coming from
            // this connection (presumably excluding this connection itself — see `broadcast`).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the guard, keeping only what is needed after this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694fn notify_rejoined_projects(
1695    rejoined_projects: &mut Vec<RejoinedProject>,
1696    session: &UserSession,
1697) -> Result<()> {
1698    for project in rejoined_projects.iter() {
1699        for collaborator in &project.collaborators {
1700            session
1701                .peer
1702                .send(
1703                    collaborator.connection_id,
1704                    proto::UpdateProjectCollaborator {
1705                        project_id: project.id.to_proto(),
1706                        old_peer_id: Some(project.old_connection_id.into()),
1707                        new_peer_id: Some(session.connection_id.into()),
1708                    },
1709                )
1710                .trace_err();
1711        }
1712    }
1713
1714    for project in rejoined_projects {
1715        for worktree in mem::take(&mut project.worktrees) {
1716            #[cfg(any(test, feature = "test-support"))]
1717            const MAX_CHUNK_SIZE: usize = 2;
1718            #[cfg(not(any(test, feature = "test-support")))]
1719            const MAX_CHUNK_SIZE: usize = 256;
1720
1721            // Stream this worktree's entries.
1722            let message = proto::UpdateWorktree {
1723                project_id: project.id.to_proto(),
1724                worktree_id: worktree.id,
1725                abs_path: worktree.abs_path.clone(),
1726                root_name: worktree.root_name,
1727                updated_entries: worktree.updated_entries,
1728                removed_entries: worktree.removed_entries,
1729                scan_id: worktree.scan_id,
1730                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1731                updated_repositories: worktree.updated_repositories,
1732                removed_repositories: worktree.removed_repositories,
1733            };
1734            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1735                session.peer.send(session.connection_id, update.clone())?;
1736            }
1737
1738            // Stream this worktree's diagnostics.
1739            for summary in worktree.diagnostic_summaries {
1740                session.peer.send(
1741                    session.connection_id,
1742                    proto::UpdateDiagnosticSummary {
1743                        project_id: project.id.to_proto(),
1744                        worktree_id: worktree.id,
1745                        summary: Some(summary),
1746                    },
1747                )?;
1748            }
1749
1750            for settings_file in worktree.settings_files {
1751                session.peer.send(
1752                    session.connection_id,
1753                    proto::UpdateWorktreeSettings {
1754                        project_id: project.id.to_proto(),
1755                        worktree_id: worktree.id,
1756                        path: settings_file.path,
1757                        content: Some(settings_file.content),
1758                    },
1759                )?;
1760            }
1761        }
1762
1763        for language_server in &project.language_servers {
1764            session.peer.send(
1765                session.connection_id,
1766                proto::UpdateLanguageServer {
1767                    project_id: project.id.to_proto(),
1768                    language_server_id: language_server.id,
1769                    variant: Some(
1770                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1771                            proto::LspDiskBasedDiagnosticsUpdated {},
1772                        ),
1773                    ),
1774                },
1775            )?;
1776        }
1777    }
1778    Ok(())
1779}
1780
1781/// leave room disconnects from the room.
1782async fn leave_room(
1783    _: proto::LeaveRoom,
1784    response: Response<proto::LeaveRoom>,
1785    session: UserSession,
1786) -> Result<()> {
1787    leave_room_for_session(&session, session.connection_id).await?;
1788    response.send(proto::Ack {})?;
1789    Ok(())
1790}
1791
1792/// Updates the permissions of someone else in the room.
1793async fn set_room_participant_role(
1794    request: proto::SetRoomParticipantRole,
1795    response: Response<proto::SetRoomParticipantRole>,
1796    session: UserSession,
1797) -> Result<()> {
1798    let user_id = UserId::from_proto(request.user_id);
1799    let role = ChannelRole::from(request.role());
1800
1801    let (live_kit_room, can_publish) = {
1802        let room = session
1803            .db()
1804            .await
1805            .set_room_participant_role(
1806                session.user_id(),
1807                RoomId::from_proto(request.room_id),
1808                user_id,
1809                role,
1810            )
1811            .await?;
1812
1813        let live_kit_room = room.live_kit_room.clone();
1814        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1815        room_updated(&room, &session.peer);
1816        (live_kit_room, can_publish)
1817    };
1818
1819    if let Some(live_kit) = session.live_kit_client.as_ref() {
1820        live_kit
1821            .update_participant(
1822                live_kit_room.clone(),
1823                request.user_id.to_string(),
1824                live_kit_server::proto::ParticipantPermission {
1825                    can_subscribe: true,
1826                    can_publish,
1827                    can_publish_data: can_publish,
1828                    hidden: false,
1829                    recorder: false,
1830                },
1831            )
1832            .await
1833            .trace_err();
1834    }
1835
1836    response.send(proto::Ack {})?;
1837    Ok(())
1838}
1839
1840/// Call someone else into the current room
1841async fn call(
1842    request: proto::Call,
1843    response: Response<proto::Call>,
1844    session: UserSession,
1845) -> Result<()> {
1846    let room_id = RoomId::from_proto(request.room_id);
1847    let calling_user_id = session.user_id();
1848    let calling_connection_id = session.connection_id;
1849    let called_user_id = UserId::from_proto(request.called_user_id);
1850    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1851    if !session
1852        .db()
1853        .await
1854        .has_contact(calling_user_id, called_user_id)
1855        .await?
1856    {
1857        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1858    }
1859
1860    let incoming_call = {
1861        let (room, incoming_call) = &mut *session
1862            .db()
1863            .await
1864            .call(
1865                room_id,
1866                calling_user_id,
1867                calling_connection_id,
1868                called_user_id,
1869                initial_project_id,
1870            )
1871            .await?;
1872        room_updated(&room, &session.peer);
1873        mem::take(incoming_call)
1874    };
1875    update_user_contacts(called_user_id, &session).await?;
1876
1877    let mut calls = session
1878        .connection_pool()
1879        .await
1880        .user_connection_ids(called_user_id)
1881        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1882        .collect::<FuturesUnordered<_>>();
1883
1884    while let Some(call_response) = calls.next().await {
1885        match call_response.as_ref() {
1886            Ok(_) => {
1887                response.send(proto::Ack {})?;
1888                return Ok(());
1889            }
1890            Err(_) => {
1891                call_response.trace_err();
1892            }
1893        }
1894    }
1895
1896    {
1897        let room = session
1898            .db()
1899            .await
1900            .call_failed(room_id, called_user_id)
1901            .await?;
1902        room_updated(&room, &session.peer);
1903    }
1904    update_user_contacts(called_user_id, &session).await?;
1905
1906    Err(anyhow!("failed to ring user"))?
1907}
1908
1909/// Cancel an outgoing call.
1910async fn cancel_call(
1911    request: proto::CancelCall,
1912    response: Response<proto::CancelCall>,
1913    session: UserSession,
1914) -> Result<()> {
1915    let called_user_id = UserId::from_proto(request.called_user_id);
1916    let room_id = RoomId::from_proto(request.room_id);
1917    {
1918        let room = session
1919            .db()
1920            .await
1921            .cancel_call(room_id, session.connection_id, called_user_id)
1922            .await?;
1923        room_updated(&room, &session.peer);
1924    }
1925
1926    for connection_id in session
1927        .connection_pool()
1928        .await
1929        .user_connection_ids(called_user_id)
1930    {
1931        session
1932            .peer
1933            .send(
1934                connection_id,
1935                proto::CallCanceled {
1936                    room_id: room_id.to_proto(),
1937                },
1938            )
1939            .trace_err();
1940    }
1941    response.send(proto::Ack {})?;
1942
1943    update_user_contacts(called_user_id, &session).await?;
1944    Ok(())
1945}
1946
1947/// Decline an incoming call.
1948async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1949    let room_id = RoomId::from_proto(message.room_id);
1950    {
1951        let room = session
1952            .db()
1953            .await
1954            .decline_call(Some(room_id), session.user_id())
1955            .await?
1956            .ok_or_else(|| anyhow!("failed to decline call"))?;
1957        room_updated(&room, &session.peer);
1958    }
1959
1960    for connection_id in session
1961        .connection_pool()
1962        .await
1963        .user_connection_ids(session.user_id())
1964    {
1965        session
1966            .peer
1967            .send(
1968                connection_id,
1969                proto::CallCanceled {
1970                    room_id: room_id.to_proto(),
1971                },
1972            )
1973            .trace_err();
1974    }
1975    update_user_contacts(session.user_id(), &session).await?;
1976    Ok(())
1977}
1978
1979/// Updates other participants in the room with your current location.
1980async fn update_participant_location(
1981    request: proto::UpdateParticipantLocation,
1982    response: Response<proto::UpdateParticipantLocation>,
1983    session: UserSession,
1984) -> Result<()> {
1985    let room_id = RoomId::from_proto(request.room_id);
1986    let location = request
1987        .location
1988        .ok_or_else(|| anyhow!("invalid location"))?;
1989
1990    let db = session.db().await;
1991    let room = db
1992        .update_room_participant_location(room_id, session.connection_id, location)
1993        .await?;
1994
1995    room_updated(&room, &session.peer);
1996    response.send(proto::Ack {})?;
1997    Ok(())
1998}
1999
2000/// Share a project into the room.
2001async fn share_project(
2002    request: proto::ShareProject,
2003    response: Response<proto::ShareProject>,
2004    session: UserSession,
2005) -> Result<()> {
2006    let (project_id, room) = &*session
2007        .db()
2008        .await
2009        .share_project(
2010            RoomId::from_proto(request.room_id),
2011            session.connection_id,
2012            &request.worktrees,
2013            request
2014                .dev_server_project_id
2015                .map(|id| DevServerProjectId::from_proto(id)),
2016        )
2017        .await?;
2018    response.send(proto::ShareProjectResponse {
2019        project_id: project_id.to_proto(),
2020    })?;
2021    room_updated(&room, &session.peer);
2022
2023    Ok(())
2024}
2025
2026/// Unshare a project from the room.
2027async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2028    let project_id = ProjectId::from_proto(message.project_id);
2029    unshare_project_internal(
2030        project_id,
2031        session.connection_id,
2032        session.user_id(),
2033        &session,
2034    )
2035    .await
2036}
2037
2038async fn unshare_project_internal(
2039    project_id: ProjectId,
2040    connection_id: ConnectionId,
2041    user_id: Option<UserId>,
2042    session: &Session,
2043) -> Result<()> {
2044    let delete = {
2045        let room_guard = session
2046            .db()
2047            .await
2048            .unshare_project(project_id, connection_id, user_id)
2049            .await?;
2050
2051        let (delete, room, guest_connection_ids) = &*room_guard;
2052
2053        let message = proto::UnshareProject {
2054            project_id: project_id.to_proto(),
2055        };
2056
2057        broadcast(
2058            Some(connection_id),
2059            guest_connection_ids.iter().copied(),
2060            |conn_id| session.peer.send(conn_id, message.clone()),
2061        );
2062        if let Some(room) = room {
2063            room_updated(room, &session.peer);
2064        }
2065
2066        *delete
2067    };
2068
2069    if delete {
2070        let db = session.db().await;
2071        db.delete_project(project_id).await?;
2072    }
2073
2074    Ok(())
2075}
2076
2077/// DevServer makes a project available online
2078async fn share_dev_server_project(
2079    request: proto::ShareDevServerProject,
2080    response: Response<proto::ShareDevServerProject>,
2081    session: DevServerSession,
2082) -> Result<()> {
2083    let (dev_server_project, user_id, status) = session
2084        .db()
2085        .await
2086        .share_dev_server_project(
2087            DevServerProjectId::from_proto(request.dev_server_project_id),
2088            session.dev_server_id(),
2089            session.connection_id,
2090            &request.worktrees,
2091        )
2092        .await?;
2093    let Some(project_id) = dev_server_project.project_id else {
2094        return Err(anyhow!("failed to share remote project"))?;
2095    };
2096
2097    send_dev_server_projects_update(user_id, status, &session).await;
2098
2099    response.send(proto::ShareProjectResponse { project_id })?;
2100
2101    Ok(())
2102}
2103
2104/// Join someone elses shared project.
2105async fn join_project(
2106    request: proto::JoinProject,
2107    response: Response<proto::JoinProject>,
2108    session: UserSession,
2109) -> Result<()> {
2110    let project_id = ProjectId::from_proto(request.project_id);
2111
2112    tracing::info!(%project_id, "join project");
2113
2114    let db = session.db().await;
2115    let (project, replica_id) = &mut *db
2116        .join_project(project_id, session.connection_id, session.user_id())
2117        .await?;
2118    drop(db);
2119    tracing::info!(%project_id, "join remote project");
2120    join_project_internal(response, session, project, replica_id)
2121}
2122
/// Abstracts over the request types whose reply is a `proto::JoinProjectResponse`
/// (`proto::JoinProject` and `proto::JoinHostedProject`), so that `join_project_internal`
/// can serve both.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Type-qualified path: inherent items take precedence over trait items, so this
        // calls the inherent `Response::send` rather than recursing into this method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same as above: resolves to the inherent `Response::send`.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2136
2137fn join_project_internal(
2138    response: impl JoinProjectInternalResponse,
2139    session: UserSession,
2140    project: &mut Project,
2141    replica_id: &ReplicaId,
2142) -> Result<()> {
2143    let collaborators = project
2144        .collaborators
2145        .iter()
2146        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2147        .map(|collaborator| collaborator.to_proto())
2148        .collect::<Vec<_>>();
2149    let project_id = project.id;
2150    let guest_user_id = session.user_id();
2151
2152    let worktrees = project
2153        .worktrees
2154        .iter()
2155        .map(|(id, worktree)| proto::WorktreeMetadata {
2156            id: *id,
2157            root_name: worktree.root_name.clone(),
2158            visible: worktree.visible,
2159            abs_path: worktree.abs_path.clone(),
2160        })
2161        .collect::<Vec<_>>();
2162
2163    let add_project_collaborator = proto::AddProjectCollaborator {
2164        project_id: project_id.to_proto(),
2165        collaborator: Some(proto::Collaborator {
2166            peer_id: Some(session.connection_id.into()),
2167            replica_id: replica_id.0 as u32,
2168            user_id: guest_user_id.to_proto(),
2169        }),
2170    };
2171
2172    for collaborator in &collaborators {
2173        session
2174            .peer
2175            .send(
2176                collaborator.peer_id.unwrap().into(),
2177                add_project_collaborator.clone(),
2178            )
2179            .trace_err();
2180    }
2181
2182    // First, we send the metadata associated with each worktree.
2183    response.send(proto::JoinProjectResponse {
2184        project_id: project.id.0 as u64,
2185        worktrees: worktrees.clone(),
2186        replica_id: replica_id.0 as u32,
2187        collaborators: collaborators.clone(),
2188        language_servers: project.language_servers.clone(),
2189        role: project.role.into(),
2190        dev_server_project_id: project
2191            .dev_server_project_id
2192            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2193    })?;
2194
2195    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2196        #[cfg(any(test, feature = "test-support"))]
2197        const MAX_CHUNK_SIZE: usize = 2;
2198        #[cfg(not(any(test, feature = "test-support")))]
2199        const MAX_CHUNK_SIZE: usize = 256;
2200
2201        // Stream this worktree's entries.
2202        let message = proto::UpdateWorktree {
2203            project_id: project_id.to_proto(),
2204            worktree_id,
2205            abs_path: worktree.abs_path.clone(),
2206            root_name: worktree.root_name,
2207            updated_entries: worktree.entries,
2208            removed_entries: Default::default(),
2209            scan_id: worktree.scan_id,
2210            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2211            updated_repositories: worktree.repository_entries.into_values().collect(),
2212            removed_repositories: Default::default(),
2213        };
2214        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2215            session.peer.send(session.connection_id, update.clone())?;
2216        }
2217
2218        // Stream this worktree's diagnostics.
2219        for summary in worktree.diagnostic_summaries {
2220            session.peer.send(
2221                session.connection_id,
2222                proto::UpdateDiagnosticSummary {
2223                    project_id: project_id.to_proto(),
2224                    worktree_id: worktree.id,
2225                    summary: Some(summary),
2226                },
2227            )?;
2228        }
2229
2230        for settings_file in worktree.settings_files {
2231            session.peer.send(
2232                session.connection_id,
2233                proto::UpdateWorktreeSettings {
2234                    project_id: project_id.to_proto(),
2235                    worktree_id: worktree.id,
2236                    path: settings_file.path,
2237                    content: Some(settings_file.content),
2238                },
2239            )?;
2240        }
2241    }
2242
2243    for language_server in &project.language_servers {
2244        session.peer.send(
2245            session.connection_id,
2246            proto::UpdateLanguageServer {
2247                project_id: project_id.to_proto(),
2248                language_server_id: language_server.id,
2249                variant: Some(
2250                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2251                        proto::LspDiskBasedDiagnosticsUpdated {},
2252                    ),
2253                ),
2254            },
2255        )?;
2256    }
2257
2258    Ok(())
2259}
2260
2261/// Leave someone elses shared project.
2262async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2263    let sender_id = session.connection_id;
2264    let project_id = ProjectId::from_proto(request.project_id);
2265    let db = session.db().await;
2266    if db.is_hosted_project(project_id).await? {
2267        let project = db.leave_hosted_project(project_id, sender_id).await?;
2268        project_left(&project, &session);
2269        return Ok(());
2270    }
2271
2272    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2273    tracing::info!(
2274        %project_id,
2275        "leave project"
2276    );
2277
2278    project_left(&project, &session);
2279    if let Some(room) = room {
2280        room_updated(&room, &session.peer);
2281    }
2282
2283    Ok(())
2284}
2285
2286async fn join_hosted_project(
2287    request: proto::JoinHostedProject,
2288    response: Response<proto::JoinHostedProject>,
2289    session: UserSession,
2290) -> Result<()> {
2291    let (mut project, replica_id) = session
2292        .db()
2293        .await
2294        .join_hosted_project(
2295            ProjectId(request.project_id as i32),
2296            session.user_id(),
2297            session.connection_id,
2298        )
2299        .await?;
2300
2301    join_project_internal(response, session, &mut project, &replica_id)
2302}
2303
2304async fn create_dev_server_project(
2305    request: proto::CreateDevServerProject,
2306    response: Response<proto::CreateDevServerProject>,
2307    session: UserSession,
2308) -> Result<()> {
2309    let dev_server_id = DevServerId(request.dev_server_id as i32);
2310    let dev_server_connection_id = session
2311        .connection_pool()
2312        .await
2313        .dev_server_connection_id(dev_server_id);
2314    let Some(dev_server_connection_id) = dev_server_connection_id else {
2315        Err(ErrorCode::DevServerOffline
2316            .message("Cannot create a remote project when the dev server is offline".to_string())
2317            .anyhow())?
2318    };
2319
2320    let path = request.path.clone();
2321    //Check that the path exists on the dev server
2322    session
2323        .peer
2324        .forward_request(
2325            session.connection_id,
2326            dev_server_connection_id,
2327            proto::ValidateDevServerProjectRequest { path: path.clone() },
2328        )
2329        .await?;
2330
2331    let (dev_server_project, update) = session
2332        .db()
2333        .await
2334        .create_dev_server_project(
2335            DevServerId(request.dev_server_id as i32),
2336            &request.path,
2337            session.user_id(),
2338        )
2339        .await?;
2340
2341    let projects = session
2342        .db()
2343        .await
2344        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2345        .await?;
2346
2347    session.peer.send(
2348        dev_server_connection_id,
2349        proto::DevServerInstructions { projects },
2350    )?;
2351
2352    send_dev_server_projects_update(session.user_id(), update, &session).await;
2353
2354    response.send(proto::CreateDevServerProjectResponse {
2355        dev_server_project: Some(dev_server_project.to_proto(None)),
2356    })?;
2357    Ok(())
2358}
2359
2360async fn create_dev_server(
2361    request: proto::CreateDevServer,
2362    response: Response<proto::CreateDevServer>,
2363    session: UserSession,
2364) -> Result<()> {
2365    let access_token = auth::random_token();
2366    let hashed_access_token = auth::hash_access_token(&access_token);
2367
2368    if request.name.is_empty() {
2369        return Err(proto::ErrorCode::Forbidden
2370            .message("Dev server name cannot be empty".to_string())
2371            .anyhow())?;
2372    }
2373
2374    let (dev_server, status) = session
2375        .db()
2376        .await
2377        .create_dev_server(
2378            &request.name,
2379            request.ssh_connection_string.as_deref(),
2380            &hashed_access_token,
2381            session.user_id(),
2382        )
2383        .await?;
2384
2385    send_dev_server_projects_update(session.user_id(), status, &session).await;
2386
2387    response.send(proto::CreateDevServerResponse {
2388        dev_server_id: dev_server.id.0 as u64,
2389        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2390        name: request.name,
2391    })?;
2392    Ok(())
2393}
2394
2395async fn regenerate_dev_server_token(
2396    request: proto::RegenerateDevServerToken,
2397    response: Response<proto::RegenerateDevServerToken>,
2398    session: UserSession,
2399) -> Result<()> {
2400    let dev_server_id = DevServerId(request.dev_server_id as i32);
2401    let access_token = auth::random_token();
2402    let hashed_access_token = auth::hash_access_token(&access_token);
2403
2404    let connection_id = session
2405        .connection_pool()
2406        .await
2407        .dev_server_connection_id(dev_server_id);
2408    if let Some(connection_id) = connection_id {
2409        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2410        session.peer.send(
2411            connection_id,
2412            proto::ShutdownDevServer {
2413                reason: Some("dev server token was regenerated".to_string()),
2414            },
2415        )?;
2416        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2417    }
2418
2419    let status = session
2420        .db()
2421        .await
2422        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2423        .await?;
2424
2425    send_dev_server_projects_update(session.user_id(), status, &session).await;
2426
2427    response.send(proto::RegenerateDevServerTokenResponse {
2428        dev_server_id: dev_server_id.to_proto(),
2429        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2430    })?;
2431    Ok(())
2432}
2433
2434async fn rename_dev_server(
2435    request: proto::RenameDevServer,
2436    response: Response<proto::RenameDevServer>,
2437    session: UserSession,
2438) -> Result<()> {
2439    if request.name.trim().is_empty() {
2440        return Err(proto::ErrorCode::Forbidden
2441            .message("Dev server name cannot be empty".to_string())
2442            .anyhow())?;
2443    }
2444
2445    let dev_server_id = DevServerId(request.dev_server_id as i32);
2446    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2447    if dev_server.user_id != session.user_id() {
2448        return Err(anyhow!(ErrorCode::Forbidden))?;
2449    }
2450
2451    let status = session
2452        .db()
2453        .await
2454        .rename_dev_server(
2455            dev_server_id,
2456            &request.name,
2457            request.ssh_connection_string.as_deref(),
2458            session.user_id(),
2459        )
2460        .await?;
2461
2462    send_dev_server_projects_update(session.user_id(), status, &session).await;
2463
2464    response.send(proto::Ack {})?;
2465    Ok(())
2466}
2467
2468async fn delete_dev_server(
2469    request: proto::DeleteDevServer,
2470    response: Response<proto::DeleteDevServer>,
2471    session: UserSession,
2472) -> Result<()> {
2473    let dev_server_id = DevServerId(request.dev_server_id as i32);
2474    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2475    if dev_server.user_id != session.user_id() {
2476        return Err(anyhow!(ErrorCode::Forbidden))?;
2477    }
2478
2479    let connection_id = session
2480        .connection_pool()
2481        .await
2482        .dev_server_connection_id(dev_server_id);
2483    if let Some(connection_id) = connection_id {
2484        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2485        session.peer.send(
2486            connection_id,
2487            proto::ShutdownDevServer {
2488                reason: Some("dev server was deleted".to_string()),
2489            },
2490        )?;
2491        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2492    }
2493
2494    let status = session
2495        .db()
2496        .await
2497        .delete_dev_server(dev_server_id, session.user_id())
2498        .await?;
2499
2500    send_dev_server_projects_update(session.user_id(), status, &session).await;
2501
2502    response.send(proto::Ack {})?;
2503    Ok(())
2504}
2505
2506async fn delete_dev_server_project(
2507    request: proto::DeleteDevServerProject,
2508    response: Response<proto::DeleteDevServerProject>,
2509    session: UserSession,
2510) -> Result<()> {
2511    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2512    let dev_server_project = session
2513        .db()
2514        .await
2515        .get_dev_server_project(dev_server_project_id)
2516        .await?;
2517
2518    let dev_server = session
2519        .db()
2520        .await
2521        .get_dev_server(dev_server_project.dev_server_id)
2522        .await?;
2523    if dev_server.user_id != session.user_id() {
2524        return Err(anyhow!(ErrorCode::Forbidden))?;
2525    }
2526
2527    let dev_server_connection_id = session
2528        .connection_pool()
2529        .await
2530        .dev_server_connection_id(dev_server.id);
2531
2532    if let Some(dev_server_connection_id) = dev_server_connection_id {
2533        let project = session
2534            .db()
2535            .await
2536            .find_dev_server_project(dev_server_project_id)
2537            .await;
2538        if let Ok(project) = project {
2539            unshare_project_internal(
2540                project.id,
2541                dev_server_connection_id,
2542                Some(session.user_id()),
2543                &session,
2544            )
2545            .await?;
2546        }
2547    }
2548
2549    let (projects, status) = session
2550        .db()
2551        .await
2552        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2553        .await?;
2554
2555    if let Some(dev_server_connection_id) = dev_server_connection_id {
2556        session.peer.send(
2557            dev_server_connection_id,
2558            proto::DevServerInstructions { projects },
2559        )?;
2560    }
2561
2562    send_dev_server_projects_update(session.user_id(), status, &session).await;
2563
2564    response.send(proto::Ack {})?;
2565    Ok(())
2566}
2567
2568async fn rejoin_dev_server_projects(
2569    request: proto::RejoinRemoteProjects,
2570    response: Response<proto::RejoinRemoteProjects>,
2571    session: UserSession,
2572) -> Result<()> {
2573    let mut rejoined_projects = {
2574        let db = session.db().await;
2575        db.rejoin_dev_server_projects(
2576            &request.rejoined_projects,
2577            session.user_id(),
2578            session.0.connection_id,
2579        )
2580        .await?
2581    };
2582    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2583
2584    response.send(proto::RejoinRemoteProjectsResponse {
2585        rejoined_projects: rejoined_projects
2586            .into_iter()
2587            .map(|project| project.to_proto())
2588            .collect(),
2589    })
2590}
2591
2592async fn reconnect_dev_server(
2593    request: proto::ReconnectDevServer,
2594    response: Response<proto::ReconnectDevServer>,
2595    session: DevServerSession,
2596) -> Result<()> {
2597    let reshared_projects = {
2598        let db = session.db().await;
2599        db.reshare_dev_server_projects(
2600            &request.reshared_projects,
2601            session.dev_server_id(),
2602            session.0.connection_id,
2603        )
2604        .await?
2605    };
2606
2607    for project in &reshared_projects {
2608        for collaborator in &project.collaborators {
2609            session
2610                .peer
2611                .send(
2612                    collaborator.connection_id,
2613                    proto::UpdateProjectCollaborator {
2614                        project_id: project.id.to_proto(),
2615                        old_peer_id: Some(project.old_connection_id.into()),
2616                        new_peer_id: Some(session.connection_id.into()),
2617                    },
2618                )
2619                .trace_err();
2620        }
2621
2622        broadcast(
2623            Some(session.connection_id),
2624            project
2625                .collaborators
2626                .iter()
2627                .map(|collaborator| collaborator.connection_id),
2628            |connection_id| {
2629                session.peer.forward_send(
2630                    session.connection_id,
2631                    connection_id,
2632                    proto::UpdateProject {
2633                        project_id: project.id.to_proto(),
2634                        worktrees: project.worktrees.clone(),
2635                    },
2636                )
2637            },
2638        );
2639    }
2640
2641    response.send(proto::ReconnectDevServerResponse {
2642        reshared_projects: reshared_projects
2643            .iter()
2644            .map(|project| proto::ResharedProject {
2645                id: project.id.to_proto(),
2646                collaborators: project
2647                    .collaborators
2648                    .iter()
2649                    .map(|collaborator| collaborator.to_proto())
2650                    .collect(),
2651            })
2652            .collect(),
2653    })?;
2654
2655    Ok(())
2656}
2657
2658async fn shutdown_dev_server(
2659    _: proto::ShutdownDevServer,
2660    response: Response<proto::ShutdownDevServer>,
2661    session: DevServerSession,
2662) -> Result<()> {
2663    response.send(proto::Ack {})?;
2664    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2665    remove_dev_server_connection(session.dev_server_id(), &session).await
2666}
2667
2668async fn shutdown_dev_server_internal(
2669    dev_server_id: DevServerId,
2670    connection_id: ConnectionId,
2671    session: &Session,
2672) -> Result<()> {
2673    let (dev_server_projects, dev_server) = {
2674        let db = session.db().await;
2675        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2676        let dev_server = db.get_dev_server(dev_server_id).await?;
2677        (dev_server_projects, dev_server)
2678    };
2679
2680    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2681        unshare_project_internal(
2682            ProjectId::from_proto(project_id),
2683            connection_id,
2684            None,
2685            session,
2686        )
2687        .await?;
2688    }
2689
2690    session
2691        .connection_pool()
2692        .await
2693        .set_dev_server_offline(dev_server_id);
2694
2695    let status = session
2696        .db()
2697        .await
2698        .dev_server_projects_update(dev_server.user_id)
2699        .await?;
2700    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2701
2702    Ok(())
2703}
2704
2705async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2706    let dev_server_connection = session
2707        .connection_pool()
2708        .await
2709        .dev_server_connection_id(dev_server_id);
2710
2711    if let Some(dev_server_connection) = dev_server_connection {
2712        session
2713            .connection_pool()
2714            .await
2715            .remove_connection(dev_server_connection)?;
2716    }
2717    Ok(())
2718}
2719
2720/// Updates other participants with changes to the project
2721async fn update_project(
2722    request: proto::UpdateProject,
2723    response: Response<proto::UpdateProject>,
2724    session: Session,
2725) -> Result<()> {
2726    let project_id = ProjectId::from_proto(request.project_id);
2727    let (room, guest_connection_ids) = &*session
2728        .db()
2729        .await
2730        .update_project(project_id, session.connection_id, &request.worktrees)
2731        .await?;
2732    broadcast(
2733        Some(session.connection_id),
2734        guest_connection_ids.iter().copied(),
2735        |connection_id| {
2736            session
2737                .peer
2738                .forward_send(session.connection_id, connection_id, request.clone())
2739        },
2740    );
2741    if let Some(room) = room {
2742        room_updated(&room, &session.peer);
2743    }
2744    response.send(proto::Ack {})?;
2745
2746    Ok(())
2747}
2748
2749/// Updates other participants with changes to the worktree
2750async fn update_worktree(
2751    request: proto::UpdateWorktree,
2752    response: Response<proto::UpdateWorktree>,
2753    session: Session,
2754) -> Result<()> {
2755    let guest_connection_ids = session
2756        .db()
2757        .await
2758        .update_worktree(&request, session.connection_id)
2759        .await?;
2760
2761    broadcast(
2762        Some(session.connection_id),
2763        guest_connection_ids.iter().copied(),
2764        |connection_id| {
2765            session
2766                .peer
2767                .forward_send(session.connection_id, connection_id, request.clone())
2768        },
2769    );
2770    response.send(proto::Ack {})?;
2771    Ok(())
2772}
2773
2774/// Updates other participants with changes to the diagnostics
2775async fn update_diagnostic_summary(
2776    message: proto::UpdateDiagnosticSummary,
2777    session: Session,
2778) -> Result<()> {
2779    let guest_connection_ids = session
2780        .db()
2781        .await
2782        .update_diagnostic_summary(&message, session.connection_id)
2783        .await?;
2784
2785    broadcast(
2786        Some(session.connection_id),
2787        guest_connection_ids.iter().copied(),
2788        |connection_id| {
2789            session
2790                .peer
2791                .forward_send(session.connection_id, connection_id, message.clone())
2792        },
2793    );
2794
2795    Ok(())
2796}
2797
2798/// Updates other participants with changes to the worktree settings
2799async fn update_worktree_settings(
2800    message: proto::UpdateWorktreeSettings,
2801    session: Session,
2802) -> Result<()> {
2803    let guest_connection_ids = session
2804        .db()
2805        .await
2806        .update_worktree_settings(&message, session.connection_id)
2807        .await?;
2808
2809    broadcast(
2810        Some(session.connection_id),
2811        guest_connection_ids.iter().copied(),
2812        |connection_id| {
2813            session
2814                .peer
2815                .forward_send(session.connection_id, connection_id, message.clone())
2816        },
2817    );
2818
2819    Ok(())
2820}
2821
2822/// Notify other participants that a  language server has started.
2823async fn start_language_server(
2824    request: proto::StartLanguageServer,
2825    session: Session,
2826) -> Result<()> {
2827    let guest_connection_ids = session
2828        .db()
2829        .await
2830        .start_language_server(&request, session.connection_id)
2831        .await?;
2832
2833    broadcast(
2834        Some(session.connection_id),
2835        guest_connection_ids.iter().copied(),
2836        |connection_id| {
2837            session
2838                .peer
2839                .forward_send(session.connection_id, connection_id, request.clone())
2840        },
2841    );
2842    Ok(())
2843}
2844
2845/// Notify other participants that a language server has changed.
2846async fn update_language_server(
2847    request: proto::UpdateLanguageServer,
2848    session: Session,
2849) -> Result<()> {
2850    let project_id = ProjectId::from_proto(request.project_id);
2851    let project_connection_ids = session
2852        .db()
2853        .await
2854        .project_connection_ids(project_id, session.connection_id, true)
2855        .await?;
2856    broadcast(
2857        Some(session.connection_id),
2858        project_connection_ids.iter().copied(),
2859        |connection_id| {
2860            session
2861                .peer
2862                .forward_send(session.connection_id, connection_id, request.clone())
2863        },
2864    );
2865    Ok(())
2866}
2867
2868/// forward a project request to the host. These requests should be read only
2869/// as guests are allowed to send them.
2870async fn forward_read_only_project_request<T>(
2871    request: T,
2872    response: Response<T>,
2873    session: UserSession,
2874) -> Result<()>
2875where
2876    T: EntityMessage + RequestMessage,
2877{
2878    let project_id = ProjectId::from_proto(request.remote_entity_id());
2879    let host_connection_id = session
2880        .db()
2881        .await
2882        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2883        .await?;
2884    let payload = session
2885        .peer
2886        .forward_request(session.connection_id, host_connection_id, request)
2887        .await?;
2888    response.send(payload)?;
2889    Ok(())
2890}
2891
2892/// forward a project request to the host. These requests are disallowed
2893/// for guests.
2894async fn forward_mutating_project_request<T>(
2895    request: T,
2896    response: Response<T>,
2897    session: UserSession,
2898) -> Result<()>
2899where
2900    T: EntityMessage + RequestMessage,
2901{
2902    let project_id = ProjectId::from_proto(request.remote_entity_id());
2903
2904    let host_connection_id = session
2905        .db()
2906        .await
2907        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2908        .await?;
2909    let payload = session
2910        .peer
2911        .forward_request(session.connection_id, host_connection_id, request)
2912        .await?;
2913    response.send(payload)?;
2914    Ok(())
2915}
2916
2917/// forward a project request to the host. These requests are disallowed
2918/// for guests.
2919async fn forward_versioned_mutating_project_request<T>(
2920    request: T,
2921    response: Response<T>,
2922    session: UserSession,
2923) -> Result<()>
2924where
2925    T: EntityMessage + RequestMessage + VersionedMessage,
2926{
2927    let project_id = ProjectId::from_proto(request.remote_entity_id());
2928
2929    let host_connection_id = session
2930        .db()
2931        .await
2932        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2933        .await?;
2934    if let Some(host_version) = session
2935        .connection_pool()
2936        .await
2937        .connection(host_connection_id)
2938        .map(|c| c.zed_version)
2939    {
2940        if let Some(min_required_version) = request.required_host_version() {
2941            if min_required_version > host_version {
2942                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2943                    .with_tag("required", &min_required_version.to_string())))?;
2944            }
2945        }
2946    }
2947
2948    let payload = session
2949        .peer
2950        .forward_request(session.connection_id, host_connection_id, request)
2951        .await?;
2952    response.send(payload)?;
2953    Ok(())
2954}
2955
2956/// Notify other participants that a new buffer has been created
2957async fn create_buffer_for_peer(
2958    request: proto::CreateBufferForPeer,
2959    session: Session,
2960) -> Result<()> {
2961    session
2962        .db()
2963        .await
2964        .check_user_is_project_host(
2965            ProjectId::from_proto(request.project_id),
2966            session.connection_id,
2967        )
2968        .await?;
2969    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2970    session
2971        .peer
2972        .forward_send(session.connection_id, peer_id.into(), request)?;
2973    Ok(())
2974}
2975
2976/// Notify other participants that a buffer has been updated. This is
2977/// allowed for guests as long as the update is limited to selections.
2978async fn update_buffer(
2979    request: proto::UpdateBuffer,
2980    response: Response<proto::UpdateBuffer>,
2981    session: Session,
2982) -> Result<()> {
2983    let project_id = ProjectId::from_proto(request.project_id);
2984    let mut capability = Capability::ReadOnly;
2985
2986    for op in request.operations.iter() {
2987        match op.variant {
2988            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2989            Some(_) => capability = Capability::ReadWrite,
2990        }
2991    }
2992
2993    let host = {
2994        let guard = session
2995            .db()
2996            .await
2997            .connections_for_buffer_update(
2998                project_id,
2999                session.principal_id(),
3000                session.connection_id,
3001                capability,
3002            )
3003            .await?;
3004
3005        let (host, guests) = &*guard;
3006
3007        broadcast(
3008            Some(session.connection_id),
3009            guests.clone(),
3010            |connection_id| {
3011                session
3012                    .peer
3013                    .forward_send(session.connection_id, connection_id, request.clone())
3014            },
3015        );
3016
3017        *host
3018    };
3019
3020    if host != session.connection_id {
3021        session
3022            .peer
3023            .forward_request(session.connection_id, host, request.clone())
3024            .await?;
3025    }
3026
3027    response.send(proto::Ack {})?;
3028    Ok(())
3029}
3030
3031/// Notify other participants that a project has been updated.
3032async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3033    request: T,
3034    session: Session,
3035) -> Result<()> {
3036    let project_id = ProjectId::from_proto(request.remote_entity_id());
3037    let project_connection_ids = session
3038        .db()
3039        .await
3040        .project_connection_ids(project_id, session.connection_id, false)
3041        .await?;
3042
3043    broadcast(
3044        Some(session.connection_id),
3045        project_connection_ids.iter().copied(),
3046        |connection_id| {
3047            session
3048                .peer
3049                .forward_send(session.connection_id, connection_id, request.clone())
3050        },
3051    );
3052    Ok(())
3053}
3054
3055/// Start following another user in a call.
3056async fn follow(
3057    request: proto::Follow,
3058    response: Response<proto::Follow>,
3059    session: UserSession,
3060) -> Result<()> {
3061    let room_id = RoomId::from_proto(request.room_id);
3062    let project_id = request.project_id.map(ProjectId::from_proto);
3063    let leader_id = request
3064        .leader_id
3065        .ok_or_else(|| anyhow!("invalid leader id"))?
3066        .into();
3067    let follower_id = session.connection_id;
3068
3069    session
3070        .db()
3071        .await
3072        .check_room_participants(room_id, leader_id, session.connection_id)
3073        .await?;
3074
3075    let response_payload = session
3076        .peer
3077        .forward_request(session.connection_id, leader_id, request)
3078        .await?;
3079    response.send(response_payload)?;
3080
3081    if let Some(project_id) = project_id {
3082        let room = session
3083            .db()
3084            .await
3085            .follow(room_id, project_id, leader_id, follower_id)
3086            .await?;
3087        room_updated(&room, &session.peer);
3088    }
3089
3090    Ok(())
3091}
3092
3093/// Stop following another user in a call.
3094async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3095    let room_id = RoomId::from_proto(request.room_id);
3096    let project_id = request.project_id.map(ProjectId::from_proto);
3097    let leader_id = request
3098        .leader_id
3099        .ok_or_else(|| anyhow!("invalid leader id"))?
3100        .into();
3101    let follower_id = session.connection_id;
3102
3103    session
3104        .db()
3105        .await
3106        .check_room_participants(room_id, leader_id, session.connection_id)
3107        .await?;
3108
3109    session
3110        .peer
3111        .forward_send(session.connection_id, leader_id, request)?;
3112
3113    if let Some(project_id) = project_id {
3114        let room = session
3115            .db()
3116            .await
3117            .unfollow(room_id, project_id, leader_id, follower_id)
3118            .await?;
3119        room_updated(&room, &session.peer);
3120    }
3121
3122    Ok(())
3123}
3124
3125/// Notify everyone following you of your current location.
3126async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3127    let room_id = RoomId::from_proto(request.room_id);
3128    let database = session.db.lock().await;
3129
3130    let connection_ids = if let Some(project_id) = request.project_id {
3131        let project_id = ProjectId::from_proto(project_id);
3132        database
3133            .project_connection_ids(project_id, session.connection_id, true)
3134            .await?
3135    } else {
3136        database
3137            .room_connection_ids(room_id, session.connection_id)
3138            .await?
3139    };
3140
3141    // For now, don't send view update messages back to that view's current leader.
3142    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3143        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3144        _ => None,
3145    });
3146
3147    for connection_id in connection_ids.iter().cloned() {
3148        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3149            session
3150                .peer
3151                .forward_send(session.connection_id, connection_id, request.clone())?;
3152        }
3153    }
3154    Ok(())
3155}
3156
3157/// Get public data about users.
3158async fn get_users(
3159    request: proto::GetUsers,
3160    response: Response<proto::GetUsers>,
3161    session: Session,
3162) -> Result<()> {
3163    let user_ids = request
3164        .user_ids
3165        .into_iter()
3166        .map(UserId::from_proto)
3167        .collect();
3168    let users = session
3169        .db()
3170        .await
3171        .get_users_by_ids(user_ids)
3172        .await?
3173        .into_iter()
3174        .map(|user| proto::User {
3175            id: user.id.to_proto(),
3176            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3177            github_login: user.github_login,
3178        })
3179        .collect();
3180    response.send(proto::UsersResponse { users })?;
3181    Ok(())
3182}
3183
3184/// Search for users (to invite) buy Github login
3185async fn fuzzy_search_users(
3186    request: proto::FuzzySearchUsers,
3187    response: Response<proto::FuzzySearchUsers>,
3188    session: UserSession,
3189) -> Result<()> {
3190    let query = request.query;
3191    let users = match query.len() {
3192        0 => vec![],
3193        1 | 2 => session
3194            .db()
3195            .await
3196            .get_user_by_github_login(&query)
3197            .await?
3198            .into_iter()
3199            .collect(),
3200        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3201    };
3202    let users = users
3203        .into_iter()
3204        .filter(|user| user.id != session.user_id())
3205        .map(|user| proto::User {
3206            id: user.id.to_proto(),
3207            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3208            github_login: user.github_login,
3209        })
3210        .collect();
3211    response.send(proto::UsersResponse { users })?;
3212    Ok(())
3213}
3214
3215/// Send a contact request to another user.
3216async fn request_contact(
3217    request: proto::RequestContact,
3218    response: Response<proto::RequestContact>,
3219    session: UserSession,
3220) -> Result<()> {
3221    let requester_id = session.user_id();
3222    let responder_id = UserId::from_proto(request.responder_id);
3223    if requester_id == responder_id {
3224        return Err(anyhow!("cannot add yourself as a contact"))?;
3225    }
3226
3227    let notifications = session
3228        .db()
3229        .await
3230        .send_contact_request(requester_id, responder_id)
3231        .await?;
3232
3233    // Update outgoing contact requests of requester
3234    let mut update = proto::UpdateContacts::default();
3235    update.outgoing_requests.push(responder_id.to_proto());
3236    for connection_id in session
3237        .connection_pool()
3238        .await
3239        .user_connection_ids(requester_id)
3240    {
3241        session.peer.send(connection_id, update.clone())?;
3242    }
3243
3244    // Update incoming contact requests of responder
3245    let mut update = proto::UpdateContacts::default();
3246    update
3247        .incoming_requests
3248        .push(proto::IncomingContactRequest {
3249            requester_id: requester_id.to_proto(),
3250        });
3251    let connection_pool = session.connection_pool().await;
3252    for connection_id in connection_pool.user_connection_ids(responder_id) {
3253        session.peer.send(connection_id, update.clone())?;
3254    }
3255
3256    send_notifications(&connection_pool, &session.peer, notifications);
3257
3258    response.send(proto::Ack {})?;
3259    Ok(())
3260}
3261
3262/// Accept or decline a contact request
3263async fn respond_to_contact_request(
3264    request: proto::RespondToContactRequest,
3265    response: Response<proto::RespondToContactRequest>,
3266    session: UserSession,
3267) -> Result<()> {
3268    let responder_id = session.user_id();
3269    let requester_id = UserId::from_proto(request.requester_id);
3270    let db = session.db().await;
3271    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3272        db.dismiss_contact_notification(responder_id, requester_id)
3273            .await?;
3274    } else {
3275        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3276
3277        let notifications = db
3278            .respond_to_contact_request(responder_id, requester_id, accept)
3279            .await?;
3280        let requester_busy = db.is_user_busy(requester_id).await?;
3281        let responder_busy = db.is_user_busy(responder_id).await?;
3282
3283        let pool = session.connection_pool().await;
3284        // Update responder with new contact
3285        let mut update = proto::UpdateContacts::default();
3286        if accept {
3287            update
3288                .contacts
3289                .push(contact_for_user(requester_id, requester_busy, &pool));
3290        }
3291        update
3292            .remove_incoming_requests
3293            .push(requester_id.to_proto());
3294        for connection_id in pool.user_connection_ids(responder_id) {
3295            session.peer.send(connection_id, update.clone())?;
3296        }
3297
3298        // Update requester with new contact
3299        let mut update = proto::UpdateContacts::default();
3300        if accept {
3301            update
3302                .contacts
3303                .push(contact_for_user(responder_id, responder_busy, &pool));
3304        }
3305        update
3306            .remove_outgoing_requests
3307            .push(responder_id.to_proto());
3308
3309        for connection_id in pool.user_connection_ids(requester_id) {
3310            session.peer.send(connection_id, update.clone())?;
3311        }
3312
3313        send_notifications(&pool, &session.peer, notifications);
3314    }
3315
3316    response.send(proto::Ack {})?;
3317    Ok(())
3318}
3319
3320/// Remove a contact.
3321async fn remove_contact(
3322    request: proto::RemoveContact,
3323    response: Response<proto::RemoveContact>,
3324    session: UserSession,
3325) -> Result<()> {
3326    let requester_id = session.user_id();
3327    let responder_id = UserId::from_proto(request.user_id);
3328    let db = session.db().await;
3329    let (contact_accepted, deleted_notification_id) =
3330        db.remove_contact(requester_id, responder_id).await?;
3331
3332    let pool = session.connection_pool().await;
3333    // Update outgoing contact requests of requester
3334    let mut update = proto::UpdateContacts::default();
3335    if contact_accepted {
3336        update.remove_contacts.push(responder_id.to_proto());
3337    } else {
3338        update
3339            .remove_outgoing_requests
3340            .push(responder_id.to_proto());
3341    }
3342    for connection_id in pool.user_connection_ids(requester_id) {
3343        session.peer.send(connection_id, update.clone())?;
3344    }
3345
3346    // Update incoming contact requests of responder
3347    let mut update = proto::UpdateContacts::default();
3348    if contact_accepted {
3349        update.remove_contacts.push(requester_id.to_proto());
3350    } else {
3351        update
3352            .remove_incoming_requests
3353            .push(requester_id.to_proto());
3354    }
3355    for connection_id in pool.user_connection_ids(responder_id) {
3356        session.peer.send(connection_id, update.clone())?;
3357        if let Some(notification_id) = deleted_notification_id {
3358            session.peer.send(
3359                connection_id,
3360                proto::DeleteNotification {
3361                    notification_id: notification_id.to_proto(),
3362                },
3363            )?;
3364        }
3365    }
3366
3367    response.send(proto::Ack {})?;
3368    Ok(())
3369}
3370
3371/// Creates a new channel.
3372async fn create_channel(
3373    request: proto::CreateChannel,
3374    response: Response<proto::CreateChannel>,
3375    session: UserSession,
3376) -> Result<()> {
3377    let db = session.db().await;
3378
3379    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3380    let (channel, membership) = db
3381        .create_channel(&request.name, parent_id, session.user_id())
3382        .await?;
3383
3384    let root_id = channel.root_id();
3385    let channel = Channel::from_model(channel);
3386
3387    response.send(proto::CreateChannelResponse {
3388        channel: Some(channel.to_proto()),
3389        parent_id: request.parent_id,
3390    })?;
3391
3392    let mut connection_pool = session.connection_pool().await;
3393    if let Some(membership) = membership {
3394        connection_pool.subscribe_to_channel(
3395            membership.user_id,
3396            membership.channel_id,
3397            membership.role,
3398        );
3399        let update = proto::UpdateUserChannels {
3400            channel_memberships: vec![proto::ChannelMembership {
3401                channel_id: membership.channel_id.to_proto(),
3402                role: membership.role.into(),
3403            }],
3404            ..Default::default()
3405        };
3406        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3407            session.peer.send(connection_id, update.clone())?;
3408        }
3409    }
3410
3411    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3412        if !role.can_see_channel(channel.visibility) {
3413            continue;
3414        }
3415
3416        let update = proto::UpdateChannels {
3417            channels: vec![channel.to_proto()],
3418            ..Default::default()
3419        };
3420        session.peer.send(connection_id, update.clone())?;
3421    }
3422
3423    Ok(())
3424}
3425
3426/// Delete a channel
3427async fn delete_channel(
3428    request: proto::DeleteChannel,
3429    response: Response<proto::DeleteChannel>,
3430    session: UserSession,
3431) -> Result<()> {
3432    let db = session.db().await;
3433
3434    let channel_id = request.channel_id;
3435    let (root_channel, removed_channels) = db
3436        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3437        .await?;
3438    response.send(proto::Ack {})?;
3439
3440    // Notify members of removed channels
3441    let mut update = proto::UpdateChannels::default();
3442    update
3443        .delete_channels
3444        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3445
3446    let connection_pool = session.connection_pool().await;
3447    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3448        session.peer.send(connection_id, update.clone())?;
3449    }
3450
3451    Ok(())
3452}
3453
3454/// Invite someone to join a channel.
3455async fn invite_channel_member(
3456    request: proto::InviteChannelMember,
3457    response: Response<proto::InviteChannelMember>,
3458    session: UserSession,
3459) -> Result<()> {
3460    let db = session.db().await;
3461    let channel_id = ChannelId::from_proto(request.channel_id);
3462    let invitee_id = UserId::from_proto(request.user_id);
3463    let InviteMemberResult {
3464        channel,
3465        notifications,
3466    } = db
3467        .invite_channel_member(
3468            channel_id,
3469            invitee_id,
3470            session.user_id(),
3471            request.role().into(),
3472        )
3473        .await?;
3474
3475    let update = proto::UpdateChannels {
3476        channel_invitations: vec![channel.to_proto()],
3477        ..Default::default()
3478    };
3479
3480    let connection_pool = session.connection_pool().await;
3481    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3482        session.peer.send(connection_id, update.clone())?;
3483    }
3484
3485    send_notifications(&connection_pool, &session.peer, notifications);
3486
3487    response.send(proto::Ack {})?;
3488    Ok(())
3489}
3490
3491/// remove someone from a channel
3492async fn remove_channel_member(
3493    request: proto::RemoveChannelMember,
3494    response: Response<proto::RemoveChannelMember>,
3495    session: UserSession,
3496) -> Result<()> {
3497    let db = session.db().await;
3498    let channel_id = ChannelId::from_proto(request.channel_id);
3499    let member_id = UserId::from_proto(request.user_id);
3500
3501    let RemoveChannelMemberResult {
3502        membership_update,
3503        notification_id,
3504    } = db
3505        .remove_channel_member(channel_id, member_id, session.user_id())
3506        .await?;
3507
3508    let mut connection_pool = session.connection_pool().await;
3509    notify_membership_updated(
3510        &mut connection_pool,
3511        membership_update,
3512        member_id,
3513        &session.peer,
3514    );
3515    for connection_id in connection_pool.user_connection_ids(member_id) {
3516        if let Some(notification_id) = notification_id {
3517            session
3518                .peer
3519                .send(
3520                    connection_id,
3521                    proto::DeleteNotification {
3522                        notification_id: notification_id.to_proto(),
3523                    },
3524                )
3525                .trace_err();
3526        }
3527    }
3528
3529    response.send(proto::Ack {})?;
3530    Ok(())
3531}
3532
3533/// Toggle the channel between public and private.
3534/// Care is taken to maintain the invariant that public channels only descend from public channels,
3535/// (though members-only channels can appear at any point in the hierarchy).
3536async fn set_channel_visibility(
3537    request: proto::SetChannelVisibility,
3538    response: Response<proto::SetChannelVisibility>,
3539    session: UserSession,
3540) -> Result<()> {
3541    let db = session.db().await;
3542    let channel_id = ChannelId::from_proto(request.channel_id);
3543    let visibility = request.visibility().into();
3544
3545    let channel_model = db
3546        .set_channel_visibility(channel_id, visibility, session.user_id())
3547        .await?;
3548    let root_id = channel_model.root_id();
3549    let channel = Channel::from_model(channel_model);
3550
3551    let mut connection_pool = session.connection_pool().await;
3552    for (user_id, role) in connection_pool
3553        .channel_user_ids(root_id)
3554        .collect::<Vec<_>>()
3555        .into_iter()
3556    {
3557        let update = if role.can_see_channel(channel.visibility) {
3558            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3559            proto::UpdateChannels {
3560                channels: vec![channel.to_proto()],
3561                ..Default::default()
3562            }
3563        } else {
3564            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3565            proto::UpdateChannels {
3566                delete_channels: vec![channel.id.to_proto()],
3567                ..Default::default()
3568            }
3569        };
3570
3571        for connection_id in connection_pool.user_connection_ids(user_id) {
3572            session.peer.send(connection_id, update.clone())?;
3573        }
3574    }
3575
3576    response.send(proto::Ack {})?;
3577    Ok(())
3578}
3579
3580/// Alter the role for a user in the channel.
3581async fn set_channel_member_role(
3582    request: proto::SetChannelMemberRole,
3583    response: Response<proto::SetChannelMemberRole>,
3584    session: UserSession,
3585) -> Result<()> {
3586    let db = session.db().await;
3587    let channel_id = ChannelId::from_proto(request.channel_id);
3588    let member_id = UserId::from_proto(request.user_id);
3589    let result = db
3590        .set_channel_member_role(
3591            channel_id,
3592            session.user_id(),
3593            member_id,
3594            request.role().into(),
3595        )
3596        .await?;
3597
3598    match result {
3599        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3600            let mut connection_pool = session.connection_pool().await;
3601            notify_membership_updated(
3602                &mut connection_pool,
3603                membership_update,
3604                member_id,
3605                &session.peer,
3606            )
3607        }
3608        db::SetMemberRoleResult::InviteUpdated(channel) => {
3609            let update = proto::UpdateChannels {
3610                channel_invitations: vec![channel.to_proto()],
3611                ..Default::default()
3612            };
3613
3614            for connection_id in session
3615                .connection_pool()
3616                .await
3617                .user_connection_ids(member_id)
3618            {
3619                session.peer.send(connection_id, update.clone())?;
3620            }
3621        }
3622    }
3623
3624    response.send(proto::Ack {})?;
3625    Ok(())
3626}
3627
3628/// Change the name of a channel
3629async fn rename_channel(
3630    request: proto::RenameChannel,
3631    response: Response<proto::RenameChannel>,
3632    session: UserSession,
3633) -> Result<()> {
3634    let db = session.db().await;
3635    let channel_id = ChannelId::from_proto(request.channel_id);
3636    let channel_model = db
3637        .rename_channel(channel_id, session.user_id(), &request.name)
3638        .await?;
3639    let root_id = channel_model.root_id();
3640    let channel = Channel::from_model(channel_model);
3641
3642    response.send(proto::RenameChannelResponse {
3643        channel: Some(channel.to_proto()),
3644    })?;
3645
3646    let connection_pool = session.connection_pool().await;
3647    let update = proto::UpdateChannels {
3648        channels: vec![channel.to_proto()],
3649        ..Default::default()
3650    };
3651    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3652        if role.can_see_channel(channel.visibility) {
3653            session.peer.send(connection_id, update.clone())?;
3654        }
3655    }
3656
3657    Ok(())
3658}
3659
3660/// Move a channel to a new parent.
3661async fn move_channel(
3662    request: proto::MoveChannel,
3663    response: Response<proto::MoveChannel>,
3664    session: UserSession,
3665) -> Result<()> {
3666    let channel_id = ChannelId::from_proto(request.channel_id);
3667    let to = ChannelId::from_proto(request.to);
3668
3669    let (root_id, channels) = session
3670        .db()
3671        .await
3672        .move_channel(channel_id, to, session.user_id())
3673        .await?;
3674
3675    let connection_pool = session.connection_pool().await;
3676    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3677        let channels = channels
3678            .iter()
3679            .filter_map(|channel| {
3680                if role.can_see_channel(channel.visibility) {
3681                    Some(channel.to_proto())
3682                } else {
3683                    None
3684                }
3685            })
3686            .collect::<Vec<_>>();
3687        if channels.is_empty() {
3688            continue;
3689        }
3690
3691        let update = proto::UpdateChannels {
3692            channels,
3693            ..Default::default()
3694        };
3695
3696        session.peer.send(connection_id, update.clone())?;
3697    }
3698
3699    response.send(Ack {})?;
3700    Ok(())
3701}
3702
3703/// Get the list of channel members
3704async fn get_channel_members(
3705    request: proto::GetChannelMembers,
3706    response: Response<proto::GetChannelMembers>,
3707    session: UserSession,
3708) -> Result<()> {
3709    let db = session.db().await;
3710    let channel_id = ChannelId::from_proto(request.channel_id);
3711    let limit = if request.limit == 0 {
3712        u16::MAX as u64
3713    } else {
3714        request.limit
3715    };
3716    let (members, users) = db
3717        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3718        .await?;
3719    response.send(proto::GetChannelMembersResponse { members, users })?;
3720    Ok(())
3721}
3722
3723/// Accept or decline a channel invitation.
3724async fn respond_to_channel_invite(
3725    request: proto::RespondToChannelInvite,
3726    response: Response<proto::RespondToChannelInvite>,
3727    session: UserSession,
3728) -> Result<()> {
3729    let db = session.db().await;
3730    let channel_id = ChannelId::from_proto(request.channel_id);
3731    let RespondToChannelInvite {
3732        membership_update,
3733        notifications,
3734    } = db
3735        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3736        .await?;
3737
3738    let mut connection_pool = session.connection_pool().await;
3739    if let Some(membership_update) = membership_update {
3740        notify_membership_updated(
3741            &mut connection_pool,
3742            membership_update,
3743            session.user_id(),
3744            &session.peer,
3745        );
3746    } else {
3747        let update = proto::UpdateChannels {
3748            remove_channel_invitations: vec![channel_id.to_proto()],
3749            ..Default::default()
3750        };
3751
3752        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3753            session.peer.send(connection_id, update.clone())?;
3754        }
3755    };
3756
3757    send_notifications(&connection_pool, &session.peer, notifications);
3758
3759    response.send(proto::Ack {})?;
3760
3761    Ok(())
3762}
3763
3764/// Join the channels' room
3765async fn join_channel(
3766    request: proto::JoinChannel,
3767    response: Response<proto::JoinChannel>,
3768    session: UserSession,
3769) -> Result<()> {
3770    let channel_id = ChannelId::from_proto(request.channel_id);
3771    join_channel_internal(channel_id, Box::new(response), session).await
3772}
3773
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// It is implemented for both the `JoinChannel` and `JoinRoom` response types,
/// so `join_channel_internal` can reply through either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s own `send`, not back
        // to this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3787
/// Shared implementation behind joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed via `leave_room_for_session`.
/// 2. The channel is joined in the database. This yields the joined room, an
///    optional membership change, and the user's role.
/// 3. The reply carries the room, the channel id (if any) and, when a LiveKit
///    client is configured, LiveKit credentials:
///    - Guests get a guest token with `can_publish: false`.
///    - Every other role gets a room token with `can_publish: true`.
///    - If token creation fails, the error goes through `trace_err` and the
///      reply omits the LiveKit info.
/// 4. Any membership change and the room update are broadcast.
/// 5. After the inner scope ends, the channel update is broadcast and the
///    user's contacts are refreshed.
///
/// Errors from the database calls, from leaving a stale room, from sending the
/// reply, or from refreshing contacts are propagated. A joined room with no
/// channel attached is also an error.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the database and connection-pool guards, so both are
    // released before the pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room (presumably
            // `leave_room_for_session` needs to acquire it itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Joining through a channel must yield a channel; its absence is an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3878
3879/// Start editing the channel notes
3880async fn join_channel_buffer(
3881    request: proto::JoinChannelBuffer,
3882    response: Response<proto::JoinChannelBuffer>,
3883    session: UserSession,
3884) -> Result<()> {
3885    let db = session.db().await;
3886    let channel_id = ChannelId::from_proto(request.channel_id);
3887
3888    let open_response = db
3889        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3890        .await?;
3891
3892    let collaborators = open_response.collaborators.clone();
3893    response.send(open_response)?;
3894
3895    let update = UpdateChannelBufferCollaborators {
3896        channel_id: channel_id.to_proto(),
3897        collaborators: collaborators.clone(),
3898    };
3899    channel_buffer_updated(
3900        session.connection_id,
3901        collaborators
3902            .iter()
3903            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3904        &update,
3905        &session.peer,
3906    );
3907
3908    Ok(())
3909}
3910
3911/// Edit the channel notes
3912async fn update_channel_buffer(
3913    request: proto::UpdateChannelBuffer,
3914    session: UserSession,
3915) -> Result<()> {
3916    let db = session.db().await;
3917    let channel_id = ChannelId::from_proto(request.channel_id);
3918
3919    let (collaborators, epoch, version) = db
3920        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3921        .await?;
3922
3923    channel_buffer_updated(
3924        session.connection_id,
3925        collaborators.clone(),
3926        &proto::UpdateChannelBuffer {
3927            channel_id: channel_id.to_proto(),
3928            operations: request.operations,
3929        },
3930        &session.peer,
3931    );
3932
3933    let pool = &*session.connection_pool().await;
3934
3935    let non_collaborators =
3936        pool.channel_connection_ids(channel_id)
3937            .filter_map(|(connection_id, _)| {
3938                if collaborators.contains(&connection_id) {
3939                    None
3940                } else {
3941                    Some(connection_id)
3942                }
3943            });
3944
3945    broadcast(None, non_collaborators, |peer_id| {
3946        session.peer.send(
3947            peer_id,
3948            proto::UpdateChannels {
3949                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3950                    channel_id: channel_id.to_proto(),
3951                    epoch: epoch as u64,
3952                    version: version.clone(),
3953                }],
3954                ..Default::default()
3955            },
3956        )
3957    });
3958
3959    Ok(())
3960}
3961
3962/// Rejoin the channel notes after a connection blip
3963async fn rejoin_channel_buffers(
3964    request: proto::RejoinChannelBuffers,
3965    response: Response<proto::RejoinChannelBuffers>,
3966    session: UserSession,
3967) -> Result<()> {
3968    let db = session.db().await;
3969    let buffers = db
3970        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3971        .await?;
3972
3973    for rejoined_buffer in &buffers {
3974        let collaborators_to_notify = rejoined_buffer
3975            .buffer
3976            .collaborators
3977            .iter()
3978            .filter_map(|c| Some(c.peer_id?.into()));
3979        channel_buffer_updated(
3980            session.connection_id,
3981            collaborators_to_notify,
3982            &proto::UpdateChannelBufferCollaborators {
3983                channel_id: rejoined_buffer.buffer.channel_id,
3984                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3985            },
3986            &session.peer,
3987        );
3988    }
3989
3990    response.send(proto::RejoinChannelBuffersResponse {
3991        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3992    })?;
3993
3994    Ok(())
3995}
3996
3997/// Stop editing the channel notes
3998async fn leave_channel_buffer(
3999    request: proto::LeaveChannelBuffer,
4000    response: Response<proto::LeaveChannelBuffer>,
4001    session: UserSession,
4002) -> Result<()> {
4003    let db = session.db().await;
4004    let channel_id = ChannelId::from_proto(request.channel_id);
4005
4006    let left_buffer = db
4007        .leave_channel_buffer(channel_id, session.connection_id)
4008        .await?;
4009
4010    response.send(Ack {})?;
4011
4012    channel_buffer_updated(
4013        session.connection_id,
4014        left_buffer.connections,
4015        &proto::UpdateChannelBufferCollaborators {
4016            channel_id: channel_id.to_proto(),
4017            collaborators: left_buffer.collaborators,
4018        },
4019        &session.peer,
4020    );
4021
4022    Ok(())
4023}
4024
4025fn channel_buffer_updated<T: EnvelopedMessage>(
4026    sender_id: ConnectionId,
4027    collaborators: impl IntoIterator<Item = ConnectionId>,
4028    message: &T,
4029    peer: &Peer,
4030) {
4031    broadcast(Some(sender_id), collaborators, |peer_id| {
4032        peer.send(peer_id, message.clone())
4033    });
4034}
4035
4036fn send_notifications(
4037    connection_pool: &ConnectionPool,
4038    peer: &Peer,
4039    notifications: db::NotificationBatch,
4040) {
4041    for (user_id, notification) in notifications {
4042        for connection_id in connection_pool.user_connection_ids(user_id) {
4043            if let Err(error) = peer.send(
4044                connection_id,
4045                proto::AddNotification {
4046                    notification: Some(notification.clone()),
4047                },
4048            ) {
4049                tracing::error!(
4050                    "failed to send notification to {:?} {}",
4051                    connection_id,
4052                    error
4053                );
4054            }
4055        }
4056    }
4057}
4058
4059/// Send a message to the channel
4060async fn send_channel_message(
4061    request: proto::SendChannelMessage,
4062    response: Response<proto::SendChannelMessage>,
4063    session: UserSession,
4064) -> Result<()> {
4065    // Validate the message body.
4066    let body = request.body.trim().to_string();
4067    if body.len() > MAX_MESSAGE_LEN {
4068        return Err(anyhow!("message is too long"))?;
4069    }
4070    if body.is_empty() {
4071        return Err(anyhow!("message can't be blank"))?;
4072    }
4073
4074    // TODO: adjust mentions if body is trimmed
4075
4076    let timestamp = OffsetDateTime::now_utc();
4077    let nonce = request
4078        .nonce
4079        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4080
4081    let channel_id = ChannelId::from_proto(request.channel_id);
4082    let CreatedChannelMessage {
4083        message_id,
4084        participant_connection_ids,
4085        notifications,
4086    } = session
4087        .db()
4088        .await
4089        .create_channel_message(
4090            channel_id,
4091            session.user_id(),
4092            &body,
4093            &request.mentions,
4094            timestamp,
4095            nonce.clone().into(),
4096            match request.reply_to_message_id {
4097                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4098                None => None,
4099            },
4100        )
4101        .await?;
4102
4103    let message = proto::ChannelMessage {
4104        sender_id: session.user_id().to_proto(),
4105        id: message_id.to_proto(),
4106        body,
4107        mentions: request.mentions,
4108        timestamp: timestamp.unix_timestamp() as u64,
4109        nonce: Some(nonce),
4110        reply_to_message_id: request.reply_to_message_id,
4111        edited_at: None,
4112    };
4113    broadcast(
4114        Some(session.connection_id),
4115        participant_connection_ids.clone(),
4116        |connection| {
4117            session.peer.send(
4118                connection,
4119                proto::ChannelMessageSent {
4120                    channel_id: channel_id.to_proto(),
4121                    message: Some(message.clone()),
4122                },
4123            )
4124        },
4125    );
4126    response.send(proto::SendChannelMessageResponse {
4127        message: Some(message),
4128    })?;
4129
4130    let pool = &*session.connection_pool().await;
4131    let non_participants =
4132        pool.channel_connection_ids(channel_id)
4133            .filter_map(|(connection_id, _)| {
4134                if participant_connection_ids.contains(&connection_id) {
4135                    None
4136                } else {
4137                    Some(connection_id)
4138                }
4139            });
4140    broadcast(None, non_participants, |peer_id| {
4141        session.peer.send(
4142            peer_id,
4143            proto::UpdateChannels {
4144                latest_channel_message_ids: vec![proto::ChannelMessageId {
4145                    channel_id: channel_id.to_proto(),
4146                    message_id: message_id.to_proto(),
4147                }],
4148                ..Default::default()
4149            },
4150        )
4151    });
4152    send_notifications(pool, &session.peer, notifications);
4153
4154    Ok(())
4155}
4156
4157/// Delete a channel message
4158async fn remove_channel_message(
4159    request: proto::RemoveChannelMessage,
4160    response: Response<proto::RemoveChannelMessage>,
4161    session: UserSession,
4162) -> Result<()> {
4163    let channel_id = ChannelId::from_proto(request.channel_id);
4164    let message_id = MessageId::from_proto(request.message_id);
4165    let (connection_ids, existing_notification_ids) = session
4166        .db()
4167        .await
4168        .remove_channel_message(channel_id, message_id, session.user_id())
4169        .await?;
4170
4171    broadcast(
4172        Some(session.connection_id),
4173        connection_ids,
4174        move |connection| {
4175            session.peer.send(connection, request.clone())?;
4176
4177            for notification_id in &existing_notification_ids {
4178                session.peer.send(
4179                    connection,
4180                    proto::DeleteNotification {
4181                        notification_id: (*notification_id).to_proto(),
4182                    },
4183                )?;
4184            }
4185
4186            Ok(())
4187        },
4188    );
4189    response.send(proto::Ack {})?;
4190    Ok(())
4191}
4192
4193async fn update_channel_message(
4194    request: proto::UpdateChannelMessage,
4195    response: Response<proto::UpdateChannelMessage>,
4196    session: UserSession,
4197) -> Result<()> {
4198    let channel_id = ChannelId::from_proto(request.channel_id);
4199    let message_id = MessageId::from_proto(request.message_id);
4200    let updated_at = OffsetDateTime::now_utc();
4201    let UpdatedChannelMessage {
4202        message_id,
4203        participant_connection_ids,
4204        notifications,
4205        reply_to_message_id,
4206        timestamp,
4207        deleted_mention_notification_ids,
4208        updated_mention_notifications,
4209    } = session
4210        .db()
4211        .await
4212        .update_channel_message(
4213            channel_id,
4214            message_id,
4215            session.user_id(),
4216            request.body.as_str(),
4217            &request.mentions,
4218            updated_at,
4219        )
4220        .await?;
4221
4222    let nonce = request
4223        .nonce
4224        .clone()
4225        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4226
4227    let message = proto::ChannelMessage {
4228        sender_id: session.user_id().to_proto(),
4229        id: message_id.to_proto(),
4230        body: request.body.clone(),
4231        mentions: request.mentions.clone(),
4232        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4233        nonce: Some(nonce),
4234        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4235        edited_at: Some(updated_at.unix_timestamp() as u64),
4236    };
4237
4238    response.send(proto::Ack {})?;
4239
4240    let pool = &*session.connection_pool().await;
4241    broadcast(
4242        Some(session.connection_id),
4243        participant_connection_ids,
4244        |connection| {
4245            session.peer.send(
4246                connection,
4247                proto::ChannelMessageUpdate {
4248                    channel_id: channel_id.to_proto(),
4249                    message: Some(message.clone()),
4250                },
4251            )?;
4252
4253            for notification_id in &deleted_mention_notification_ids {
4254                session.peer.send(
4255                    connection,
4256                    proto::DeleteNotification {
4257                        notification_id: (*notification_id).to_proto(),
4258                    },
4259                )?;
4260            }
4261
4262            for notification in &updated_mention_notifications {
4263                session.peer.send(
4264                    connection,
4265                    proto::UpdateNotification {
4266                        notification: Some(notification.clone()),
4267                    },
4268                )?;
4269            }
4270
4271            Ok(())
4272        },
4273    );
4274
4275    send_notifications(pool, &session.peer, notifications);
4276
4277    Ok(())
4278}
4279
4280/// Mark a channel message as read
4281async fn acknowledge_channel_message(
4282    request: proto::AckChannelMessage,
4283    session: UserSession,
4284) -> Result<()> {
4285    let channel_id = ChannelId::from_proto(request.channel_id);
4286    let message_id = MessageId::from_proto(request.message_id);
4287    let notifications = session
4288        .db()
4289        .await
4290        .observe_channel_message(channel_id, session.user_id(), message_id)
4291        .await?;
4292    send_notifications(
4293        &*session.connection_pool().await,
4294        &session.peer,
4295        notifications,
4296    );
4297    Ok(())
4298}
4299
4300/// Mark a buffer version as synced
4301async fn acknowledge_buffer_version(
4302    request: proto::AckBufferOperation,
4303    session: UserSession,
4304) -> Result<()> {
4305    let buffer_id = BufferId::from_proto(request.buffer_id);
4306    session
4307        .db()
4308        .await
4309        .observe_buffer_version(
4310            buffer_id,
4311            session.user_id(),
4312            request.epoch as i32,
4313            &request.version,
4314        )
4315        .await?;
4316    Ok(())
4317}
4318
4319struct CompleteWithLanguageModelRateLimit;
4320
4321impl RateLimit for CompleteWithLanguageModelRateLimit {
4322    fn capacity() -> usize {
4323        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4324            .ok()
4325            .and_then(|v| v.parse().ok())
4326            .unwrap_or(120) // Picked arbitrarily
4327    }
4328
4329    fn refill_duration() -> chrono::Duration {
4330        chrono::Duration::hours(1)
4331    }
4332
4333    fn db_name() -> &'static str {
4334        "complete-with-language-model"
4335    }
4336}
4337
4338async fn complete_with_language_model(
4339    request: proto::CompleteWithLanguageModel,
4340    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4341    session: Session,
4342    open_ai_api_key: Option<Arc<str>>,
4343    google_ai_api_key: Option<Arc<str>>,
4344    anthropic_api_key: Option<Arc<str>>,
4345) -> Result<()> {
4346    let Some(session) = session.for_user() else {
4347        return Err(anyhow!("user not found"))?;
4348    };
4349    authorize_access_to_language_models(&session).await?;
4350    session
4351        .rate_limiter
4352        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4353        .await?;
4354
4355    if request.model.starts_with("gpt") {
4356        let api_key =
4357            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4358        complete_with_open_ai(request, response, session, api_key).await?;
4359    } else if request.model.starts_with("gemini") {
4360        let api_key = google_ai_api_key
4361            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4362        complete_with_google_ai(request, response, session, api_key).await?;
4363    } else if request.model.starts_with("claude") {
4364        let api_key = anthropic_api_key
4365            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4366        complete_with_anthropic(request, response, session, api_key).await?;
4367    }
4368
4369    Ok(())
4370}
4371
4372async fn complete_with_open_ai(
4373    request: proto::CompleteWithLanguageModel,
4374    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4375    session: UserSession,
4376    api_key: Arc<str>,
4377) -> Result<()> {
4378    let mut completion_stream = open_ai::stream_completion(
4379        session.http_client.as_ref(),
4380        OPEN_AI_API_URL,
4381        &api_key,
4382        crate::ai::language_model_request_to_open_ai(request)?,
4383        None,
4384    )
4385    .await
4386    .context("open_ai::stream_completion request failed within collab")?;
4387
4388    while let Some(event) = completion_stream.next().await {
4389        let event = event?;
4390        response.send(proto::LanguageModelResponse {
4391            choices: event
4392                .choices
4393                .into_iter()
4394                .map(|choice| proto::LanguageModelChoiceDelta {
4395                    index: choice.index,
4396                    delta: Some(proto::LanguageModelResponseMessage {
4397                        role: choice.delta.role.map(|role| match role {
4398                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4399                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4400                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4401                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4402                        } as i32),
4403                        content: choice.delta.content,
4404                        tool_calls: choice
4405                            .delta
4406                            .tool_calls
4407                            .into_iter()
4408                            .map(|delta| proto::ToolCallDelta {
4409                                index: delta.index as u32,
4410                                id: delta.id,
4411                                variant: match delta.function {
4412                                    Some(function) => {
4413                                        let name = function.name;
4414                                        let arguments = function.arguments;
4415
4416                                        Some(proto::tool_call_delta::Variant::Function(
4417                                            proto::tool_call_delta::FunctionCallDelta {
4418                                                name,
4419                                                arguments,
4420                                            },
4421                                        ))
4422                                    }
4423                                    None => None,
4424                                },
4425                            })
4426                            .collect(),
4427                    }),
4428                    finish_reason: choice.finish_reason,
4429                })
4430                .collect(),
4431        })?;
4432    }
4433
4434    Ok(())
4435}
4436
4437async fn complete_with_google_ai(
4438    request: proto::CompleteWithLanguageModel,
4439    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4440    session: UserSession,
4441    api_key: Arc<str>,
4442) -> Result<()> {
4443    let mut stream = google_ai::stream_generate_content(
4444        session.http_client.clone(),
4445        google_ai::API_URL,
4446        api_key.as_ref(),
4447        crate::ai::language_model_request_to_google_ai(request)?,
4448    )
4449    .await
4450    .context("google_ai::stream_generate_content request failed")?;
4451
4452    while let Some(event) = stream.next().await {
4453        let event = event?;
4454        response.send(proto::LanguageModelResponse {
4455            choices: event
4456                .candidates
4457                .unwrap_or_default()
4458                .into_iter()
4459                .map(|candidate| proto::LanguageModelChoiceDelta {
4460                    index: candidate.index as u32,
4461                    delta: Some(proto::LanguageModelResponseMessage {
4462                        role: Some(match candidate.content.role {
4463                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4464                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4465                        } as i32),
4466                        content: Some(
4467                            candidate
4468                                .content
4469                                .parts
4470                                .into_iter()
4471                                .filter_map(|part| match part {
4472                                    google_ai::Part::TextPart(part) => Some(part.text),
4473                                    google_ai::Part::InlineDataPart(_) => None,
4474                                })
4475                                .collect(),
4476                        ),
4477                        // Tool calls are not supported for Google
4478                        tool_calls: Vec::new(),
4479                    }),
4480                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4481                })
4482                .collect(),
4483        })?;
4484    }
4485
4486    Ok(())
4487}
4488
4489async fn complete_with_anthropic(
4490    request: proto::CompleteWithLanguageModel,
4491    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4492    session: UserSession,
4493    api_key: Arc<str>,
4494) -> Result<()> {
4495    let model = anthropic::Model::from_id(&request.model)?;
4496
4497    let mut system_message = String::new();
4498    let messages = request
4499        .messages
4500        .into_iter()
4501        .filter_map(|message| {
4502            match message.role() {
4503                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4504                    role: anthropic::Role::User,
4505                    content: message.content,
4506                }),
4507                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4508                    role: anthropic::Role::Assistant,
4509                    content: message.content,
4510                }),
4511                // Anthropic's API breaks system instructions out as a separate field rather
4512                // than having a system message role.
4513                LanguageModelRole::LanguageModelSystem => {
4514                    if !system_message.is_empty() {
4515                        system_message.push_str("\n\n");
4516                    }
4517                    system_message.push_str(&message.content);
4518
4519                    None
4520                }
4521                // We don't yet support tool calls for Anthropic
4522                LanguageModelRole::LanguageModelTool => None,
4523            }
4524        })
4525        .collect();
4526
4527    let mut stream = anthropic::stream_completion(
4528        session.http_client.as_ref(),
4529        anthropic::ANTHROPIC_API_URL,
4530        &api_key,
4531        anthropic::Request {
4532            model,
4533            messages,
4534            stream: true,
4535            system: system_message,
4536            max_tokens: 4092,
4537        },
4538        None,
4539    )
4540    .await?;
4541
4542    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4543
4544    while let Some(event) = stream.next().await {
4545        let event = event?;
4546
4547        match event {
4548            anthropic::ResponseEvent::MessageStart { message } => {
4549                if let Some(role) = message.role {
4550                    if role == "assistant" {
4551                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4552                    } else if role == "user" {
4553                        current_role = proto::LanguageModelRole::LanguageModelUser;
4554                    }
4555                }
4556            }
4557            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4558                match content_block {
4559                    anthropic::ContentBlock::Text { text } => {
4560                        if !text.is_empty() {
4561                            response.send(proto::LanguageModelResponse {
4562                                choices: vec![proto::LanguageModelChoiceDelta {
4563                                    index: 0,
4564                                    delta: Some(proto::LanguageModelResponseMessage {
4565                                        role: Some(current_role as i32),
4566                                        content: Some(text),
4567                                        tool_calls: Vec::new(),
4568                                    }),
4569                                    finish_reason: None,
4570                                }],
4571                            })?;
4572                        }
4573                    }
4574                }
4575            }
4576            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4577                anthropic::TextDelta::TextDelta { text } => {
4578                    response.send(proto::LanguageModelResponse {
4579                        choices: vec![proto::LanguageModelChoiceDelta {
4580                            index: 0,
4581                            delta: Some(proto::LanguageModelResponseMessage {
4582                                role: Some(current_role as i32),
4583                                content: Some(text),
4584                                tool_calls: Vec::new(),
4585                            }),
4586                            finish_reason: None,
4587                        }],
4588                    })?;
4589                }
4590            },
4591            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4592                if let Some(stop_reason) = delta.stop_reason {
4593                    response.send(proto::LanguageModelResponse {
4594                        choices: vec![proto::LanguageModelChoiceDelta {
4595                            index: 0,
4596                            delta: None,
4597                            finish_reason: Some(stop_reason),
4598                        }],
4599                    })?;
4600                }
4601            }
4602            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4603            anthropic::ResponseEvent::MessageStop {} => {}
4604            anthropic::ResponseEvent::Ping {} => {}
4605        }
4606    }
4607
4608    Ok(())
4609}
4610
4611struct CountTokensWithLanguageModelRateLimit;
4612
4613impl RateLimit for CountTokensWithLanguageModelRateLimit {
4614    fn capacity() -> usize {
4615        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4616            .ok()
4617            .and_then(|v| v.parse().ok())
4618            .unwrap_or(600) // Picked arbitrarily
4619    }
4620
4621    fn refill_duration() -> chrono::Duration {
4622        chrono::Duration::hours(1)
4623    }
4624
4625    fn db_name() -> &'static str {
4626        "count-tokens-with-language-model"
4627    }
4628}
4629
4630async fn count_tokens_with_language_model(
4631    request: proto::CountTokensWithLanguageModel,
4632    response: Response<proto::CountTokensWithLanguageModel>,
4633    session: UserSession,
4634    google_ai_api_key: Option<Arc<str>>,
4635) -> Result<()> {
4636    authorize_access_to_language_models(&session).await?;
4637
4638    if !request.model.starts_with("gemini") {
4639        return Err(anyhow!(
4640            "counting tokens for model: {:?} is not supported",
4641            request.model
4642        ))?;
4643    }
4644
4645    session
4646        .rate_limiter
4647        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4648        .await?;
4649
4650    let api_key = google_ai_api_key
4651        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4652    let tokens_response = google_ai::count_tokens(
4653        session.http_client.as_ref(),
4654        google_ai::API_URL,
4655        &api_key,
4656        crate::ai::count_tokens_request_to_google_ai(request)?,
4657    )
4658    .await?;
4659    response.send(proto::CountTokensResponse {
4660        token_count: tokens_response.total_tokens as u32,
4661    })?;
4662    Ok(())
4663}
4664
4665struct ComputeEmbeddingsRateLimit;
4666
4667impl RateLimit for ComputeEmbeddingsRateLimit {
4668    fn capacity() -> usize {
4669        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4670            .ok()
4671            .and_then(|v| v.parse().ok())
4672            .unwrap_or(5000) // Picked arbitrarily
4673    }
4674
4675    fn refill_duration() -> chrono::Duration {
4676        chrono::Duration::hours(1)
4677    }
4678
4679    fn db_name() -> &'static str {
4680        "compute-embeddings"
4681    }
4682}
4683
4684async fn compute_embeddings(
4685    request: proto::ComputeEmbeddings,
4686    response: Response<proto::ComputeEmbeddings>,
4687    session: UserSession,
4688    api_key: Option<Arc<str>>,
4689) -> Result<()> {
4690    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4691    authorize_access_to_language_models(&session).await?;
4692
4693    session
4694        .rate_limiter
4695        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4696        .await?;
4697
4698    let embeddings = match request.model.as_str() {
4699        "openai/text-embedding-3-small" => {
4700            open_ai::embed(
4701                session.http_client.as_ref(),
4702                OPEN_AI_API_URL,
4703                &api_key,
4704                OpenAiEmbeddingModel::TextEmbedding3Small,
4705                request.texts.iter().map(|text| text.as_str()),
4706            )
4707            .await?
4708        }
4709        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4710    };
4711
4712    let embeddings = request
4713        .texts
4714        .iter()
4715        .map(|text| {
4716            let mut hasher = sha2::Sha256::new();
4717            hasher.update(text.as_bytes());
4718            let result = hasher.finalize();
4719            result.to_vec()
4720        })
4721        .zip(
4722            embeddings
4723                .data
4724                .into_iter()
4725                .map(|embedding| embedding.embedding),
4726        )
4727        .collect::<HashMap<_, _>>();
4728
4729    let db = session.db().await;
4730    db.save_embeddings(&request.model, &embeddings)
4731        .await
4732        .context("failed to save embeddings")
4733        .trace_err();
4734
4735    response.send(proto::ComputeEmbeddingsResponse {
4736        embeddings: embeddings
4737            .into_iter()
4738            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4739            .collect(),
4740    })?;
4741    Ok(())
4742}
4743
4744async fn get_cached_embeddings(
4745    request: proto::GetCachedEmbeddings,
4746    response: Response<proto::GetCachedEmbeddings>,
4747    session: UserSession,
4748) -> Result<()> {
4749    authorize_access_to_language_models(&session).await?;
4750
4751    let db = session.db().await;
4752    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4753
4754    response.send(proto::GetCachedEmbeddingsResponse {
4755        embeddings: embeddings
4756            .into_iter()
4757            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4758            .collect(),
4759    })?;
4760    Ok(())
4761}
4762
4763async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4764    let db = session.db().await;
4765    let flags = db.get_user_flags(session.user_id()).await?;
4766    if flags.iter().any(|flag| flag == "language-models") {
4767        Ok(())
4768    } else {
4769        Err(anyhow!("permission denied"))?
4770    }
4771}
4772
4773/// Get a Supermaven API key for the user
4774async fn get_supermaven_api_key(
4775    _request: proto::GetSupermavenApiKey,
4776    response: Response<proto::GetSupermavenApiKey>,
4777    session: UserSession,
4778) -> Result<()> {
4779    let user_id: String = session.user_id().to_string();
4780    if !session.is_staff() {
4781        return Err(anyhow!("supermaven not enabled for this account"))?;
4782    }
4783
4784    let email = session
4785        .email()
4786        .ok_or_else(|| anyhow!("user must have an email"))?;
4787
4788    let supermaven_admin_api = session
4789        .supermaven_client
4790        .as_ref()
4791        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4792
4793    let result = supermaven_admin_api
4794        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4795        .await?;
4796
4797    response.send(proto::GetSupermavenApiKeyResponse {
4798        api_key: result.api_key,
4799    })?;
4800
4801    Ok(())
4802}
4803
4804/// Start receiving chat updates for a channel
4805async fn join_channel_chat(
4806    request: proto::JoinChannelChat,
4807    response: Response<proto::JoinChannelChat>,
4808    session: UserSession,
4809) -> Result<()> {
4810    let channel_id = ChannelId::from_proto(request.channel_id);
4811
4812    let db = session.db().await;
4813    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4814        .await?;
4815    let messages = db
4816        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4817        .await?;
4818    response.send(proto::JoinChannelChatResponse {
4819        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4820        messages,
4821    })?;
4822    Ok(())
4823}
4824
4825/// Stop receiving chat updates for a channel
4826async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4827    let channel_id = ChannelId::from_proto(request.channel_id);
4828    session
4829        .db()
4830        .await
4831        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4832        .await?;
4833    Ok(())
4834}
4835
4836/// Retrieve the chat history for a channel
4837async fn get_channel_messages(
4838    request: proto::GetChannelMessages,
4839    response: Response<proto::GetChannelMessages>,
4840    session: UserSession,
4841) -> Result<()> {
4842    let channel_id = ChannelId::from_proto(request.channel_id);
4843    let messages = session
4844        .db()
4845        .await
4846        .get_channel_messages(
4847            channel_id,
4848            session.user_id(),
4849            MESSAGE_COUNT_PER_PAGE,
4850            Some(MessageId::from_proto(request.before_message_id)),
4851        )
4852        .await?;
4853    response.send(proto::GetChannelMessagesResponse {
4854        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4855        messages,
4856    })?;
4857    Ok(())
4858}
4859
4860/// Retrieve specific chat messages
4861async fn get_channel_messages_by_id(
4862    request: proto::GetChannelMessagesById,
4863    response: Response<proto::GetChannelMessagesById>,
4864    session: UserSession,
4865) -> Result<()> {
4866    let message_ids = request
4867        .message_ids
4868        .iter()
4869        .map(|id| MessageId::from_proto(*id))
4870        .collect::<Vec<_>>();
4871    let messages = session
4872        .db()
4873        .await
4874        .get_channel_messages_by_id(session.user_id(), &message_ids)
4875        .await?;
4876    response.send(proto::GetChannelMessagesResponse {
4877        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4878        messages,
4879    })?;
4880    Ok(())
4881}
4882
4883/// Retrieve the current users notifications
4884async fn get_notifications(
4885    request: proto::GetNotifications,
4886    response: Response<proto::GetNotifications>,
4887    session: UserSession,
4888) -> Result<()> {
4889    let notifications = session
4890        .db()
4891        .await
4892        .get_notifications(
4893            session.user_id(),
4894            NOTIFICATION_COUNT_PER_PAGE,
4895            request
4896                .before_id
4897                .map(|id| db::NotificationId::from_proto(id)),
4898        )
4899        .await?;
4900    response.send(proto::GetNotificationsResponse {
4901        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4902        notifications,
4903    })?;
4904    Ok(())
4905}
4906
4907/// Mark notifications as read
4908async fn mark_notification_as_read(
4909    request: proto::MarkNotificationRead,
4910    response: Response<proto::MarkNotificationRead>,
4911    session: UserSession,
4912) -> Result<()> {
4913    let database = &session.db().await;
4914    let notifications = database
4915        .mark_notification_as_read_by_id(
4916            session.user_id(),
4917            NotificationId::from_proto(request.notification_id),
4918        )
4919        .await?;
4920    send_notifications(
4921        &*session.connection_pool().await,
4922        &session.peer,
4923        notifications,
4924    );
4925    response.send(proto::Ack {})?;
4926    Ok(())
4927}
4928
4929/// Get the current users information
4930async fn get_private_user_info(
4931    _request: proto::GetPrivateUserInfo,
4932    response: Response<proto::GetPrivateUserInfo>,
4933    session: UserSession,
4934) -> Result<()> {
4935    let db = session.db().await;
4936
4937    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4938    let user = db
4939        .get_user_by_id(session.user_id())
4940        .await?
4941        .ok_or_else(|| anyhow!("user not found"))?;
4942    let flags = db.get_user_flags(session.user_id()).await?;
4943
4944    response.send(proto::GetPrivateUserInfoResponse {
4945        metrics_id,
4946        staff: user.admin,
4947        flags,
4948    })?;
4949    Ok(())
4950}
4951
4952fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4953    match message {
4954        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4955        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4956        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4957        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4958        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4959            code: frame.code.into(),
4960            reason: frame.reason,
4961        })),
4962    }
4963}
4964
4965fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4966    match message {
4967        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4968        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4969        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4970        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4971        AxumMessage::Close(frame) => {
4972            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4973                code: frame.code.into(),
4974                reason: frame.reason,
4975            }))
4976        }
4977    }
4978}
4979
4980fn notify_membership_updated(
4981    connection_pool: &mut ConnectionPool,
4982    result: MembershipUpdated,
4983    user_id: UserId,
4984    peer: &Peer,
4985) {
4986    for membership in &result.new_channels.channel_memberships {
4987        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4988    }
4989    for channel_id in &result.removed_channels {
4990        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4991    }
4992
4993    let user_channels_update = proto::UpdateUserChannels {
4994        channel_memberships: result
4995            .new_channels
4996            .channel_memberships
4997            .iter()
4998            .map(|cm| proto::ChannelMembership {
4999                channel_id: cm.channel_id.to_proto(),
5000                role: cm.role.into(),
5001            })
5002            .collect(),
5003        ..Default::default()
5004    };
5005
5006    let mut update = build_channels_update(result.new_channels, vec![]);
5007    update.delete_channels = result
5008        .removed_channels
5009        .into_iter()
5010        .map(|id| id.to_proto())
5011        .collect();
5012    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5013
5014    for connection_id in connection_pool.user_connection_ids(user_id) {
5015        peer.send(connection_id, user_channels_update.clone())
5016            .trace_err();
5017        peer.send(connection_id, update.clone()).trace_err();
5018    }
5019}
5020
5021fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5022    proto::UpdateUserChannels {
5023        channel_memberships: channels
5024            .channel_memberships
5025            .iter()
5026            .map(|m| proto::ChannelMembership {
5027                channel_id: m.channel_id.to_proto(),
5028                role: m.role.into(),
5029            })
5030            .collect(),
5031        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5032        observed_channel_message_id: channels.observed_channel_messages.clone(),
5033    }
5034}
5035
5036fn build_channels_update(
5037    channels: ChannelsForUser,
5038    channel_invites: Vec<db::Channel>,
5039) -> proto::UpdateChannels {
5040    let mut update = proto::UpdateChannels::default();
5041
5042    for channel in channels.channels {
5043        update.channels.push(channel.to_proto());
5044    }
5045
5046    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5047    update.latest_channel_message_ids = channels.latest_channel_messages;
5048
5049    for (channel_id, participants) in channels.channel_participants {
5050        update
5051            .channel_participants
5052            .push(proto::ChannelParticipants {
5053                channel_id: channel_id.to_proto(),
5054                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5055            });
5056    }
5057
5058    for channel in channel_invites {
5059        update.channel_invitations.push(channel.to_proto());
5060    }
5061
5062    update.hosted_projects = channels.hosted_projects;
5063    update
5064}
5065
5066fn build_initial_contacts_update(
5067    contacts: Vec<db::Contact>,
5068    pool: &ConnectionPool,
5069) -> proto::UpdateContacts {
5070    let mut update = proto::UpdateContacts::default();
5071
5072    for contact in contacts {
5073        match contact {
5074            db::Contact::Accepted { user_id, busy } => {
5075                update.contacts.push(contact_for_user(user_id, busy, &pool));
5076            }
5077            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5078            db::Contact::Incoming { user_id } => {
5079                update
5080                    .incoming_requests
5081                    .push(proto::IncomingContactRequest {
5082                        requester_id: user_id.to_proto(),
5083                    })
5084            }
5085        }
5086    }
5087
5088    update
5089}
5090
5091fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5092    proto::Contact {
5093        user_id: user_id.to_proto(),
5094        online: pool.is_user_online(user_id),
5095        busy,
5096    }
5097}
5098
5099fn room_updated(room: &proto::Room, peer: &Peer) {
5100    broadcast(
5101        None,
5102        room.participants
5103            .iter()
5104            .filter_map(|participant| Some(participant.peer_id?.into())),
5105        |peer_id| {
5106            peer.send(
5107                peer_id,
5108                proto::RoomUpdated {
5109                    room: Some(room.clone()),
5110                },
5111            )
5112        },
5113    );
5114}
5115
5116fn channel_updated(
5117    channel: &db::channel::Model,
5118    room: &proto::Room,
5119    peer: &Peer,
5120    pool: &ConnectionPool,
5121) {
5122    let participants = room
5123        .participants
5124        .iter()
5125        .map(|p| p.user_id)
5126        .collect::<Vec<_>>();
5127
5128    broadcast(
5129        None,
5130        pool.channel_connection_ids(channel.root_id())
5131            .filter_map(|(channel_id, role)| {
5132                role.can_see_channel(channel.visibility).then(|| channel_id)
5133            }),
5134        |peer_id| {
5135            peer.send(
5136                peer_id,
5137                proto::UpdateChannels {
5138                    channel_participants: vec![proto::ChannelParticipants {
5139                        channel_id: channel.id.to_proto(),
5140                        participant_user_ids: participants.clone(),
5141                    }],
5142                    ..Default::default()
5143                },
5144            )
5145        },
5146    );
5147}
5148
5149async fn send_dev_server_projects_update(
5150    user_id: UserId,
5151    mut status: proto::DevServerProjectsUpdate,
5152    session: &Session,
5153) {
5154    let pool = session.connection_pool().await;
5155    for dev_server in &mut status.dev_servers {
5156        dev_server.status =
5157            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5158    }
5159    let connections = pool.user_connection_ids(user_id);
5160    for connection_id in connections {
5161        session.peer.send(connection_id, status.clone()).trace_err();
5162    }
5163}
5164
5165async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5166    let db = session.db().await;
5167
5168    let contacts = db.get_contacts(user_id).await?;
5169    let busy = db.is_user_busy(user_id).await?;
5170
5171    let pool = session.connection_pool().await;
5172    let updated_contact = contact_for_user(user_id, busy, &pool);
5173    for contact in contacts {
5174        if let db::Contact::Accepted {
5175            user_id: contact_user_id,
5176            ..
5177        } = contact
5178        {
5179            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5180                session
5181                    .peer
5182                    .send(
5183                        contact_conn_id,
5184                        proto::UpdateContacts {
5185                            contacts: vec![updated_contact.clone()],
5186                            remove_contacts: Default::default(),
5187                            incoming_requests: Default::default(),
5188                            remove_incoming_requests: Default::default(),
5189                            outgoing_requests: Default::default(),
5190                            remove_outgoing_requests: Default::default(),
5191                        },
5192                    )
5193                    .trace_err();
5194            }
5195        }
5196    }
5197    Ok(())
5198}
5199
5200async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5201    log::info!("lost dev server connection, unsharing projects");
5202    let project_ids = session
5203        .db()
5204        .await
5205        .get_stale_dev_server_projects(session.connection_id)
5206        .await?;
5207
5208    for project_id in project_ids {
5209        // not unshare re-checks the connection ids match, so we get away with no transaction
5210        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5211    }
5212
5213    let user_id = session.dev_server().user_id;
5214    let update = session
5215        .db()
5216        .await
5217        .dev_server_projects_update(user_id)
5218        .await?;
5219
5220    send_dev_server_projects_update(user_id, update, session).await;
5221
5222    Ok(())
5223}
5224
5225async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5226    let mut contacts_to_update = HashSet::default();
5227
5228    let room_id;
5229    let canceled_calls_to_user_ids;
5230    let live_kit_room;
5231    let delete_live_kit_room;
5232    let room;
5233    let channel;
5234
5235    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5236        contacts_to_update.insert(session.user_id());
5237
5238        for project in left_room.left_projects.values() {
5239            project_left(project, session);
5240        }
5241
5242        room_id = RoomId::from_proto(left_room.room.id);
5243        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5244        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5245        delete_live_kit_room = left_room.deleted;
5246        room = mem::take(&mut left_room.room);
5247        channel = mem::take(&mut left_room.channel);
5248
5249        room_updated(&room, &session.peer);
5250    } else {
5251        return Ok(());
5252    }
5253
5254    if let Some(channel) = channel {
5255        channel_updated(
5256            &channel,
5257            &room,
5258            &session.peer,
5259            &*session.connection_pool().await,
5260        );
5261    }
5262
5263    {
5264        let pool = session.connection_pool().await;
5265        for canceled_user_id in canceled_calls_to_user_ids {
5266            for connection_id in pool.user_connection_ids(canceled_user_id) {
5267                session
5268                    .peer
5269                    .send(
5270                        connection_id,
5271                        proto::CallCanceled {
5272                            room_id: room_id.to_proto(),
5273                        },
5274                    )
5275                    .trace_err();
5276            }
5277            contacts_to_update.insert(canceled_user_id);
5278        }
5279    }
5280
5281    for contact_user_id in contacts_to_update {
5282        update_user_contacts(contact_user_id, &session).await?;
5283    }
5284
5285    if let Some(live_kit) = session.live_kit_client.as_ref() {
5286        live_kit
5287            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5288            .await
5289            .trace_err();
5290
5291        if delete_live_kit_room {
5292            live_kit.delete_room(live_kit_room).await.trace_err();
5293        }
5294    }
5295
5296    Ok(())
5297}
5298
5299async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5300    let left_channel_buffers = session
5301        .db()
5302        .await
5303        .leave_channel_buffers(session.connection_id)
5304        .await?;
5305
5306    for left_buffer in left_channel_buffers {
5307        channel_buffer_updated(
5308            session.connection_id,
5309            left_buffer.connections,
5310            &proto::UpdateChannelBufferCollaborators {
5311                channel_id: left_buffer.channel_id.to_proto(),
5312                collaborators: left_buffer.collaborators,
5313            },
5314            &session.peer,
5315        );
5316    }
5317
5318    Ok(())
5319}
5320
5321fn project_left(project: &db::LeftProject, session: &UserSession) {
5322    for connection_id in &project.connection_ids {
5323        if project.should_unshare {
5324            session
5325                .peer
5326                .send(
5327                    *connection_id,
5328                    proto::UnshareProject {
5329                        project_id: project.id.to_proto(),
5330                    },
5331                )
5332                .trace_err();
5333        } else {
5334            session
5335                .peer
5336                .send(
5337                    *connection_id,
5338                    proto::RemoveProjectCollaborator {
5339                        project_id: project.id.to_proto(),
5340                        peer_id: Some(session.connection_id.into()),
5341                    },
5342                )
5343                .trace_err();
5344        }
5345    }
5346}
5347
5348pub trait ResultExt {
5349    type Ok;
5350
5351    fn trace_err(self) -> Option<Self::Ok>;
5352}
5353
5354impl<T, E> ResultExt for Result<T, E>
5355where
5356    E: std::fmt::Debug,
5357{
5358    type Ok = T;
5359
5360    #[track_caller]
5361    fn trace_err(self) -> Option<T> {
5362        match self {
5363            Ok(value) => Some(value),
5364            Err(error) => {
5365                tracing::error!("{:?}", error);
5366                None
5367            }
5368        }
5369    }
5370}