rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
/// A 30-second timeout related to reconnection.
///
/// NOTE(review): not referenced in the part of the file reviewed here; presumably the grace
/// period a dropped connection is given to come back — confirm against its call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` sleeps before it asks the database for stale rooms and channel
/// buffers left over from other server instances.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the part of the file reviewed here;
// judging by their names they bound chat-message paging, chat-message length, and
// notification paging respectively — confirm the units (e.g. bytes vs. chars) at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
/// Type-erased message handler, as stored in `Server::handlers` keyed by `TypeId`.
///
/// It receives the boxed, dynamically typed envelope together with the `Session` of the
/// connection and returns a boxed `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
/// The identity on whose behalf a connection acts.
#[derive(Clone, Debug)]
pub enum Principal {
    /// An ordinary user acting as themselves.
    User(User),
    /// `admin` is acting as `user`: user-level lookups resolve to `user`, while `admin` is
    /// recorded as the impersonator in tracing spans and debug output.
    Impersonated { user: User, admin: User },
    /// A dev server, identified by its database model.
    DevServer(dev_server::Model),
}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// Per-connection context handed to every message handler.
///
/// Cloning is cheap for the shared parts: the database handle, peer, connection pool, HTTP
/// client and rate limiter are all behind `Arc`.
#[derive(Clone)]
struct Session {
    /// Who this connection is acting as.
    principal: Principal,
    /// Identifier of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Shared pool of live connections; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client, if one is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client available to handlers for outbound requests.
    http_client: Arc<IsahcHttpClient>,
    /// Shared rate limiter.
    rate_limiter: Arc<RateLimiter>,
    // NOTE(review): the leading underscore suggests this is carried along but not read — confirm.
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server instance's id, behind a mutex.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (database, config, executor, optional service clients).
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel sender created with the value `false` in `Server::new`.
    // NOTE(review): presumably flipped to `true` to signal teardown — confirm where it is sent.
    teardown: watch::Sender<bool>,
}
 384
/// A held lock on the `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is not `Send`, so this marker makes the whole guard `!Send`.
    // NOTE(review): presumably this keeps the guard from being held across an `.await` in a
    // future that must be `Send` — confirm the intent.
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable view of the server: the peer plus the connection pool, held locked for the
/// lifetime `'a` of the snapshot.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// The server's RPC peer.
    peer: &'a Peer,
    /// The locked pool; serialized through `serialize_deref`, i.e. as the `ConnectionPool`
    /// the guard dereferences to rather than as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_project_request_for_owner::<proto::TaskTemplates>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetHover>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetDefinition>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::GetTypeDefinition>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetReferences>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::SearchProject>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetProjectSymbols>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::OpenBufferById>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::InlayHints>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferByPath>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::GetCompletions>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCodeActions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCodeAction>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::PrepareRename>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::PerformRename>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::ReloadBuffers>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::FormatBuffers>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CreateProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::RenameProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::CopyProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::OnTypeFormatting>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::BlameBuffer>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::MultiLspQuery>,
 547            ))
 548            .add_request_handler(user_handler(
 549                forward_mutating_project_request::<proto::RestartLanguageServers>,
 550            ))
 551            .add_request_handler(user_handler(
 552                forward_mutating_project_request::<proto::LinkedEditingRange>,
 553            ))
 554            .add_message_handler(create_buffer_for_peer)
 555            .add_request_handler(update_buffer)
 556            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 557            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 558            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 559            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 560            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 561            .add_request_handler(get_users)
 562            .add_request_handler(user_handler(fuzzy_search_users))
 563            .add_request_handler(user_handler(request_contact))
 564            .add_request_handler(user_handler(remove_contact))
 565            .add_request_handler(user_handler(respond_to_contact_request))
 566            .add_message_handler(subscribe_to_channels)
 567            .add_request_handler(user_handler(create_channel))
 568            .add_request_handler(user_handler(delete_channel))
 569            .add_request_handler(user_handler(invite_channel_member))
 570            .add_request_handler(user_handler(remove_channel_member))
 571            .add_request_handler(user_handler(set_channel_member_role))
 572            .add_request_handler(user_handler(set_channel_visibility))
 573            .add_request_handler(user_handler(rename_channel))
 574            .add_request_handler(user_handler(join_channel_buffer))
 575            .add_request_handler(user_handler(leave_channel_buffer))
 576            .add_message_handler(user_message_handler(update_channel_buffer))
 577            .add_request_handler(user_handler(rejoin_channel_buffers))
 578            .add_request_handler(user_handler(get_channel_members))
 579            .add_request_handler(user_handler(respond_to_channel_invite))
 580            .add_request_handler(user_handler(join_channel))
 581            .add_request_handler(user_handler(join_channel_chat))
 582            .add_message_handler(user_message_handler(leave_channel_chat))
 583            .add_request_handler(user_handler(send_channel_message))
 584            .add_request_handler(user_handler(remove_channel_message))
 585            .add_request_handler(user_handler(update_channel_message))
 586            .add_request_handler(user_handler(get_channel_messages))
 587            .add_request_handler(user_handler(get_channel_messages_by_id))
 588            .add_request_handler(user_handler(get_notifications))
 589            .add_request_handler(user_handler(mark_notification_as_read))
 590            .add_request_handler(user_handler(move_channel))
 591            .add_request_handler(user_handler(follow))
 592            .add_message_handler(user_message_handler(unfollow))
 593            .add_message_handler(user_message_handler(update_followers))
 594            .add_request_handler(user_handler(get_private_user_info))
 595            .add_message_handler(user_message_handler(acknowledge_channel_message))
 596            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 597            .add_request_handler(user_handler(get_supermaven_api_key))
 598            .add_request_handler(user_handler(
 599                forward_mutating_project_request::<proto::OpenContext>,
 600            ))
 601            .add_request_handler(user_handler(
 602                forward_mutating_project_request::<proto::SynchronizeContexts>,
 603            ))
 604            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 605            .add_message_handler(update_context)
 606            .add_streaming_request_handler({
 607                let app_state = app_state.clone();
 608                move |request, response, session| {
 609                    complete_with_language_model(
 610                        request,
 611                        response,
 612                        session,
 613                        app_state.config.openai_api_key.clone(),
 614                        app_state.config.google_ai_api_key.clone(),
 615                        app_state.config.anthropic_api_key.clone(),
 616                    )
 617                }
 618            })
 619            .add_request_handler({
 620                let app_state = app_state.clone();
 621                user_handler(move |request, response, session| {
 622                    count_tokens_with_language_model(
 623                        request,
 624                        response,
 625                        session,
 626                        app_state.config.google_ai_api_key.clone(),
 627                    )
 628                })
 629            })
 630            .add_request_handler({
 631                user_handler(move |request, response, session| {
 632                    get_cached_embeddings(request, response, session)
 633                })
 634            })
 635            .add_request_handler({
 636                let app_state = app_state.clone();
 637                user_handler(move |request, response, session| {
 638                    compute_embeddings(
 639                        request,
 640                        response,
 641                        session,
 642                        app_state.config.openai_api_key.clone(),
 643                    )
 644                })
 645            });
 646
 647        Arc::new(server)
 648    }
 649
 650    pub async fn start(&self) -> Result<()> {
 651        let server_id = *self.id.lock();
 652        let app_state = self.app_state.clone();
 653        let peer = self.peer.clone();
 654        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 655        let pool = self.connection_pool.clone();
 656        let live_kit_client = self.app_state.live_kit_client.clone();
 657
 658        let span = info_span!("start server");
 659        self.app_state.executor.spawn_detached(
 660            async move {
 661                tracing::info!("waiting for cleanup timeout");
 662                timeout.await;
 663                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 664                if let Some((room_ids, channel_ids)) = app_state
 665                    .db
 666                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 667                    .await
 668                    .trace_err()
 669                {
 670                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 671                    tracing::info!(
 672                        stale_channel_buffer_count = channel_ids.len(),
 673                        "retrieved stale channel buffers"
 674                    );
 675
 676                    for channel_id in channel_ids {
 677                        if let Some(refreshed_channel_buffer) = app_state
 678                            .db
 679                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 680                            .await
 681                            .trace_err()
 682                        {
 683                            for connection_id in refreshed_channel_buffer.connection_ids {
 684                                peer.send(
 685                                    connection_id,
 686                                    proto::UpdateChannelBufferCollaborators {
 687                                        channel_id: channel_id.to_proto(),
 688                                        collaborators: refreshed_channel_buffer
 689                                            .collaborators
 690                                            .clone(),
 691                                    },
 692                                )
 693                                .trace_err();
 694                            }
 695                        }
 696                    }
 697
 698                    for room_id in room_ids {
 699                        let mut contacts_to_update = HashSet::default();
 700                        let mut canceled_calls_to_user_ids = Vec::new();
 701                        let mut live_kit_room = String::new();
 702                        let mut delete_live_kit_room = false;
 703
 704                        if let Some(mut refreshed_room) = app_state
 705                            .db
 706                            .clear_stale_room_participants(room_id, server_id)
 707                            .await
 708                            .trace_err()
 709                        {
 710                            tracing::info!(
 711                                room_id = room_id.0,
 712                                new_participant_count = refreshed_room.room.participants.len(),
 713                                "refreshed room"
 714                            );
 715                            room_updated(&refreshed_room.room, &peer);
 716                            if let Some(channel) = refreshed_room.channel.as_ref() {
 717                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 718                            }
 719                            contacts_to_update
 720                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 721                            contacts_to_update
 722                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 723                            canceled_calls_to_user_ids =
 724                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 725                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 726                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 727                        }
 728
 729                        {
 730                            let pool = pool.lock();
 731                            for canceled_user_id in canceled_calls_to_user_ids {
 732                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 733                                    peer.send(
 734                                        connection_id,
 735                                        proto::CallCanceled {
 736                                            room_id: room_id.to_proto(),
 737                                        },
 738                                    )
 739                                    .trace_err();
 740                                }
 741                            }
 742                        }
 743
 744                        for user_id in contacts_to_update {
 745                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 746                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 747                            if let Some((busy, contacts)) = busy.zip(contacts) {
 748                                let pool = pool.lock();
 749                                let updated_contact = contact_for_user(user_id, busy, &pool);
 750                                for contact in contacts {
 751                                    if let db::Contact::Accepted {
 752                                        user_id: contact_user_id,
 753                                        ..
 754                                    } = contact
 755                                    {
 756                                        for contact_conn_id in
 757                                            pool.user_connection_ids(contact_user_id)
 758                                        {
 759                                            peer.send(
 760                                                contact_conn_id,
 761                                                proto::UpdateContacts {
 762                                                    contacts: vec![updated_contact.clone()],
 763                                                    remove_contacts: Default::default(),
 764                                                    incoming_requests: Default::default(),
 765                                                    remove_incoming_requests: Default::default(),
 766                                                    outgoing_requests: Default::default(),
 767                                                    remove_outgoing_requests: Default::default(),
 768                                                },
 769                                            )
 770                                            .trace_err();
 771                                        }
 772                                    }
 773                                }
 774                            }
 775                        }
 776
 777                        if let Some(live_kit) = live_kit_client.as_ref() {
 778                            if delete_live_kit_room {
 779                                live_kit.delete_room(live_kit_room).await.trace_err();
 780                            }
 781                        }
 782                    }
 783                }
 784
 785                app_state
 786                    .db
 787                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 788                    .await
 789                    .trace_err();
 790            }
 791            .instrument(span),
 792        );
 793        Ok(())
 794    }
 795
 796    pub fn teardown(&self) {
 797        self.peer.teardown();
 798        self.connection_pool.lock().reset();
 799        let _ = self.teardown.send(true);
 800    }
 801
 802    #[cfg(test)]
 803    pub fn reset(&self, id: ServerId) {
 804        self.teardown();
 805        *self.id.lock() = id;
 806        self.peer.reset(id.0 as u32);
 807        let _ = self.teardown.send(false);
 808    }
 809
 810    #[cfg(test)]
 811    pub fn id(&self) -> ServerId {
 812        *self.id.lock()
 813    }
 814
 815    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 816    where
 817        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 818        Fut: 'static + Send + Future<Output = Result<()>>,
 819        M: EnvelopedMessage,
 820    {
 821        let prev_handler = self.handlers.insert(
 822            TypeId::of::<M>(),
 823            Box::new(move |envelope, session| {
 824                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 825                let received_at = envelope.received_at;
 826                tracing::info!("message received");
 827                let start_time = Instant::now();
 828                let future = (handler)(*envelope, session);
 829                async move {
 830                    let result = future.await;
 831                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 832                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 833                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 834                    let payload_type = M::NAME;
 835
 836                    match result {
 837                        Err(error) => {
 838                            tracing::error!(
 839                                ?error,
 840                                total_duration_ms,
 841                                processing_duration_ms,
 842                                queue_duration_ms,
 843                                payload_type,
 844                                "error handling message"
 845                            )
 846                        }
 847                        Ok(()) => tracing::info!(
 848                            total_duration_ms,
 849                            processing_duration_ms,
 850                            queue_duration_ms,
 851                            "finished handling message"
 852                        ),
 853                    }
 854                }
 855                .boxed()
 856            }),
 857        );
 858        if prev_handler.is_some() {
 859            panic!("registered a handler for the same message twice");
 860        }
 861        self
 862    }
 863
 864    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 865    where
 866        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 867        Fut: 'static + Send + Future<Output = Result<()>>,
 868        M: EnvelopedMessage,
 869    {
 870        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 871        self
 872    }
 873
 874    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 875    where
 876        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 877        Fut: Send + Future<Output = Result<()>>,
 878        M: RequestMessage,
 879    {
 880        let handler = Arc::new(handler);
 881        self.add_handler(move |envelope, session| {
 882            let receipt = envelope.receipt();
 883            let handler = handler.clone();
 884            async move {
 885                let peer = session.peer.clone();
 886                let responded = Arc::new(AtomicBool::default());
 887                let response = Response {
 888                    peer: peer.clone(),
 889                    responded: responded.clone(),
 890                    receipt,
 891                };
 892                match (handler)(envelope.payload, response, session).await {
 893                    Ok(()) => {
 894                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 895                            Ok(())
 896                        } else {
 897                            Err(anyhow!("handler did not send a response"))?
 898                        }
 899                    }
 900                    Err(error) => {
 901                        let proto_err = match &error {
 902                            Error::Internal(err) => err.to_proto(),
 903                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 904                        };
 905                        peer.respond_with_error(receipt, proto_err)?;
 906                        Err(error)
 907                    }
 908                }
 909            }
 910        })
 911    }
 912
 913    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 914    where
 915        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 916        Fut: Send + Future<Output = Result<()>>,
 917        M: RequestMessage,
 918    {
 919        let handler = Arc::new(handler);
 920        self.add_handler(move |envelope, session| {
 921            let receipt = envelope.receipt();
 922            let handler = handler.clone();
 923            async move {
 924                let peer = session.peer.clone();
 925                let response = StreamingResponse {
 926                    peer: peer.clone(),
 927                    receipt,
 928                };
 929                match (handler)(envelope.payload, response, session).await {
 930                    Ok(()) => {
 931                        peer.end_stream(receipt)?;
 932                        Ok(())
 933                    }
 934                    Err(error) => {
 935                        let proto_err = match &error {
 936                            Error::Internal(err) => err.to_proto(),
 937                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 938                        };
 939                        peer.respond_with_error(receipt, proto_err)?;
 940                        Err(error)
 941                    }
 942                }
 943            }
 944        })
 945    }
 946
 947    #[allow(clippy::too_many_arguments)]
 948    pub fn handle_connection(
 949        self: &Arc<Self>,
 950        connection: Connection,
 951        address: String,
 952        principal: Principal,
 953        zed_version: ZedVersion,
 954        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 955        executor: Executor,
 956    ) -> impl Future<Output = ()> {
 957        let this = self.clone();
 958        let span = info_span!("handle connection", %address,
 959            connection_id=field::Empty,
 960            user_id=field::Empty,
 961            login=field::Empty,
 962            impersonator=field::Empty,
 963            dev_server_id=field::Empty
 964        );
 965        principal.update_span(&span);
 966
 967        let mut teardown = self.teardown.subscribe();
 968        async move {
 969            if *teardown.borrow() {
 970                tracing::error!("server is tearing down");
 971                return
 972            }
 973            let (connection_id, handle_io, mut incoming_rx) = this
 974                .peer
 975                .add_connection(connection, {
 976                    let executor = executor.clone();
 977                    move |duration| executor.sleep(duration)
 978                });
 979            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 980            tracing::info!("connection opened");
 981
 982            let http_client = match IsahcHttpClient::new() {
 983                Ok(http_client) => Arc::new(http_client),
 984                Err(error) => {
 985                    tracing::error!(?error, "failed to create HTTP client");
 986                    return;
 987                }
 988            };
 989
 990            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 991                Some(Arc::new(SupermavenAdminApi::new(
 992                    supermaven_admin_api_key.to_string(),
 993                    http_client.clone(),
 994                )))
 995            } else {
 996                None
 997            };
 998
 999            let session = Session {
1000                principal: principal.clone(),
1001                connection_id,
1002                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1003                peer: this.peer.clone(),
1004                connection_pool: this.connection_pool.clone(),
1005                live_kit_client: this.app_state.live_kit_client.clone(),
1006                http_client,
1007                rate_limiter: this.app_state.rate_limiter.clone(),
1008                _executor: executor.clone(),
1009                supermaven_client,
1010            };
1011
1012            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1013                tracing::error!(?error, "failed to send initial client update");
1014                return;
1015            }
1016
1017            let handle_io = handle_io.fuse();
1018            futures::pin_mut!(handle_io);
1019
1020            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1021            // This prevents deadlocks when e.g., client A performs a request to client B and
1022            // client B performs a request to client A. If both clients stop processing further
1023            // messages until their respective request completes, they won't have a chance to
1024            // respond to the other client's request and cause a deadlock.
1025            //
1026            // This arrangement ensures we will attempt to process earlier messages first, but fall
1027            // back to processing messages arrived later in the spirit of making progress.
1028            let mut foreground_message_handlers = FuturesUnordered::new();
1029            let concurrent_handlers = Arc::new(Semaphore::new(256));
1030            loop {
1031                let next_message = async {
1032                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1033                    let message = incoming_rx.next().await;
1034                    (permit, message)
1035                }.fuse();
1036                futures::pin_mut!(next_message);
1037                futures::select_biased! {
1038                    _ = teardown.changed().fuse() => return,
1039                    result = handle_io => {
1040                        if let Err(error) = result {
1041                            tracing::error!(?error, "error handling I/O");
1042                        }
1043                        break;
1044                    }
1045                    _ = foreground_message_handlers.next() => {}
1046                    next_message = next_message => {
1047                        let (permit, message) = next_message;
1048                        if let Some(message) = message {
1049                            let type_name = message.payload_type_name();
1050                            // note: we copy all the fields from the parent span so we can query them in the logs.
1051                            // (https://github.com/tokio-rs/tracing/issues/2670).
1052                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1053                                user_id=field::Empty,
1054                                login=field::Empty,
1055                                impersonator=field::Empty,
1056                                dev_server_id=field::Empty
1057                            );
1058                            principal.update_span(&span);
1059                            let span_enter = span.enter();
1060                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1061                                let is_background = message.is_background();
1062                                let handle_message = (handler)(message, session.clone());
1063                                drop(span_enter);
1064
1065                                let handle_message = async move {
1066                                    handle_message.await;
1067                                    drop(permit);
1068                                }.instrument(span);
1069                                if is_background {
1070                                    executor.spawn_detached(handle_message);
1071                                } else {
1072                                    foreground_message_handlers.push(handle_message);
1073                                }
1074                            } else {
1075                                tracing::error!("no message handler");
1076                            }
1077                        } else {
1078                            tracing::info!("connection closed");
1079                            break;
1080                        }
1081                    }
1082                }
1083            }
1084
1085            drop(foreground_message_handlers);
1086            tracing::info!("signing out");
1087            if let Err(error) = connection_lost(session, teardown, executor).await {
1088                tracing::error!(?error, "error signing out");
1089            }
1090
1091        }.instrument(span)
1092    }
1093
1094    async fn send_initial_client_update(
1095        &self,
1096        connection_id: ConnectionId,
1097        principal: &Principal,
1098        zed_version: ZedVersion,
1099        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1100        session: &Session,
1101    ) -> Result<()> {
1102        self.peer.send(
1103            connection_id,
1104            proto::Hello {
1105                peer_id: Some(connection_id.into()),
1106            },
1107        )?;
1108        tracing::info!("sent hello message");
1109        if let Some(send_connection_id) = send_connection_id.take() {
1110            let _ = send_connection_id.send(connection_id);
1111        }
1112
1113        match principal {
1114            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1115                if !user.connected_once {
1116                    self.peer.send(connection_id, proto::ShowContacts {})?;
1117                    self.app_state
1118                        .db
1119                        .set_user_connected_once(user.id, true)
1120                        .await?;
1121                }
1122
1123                let (contacts, dev_server_projects) = future::try_join(
1124                    self.app_state.db.get_contacts(user.id),
1125                    self.app_state.db.dev_server_projects_update(user.id),
1126                )
1127                .await?;
1128
1129                {
1130                    let mut pool = self.connection_pool.lock();
1131                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1132                    self.peer.send(
1133                        connection_id,
1134                        build_initial_contacts_update(contacts, &pool),
1135                    )?;
1136                }
1137
1138                if should_auto_subscribe_to_channels(zed_version) {
1139                    subscribe_user_to_channels(user.id, session).await?;
1140                }
1141
1142                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1143
1144                if let Some(incoming_call) =
1145                    self.app_state.db.incoming_call_for_user(user.id).await?
1146                {
1147                    self.peer.send(connection_id, incoming_call)?;
1148                }
1149
1150                update_user_contacts(user.id, &session).await?;
1151            }
1152            Principal::DevServer(dev_server) => {
1153                {
1154                    let mut pool = self.connection_pool.lock();
1155                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1156                    {
1157                        self.peer.send(
1158                            stale_connection_id,
1159                            proto::ShutdownDevServer {
1160                                reason: Some(
1161                                    "another dev server connected with the same token".to_string(),
1162                                ),
1163                            },
1164                        )?;
1165                        pool.remove_connection(stale_connection_id)?;
1166                    };
1167                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1168                }
1169
1170                let projects = self
1171                    .app_state
1172                    .db
1173                    .get_projects_for_dev_server(dev_server.id)
1174                    .await?;
1175                self.peer
1176                    .send(connection_id, proto::DevServerInstructions { projects })?;
1177
1178                let status = self
1179                    .app_state
1180                    .db
1181                    .dev_server_projects_update(dev_server.user_id)
1182                    .await?;
1183                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1184            }
1185        }
1186
1187        Ok(())
1188    }
1189
1190    pub async fn invite_code_redeemed(
1191        self: &Arc<Self>,
1192        inviter_id: UserId,
1193        invitee_id: UserId,
1194    ) -> Result<()> {
1195        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1196            if let Some(code) = &user.invite_code {
1197                let pool = self.connection_pool.lock();
1198                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1199                for connection_id in pool.user_connection_ids(inviter_id) {
1200                    self.peer.send(
1201                        connection_id,
1202                        proto::UpdateContacts {
1203                            contacts: vec![invitee_contact.clone()],
1204                            ..Default::default()
1205                        },
1206                    )?;
1207                    self.peer.send(
1208                        connection_id,
1209                        proto::UpdateInviteInfo {
1210                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1211                            count: user.invite_count as u32,
1212                        },
1213                    )?;
1214                }
1215            }
1216        }
1217        Ok(())
1218    }
1219
1220    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1221        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1222            if let Some(invite_code) = &user.invite_code {
1223                let pool = self.connection_pool.lock();
1224                for connection_id in pool.user_connection_ids(user_id) {
1225                    self.peer.send(
1226                        connection_id,
1227                        proto::UpdateInviteInfo {
1228                            url: format!(
1229                                "{}{}",
1230                                self.app_state.config.invite_link_prefix, invite_code
1231                            ),
1232                            count: user.invite_count as u32,
1233                        },
1234                    )?;
1235                }
1236            }
1237        }
1238        Ok(())
1239    }
1240
1241    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1242        ServerSnapshot {
1243            connection_pool: ConnectionPoolGuard {
1244                guard: self.connection_pool.lock(),
1245                _not_send: PhantomData,
1246            },
1247            peer: &self.peer,
1248        }
1249    }
1250}
1251
1252impl<'a> Deref for ConnectionPoolGuard<'a> {
1253    type Target = ConnectionPool;
1254
1255    fn deref(&self) -> &Self::Target {
1256        &self.guard
1257    }
1258}
1259
1260impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1261    fn deref_mut(&mut self) -> &mut Self::Target {
1262        &mut self.guard
1263    }
1264}
1265
/// In test builds, checks the connection pool's invariants whenever a guard is
/// released. In non-test builds the body is empty.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1272
1273fn broadcast<F>(
1274    sender_id: Option<ConnectionId>,
1275    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1276    mut f: F,
1277) where
1278    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1279{
1280    for receiver_id in receiver_ids {
1281        if Some(receiver_id) != sender_id {
1282            if let Err(error) = f(receiver_id) {
1283                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1284            }
1285        }
1286    }
1287}
1288
1289pub struct ProtocolVersion(u32);
1290
1291impl Header for ProtocolVersion {
1292    fn name() -> &'static HeaderName {
1293        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1295    }
1296
1297    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298    where
1299        Self: Sized,
1300        I: Iterator<Item = &'i axum::http::HeaderValue>,
1301    {
1302        let version = values
1303            .next()
1304            .ok_or_else(axum::headers::Error::invalid)?
1305            .to_str()
1306            .map_err(|_| axum::headers::Error::invalid())?
1307            .parse()
1308            .map_err(|_| axum::headers::Error::invalid())?;
1309        Ok(Self(version))
1310    }
1311
1312    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313        values.extend([self.0.to_string().parse().unwrap()]);
1314    }
1315}
1316
1317pub struct AppVersionHeader(SemanticVersion);
1318impl Header for AppVersionHeader {
1319    fn name() -> &'static HeaderName {
1320        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1321        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1322    }
1323
1324    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1325    where
1326        Self: Sized,
1327        I: Iterator<Item = &'i axum::http::HeaderValue>,
1328    {
1329        let version = values
1330            .next()
1331            .ok_or_else(axum::headers::Error::invalid)?
1332            .to_str()
1333            .map_err(|_| axum::headers::Error::invalid())?
1334            .parse()
1335            .map_err(|_| axum::headers::Error::invalid())?;
1336        Ok(Self(version))
1337    }
1338
1339    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1340        values.extend([self.0.to_string().parse().unwrap()]);
1341    }
1342}
1343
1344pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1345    Router::new()
1346        .route("/rpc", get(handle_websocket_request))
1347        .layer(
1348            ServiceBuilder::new()
1349                .layer(Extension(server.app_state.clone()))
1350                .layer(middleware::from_fn(auth::validate_header)),
1351        )
1352        .route("/metrics", get(handle_metrics))
1353        .layer(Extension(server))
1354}
1355
1356pub async fn handle_websocket_request(
1357    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1358    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1359    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1360    Extension(server): Extension<Arc<Server>>,
1361    Extension(principal): Extension<Principal>,
1362    ws: WebSocketUpgrade,
1363) -> axum::response::Response {
1364    if protocol_version != rpc::PROTOCOL_VERSION {
1365        return (
1366            StatusCode::UPGRADE_REQUIRED,
1367            "client must be upgraded".to_string(),
1368        )
1369            .into_response();
1370    }
1371
1372    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1373        return (
1374            StatusCode::UPGRADE_REQUIRED,
1375            "no version header found".to_string(),
1376        )
1377            .into_response();
1378    };
1379
1380    if !version.can_collaborate() {
1381        return (
1382            StatusCode::UPGRADE_REQUIRED,
1383            "client must be upgraded".to_string(),
1384        )
1385            .into_response();
1386    }
1387
1388    let socket_address = socket_address.to_string();
1389    ws.on_upgrade(move |socket| {
1390        let socket = socket
1391            .map_ok(to_tungstenite_message)
1392            .err_into()
1393            .with(|message| async move { Ok(to_axum_message(message)) });
1394        let connection = Connection::new(Box::pin(socket));
1395        async move {
1396            server
1397                .handle_connection(
1398                    connection,
1399                    socket_address,
1400                    principal,
1401                    version,
1402                    None,
1403                    Executor::Production,
1404                )
1405                .await;
1406        }
1407    })
1408}
1409
1410pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1411    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1412    let connections_metric = CONNECTIONS_METRIC
1413        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1414
1415    let connections = server
1416        .connection_pool
1417        .lock()
1418        .connections()
1419        .filter(|connection| !connection.admin)
1420        .count();
1421    connections_metric.set(connections as _);
1422
1423    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1424    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1425        register_int_gauge!(
1426            "shared_projects",
1427            "number of open projects with one or more guests"
1428        )
1429        .unwrap()
1430    });
1431
1432    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1433    shared_projects_metric.set(shared_projects as _);
1434
1435    let encoder = prometheus::TextEncoder::new();
1436    let metric_families = prometheus::gather();
1437    let encoded_metrics = encoder
1438        .encode_to_string(&metric_families)
1439        .map_err(|err| anyhow!("{}", err))?;
1440    Ok(encoded_metrics)
1441}
1442
/// Cleans up after a connection has ended.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the loss in the database (a database failure is only traced). It then
/// waits for `RECONNECT_TIMEOUT` before releasing the principal's remaining
/// resources; if the teardown channel changes first, that wait is abandoned and no
/// further cleanup happens here.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool, or if the
/// final contact update / dev-server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a failure here is traced, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Branches are polled in the order written: the reconnect timer first, then teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when the user has no other connection online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown signalled: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1497
1498/// Acknowledges a ping from a client, used to keep the connection alive.
1499async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1500    response.send(proto::Ack {})?;
1501    Ok(())
1502}
1503
1504/// Creates a new room for calling (outside of channels)
1505async fn create_room(
1506    _request: proto::CreateRoom,
1507    response: Response<proto::CreateRoom>,
1508    session: UserSession,
1509) -> Result<()> {
1510    let live_kit_room = nanoid::nanoid!(30);
1511
1512    let live_kit_connection_info = util::maybe!(async {
1513        let live_kit = session.live_kit_client.as_ref();
1514        let live_kit = live_kit?;
1515        let user_id = session.user_id().to_string();
1516
1517        let token = live_kit
1518            .room_token(&live_kit_room, &user_id.to_string())
1519            .trace_err()?;
1520
1521        Some(proto::LiveKitConnectionInfo {
1522            server_url: live_kit.url().into(),
1523            token,
1524            can_publish: true,
1525        })
1526    })
1527    .await;
1528
1529    let room = session
1530        .db()
1531        .await
1532        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1533        .await?;
1534
1535    response.send(proto::CreateRoomResponse {
1536        room: Some(room.clone()),
1537        live_kit_connection_info,
1538    })?;
1539
1540    update_user_contacts(session.user_id(), &session).await?;
1541    Ok(())
1542}
1543
1544/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1545async fn join_room(
1546    request: proto::JoinRoom,
1547    response: Response<proto::JoinRoom>,
1548    session: UserSession,
1549) -> Result<()> {
1550    let room_id = RoomId::from_proto(request.id);
1551
1552    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1553
1554    if let Some(channel_id) = channel_id {
1555        return join_channel_internal(channel_id, Box::new(response), session).await;
1556    }
1557
1558    let joined_room = {
1559        let room = session
1560            .db()
1561            .await
1562            .join_room(room_id, session.user_id(), session.connection_id)
1563            .await?;
1564        room_updated(&room.room, &session.peer);
1565        room.into_inner()
1566    };
1567
1568    for connection_id in session
1569        .connection_pool()
1570        .await
1571        .user_connection_ids(session.user_id())
1572    {
1573        session
1574            .peer
1575            .send(
1576                connection_id,
1577                proto::CallCanceled {
1578                    room_id: room_id.to_proto(),
1579                },
1580            )
1581            .trace_err();
1582    }
1583
1584    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1585        if let Some(token) = live_kit
1586            .room_token(
1587                &joined_room.room.live_kit_room,
1588                &session.user_id().to_string(),
1589            )
1590            .trace_err()
1591        {
1592            Some(proto::LiveKitConnectionInfo {
1593                server_url: live_kit.url().into(),
1594                token,
1595                can_publish: true,
1596            })
1597        } else {
1598            None
1599        }
1600    } else {
1601        None
1602    };
1603
1604    response.send(proto::JoinRoomResponse {
1605        room: Some(joined_room.room),
1606        channel_id: None,
1607        live_kit_connection_info,
1608    })?;
1609
1610    update_user_contacts(session.user_id(), &session).await?;
1611    Ok(())
1612}
1613
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The response lists the projects this connection re-shared and the projects
/// it rejoined. Collaborators on re-shared projects are told about the new
/// peer id and receive an `UpdateProject`; rejoined projects are streamed back
/// to the caller via `notify_rejoined_projects`. If the room belongs to a
/// channel, `channel_updated` is called, and finally the user's contacts are
/// updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive `rejoined_room`, which is
    // consumed (`into_inner`) at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the caller first, before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator that the host's peer id changed from the
            // old connection to this one. Send failures go to `trace_err`.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to the collaborators on
            // behalf of this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1705
1706fn notify_rejoined_projects(
1707    rejoined_projects: &mut Vec<RejoinedProject>,
1708    session: &UserSession,
1709) -> Result<()> {
1710    for project in rejoined_projects.iter() {
1711        for collaborator in &project.collaborators {
1712            session
1713                .peer
1714                .send(
1715                    collaborator.connection_id,
1716                    proto::UpdateProjectCollaborator {
1717                        project_id: project.id.to_proto(),
1718                        old_peer_id: Some(project.old_connection_id.into()),
1719                        new_peer_id: Some(session.connection_id.into()),
1720                    },
1721                )
1722                .trace_err();
1723        }
1724    }
1725
1726    for project in rejoined_projects {
1727        for worktree in mem::take(&mut project.worktrees) {
1728            #[cfg(any(test, feature = "test-support"))]
1729            const MAX_CHUNK_SIZE: usize = 2;
1730            #[cfg(not(any(test, feature = "test-support")))]
1731            const MAX_CHUNK_SIZE: usize = 256;
1732
1733            // Stream this worktree's entries.
1734            let message = proto::UpdateWorktree {
1735                project_id: project.id.to_proto(),
1736                worktree_id: worktree.id,
1737                abs_path: worktree.abs_path.clone(),
1738                root_name: worktree.root_name,
1739                updated_entries: worktree.updated_entries,
1740                removed_entries: worktree.removed_entries,
1741                scan_id: worktree.scan_id,
1742                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1743                updated_repositories: worktree.updated_repositories,
1744                removed_repositories: worktree.removed_repositories,
1745            };
1746            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1747                session.peer.send(session.connection_id, update.clone())?;
1748            }
1749
1750            // Stream this worktree's diagnostics.
1751            for summary in worktree.diagnostic_summaries {
1752                session.peer.send(
1753                    session.connection_id,
1754                    proto::UpdateDiagnosticSummary {
1755                        project_id: project.id.to_proto(),
1756                        worktree_id: worktree.id,
1757                        summary: Some(summary),
1758                    },
1759                )?;
1760            }
1761
1762            for settings_file in worktree.settings_files {
1763                session.peer.send(
1764                    session.connection_id,
1765                    proto::UpdateWorktreeSettings {
1766                        project_id: project.id.to_proto(),
1767                        worktree_id: worktree.id,
1768                        path: settings_file.path,
1769                        content: Some(settings_file.content),
1770                    },
1771                )?;
1772            }
1773        }
1774
1775        for language_server in &project.language_servers {
1776            session.peer.send(
1777                session.connection_id,
1778                proto::UpdateLanguageServer {
1779                    project_id: project.id.to_proto(),
1780                    language_server_id: language_server.id,
1781                    variant: Some(
1782                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1783                            proto::LspDiskBasedDiagnosticsUpdated {},
1784                        ),
1785                    ),
1786                },
1787            )?;
1788        }
1789    }
1790    Ok(())
1791}
1792
1793/// leave room disconnects from the room.
1794async fn leave_room(
1795    _: proto::LeaveRoom,
1796    response: Response<proto::LeaveRoom>,
1797    session: UserSession,
1798) -> Result<()> {
1799    leave_room_for_session(&session, session.connection_id).await?;
1800    response.send(proto::Ack {})?;
1801    Ok(())
1802}
1803
1804/// Updates the permissions of someone else in the room.
1805async fn set_room_participant_role(
1806    request: proto::SetRoomParticipantRole,
1807    response: Response<proto::SetRoomParticipantRole>,
1808    session: UserSession,
1809) -> Result<()> {
1810    let user_id = UserId::from_proto(request.user_id);
1811    let role = ChannelRole::from(request.role());
1812
1813    let (live_kit_room, can_publish) = {
1814        let room = session
1815            .db()
1816            .await
1817            .set_room_participant_role(
1818                session.user_id(),
1819                RoomId::from_proto(request.room_id),
1820                user_id,
1821                role,
1822            )
1823            .await?;
1824
1825        let live_kit_room = room.live_kit_room.clone();
1826        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1827        room_updated(&room, &session.peer);
1828        (live_kit_room, can_publish)
1829    };
1830
1831    if let Some(live_kit) = session.live_kit_client.as_ref() {
1832        live_kit
1833            .update_participant(
1834                live_kit_room.clone(),
1835                request.user_id.to_string(),
1836                live_kit_server::proto::ParticipantPermission {
1837                    can_subscribe: true,
1838                    can_publish,
1839                    can_publish_data: can_publish,
1840                    hidden: false,
1841                    recorder: false,
1842                },
1843            )
1844            .await
1845            .trace_err();
1846    }
1847
1848    response.send(proto::Ack {})?;
1849    Ok(())
1850}
1851
1852/// Call someone else into the current room
1853async fn call(
1854    request: proto::Call,
1855    response: Response<proto::Call>,
1856    session: UserSession,
1857) -> Result<()> {
1858    let room_id = RoomId::from_proto(request.room_id);
1859    let calling_user_id = session.user_id();
1860    let calling_connection_id = session.connection_id;
1861    let called_user_id = UserId::from_proto(request.called_user_id);
1862    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1863    if !session
1864        .db()
1865        .await
1866        .has_contact(calling_user_id, called_user_id)
1867        .await?
1868    {
1869        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1870    }
1871
1872    let incoming_call = {
1873        let (room, incoming_call) = &mut *session
1874            .db()
1875            .await
1876            .call(
1877                room_id,
1878                calling_user_id,
1879                calling_connection_id,
1880                called_user_id,
1881                initial_project_id,
1882            )
1883            .await?;
1884        room_updated(&room, &session.peer);
1885        mem::take(incoming_call)
1886    };
1887    update_user_contacts(called_user_id, &session).await?;
1888
1889    let mut calls = session
1890        .connection_pool()
1891        .await
1892        .user_connection_ids(called_user_id)
1893        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1894        .collect::<FuturesUnordered<_>>();
1895
1896    while let Some(call_response) = calls.next().await {
1897        match call_response.as_ref() {
1898            Ok(_) => {
1899                response.send(proto::Ack {})?;
1900                return Ok(());
1901            }
1902            Err(_) => {
1903                call_response.trace_err();
1904            }
1905        }
1906    }
1907
1908    {
1909        let room = session
1910            .db()
1911            .await
1912            .call_failed(room_id, called_user_id)
1913            .await?;
1914        room_updated(&room, &session.peer);
1915    }
1916    update_user_contacts(called_user_id, &session).await?;
1917
1918    Err(anyhow!("failed to ring user"))?
1919}
1920
1921/// Cancel an outgoing call.
1922async fn cancel_call(
1923    request: proto::CancelCall,
1924    response: Response<proto::CancelCall>,
1925    session: UserSession,
1926) -> Result<()> {
1927    let called_user_id = UserId::from_proto(request.called_user_id);
1928    let room_id = RoomId::from_proto(request.room_id);
1929    {
1930        let room = session
1931            .db()
1932            .await
1933            .cancel_call(room_id, session.connection_id, called_user_id)
1934            .await?;
1935        room_updated(&room, &session.peer);
1936    }
1937
1938    for connection_id in session
1939        .connection_pool()
1940        .await
1941        .user_connection_ids(called_user_id)
1942    {
1943        session
1944            .peer
1945            .send(
1946                connection_id,
1947                proto::CallCanceled {
1948                    room_id: room_id.to_proto(),
1949                },
1950            )
1951            .trace_err();
1952    }
1953    response.send(proto::Ack {})?;
1954
1955    update_user_contacts(called_user_id, &session).await?;
1956    Ok(())
1957}
1958
1959/// Decline an incoming call.
1960async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1961    let room_id = RoomId::from_proto(message.room_id);
1962    {
1963        let room = session
1964            .db()
1965            .await
1966            .decline_call(Some(room_id), session.user_id())
1967            .await?
1968            .ok_or_else(|| anyhow!("failed to decline call"))?;
1969        room_updated(&room, &session.peer);
1970    }
1971
1972    for connection_id in session
1973        .connection_pool()
1974        .await
1975        .user_connection_ids(session.user_id())
1976    {
1977        session
1978            .peer
1979            .send(
1980                connection_id,
1981                proto::CallCanceled {
1982                    room_id: room_id.to_proto(),
1983                },
1984            )
1985            .trace_err();
1986    }
1987    update_user_contacts(session.user_id(), &session).await?;
1988    Ok(())
1989}
1990
1991/// Updates other participants in the room with your current location.
1992async fn update_participant_location(
1993    request: proto::UpdateParticipantLocation,
1994    response: Response<proto::UpdateParticipantLocation>,
1995    session: UserSession,
1996) -> Result<()> {
1997    let room_id = RoomId::from_proto(request.room_id);
1998    let location = request
1999        .location
2000        .ok_or_else(|| anyhow!("invalid location"))?;
2001
2002    let db = session.db().await;
2003    let room = db
2004        .update_room_participant_location(room_id, session.connection_id, location)
2005        .await?;
2006
2007    room_updated(&room, &session.peer);
2008    response.send(proto::Ack {})?;
2009    Ok(())
2010}
2011
2012/// Share a project into the room.
2013async fn share_project(
2014    request: proto::ShareProject,
2015    response: Response<proto::ShareProject>,
2016    session: UserSession,
2017) -> Result<()> {
2018    let (project_id, room) = &*session
2019        .db()
2020        .await
2021        .share_project(
2022            RoomId::from_proto(request.room_id),
2023            session.connection_id,
2024            &request.worktrees,
2025            request
2026                .dev_server_project_id
2027                .map(|id| DevServerProjectId::from_proto(id)),
2028        )
2029        .await?;
2030    response.send(proto::ShareProjectResponse {
2031        project_id: project_id.to_proto(),
2032    })?;
2033    room_updated(&room, &session.peer);
2034
2035    Ok(())
2036}
2037
2038/// Unshare a project from the room.
2039async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2040    let project_id = ProjectId::from_proto(message.project_id);
2041    unshare_project_internal(
2042        project_id,
2043        session.connection_id,
2044        session.user_id(),
2045        &session,
2046    )
2047    .await
2048}
2049
2050async fn unshare_project_internal(
2051    project_id: ProjectId,
2052    connection_id: ConnectionId,
2053    user_id: Option<UserId>,
2054    session: &Session,
2055) -> Result<()> {
2056    let delete = {
2057        let room_guard = session
2058            .db()
2059            .await
2060            .unshare_project(project_id, connection_id, user_id)
2061            .await?;
2062
2063        let (delete, room, guest_connection_ids) = &*room_guard;
2064
2065        let message = proto::UnshareProject {
2066            project_id: project_id.to_proto(),
2067        };
2068
2069        broadcast(
2070            Some(connection_id),
2071            guest_connection_ids.iter().copied(),
2072            |conn_id| session.peer.send(conn_id, message.clone()),
2073        );
2074        if let Some(room) = room {
2075            room_updated(room, &session.peer);
2076        }
2077
2078        *delete
2079    };
2080
2081    if delete {
2082        let db = session.db().await;
2083        db.delete_project(project_id).await?;
2084    }
2085
2086    Ok(())
2087}
2088
2089/// DevServer makes a project available online
2090async fn share_dev_server_project(
2091    request: proto::ShareDevServerProject,
2092    response: Response<proto::ShareDevServerProject>,
2093    session: DevServerSession,
2094) -> Result<()> {
2095    let (dev_server_project, user_id, status) = session
2096        .db()
2097        .await
2098        .share_dev_server_project(
2099            DevServerProjectId::from_proto(request.dev_server_project_id),
2100            session.dev_server_id(),
2101            session.connection_id,
2102            &request.worktrees,
2103        )
2104        .await?;
2105    let Some(project_id) = dev_server_project.project_id else {
2106        return Err(anyhow!("failed to share remote project"))?;
2107    };
2108
2109    send_dev_server_projects_update(user_id, status, &session).await;
2110
2111    response.send(proto::ShareProjectResponse { project_id })?;
2112
2113    Ok(())
2114}
2115
/// Join someone else's shared project.
///
/// Registers the caller as a collaborator in the database and then hands the
/// project and assigned replica id to `join_project_internal`, which replies
/// to the caller and notifies the other collaborators.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The value returned by `join_project` is a temporary kept alive for the
    // rest of the function because it is borrowed by this `let`.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before doing the (synchronous) sends below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2134
/// Abstraction over the response handles that can answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Spelled with the concrete type; presumably this resolves to the
        // inherent `Response::send` rather than to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same delegation as above, for hosted-project requests.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2148
2149fn join_project_internal(
2150    response: impl JoinProjectInternalResponse,
2151    session: UserSession,
2152    project: &mut Project,
2153    replica_id: &ReplicaId,
2154) -> Result<()> {
2155    let collaborators = project
2156        .collaborators
2157        .iter()
2158        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2159        .map(|collaborator| collaborator.to_proto())
2160        .collect::<Vec<_>>();
2161    let project_id = project.id;
2162    let guest_user_id = session.user_id();
2163
2164    let worktrees = project
2165        .worktrees
2166        .iter()
2167        .map(|(id, worktree)| proto::WorktreeMetadata {
2168            id: *id,
2169            root_name: worktree.root_name.clone(),
2170            visible: worktree.visible,
2171            abs_path: worktree.abs_path.clone(),
2172        })
2173        .collect::<Vec<_>>();
2174
2175    let add_project_collaborator = proto::AddProjectCollaborator {
2176        project_id: project_id.to_proto(),
2177        collaborator: Some(proto::Collaborator {
2178            peer_id: Some(session.connection_id.into()),
2179            replica_id: replica_id.0 as u32,
2180            user_id: guest_user_id.to_proto(),
2181        }),
2182    };
2183
2184    for collaborator in &collaborators {
2185        session
2186            .peer
2187            .send(
2188                collaborator.peer_id.unwrap().into(),
2189                add_project_collaborator.clone(),
2190            )
2191            .trace_err();
2192    }
2193
2194    // First, we send the metadata associated with each worktree.
2195    response.send(proto::JoinProjectResponse {
2196        project_id: project.id.0 as u64,
2197        worktrees: worktrees.clone(),
2198        replica_id: replica_id.0 as u32,
2199        collaborators: collaborators.clone(),
2200        language_servers: project.language_servers.clone(),
2201        role: project.role.into(),
2202        dev_server_project_id: project
2203            .dev_server_project_id
2204            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2205    })?;
2206
2207    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2208        #[cfg(any(test, feature = "test-support"))]
2209        const MAX_CHUNK_SIZE: usize = 2;
2210        #[cfg(not(any(test, feature = "test-support")))]
2211        const MAX_CHUNK_SIZE: usize = 256;
2212
2213        // Stream this worktree's entries.
2214        let message = proto::UpdateWorktree {
2215            project_id: project_id.to_proto(),
2216            worktree_id,
2217            abs_path: worktree.abs_path.clone(),
2218            root_name: worktree.root_name,
2219            updated_entries: worktree.entries,
2220            removed_entries: Default::default(),
2221            scan_id: worktree.scan_id,
2222            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2223            updated_repositories: worktree.repository_entries.into_values().collect(),
2224            removed_repositories: Default::default(),
2225        };
2226        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2227            session.peer.send(session.connection_id, update.clone())?;
2228        }
2229
2230        // Stream this worktree's diagnostics.
2231        for summary in worktree.diagnostic_summaries {
2232            session.peer.send(
2233                session.connection_id,
2234                proto::UpdateDiagnosticSummary {
2235                    project_id: project_id.to_proto(),
2236                    worktree_id: worktree.id,
2237                    summary: Some(summary),
2238                },
2239            )?;
2240        }
2241
2242        for settings_file in worktree.settings_files {
2243            session.peer.send(
2244                session.connection_id,
2245                proto::UpdateWorktreeSettings {
2246                    project_id: project_id.to_proto(),
2247                    worktree_id: worktree.id,
2248                    path: settings_file.path,
2249                    content: Some(settings_file.content),
2250                },
2251            )?;
2252        }
2253    }
2254
2255    for language_server in &project.language_servers {
2256        session.peer.send(
2257            session.connection_id,
2258            proto::UpdateLanguageServer {
2259                project_id: project_id.to_proto(),
2260                language_server_id: language_server.id,
2261                variant: Some(
2262                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2263                        proto::LspDiskBasedDiagnosticsUpdated {},
2264                    ),
2265                ),
2266            },
2267        )?;
2268    }
2269
2270    Ok(())
2271}
2272
2273/// Leave someone elses shared project.
2274async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2275    let sender_id = session.connection_id;
2276    let project_id = ProjectId::from_proto(request.project_id);
2277    let db = session.db().await;
2278    if db.is_hosted_project(project_id).await? {
2279        let project = db.leave_hosted_project(project_id, sender_id).await?;
2280        project_left(&project, &session);
2281        return Ok(());
2282    }
2283
2284    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2285    tracing::info!(
2286        %project_id,
2287        "leave project"
2288    );
2289
2290    project_left(&project, &session);
2291    if let Some(room) = room {
2292        room_updated(&room, &session.peer);
2293    }
2294
2295    Ok(())
2296}
2297
2298async fn join_hosted_project(
2299    request: proto::JoinHostedProject,
2300    response: Response<proto::JoinHostedProject>,
2301    session: UserSession,
2302) -> Result<()> {
2303    let (mut project, replica_id) = session
2304        .db()
2305        .await
2306        .join_hosted_project(
2307            ProjectId(request.project_id as i32),
2308            session.user_id(),
2309            session.connection_id,
2310        )
2311        .await?;
2312
2313    join_project_internal(response, session, &mut project, &replica_id)
2314}
2315
2316async fn create_dev_server_project(
2317    request: proto::CreateDevServerProject,
2318    response: Response<proto::CreateDevServerProject>,
2319    session: UserSession,
2320) -> Result<()> {
2321    let dev_server_id = DevServerId(request.dev_server_id as i32);
2322    let dev_server_connection_id = session
2323        .connection_pool()
2324        .await
2325        .dev_server_connection_id(dev_server_id);
2326    let Some(dev_server_connection_id) = dev_server_connection_id else {
2327        Err(ErrorCode::DevServerOffline
2328            .message("Cannot create a remote project when the dev server is offline".to_string())
2329            .anyhow())?
2330    };
2331
2332    let path = request.path.clone();
2333    //Check that the path exists on the dev server
2334    session
2335        .peer
2336        .forward_request(
2337            session.connection_id,
2338            dev_server_connection_id,
2339            proto::ValidateDevServerProjectRequest { path: path.clone() },
2340        )
2341        .await?;
2342
2343    let (dev_server_project, update) = session
2344        .db()
2345        .await
2346        .create_dev_server_project(
2347            DevServerId(request.dev_server_id as i32),
2348            &request.path,
2349            session.user_id(),
2350        )
2351        .await?;
2352
2353    let projects = session
2354        .db()
2355        .await
2356        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2357        .await?;
2358
2359    session.peer.send(
2360        dev_server_connection_id,
2361        proto::DevServerInstructions { projects },
2362    )?;
2363
2364    send_dev_server_projects_update(session.user_id(), update, &session).await;
2365
2366    response.send(proto::CreateDevServerProjectResponse {
2367        dev_server_project: Some(dev_server_project.to_proto(None)),
2368    })?;
2369    Ok(())
2370}
2371
2372async fn create_dev_server(
2373    request: proto::CreateDevServer,
2374    response: Response<proto::CreateDevServer>,
2375    session: UserSession,
2376) -> Result<()> {
2377    let access_token = auth::random_token();
2378    let hashed_access_token = auth::hash_access_token(&access_token);
2379
2380    if request.name.is_empty() {
2381        return Err(proto::ErrorCode::Forbidden
2382            .message("Dev server name cannot be empty".to_string())
2383            .anyhow())?;
2384    }
2385
2386    let (dev_server, status) = session
2387        .db()
2388        .await
2389        .create_dev_server(
2390            &request.name,
2391            request.ssh_connection_string.as_deref(),
2392            &hashed_access_token,
2393            session.user_id(),
2394        )
2395        .await?;
2396
2397    send_dev_server_projects_update(session.user_id(), status, &session).await;
2398
2399    response.send(proto::CreateDevServerResponse {
2400        dev_server_id: dev_server.id.0 as u64,
2401        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2402        name: request.name,
2403    })?;
2404    Ok(())
2405}
2406
2407async fn regenerate_dev_server_token(
2408    request: proto::RegenerateDevServerToken,
2409    response: Response<proto::RegenerateDevServerToken>,
2410    session: UserSession,
2411) -> Result<()> {
2412    let dev_server_id = DevServerId(request.dev_server_id as i32);
2413    let access_token = auth::random_token();
2414    let hashed_access_token = auth::hash_access_token(&access_token);
2415
2416    let connection_id = session
2417        .connection_pool()
2418        .await
2419        .dev_server_connection_id(dev_server_id);
2420    if let Some(connection_id) = connection_id {
2421        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2422        session.peer.send(
2423            connection_id,
2424            proto::ShutdownDevServer {
2425                reason: Some("dev server token was regenerated".to_string()),
2426            },
2427        )?;
2428        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2429    }
2430
2431    let status = session
2432        .db()
2433        .await
2434        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2435        .await?;
2436
2437    send_dev_server_projects_update(session.user_id(), status, &session).await;
2438
2439    response.send(proto::RegenerateDevServerTokenResponse {
2440        dev_server_id: dev_server_id.to_proto(),
2441        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2442    })?;
2443    Ok(())
2444}
2445
2446async fn rename_dev_server(
2447    request: proto::RenameDevServer,
2448    response: Response<proto::RenameDevServer>,
2449    session: UserSession,
2450) -> Result<()> {
2451    if request.name.trim().is_empty() {
2452        return Err(proto::ErrorCode::Forbidden
2453            .message("Dev server name cannot be empty".to_string())
2454            .anyhow())?;
2455    }
2456
2457    let dev_server_id = DevServerId(request.dev_server_id as i32);
2458    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2459    if dev_server.user_id != session.user_id() {
2460        return Err(anyhow!(ErrorCode::Forbidden))?;
2461    }
2462
2463    let status = session
2464        .db()
2465        .await
2466        .rename_dev_server(
2467            dev_server_id,
2468            &request.name,
2469            request.ssh_connection_string.as_deref(),
2470            session.user_id(),
2471        )
2472        .await?;
2473
2474    send_dev_server_projects_update(session.user_id(), status, &session).await;
2475
2476    response.send(proto::Ack {})?;
2477    Ok(())
2478}
2479
2480async fn delete_dev_server(
2481    request: proto::DeleteDevServer,
2482    response: Response<proto::DeleteDevServer>,
2483    session: UserSession,
2484) -> Result<()> {
2485    let dev_server_id = DevServerId(request.dev_server_id as i32);
2486    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2487    if dev_server.user_id != session.user_id() {
2488        return Err(anyhow!(ErrorCode::Forbidden))?;
2489    }
2490
2491    let connection_id = session
2492        .connection_pool()
2493        .await
2494        .dev_server_connection_id(dev_server_id);
2495    if let Some(connection_id) = connection_id {
2496        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2497        session.peer.send(
2498            connection_id,
2499            proto::ShutdownDevServer {
2500                reason: Some("dev server was deleted".to_string()),
2501            },
2502        )?;
2503        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2504    }
2505
2506    let status = session
2507        .db()
2508        .await
2509        .delete_dev_server(dev_server_id, session.user_id())
2510        .await?;
2511
2512    send_dev_server_projects_update(session.user_id(), status, &session).await;
2513
2514    response.send(proto::Ack {})?;
2515    Ok(())
2516}
2517
2518async fn delete_dev_server_project(
2519    request: proto::DeleteDevServerProject,
2520    response: Response<proto::DeleteDevServerProject>,
2521    session: UserSession,
2522) -> Result<()> {
2523    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2524    let dev_server_project = session
2525        .db()
2526        .await
2527        .get_dev_server_project(dev_server_project_id)
2528        .await?;
2529
2530    let dev_server = session
2531        .db()
2532        .await
2533        .get_dev_server(dev_server_project.dev_server_id)
2534        .await?;
2535    if dev_server.user_id != session.user_id() {
2536        return Err(anyhow!(ErrorCode::Forbidden))?;
2537    }
2538
2539    let dev_server_connection_id = session
2540        .connection_pool()
2541        .await
2542        .dev_server_connection_id(dev_server.id);
2543
2544    if let Some(dev_server_connection_id) = dev_server_connection_id {
2545        let project = session
2546            .db()
2547            .await
2548            .find_dev_server_project(dev_server_project_id)
2549            .await;
2550        if let Ok(project) = project {
2551            unshare_project_internal(
2552                project.id,
2553                dev_server_connection_id,
2554                Some(session.user_id()),
2555                &session,
2556            )
2557            .await?;
2558        }
2559    }
2560
2561    let (projects, status) = session
2562        .db()
2563        .await
2564        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2565        .await?;
2566
2567    if let Some(dev_server_connection_id) = dev_server_connection_id {
2568        session.peer.send(
2569            dev_server_connection_id,
2570            proto::DevServerInstructions { projects },
2571        )?;
2572    }
2573
2574    send_dev_server_projects_update(session.user_id(), status, &session).await;
2575
2576    response.send(proto::Ack {})?;
2577    Ok(())
2578}
2579
2580async fn rejoin_dev_server_projects(
2581    request: proto::RejoinRemoteProjects,
2582    response: Response<proto::RejoinRemoteProjects>,
2583    session: UserSession,
2584) -> Result<()> {
2585    let mut rejoined_projects = {
2586        let db = session.db().await;
2587        db.rejoin_dev_server_projects(
2588            &request.rejoined_projects,
2589            session.user_id(),
2590            session.0.connection_id,
2591        )
2592        .await?
2593    };
2594    response.send(proto::RejoinRemoteProjectsResponse {
2595        rejoined_projects: rejoined_projects
2596            .iter()
2597            .map(|project| project.to_proto())
2598            .collect(),
2599    })?;
2600    notify_rejoined_projects(&mut rejoined_projects, &session)
2601}
2602
2603async fn reconnect_dev_server(
2604    request: proto::ReconnectDevServer,
2605    response: Response<proto::ReconnectDevServer>,
2606    session: DevServerSession,
2607) -> Result<()> {
2608    let reshared_projects = {
2609        let db = session.db().await;
2610        db.reshare_dev_server_projects(
2611            &request.reshared_projects,
2612            session.dev_server_id(),
2613            session.0.connection_id,
2614        )
2615        .await?
2616    };
2617
2618    for project in &reshared_projects {
2619        for collaborator in &project.collaborators {
2620            session
2621                .peer
2622                .send(
2623                    collaborator.connection_id,
2624                    proto::UpdateProjectCollaborator {
2625                        project_id: project.id.to_proto(),
2626                        old_peer_id: Some(project.old_connection_id.into()),
2627                        new_peer_id: Some(session.connection_id.into()),
2628                    },
2629                )
2630                .trace_err();
2631        }
2632
2633        broadcast(
2634            Some(session.connection_id),
2635            project
2636                .collaborators
2637                .iter()
2638                .map(|collaborator| collaborator.connection_id),
2639            |connection_id| {
2640                session.peer.forward_send(
2641                    session.connection_id,
2642                    connection_id,
2643                    proto::UpdateProject {
2644                        project_id: project.id.to_proto(),
2645                        worktrees: project.worktrees.clone(),
2646                    },
2647                )
2648            },
2649        );
2650    }
2651
2652    response.send(proto::ReconnectDevServerResponse {
2653        reshared_projects: reshared_projects
2654            .iter()
2655            .map(|project| proto::ResharedProject {
2656                id: project.id.to_proto(),
2657                collaborators: project
2658                    .collaborators
2659                    .iter()
2660                    .map(|collaborator| collaborator.to_proto())
2661                    .collect(),
2662            })
2663            .collect(),
2664    })?;
2665
2666    Ok(())
2667}
2668
2669async fn shutdown_dev_server(
2670    _: proto::ShutdownDevServer,
2671    response: Response<proto::ShutdownDevServer>,
2672    session: DevServerSession,
2673) -> Result<()> {
2674    response.send(proto::Ack {})?;
2675    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2676    remove_dev_server_connection(session.dev_server_id(), &session).await
2677}
2678
2679async fn shutdown_dev_server_internal(
2680    dev_server_id: DevServerId,
2681    connection_id: ConnectionId,
2682    session: &Session,
2683) -> Result<()> {
2684    let (dev_server_projects, dev_server) = {
2685        let db = session.db().await;
2686        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2687        let dev_server = db.get_dev_server(dev_server_id).await?;
2688        (dev_server_projects, dev_server)
2689    };
2690
2691    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2692        unshare_project_internal(
2693            ProjectId::from_proto(project_id),
2694            connection_id,
2695            None,
2696            session,
2697        )
2698        .await?;
2699    }
2700
2701    session
2702        .connection_pool()
2703        .await
2704        .set_dev_server_offline(dev_server_id);
2705
2706    let status = session
2707        .db()
2708        .await
2709        .dev_server_projects_update(dev_server.user_id)
2710        .await?;
2711    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2712
2713    Ok(())
2714}
2715
2716async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2717    let dev_server_connection = session
2718        .connection_pool()
2719        .await
2720        .dev_server_connection_id(dev_server_id);
2721
2722    if let Some(dev_server_connection) = dev_server_connection {
2723        session
2724            .connection_pool()
2725            .await
2726            .remove_connection(dev_server_connection)?;
2727    }
2728    Ok(())
2729}
2730
2731/// Updates other participants with changes to the project
2732async fn update_project(
2733    request: proto::UpdateProject,
2734    response: Response<proto::UpdateProject>,
2735    session: Session,
2736) -> Result<()> {
2737    let project_id = ProjectId::from_proto(request.project_id);
2738    let (room, guest_connection_ids) = &*session
2739        .db()
2740        .await
2741        .update_project(project_id, session.connection_id, &request.worktrees)
2742        .await?;
2743    broadcast(
2744        Some(session.connection_id),
2745        guest_connection_ids.iter().copied(),
2746        |connection_id| {
2747            session
2748                .peer
2749                .forward_send(session.connection_id, connection_id, request.clone())
2750        },
2751    );
2752    if let Some(room) = room {
2753        room_updated(&room, &session.peer);
2754    }
2755    response.send(proto::Ack {})?;
2756
2757    Ok(())
2758}
2759
2760/// Updates other participants with changes to the worktree
2761async fn update_worktree(
2762    request: proto::UpdateWorktree,
2763    response: Response<proto::UpdateWorktree>,
2764    session: Session,
2765) -> Result<()> {
2766    let guest_connection_ids = session
2767        .db()
2768        .await
2769        .update_worktree(&request, session.connection_id)
2770        .await?;
2771
2772    broadcast(
2773        Some(session.connection_id),
2774        guest_connection_ids.iter().copied(),
2775        |connection_id| {
2776            session
2777                .peer
2778                .forward_send(session.connection_id, connection_id, request.clone())
2779        },
2780    );
2781    response.send(proto::Ack {})?;
2782    Ok(())
2783}
2784
2785/// Updates other participants with changes to the diagnostics
2786async fn update_diagnostic_summary(
2787    message: proto::UpdateDiagnosticSummary,
2788    session: Session,
2789) -> Result<()> {
2790    let guest_connection_ids = session
2791        .db()
2792        .await
2793        .update_diagnostic_summary(&message, session.connection_id)
2794        .await?;
2795
2796    broadcast(
2797        Some(session.connection_id),
2798        guest_connection_ids.iter().copied(),
2799        |connection_id| {
2800            session
2801                .peer
2802                .forward_send(session.connection_id, connection_id, message.clone())
2803        },
2804    );
2805
2806    Ok(())
2807}
2808
2809/// Updates other participants with changes to the worktree settings
2810async fn update_worktree_settings(
2811    message: proto::UpdateWorktreeSettings,
2812    session: Session,
2813) -> Result<()> {
2814    let guest_connection_ids = session
2815        .db()
2816        .await
2817        .update_worktree_settings(&message, session.connection_id)
2818        .await?;
2819
2820    broadcast(
2821        Some(session.connection_id),
2822        guest_connection_ids.iter().copied(),
2823        |connection_id| {
2824            session
2825                .peer
2826                .forward_send(session.connection_id, connection_id, message.clone())
2827        },
2828    );
2829
2830    Ok(())
2831}
2832
2833/// Notify other participants that a  language server has started.
2834async fn start_language_server(
2835    request: proto::StartLanguageServer,
2836    session: Session,
2837) -> Result<()> {
2838    let guest_connection_ids = session
2839        .db()
2840        .await
2841        .start_language_server(&request, session.connection_id)
2842        .await?;
2843
2844    broadcast(
2845        Some(session.connection_id),
2846        guest_connection_ids.iter().copied(),
2847        |connection_id| {
2848            session
2849                .peer
2850                .forward_send(session.connection_id, connection_id, request.clone())
2851        },
2852    );
2853    Ok(())
2854}
2855
2856/// Notify other participants that a language server has changed.
2857async fn update_language_server(
2858    request: proto::UpdateLanguageServer,
2859    session: Session,
2860) -> Result<()> {
2861    let project_id = ProjectId::from_proto(request.project_id);
2862    let project_connection_ids = session
2863        .db()
2864        .await
2865        .project_connection_ids(project_id, session.connection_id, true)
2866        .await?;
2867    broadcast(
2868        Some(session.connection_id),
2869        project_connection_ids.iter().copied(),
2870        |connection_id| {
2871            session
2872                .peer
2873                .forward_send(session.connection_id, connection_id, request.clone())
2874        },
2875    );
2876    Ok(())
2877}
2878
2879/// forward a project request to the host. These requests should be read only
2880/// as guests are allowed to send them.
2881async fn forward_read_only_project_request<T>(
2882    request: T,
2883    response: Response<T>,
2884    session: UserSession,
2885) -> Result<()>
2886where
2887    T: EntityMessage + RequestMessage,
2888{
2889    let project_id = ProjectId::from_proto(request.remote_entity_id());
2890    let host_connection_id = session
2891        .db()
2892        .await
2893        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2894        .await?;
2895    let payload = session
2896        .peer
2897        .forward_request(session.connection_id, host_connection_id, request)
2898        .await?;
2899    response.send(payload)?;
2900    Ok(())
2901}
2902
2903/// forward a project request to the dev server. Only allowed
2904/// if it's your dev server.
2905async fn forward_project_request_for_owner<T>(
2906    request: T,
2907    response: Response<T>,
2908    session: UserSession,
2909) -> Result<()>
2910where
2911    T: EntityMessage + RequestMessage,
2912{
2913    let project_id = ProjectId::from_proto(request.remote_entity_id());
2914
2915    let host_connection_id = session
2916        .db()
2917        .await
2918        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2919        .await?;
2920    let payload = session
2921        .peer
2922        .forward_request(session.connection_id, host_connection_id, request)
2923        .await?;
2924    response.send(payload)?;
2925    Ok(())
2926}
2927
2928/// forward a project request to the host. These requests are disallowed
2929/// for guests.
2930async fn forward_mutating_project_request<T>(
2931    request: T,
2932    response: Response<T>,
2933    session: UserSession,
2934) -> Result<()>
2935where
2936    T: EntityMessage + RequestMessage,
2937{
2938    let project_id = ProjectId::from_proto(request.remote_entity_id());
2939
2940    let host_connection_id = session
2941        .db()
2942        .await
2943        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2944        .await?;
2945    let payload = session
2946        .peer
2947        .forward_request(session.connection_id, host_connection_id, request)
2948        .await?;
2949    response.send(payload)?;
2950    Ok(())
2951}
2952
2953/// forward a project request to the host. These requests are disallowed
2954/// for guests.
2955async fn forward_versioned_mutating_project_request<T>(
2956    request: T,
2957    response: Response<T>,
2958    session: UserSession,
2959) -> Result<()>
2960where
2961    T: EntityMessage + RequestMessage + VersionedMessage,
2962{
2963    let project_id = ProjectId::from_proto(request.remote_entity_id());
2964
2965    let host_connection_id = session
2966        .db()
2967        .await
2968        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2969        .await?;
2970    if let Some(host_version) = session
2971        .connection_pool()
2972        .await
2973        .connection(host_connection_id)
2974        .map(|c| c.zed_version)
2975    {
2976        if let Some(min_required_version) = request.required_host_version() {
2977            if min_required_version > host_version {
2978                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2979                    .with_tag("required", &min_required_version.to_string())))?;
2980            }
2981        }
2982    }
2983
2984    let payload = session
2985        .peer
2986        .forward_request(session.connection_id, host_connection_id, request)
2987        .await?;
2988    response.send(payload)?;
2989    Ok(())
2990}
2991
2992/// Notify other participants that a new buffer has been created
2993async fn create_buffer_for_peer(
2994    request: proto::CreateBufferForPeer,
2995    session: Session,
2996) -> Result<()> {
2997    session
2998        .db()
2999        .await
3000        .check_user_is_project_host(
3001            ProjectId::from_proto(request.project_id),
3002            session.connection_id,
3003        )
3004        .await?;
3005    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3006    session
3007        .peer
3008        .forward_send(session.connection_id, peer_id.into(), request)?;
3009    Ok(())
3010}
3011
3012/// Notify other participants that a buffer has been updated. This is
3013/// allowed for guests as long as the update is limited to selections.
3014async fn update_buffer(
3015    request: proto::UpdateBuffer,
3016    response: Response<proto::UpdateBuffer>,
3017    session: Session,
3018) -> Result<()> {
3019    let project_id = ProjectId::from_proto(request.project_id);
3020    let mut capability = Capability::ReadOnly;
3021
3022    for op in request.operations.iter() {
3023        match op.variant {
3024            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3025            Some(_) => capability = Capability::ReadWrite,
3026        }
3027    }
3028
3029    let host = {
3030        let guard = session
3031            .db()
3032            .await
3033            .connections_for_buffer_update(
3034                project_id,
3035                session.principal_id(),
3036                session.connection_id,
3037                capability,
3038            )
3039            .await?;
3040
3041        let (host, guests) = &*guard;
3042
3043        broadcast(
3044            Some(session.connection_id),
3045            guests.clone(),
3046            |connection_id| {
3047                session
3048                    .peer
3049                    .forward_send(session.connection_id, connection_id, request.clone())
3050            },
3051        );
3052
3053        *host
3054    };
3055
3056    if host != session.connection_id {
3057        session
3058            .peer
3059            .forward_request(session.connection_id, host, request.clone())
3060            .await?;
3061    }
3062
3063    response.send(proto::Ack {})?;
3064    Ok(())
3065}
3066
3067async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3068    let project_id = ProjectId::from_proto(message.project_id);
3069
3070    let operation = message.operation.as_ref().context("invalid operation")?;
3071    let capability = match operation.variant.as_ref() {
3072        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3073            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3074                match buffer_op.variant {
3075                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3076                        Capability::ReadOnly
3077                    }
3078                    _ => Capability::ReadWrite,
3079                }
3080            } else {
3081                Capability::ReadWrite
3082            }
3083        }
3084        Some(_) => Capability::ReadWrite,
3085        None => Capability::ReadOnly,
3086    };
3087
3088    let guard = session
3089        .db()
3090        .await
3091        .connections_for_buffer_update(
3092            project_id,
3093            session.principal_id(),
3094            session.connection_id,
3095            capability,
3096        )
3097        .await?;
3098
3099    let (host, guests) = &*guard;
3100
3101    broadcast(
3102        Some(session.connection_id),
3103        guests.iter().chain([host]).copied(),
3104        |connection_id| {
3105            session
3106                .peer
3107                .forward_send(session.connection_id, connection_id, message.clone())
3108        },
3109    );
3110
3111    Ok(())
3112}
3113
3114/// Notify other participants that a project has been updated.
3115async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3116    request: T,
3117    session: Session,
3118) -> Result<()> {
3119    let project_id = ProjectId::from_proto(request.remote_entity_id());
3120    let project_connection_ids = session
3121        .db()
3122        .await
3123        .project_connection_ids(project_id, session.connection_id, false)
3124        .await?;
3125
3126    broadcast(
3127        Some(session.connection_id),
3128        project_connection_ids.iter().copied(),
3129        |connection_id| {
3130            session
3131                .peer
3132                .forward_send(session.connection_id, connection_id, request.clone())
3133        },
3134    );
3135    Ok(())
3136}
3137
3138/// Start following another user in a call.
3139async fn follow(
3140    request: proto::Follow,
3141    response: Response<proto::Follow>,
3142    session: UserSession,
3143) -> Result<()> {
3144    let room_id = RoomId::from_proto(request.room_id);
3145    let project_id = request.project_id.map(ProjectId::from_proto);
3146    let leader_id = request
3147        .leader_id
3148        .ok_or_else(|| anyhow!("invalid leader id"))?
3149        .into();
3150    let follower_id = session.connection_id;
3151
3152    session
3153        .db()
3154        .await
3155        .check_room_participants(room_id, leader_id, session.connection_id)
3156        .await?;
3157
3158    let response_payload = session
3159        .peer
3160        .forward_request(session.connection_id, leader_id, request)
3161        .await?;
3162    response.send(response_payload)?;
3163
3164    if let Some(project_id) = project_id {
3165        let room = session
3166            .db()
3167            .await
3168            .follow(room_id, project_id, leader_id, follower_id)
3169            .await?;
3170        room_updated(&room, &session.peer);
3171    }
3172
3173    Ok(())
3174}
3175
3176/// Stop following another user in a call.
3177async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3178    let room_id = RoomId::from_proto(request.room_id);
3179    let project_id = request.project_id.map(ProjectId::from_proto);
3180    let leader_id = request
3181        .leader_id
3182        .ok_or_else(|| anyhow!("invalid leader id"))?
3183        .into();
3184    let follower_id = session.connection_id;
3185
3186    session
3187        .db()
3188        .await
3189        .check_room_participants(room_id, leader_id, session.connection_id)
3190        .await?;
3191
3192    session
3193        .peer
3194        .forward_send(session.connection_id, leader_id, request)?;
3195
3196    if let Some(project_id) = project_id {
3197        let room = session
3198            .db()
3199            .await
3200            .unfollow(room_id, project_id, leader_id, follower_id)
3201            .await?;
3202        room_updated(&room, &session.peer);
3203    }
3204
3205    Ok(())
3206}
3207
3208/// Notify everyone following you of your current location.
3209async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3210    let room_id = RoomId::from_proto(request.room_id);
3211    let database = session.db.lock().await;
3212
3213    let connection_ids = if let Some(project_id) = request.project_id {
3214        let project_id = ProjectId::from_proto(project_id);
3215        database
3216            .project_connection_ids(project_id, session.connection_id, true)
3217            .await?
3218    } else {
3219        database
3220            .room_connection_ids(room_id, session.connection_id)
3221            .await?
3222    };
3223
3224    // For now, don't send view update messages back to that view's current leader.
3225    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3226        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3227        _ => None,
3228    });
3229
3230    for connection_id in connection_ids.iter().cloned() {
3231        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3232            session
3233                .peer
3234                .forward_send(session.connection_id, connection_id, request.clone())?;
3235        }
3236    }
3237    Ok(())
3238}
3239
3240/// Get public data about users.
3241async fn get_users(
3242    request: proto::GetUsers,
3243    response: Response<proto::GetUsers>,
3244    session: Session,
3245) -> Result<()> {
3246    let user_ids = request
3247        .user_ids
3248        .into_iter()
3249        .map(UserId::from_proto)
3250        .collect();
3251    let users = session
3252        .db()
3253        .await
3254        .get_users_by_ids(user_ids)
3255        .await?
3256        .into_iter()
3257        .map(|user| proto::User {
3258            id: user.id.to_proto(),
3259            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3260            github_login: user.github_login,
3261        })
3262        .collect();
3263    response.send(proto::UsersResponse { users })?;
3264    Ok(())
3265}
3266
3267/// Search for users (to invite) buy Github login
3268async fn fuzzy_search_users(
3269    request: proto::FuzzySearchUsers,
3270    response: Response<proto::FuzzySearchUsers>,
3271    session: UserSession,
3272) -> Result<()> {
3273    let query = request.query;
3274    let users = match query.len() {
3275        0 => vec![],
3276        1 | 2 => session
3277            .db()
3278            .await
3279            .get_user_by_github_login(&query)
3280            .await?
3281            .into_iter()
3282            .collect(),
3283        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3284    };
3285    let users = users
3286        .into_iter()
3287        .filter(|user| user.id != session.user_id())
3288        .map(|user| proto::User {
3289            id: user.id.to_proto(),
3290            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3291            github_login: user.github_login,
3292        })
3293        .collect();
3294    response.send(proto::UsersResponse { users })?;
3295    Ok(())
3296}
3297
3298/// Send a contact request to another user.
3299async fn request_contact(
3300    request: proto::RequestContact,
3301    response: Response<proto::RequestContact>,
3302    session: UserSession,
3303) -> Result<()> {
3304    let requester_id = session.user_id();
3305    let responder_id = UserId::from_proto(request.responder_id);
3306    if requester_id == responder_id {
3307        return Err(anyhow!("cannot add yourself as a contact"))?;
3308    }
3309
3310    let notifications = session
3311        .db()
3312        .await
3313        .send_contact_request(requester_id, responder_id)
3314        .await?;
3315
3316    // Update outgoing contact requests of requester
3317    let mut update = proto::UpdateContacts::default();
3318    update.outgoing_requests.push(responder_id.to_proto());
3319    for connection_id in session
3320        .connection_pool()
3321        .await
3322        .user_connection_ids(requester_id)
3323    {
3324        session.peer.send(connection_id, update.clone())?;
3325    }
3326
3327    // Update incoming contact requests of responder
3328    let mut update = proto::UpdateContacts::default();
3329    update
3330        .incoming_requests
3331        .push(proto::IncomingContactRequest {
3332            requester_id: requester_id.to_proto(),
3333        });
3334    let connection_pool = session.connection_pool().await;
3335    for connection_id in connection_pool.user_connection_ids(responder_id) {
3336        session.peer.send(connection_id, update.clone())?;
3337    }
3338
3339    send_notifications(&connection_pool, &session.peer, notifications);
3340
3341    response.send(proto::Ack {})?;
3342    Ok(())
3343}
3344
3345/// Accept or decline a contact request
3346async fn respond_to_contact_request(
3347    request: proto::RespondToContactRequest,
3348    response: Response<proto::RespondToContactRequest>,
3349    session: UserSession,
3350) -> Result<()> {
3351    let responder_id = session.user_id();
3352    let requester_id = UserId::from_proto(request.requester_id);
3353    let db = session.db().await;
3354    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3355        db.dismiss_contact_notification(responder_id, requester_id)
3356            .await?;
3357    } else {
3358        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3359
3360        let notifications = db
3361            .respond_to_contact_request(responder_id, requester_id, accept)
3362            .await?;
3363        let requester_busy = db.is_user_busy(requester_id).await?;
3364        let responder_busy = db.is_user_busy(responder_id).await?;
3365
3366        let pool = session.connection_pool().await;
3367        // Update responder with new contact
3368        let mut update = proto::UpdateContacts::default();
3369        if accept {
3370            update
3371                .contacts
3372                .push(contact_for_user(requester_id, requester_busy, &pool));
3373        }
3374        update
3375            .remove_incoming_requests
3376            .push(requester_id.to_proto());
3377        for connection_id in pool.user_connection_ids(responder_id) {
3378            session.peer.send(connection_id, update.clone())?;
3379        }
3380
3381        // Update requester with new contact
3382        let mut update = proto::UpdateContacts::default();
3383        if accept {
3384            update
3385                .contacts
3386                .push(contact_for_user(responder_id, responder_busy, &pool));
3387        }
3388        update
3389            .remove_outgoing_requests
3390            .push(responder_id.to_proto());
3391
3392        for connection_id in pool.user_connection_ids(requester_id) {
3393            session.peer.send(connection_id, update.clone())?;
3394        }
3395
3396        send_notifications(&pool, &session.peer, notifications);
3397    }
3398
3399    response.send(proto::Ack {})?;
3400    Ok(())
3401}
3402
3403/// Remove a contact.
3404async fn remove_contact(
3405    request: proto::RemoveContact,
3406    response: Response<proto::RemoveContact>,
3407    session: UserSession,
3408) -> Result<()> {
3409    let requester_id = session.user_id();
3410    let responder_id = UserId::from_proto(request.user_id);
3411    let db = session.db().await;
3412    let (contact_accepted, deleted_notification_id) =
3413        db.remove_contact(requester_id, responder_id).await?;
3414
3415    let pool = session.connection_pool().await;
3416    // Update outgoing contact requests of requester
3417    let mut update = proto::UpdateContacts::default();
3418    if contact_accepted {
3419        update.remove_contacts.push(responder_id.to_proto());
3420    } else {
3421        update
3422            .remove_outgoing_requests
3423            .push(responder_id.to_proto());
3424    }
3425    for connection_id in pool.user_connection_ids(requester_id) {
3426        session.peer.send(connection_id, update.clone())?;
3427    }
3428
3429    // Update incoming contact requests of responder
3430    let mut update = proto::UpdateContacts::default();
3431    if contact_accepted {
3432        update.remove_contacts.push(requester_id.to_proto());
3433    } else {
3434        update
3435            .remove_incoming_requests
3436            .push(requester_id.to_proto());
3437    }
3438    for connection_id in pool.user_connection_ids(responder_id) {
3439        session.peer.send(connection_id, update.clone())?;
3440        if let Some(notification_id) = deleted_notification_id {
3441            session.peer.send(
3442                connection_id,
3443                proto::DeleteNotification {
3444                    notification_id: notification_id.to_proto(),
3445                },
3446            )?;
3447        }
3448    }
3449
3450    response.send(proto::Ack {})?;
3451    Ok(())
3452}
3453
3454fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3455    version.0.minor() < 139
3456}
3457
3458async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3459    subscribe_user_to_channels(
3460        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3461        &session,
3462    )
3463    .await?;
3464    Ok(())
3465}
3466
3467async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3468    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3469    let mut pool = session.connection_pool().await;
3470    for membership in &channels_for_user.channel_memberships {
3471        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3472    }
3473    session.peer.send(
3474        session.connection_id,
3475        build_update_user_channels(&channels_for_user),
3476    )?;
3477    session.peer.send(
3478        session.connection_id,
3479        build_channels_update(channels_for_user),
3480    )?;
3481    Ok(())
3482}
3483
3484/// Creates a new channel.
3485async fn create_channel(
3486    request: proto::CreateChannel,
3487    response: Response<proto::CreateChannel>,
3488    session: UserSession,
3489) -> Result<()> {
3490    let db = session.db().await;
3491
3492    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3493    let (channel, membership) = db
3494        .create_channel(&request.name, parent_id, session.user_id())
3495        .await?;
3496
3497    let root_id = channel.root_id();
3498    let channel = Channel::from_model(channel);
3499
3500    response.send(proto::CreateChannelResponse {
3501        channel: Some(channel.to_proto()),
3502        parent_id: request.parent_id,
3503    })?;
3504
3505    let mut connection_pool = session.connection_pool().await;
3506    if let Some(membership) = membership {
3507        connection_pool.subscribe_to_channel(
3508            membership.user_id,
3509            membership.channel_id,
3510            membership.role,
3511        );
3512        let update = proto::UpdateUserChannels {
3513            channel_memberships: vec![proto::ChannelMembership {
3514                channel_id: membership.channel_id.to_proto(),
3515                role: membership.role.into(),
3516            }],
3517            ..Default::default()
3518        };
3519        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3520            session.peer.send(connection_id, update.clone())?;
3521        }
3522    }
3523
3524    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3525        if !role.can_see_channel(channel.visibility) {
3526            continue;
3527        }
3528
3529        let update = proto::UpdateChannels {
3530            channels: vec![channel.to_proto()],
3531            ..Default::default()
3532        };
3533        session.peer.send(connection_id, update.clone())?;
3534    }
3535
3536    Ok(())
3537}
3538
3539/// Delete a channel
3540async fn delete_channel(
3541    request: proto::DeleteChannel,
3542    response: Response<proto::DeleteChannel>,
3543    session: UserSession,
3544) -> Result<()> {
3545    let db = session.db().await;
3546
3547    let channel_id = request.channel_id;
3548    let (root_channel, removed_channels) = db
3549        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3550        .await?;
3551    response.send(proto::Ack {})?;
3552
3553    // Notify members of removed channels
3554    let mut update = proto::UpdateChannels::default();
3555    update
3556        .delete_channels
3557        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3558
3559    let connection_pool = session.connection_pool().await;
3560    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3561        session.peer.send(connection_id, update.clone())?;
3562    }
3563
3564    Ok(())
3565}
3566
3567/// Invite someone to join a channel.
3568async fn invite_channel_member(
3569    request: proto::InviteChannelMember,
3570    response: Response<proto::InviteChannelMember>,
3571    session: UserSession,
3572) -> Result<()> {
3573    let db = session.db().await;
3574    let channel_id = ChannelId::from_proto(request.channel_id);
3575    let invitee_id = UserId::from_proto(request.user_id);
3576    let InviteMemberResult {
3577        channel,
3578        notifications,
3579    } = db
3580        .invite_channel_member(
3581            channel_id,
3582            invitee_id,
3583            session.user_id(),
3584            request.role().into(),
3585        )
3586        .await?;
3587
3588    let update = proto::UpdateChannels {
3589        channel_invitations: vec![channel.to_proto()],
3590        ..Default::default()
3591    };
3592
3593    let connection_pool = session.connection_pool().await;
3594    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3595        session.peer.send(connection_id, update.clone())?;
3596    }
3597
3598    send_notifications(&connection_pool, &session.peer, notifications);
3599
3600    response.send(proto::Ack {})?;
3601    Ok(())
3602}
3603
3604/// remove someone from a channel
3605async fn remove_channel_member(
3606    request: proto::RemoveChannelMember,
3607    response: Response<proto::RemoveChannelMember>,
3608    session: UserSession,
3609) -> Result<()> {
3610    let db = session.db().await;
3611    let channel_id = ChannelId::from_proto(request.channel_id);
3612    let member_id = UserId::from_proto(request.user_id);
3613
3614    let RemoveChannelMemberResult {
3615        membership_update,
3616        notification_id,
3617    } = db
3618        .remove_channel_member(channel_id, member_id, session.user_id())
3619        .await?;
3620
3621    let mut connection_pool = session.connection_pool().await;
3622    notify_membership_updated(
3623        &mut connection_pool,
3624        membership_update,
3625        member_id,
3626        &session.peer,
3627    );
3628    for connection_id in connection_pool.user_connection_ids(member_id) {
3629        if let Some(notification_id) = notification_id {
3630            session
3631                .peer
3632                .send(
3633                    connection_id,
3634                    proto::DeleteNotification {
3635                        notification_id: notification_id.to_proto(),
3636                    },
3637                )
3638                .trace_err();
3639        }
3640    }
3641
3642    response.send(proto::Ack {})?;
3643    Ok(())
3644}
3645
3646/// Toggle the channel between public and private.
3647/// Care is taken to maintain the invariant that public channels only descend from public channels,
3648/// (though members-only channels can appear at any point in the hierarchy).
3649async fn set_channel_visibility(
3650    request: proto::SetChannelVisibility,
3651    response: Response<proto::SetChannelVisibility>,
3652    session: UserSession,
3653) -> Result<()> {
3654    let db = session.db().await;
3655    let channel_id = ChannelId::from_proto(request.channel_id);
3656    let visibility = request.visibility().into();
3657
3658    let channel_model = db
3659        .set_channel_visibility(channel_id, visibility, session.user_id())
3660        .await?;
3661    let root_id = channel_model.root_id();
3662    let channel = Channel::from_model(channel_model);
3663
3664    let mut connection_pool = session.connection_pool().await;
3665    for (user_id, role) in connection_pool
3666        .channel_user_ids(root_id)
3667        .collect::<Vec<_>>()
3668        .into_iter()
3669    {
3670        let update = if role.can_see_channel(channel.visibility) {
3671            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3672            proto::UpdateChannels {
3673                channels: vec![channel.to_proto()],
3674                ..Default::default()
3675            }
3676        } else {
3677            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3678            proto::UpdateChannels {
3679                delete_channels: vec![channel.id.to_proto()],
3680                ..Default::default()
3681            }
3682        };
3683
3684        for connection_id in connection_pool.user_connection_ids(user_id) {
3685            session.peer.send(connection_id, update.clone())?;
3686        }
3687    }
3688
3689    response.send(proto::Ack {})?;
3690    Ok(())
3691}
3692
3693/// Alter the role for a user in the channel.
3694async fn set_channel_member_role(
3695    request: proto::SetChannelMemberRole,
3696    response: Response<proto::SetChannelMemberRole>,
3697    session: UserSession,
3698) -> Result<()> {
3699    let db = session.db().await;
3700    let channel_id = ChannelId::from_proto(request.channel_id);
3701    let member_id = UserId::from_proto(request.user_id);
3702    let result = db
3703        .set_channel_member_role(
3704            channel_id,
3705            session.user_id(),
3706            member_id,
3707            request.role().into(),
3708        )
3709        .await?;
3710
3711    match result {
3712        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3713            let mut connection_pool = session.connection_pool().await;
3714            notify_membership_updated(
3715                &mut connection_pool,
3716                membership_update,
3717                member_id,
3718                &session.peer,
3719            )
3720        }
3721        db::SetMemberRoleResult::InviteUpdated(channel) => {
3722            let update = proto::UpdateChannels {
3723                channel_invitations: vec![channel.to_proto()],
3724                ..Default::default()
3725            };
3726
3727            for connection_id in session
3728                .connection_pool()
3729                .await
3730                .user_connection_ids(member_id)
3731            {
3732                session.peer.send(connection_id, update.clone())?;
3733            }
3734        }
3735    }
3736
3737    response.send(proto::Ack {})?;
3738    Ok(())
3739}
3740
3741/// Change the name of a channel
3742async fn rename_channel(
3743    request: proto::RenameChannel,
3744    response: Response<proto::RenameChannel>,
3745    session: UserSession,
3746) -> Result<()> {
3747    let db = session.db().await;
3748    let channel_id = ChannelId::from_proto(request.channel_id);
3749    let channel_model = db
3750        .rename_channel(channel_id, session.user_id(), &request.name)
3751        .await?;
3752    let root_id = channel_model.root_id();
3753    let channel = Channel::from_model(channel_model);
3754
3755    response.send(proto::RenameChannelResponse {
3756        channel: Some(channel.to_proto()),
3757    })?;
3758
3759    let connection_pool = session.connection_pool().await;
3760    let update = proto::UpdateChannels {
3761        channels: vec![channel.to_proto()],
3762        ..Default::default()
3763    };
3764    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3765        if role.can_see_channel(channel.visibility) {
3766            session.peer.send(connection_id, update.clone())?;
3767        }
3768    }
3769
3770    Ok(())
3771}
3772
3773/// Move a channel to a new parent.
3774async fn move_channel(
3775    request: proto::MoveChannel,
3776    response: Response<proto::MoveChannel>,
3777    session: UserSession,
3778) -> Result<()> {
3779    let channel_id = ChannelId::from_proto(request.channel_id);
3780    let to = ChannelId::from_proto(request.to);
3781
3782    let (root_id, channels) = session
3783        .db()
3784        .await
3785        .move_channel(channel_id, to, session.user_id())
3786        .await?;
3787
3788    let connection_pool = session.connection_pool().await;
3789    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3790        let channels = channels
3791            .iter()
3792            .filter_map(|channel| {
3793                if role.can_see_channel(channel.visibility) {
3794                    Some(channel.to_proto())
3795                } else {
3796                    None
3797                }
3798            })
3799            .collect::<Vec<_>>();
3800        if channels.is_empty() {
3801            continue;
3802        }
3803
3804        let update = proto::UpdateChannels {
3805            channels,
3806            ..Default::default()
3807        };
3808
3809        session.peer.send(connection_id, update.clone())?;
3810    }
3811
3812    response.send(Ack {})?;
3813    Ok(())
3814}
3815
3816/// Get the list of channel members
3817async fn get_channel_members(
3818    request: proto::GetChannelMembers,
3819    response: Response<proto::GetChannelMembers>,
3820    session: UserSession,
3821) -> Result<()> {
3822    let db = session.db().await;
3823    let channel_id = ChannelId::from_proto(request.channel_id);
3824    let limit = if request.limit == 0 {
3825        u16::MAX as u64
3826    } else {
3827        request.limit
3828    };
3829    let (members, users) = db
3830        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3831        .await?;
3832    response.send(proto::GetChannelMembersResponse { members, users })?;
3833    Ok(())
3834}
3835
3836/// Accept or decline a channel invitation.
3837async fn respond_to_channel_invite(
3838    request: proto::RespondToChannelInvite,
3839    response: Response<proto::RespondToChannelInvite>,
3840    session: UserSession,
3841) -> Result<()> {
3842    let db = session.db().await;
3843    let channel_id = ChannelId::from_proto(request.channel_id);
3844    let RespondToChannelInvite {
3845        membership_update,
3846        notifications,
3847    } = db
3848        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3849        .await?;
3850
3851    let mut connection_pool = session.connection_pool().await;
3852    if let Some(membership_update) = membership_update {
3853        notify_membership_updated(
3854            &mut connection_pool,
3855            membership_update,
3856            session.user_id(),
3857            &session.peer,
3858        );
3859    } else {
3860        let update = proto::UpdateChannels {
3861            remove_channel_invitations: vec![channel_id.to_proto()],
3862            ..Default::default()
3863        };
3864
3865        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3866            session.peer.send(connection_id, update.clone())?;
3867        }
3868    };
3869
3870    send_notifications(&connection_pool, &session.peer, notifications);
3871
3872    response.send(proto::Ack {})?;
3873
3874    Ok(())
3875}
3876
3877/// Join the channels' room
3878async fn join_channel(
3879    request: proto::JoinChannel,
3880    response: Response<proto::JoinChannel>,
3881    session: UserSession,
3882) -> Result<()> {
3883    let channel_id = ChannelId::from_proto(request.channel_id);
3884    join_channel_internal(channel_id, Box::new(response), session).await
3885}
3886
3887trait JoinChannelInternalResponse {
3888    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3889}
3890impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3891    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3892        Response::<proto::JoinChannel>::send(self, result)
3893    }
3894}
3895impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3896    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3897        Response::<proto::JoinRoom>::send(self, result)
3898    }
3899}
3900
3901async fn join_channel_internal(
3902    channel_id: ChannelId,
3903    response: Box<impl JoinChannelInternalResponse>,
3904    session: UserSession,
3905) -> Result<()> {
3906    let joined_room = {
3907        let mut db = session.db().await;
3908        // If zed quits without leaving the room, and the user re-opens zed before the
3909        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3910        // room they were in.
3911        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3912            tracing::info!(
3913                stale_connection_id = %connection,
3914                "cleaning up stale connection",
3915            );
3916            drop(db);
3917            leave_room_for_session(&session, connection).await?;
3918            db = session.db().await;
3919        }
3920
3921        let (joined_room, membership_updated, role) = db
3922            .join_channel(channel_id, session.user_id(), session.connection_id)
3923            .await?;
3924
3925        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3926            let (can_publish, token) = if role == ChannelRole::Guest {
3927                (
3928                    false,
3929                    live_kit
3930                        .guest_token(
3931                            &joined_room.room.live_kit_room,
3932                            &session.user_id().to_string(),
3933                        )
3934                        .trace_err()?,
3935                )
3936            } else {
3937                (
3938                    true,
3939                    live_kit
3940                        .room_token(
3941                            &joined_room.room.live_kit_room,
3942                            &session.user_id().to_string(),
3943                        )
3944                        .trace_err()?,
3945                )
3946            };
3947
3948            Some(LiveKitConnectionInfo {
3949                server_url: live_kit.url().into(),
3950                token,
3951                can_publish,
3952            })
3953        });
3954
3955        response.send(proto::JoinRoomResponse {
3956            room: Some(joined_room.room.clone()),
3957            channel_id: joined_room
3958                .channel
3959                .as_ref()
3960                .map(|channel| channel.id.to_proto()),
3961            live_kit_connection_info,
3962        })?;
3963
3964        let mut connection_pool = session.connection_pool().await;
3965        if let Some(membership_updated) = membership_updated {
3966            notify_membership_updated(
3967                &mut connection_pool,
3968                membership_updated,
3969                session.user_id(),
3970                &session.peer,
3971            );
3972        }
3973
3974        room_updated(&joined_room.room, &session.peer);
3975
3976        joined_room
3977    };
3978
3979    channel_updated(
3980        &joined_room
3981            .channel
3982            .ok_or_else(|| anyhow!("channel not returned"))?,
3983        &joined_room.room,
3984        &session.peer,
3985        &*session.connection_pool().await,
3986    );
3987
3988    update_user_contacts(session.user_id(), &session).await?;
3989    Ok(())
3990}
3991
3992/// Start editing the channel notes
3993async fn join_channel_buffer(
3994    request: proto::JoinChannelBuffer,
3995    response: Response<proto::JoinChannelBuffer>,
3996    session: UserSession,
3997) -> Result<()> {
3998    let db = session.db().await;
3999    let channel_id = ChannelId::from_proto(request.channel_id);
4000
4001    let open_response = db
4002        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4003        .await?;
4004
4005    let collaborators = open_response.collaborators.clone();
4006    response.send(open_response)?;
4007
4008    let update = UpdateChannelBufferCollaborators {
4009        channel_id: channel_id.to_proto(),
4010        collaborators: collaborators.clone(),
4011    };
4012    channel_buffer_updated(
4013        session.connection_id,
4014        collaborators
4015            .iter()
4016            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4017        &update,
4018        &session.peer,
4019    );
4020
4021    Ok(())
4022}
4023
4024/// Edit the channel notes
4025async fn update_channel_buffer(
4026    request: proto::UpdateChannelBuffer,
4027    session: UserSession,
4028) -> Result<()> {
4029    let db = session.db().await;
4030    let channel_id = ChannelId::from_proto(request.channel_id);
4031
4032    let (collaborators, epoch, version) = db
4033        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4034        .await?;
4035
4036    channel_buffer_updated(
4037        session.connection_id,
4038        collaborators.clone(),
4039        &proto::UpdateChannelBuffer {
4040            channel_id: channel_id.to_proto(),
4041            operations: request.operations,
4042        },
4043        &session.peer,
4044    );
4045
4046    let pool = &*session.connection_pool().await;
4047
4048    let non_collaborators =
4049        pool.channel_connection_ids(channel_id)
4050            .filter_map(|(connection_id, _)| {
4051                if collaborators.contains(&connection_id) {
4052                    None
4053                } else {
4054                    Some(connection_id)
4055                }
4056            });
4057
4058    broadcast(None, non_collaborators, |peer_id| {
4059        session.peer.send(
4060            peer_id,
4061            proto::UpdateChannels {
4062                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4063                    channel_id: channel_id.to_proto(),
4064                    epoch: epoch as u64,
4065                    version: version.clone(),
4066                }],
4067                ..Default::default()
4068            },
4069        )
4070    });
4071
4072    Ok(())
4073}
4074
4075/// Rejoin the channel notes after a connection blip
4076async fn rejoin_channel_buffers(
4077    request: proto::RejoinChannelBuffers,
4078    response: Response<proto::RejoinChannelBuffers>,
4079    session: UserSession,
4080) -> Result<()> {
4081    let db = session.db().await;
4082    let buffers = db
4083        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4084        .await?;
4085
4086    for rejoined_buffer in &buffers {
4087        let collaborators_to_notify = rejoined_buffer
4088            .buffer
4089            .collaborators
4090            .iter()
4091            .filter_map(|c| Some(c.peer_id?.into()));
4092        channel_buffer_updated(
4093            session.connection_id,
4094            collaborators_to_notify,
4095            &proto::UpdateChannelBufferCollaborators {
4096                channel_id: rejoined_buffer.buffer.channel_id,
4097                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4098            },
4099            &session.peer,
4100        );
4101    }
4102
4103    response.send(proto::RejoinChannelBuffersResponse {
4104        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4105    })?;
4106
4107    Ok(())
4108}
4109
4110/// Stop editing the channel notes
4111async fn leave_channel_buffer(
4112    request: proto::LeaveChannelBuffer,
4113    response: Response<proto::LeaveChannelBuffer>,
4114    session: UserSession,
4115) -> Result<()> {
4116    let db = session.db().await;
4117    let channel_id = ChannelId::from_proto(request.channel_id);
4118
4119    let left_buffer = db
4120        .leave_channel_buffer(channel_id, session.connection_id)
4121        .await?;
4122
4123    response.send(Ack {})?;
4124
4125    channel_buffer_updated(
4126        session.connection_id,
4127        left_buffer.connections,
4128        &proto::UpdateChannelBufferCollaborators {
4129            channel_id: channel_id.to_proto(),
4130            collaborators: left_buffer.collaborators,
4131        },
4132        &session.peer,
4133    );
4134
4135    Ok(())
4136}
4137
4138fn channel_buffer_updated<T: EnvelopedMessage>(
4139    sender_id: ConnectionId,
4140    collaborators: impl IntoIterator<Item = ConnectionId>,
4141    message: &T,
4142    peer: &Peer,
4143) {
4144    broadcast(Some(sender_id), collaborators, |peer_id| {
4145        peer.send(peer_id, message.clone())
4146    });
4147}
4148
4149fn send_notifications(
4150    connection_pool: &ConnectionPool,
4151    peer: &Peer,
4152    notifications: db::NotificationBatch,
4153) {
4154    for (user_id, notification) in notifications {
4155        for connection_id in connection_pool.user_connection_ids(user_id) {
4156            if let Err(error) = peer.send(
4157                connection_id,
4158                proto::AddNotification {
4159                    notification: Some(notification.clone()),
4160                },
4161            ) {
4162                tracing::error!(
4163                    "failed to send notification to {:?} {}",
4164                    connection_id,
4165                    error
4166                );
4167            }
4168        }
4169    }
4170}
4171
4172/// Send a message to the channel
4173async fn send_channel_message(
4174    request: proto::SendChannelMessage,
4175    response: Response<proto::SendChannelMessage>,
4176    session: UserSession,
4177) -> Result<()> {
4178    // Validate the message body.
4179    let body = request.body.trim().to_string();
4180    if body.len() > MAX_MESSAGE_LEN {
4181        return Err(anyhow!("message is too long"))?;
4182    }
4183    if body.is_empty() {
4184        return Err(anyhow!("message can't be blank"))?;
4185    }
4186
4187    // TODO: adjust mentions if body is trimmed
4188
4189    let timestamp = OffsetDateTime::now_utc();
4190    let nonce = request
4191        .nonce
4192        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4193
4194    let channel_id = ChannelId::from_proto(request.channel_id);
4195    let CreatedChannelMessage {
4196        message_id,
4197        participant_connection_ids,
4198        notifications,
4199    } = session
4200        .db()
4201        .await
4202        .create_channel_message(
4203            channel_id,
4204            session.user_id(),
4205            &body,
4206            &request.mentions,
4207            timestamp,
4208            nonce.clone().into(),
4209            match request.reply_to_message_id {
4210                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4211                None => None,
4212            },
4213        )
4214        .await?;
4215
4216    let message = proto::ChannelMessage {
4217        sender_id: session.user_id().to_proto(),
4218        id: message_id.to_proto(),
4219        body,
4220        mentions: request.mentions,
4221        timestamp: timestamp.unix_timestamp() as u64,
4222        nonce: Some(nonce),
4223        reply_to_message_id: request.reply_to_message_id,
4224        edited_at: None,
4225    };
4226    broadcast(
4227        Some(session.connection_id),
4228        participant_connection_ids.clone(),
4229        |connection| {
4230            session.peer.send(
4231                connection,
4232                proto::ChannelMessageSent {
4233                    channel_id: channel_id.to_proto(),
4234                    message: Some(message.clone()),
4235                },
4236            )
4237        },
4238    );
4239    response.send(proto::SendChannelMessageResponse {
4240        message: Some(message),
4241    })?;
4242
4243    let pool = &*session.connection_pool().await;
4244    let non_participants =
4245        pool.channel_connection_ids(channel_id)
4246            .filter_map(|(connection_id, _)| {
4247                if participant_connection_ids.contains(&connection_id) {
4248                    None
4249                } else {
4250                    Some(connection_id)
4251                }
4252            });
4253    broadcast(None, non_participants, |peer_id| {
4254        session.peer.send(
4255            peer_id,
4256            proto::UpdateChannels {
4257                latest_channel_message_ids: vec![proto::ChannelMessageId {
4258                    channel_id: channel_id.to_proto(),
4259                    message_id: message_id.to_proto(),
4260                }],
4261                ..Default::default()
4262            },
4263        )
4264    });
4265    send_notifications(pool, &session.peer, notifications);
4266
4267    Ok(())
4268}
4269
4270/// Delete a channel message
4271async fn remove_channel_message(
4272    request: proto::RemoveChannelMessage,
4273    response: Response<proto::RemoveChannelMessage>,
4274    session: UserSession,
4275) -> Result<()> {
4276    let channel_id = ChannelId::from_proto(request.channel_id);
4277    let message_id = MessageId::from_proto(request.message_id);
4278    let (connection_ids, existing_notification_ids) = session
4279        .db()
4280        .await
4281        .remove_channel_message(channel_id, message_id, session.user_id())
4282        .await?;
4283
4284    broadcast(
4285        Some(session.connection_id),
4286        connection_ids,
4287        move |connection| {
4288            session.peer.send(connection, request.clone())?;
4289
4290            for notification_id in &existing_notification_ids {
4291                session.peer.send(
4292                    connection,
4293                    proto::DeleteNotification {
4294                        notification_id: (*notification_id).to_proto(),
4295                    },
4296                )?;
4297            }
4298
4299            Ok(())
4300        },
4301    );
4302    response.send(proto::Ack {})?;
4303    Ok(())
4304}
4305
4306async fn update_channel_message(
4307    request: proto::UpdateChannelMessage,
4308    response: Response<proto::UpdateChannelMessage>,
4309    session: UserSession,
4310) -> Result<()> {
4311    let channel_id = ChannelId::from_proto(request.channel_id);
4312    let message_id = MessageId::from_proto(request.message_id);
4313    let updated_at = OffsetDateTime::now_utc();
4314    let UpdatedChannelMessage {
4315        message_id,
4316        participant_connection_ids,
4317        notifications,
4318        reply_to_message_id,
4319        timestamp,
4320        deleted_mention_notification_ids,
4321        updated_mention_notifications,
4322    } = session
4323        .db()
4324        .await
4325        .update_channel_message(
4326            channel_id,
4327            message_id,
4328            session.user_id(),
4329            request.body.as_str(),
4330            &request.mentions,
4331            updated_at,
4332        )
4333        .await?;
4334
4335    let nonce = request
4336        .nonce
4337        .clone()
4338        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4339
4340    let message = proto::ChannelMessage {
4341        sender_id: session.user_id().to_proto(),
4342        id: message_id.to_proto(),
4343        body: request.body.clone(),
4344        mentions: request.mentions.clone(),
4345        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4346        nonce: Some(nonce),
4347        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4348        edited_at: Some(updated_at.unix_timestamp() as u64),
4349    };
4350
4351    response.send(proto::Ack {})?;
4352
4353    let pool = &*session.connection_pool().await;
4354    broadcast(
4355        Some(session.connection_id),
4356        participant_connection_ids,
4357        |connection| {
4358            session.peer.send(
4359                connection,
4360                proto::ChannelMessageUpdate {
4361                    channel_id: channel_id.to_proto(),
4362                    message: Some(message.clone()),
4363                },
4364            )?;
4365
4366            for notification_id in &deleted_mention_notification_ids {
4367                session.peer.send(
4368                    connection,
4369                    proto::DeleteNotification {
4370                        notification_id: (*notification_id).to_proto(),
4371                    },
4372                )?;
4373            }
4374
4375            for notification in &updated_mention_notifications {
4376                session.peer.send(
4377                    connection,
4378                    proto::UpdateNotification {
4379                        notification: Some(notification.clone()),
4380                    },
4381                )?;
4382            }
4383
4384            Ok(())
4385        },
4386    );
4387
4388    send_notifications(pool, &session.peer, notifications);
4389
4390    Ok(())
4391}
4392
4393/// Mark a channel message as read
4394async fn acknowledge_channel_message(
4395    request: proto::AckChannelMessage,
4396    session: UserSession,
4397) -> Result<()> {
4398    let channel_id = ChannelId::from_proto(request.channel_id);
4399    let message_id = MessageId::from_proto(request.message_id);
4400    let notifications = session
4401        .db()
4402        .await
4403        .observe_channel_message(channel_id, session.user_id(), message_id)
4404        .await?;
4405    send_notifications(
4406        &*session.connection_pool().await,
4407        &session.peer,
4408        notifications,
4409    );
4410    Ok(())
4411}
4412
4413/// Mark a buffer version as synced
4414async fn acknowledge_buffer_version(
4415    request: proto::AckBufferOperation,
4416    session: UserSession,
4417) -> Result<()> {
4418    let buffer_id = BufferId::from_proto(request.buffer_id);
4419    session
4420        .db()
4421        .await
4422        .observe_buffer_version(
4423            buffer_id,
4424            session.user_id(),
4425            request.epoch as i32,
4426            &request.version,
4427        )
4428        .await?;
4429    Ok(())
4430}
4431
4432struct CompleteWithLanguageModelRateLimit;
4433
4434impl RateLimit for CompleteWithLanguageModelRateLimit {
4435    fn capacity() -> usize {
4436        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4437            .ok()
4438            .and_then(|v| v.parse().ok())
4439            .unwrap_or(120) // Picked arbitrarily
4440    }
4441
4442    fn refill_duration() -> chrono::Duration {
4443        chrono::Duration::hours(1)
4444    }
4445
4446    fn db_name() -> &'static str {
4447        "complete-with-language-model"
4448    }
4449}
4450
4451async fn complete_with_language_model(
4452    request: proto::CompleteWithLanguageModel,
4453    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4454    session: Session,
4455    open_ai_api_key: Option<Arc<str>>,
4456    google_ai_api_key: Option<Arc<str>>,
4457    anthropic_api_key: Option<Arc<str>>,
4458) -> Result<()> {
4459    let Some(session) = session.for_user() else {
4460        return Err(anyhow!("user not found"))?;
4461    };
4462    authorize_access_to_language_models(&session).await?;
4463    session
4464        .rate_limiter
4465        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4466        .await?;
4467
4468    if request.model.starts_with("gpt") {
4469        let api_key =
4470            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4471        complete_with_open_ai(request, response, session, api_key).await?;
4472    } else if request.model.starts_with("gemini") {
4473        let api_key = google_ai_api_key
4474            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4475        complete_with_google_ai(request, response, session, api_key).await?;
4476    } else if request.model.starts_with("claude") {
4477        let api_key = anthropic_api_key
4478            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4479        complete_with_anthropic(request, response, session, api_key).await?;
4480    }
4481
4482    Ok(())
4483}
4484
4485async fn complete_with_open_ai(
4486    request: proto::CompleteWithLanguageModel,
4487    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4488    session: UserSession,
4489    api_key: Arc<str>,
4490) -> Result<()> {
4491    let mut completion_stream = open_ai::stream_completion(
4492        session.http_client.as_ref(),
4493        OPEN_AI_API_URL,
4494        &api_key,
4495        crate::ai::language_model_request_to_open_ai(request)?,
4496        None,
4497    )
4498    .await
4499    .context("open_ai::stream_completion request failed within collab")?;
4500
4501    while let Some(event) = completion_stream.next().await {
4502        let event = event?;
4503        response.send(proto::LanguageModelResponse {
4504            choices: event
4505                .choices
4506                .into_iter()
4507                .map(|choice| proto::LanguageModelChoiceDelta {
4508                    index: choice.index,
4509                    delta: Some(proto::LanguageModelResponseMessage {
4510                        role: choice.delta.role.map(|role| match role {
4511                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4512                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4513                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4514                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4515                        } as i32),
4516                        content: choice.delta.content,
4517                        tool_calls: choice
4518                            .delta
4519                            .tool_calls
4520                            .unwrap_or_default()
4521                            .into_iter()
4522                            .map(|delta| proto::ToolCallDelta {
4523                                index: delta.index as u32,
4524                                id: delta.id,
4525                                variant: match delta.function {
4526                                    Some(function) => {
4527                                        let name = function.name;
4528                                        let arguments = function.arguments;
4529
4530                                        Some(proto::tool_call_delta::Variant::Function(
4531                                            proto::tool_call_delta::FunctionCallDelta {
4532                                                name,
4533                                                arguments,
4534                                            },
4535                                        ))
4536                                    }
4537                                    None => None,
4538                                },
4539                            })
4540                            .collect(),
4541                    }),
4542                    finish_reason: choice.finish_reason,
4543                })
4544                .collect(),
4545        })?;
4546    }
4547
4548    Ok(())
4549}
4550
4551async fn complete_with_google_ai(
4552    request: proto::CompleteWithLanguageModel,
4553    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4554    session: UserSession,
4555    api_key: Arc<str>,
4556) -> Result<()> {
4557    let mut stream = google_ai::stream_generate_content(
4558        session.http_client.clone(),
4559        google_ai::API_URL,
4560        api_key.as_ref(),
4561        &request.model.clone(),
4562        crate::ai::language_model_request_to_google_ai(request)?,
4563    )
4564    .await
4565    .context("google_ai::stream_generate_content request failed")?;
4566
4567    while let Some(event) = stream.next().await {
4568        let event = event?;
4569        response.send(proto::LanguageModelResponse {
4570            choices: event
4571                .candidates
4572                .unwrap_or_default()
4573                .into_iter()
4574                .map(|candidate| proto::LanguageModelChoiceDelta {
4575                    index: candidate.index as u32,
4576                    delta: Some(proto::LanguageModelResponseMessage {
4577                        role: Some(match candidate.content.role {
4578                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4579                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4580                        } as i32),
4581                        content: Some(
4582                            candidate
4583                                .content
4584                                .parts
4585                                .into_iter()
4586                                .filter_map(|part| match part {
4587                                    google_ai::Part::TextPart(part) => Some(part.text),
4588                                    google_ai::Part::InlineDataPart(_) => None,
4589                                })
4590                                .collect(),
4591                        ),
4592                        // Tool calls are not supported for Google
4593                        tool_calls: Vec::new(),
4594                    }),
4595                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4596                })
4597                .collect(),
4598        })?;
4599    }
4600
4601    Ok(())
4602}
4603
4604async fn complete_with_anthropic(
4605    request: proto::CompleteWithLanguageModel,
4606    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4607    session: UserSession,
4608    api_key: Arc<str>,
4609) -> Result<()> {
4610    let model = anthropic::Model::from_id(&request.model)?;
4611
4612    let mut system_message = String::new();
4613    let messages = request
4614        .messages
4615        .into_iter()
4616        .filter_map(|message| {
4617            match message.role() {
4618                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4619                    role: anthropic::Role::User,
4620                    content: message.content,
4621                }),
4622                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4623                    role: anthropic::Role::Assistant,
4624                    content: message.content,
4625                }),
4626                // Anthropic's API breaks system instructions out as a separate field rather
4627                // than having a system message role.
4628                LanguageModelRole::LanguageModelSystem => {
4629                    if !system_message.is_empty() {
4630                        system_message.push_str("\n\n");
4631                    }
4632                    system_message.push_str(&message.content);
4633
4634                    None
4635                }
4636                // We don't yet support tool calls for Anthropic
4637                LanguageModelRole::LanguageModelTool => None,
4638            }
4639        })
4640        .collect();
4641
4642    let mut stream = anthropic::stream_completion(
4643        session.http_client.as_ref(),
4644        anthropic::ANTHROPIC_API_URL,
4645        &api_key,
4646        anthropic::Request {
4647            model,
4648            messages,
4649            stream: true,
4650            system: system_message,
4651            max_tokens: 4092,
4652        },
4653        None,
4654    )
4655    .await?;
4656
4657    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4658
4659    while let Some(event) = stream.next().await {
4660        let event = event?;
4661
4662        match event {
4663            anthropic::ResponseEvent::MessageStart { message } => {
4664                if let Some(role) = message.role {
4665                    if role == "assistant" {
4666                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4667                    } else if role == "user" {
4668                        current_role = proto::LanguageModelRole::LanguageModelUser;
4669                    }
4670                }
4671            }
4672            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4673                match content_block {
4674                    anthropic::ContentBlock::Text { text } => {
4675                        if !text.is_empty() {
4676                            response.send(proto::LanguageModelResponse {
4677                                choices: vec![proto::LanguageModelChoiceDelta {
4678                                    index: 0,
4679                                    delta: Some(proto::LanguageModelResponseMessage {
4680                                        role: Some(current_role as i32),
4681                                        content: Some(text),
4682                                        tool_calls: Vec::new(),
4683                                    }),
4684                                    finish_reason: None,
4685                                }],
4686                            })?;
4687                        }
4688                    }
4689                }
4690            }
4691            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4692                anthropic::TextDelta::TextDelta { text } => {
4693                    response.send(proto::LanguageModelResponse {
4694                        choices: vec![proto::LanguageModelChoiceDelta {
4695                            index: 0,
4696                            delta: Some(proto::LanguageModelResponseMessage {
4697                                role: Some(current_role as i32),
4698                                content: Some(text),
4699                                tool_calls: Vec::new(),
4700                            }),
4701                            finish_reason: None,
4702                        }],
4703                    })?;
4704                }
4705            },
4706            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4707                if let Some(stop_reason) = delta.stop_reason {
4708                    response.send(proto::LanguageModelResponse {
4709                        choices: vec![proto::LanguageModelChoiceDelta {
4710                            index: 0,
4711                            delta: None,
4712                            finish_reason: Some(stop_reason),
4713                        }],
4714                    })?;
4715                }
4716            }
4717            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4718            anthropic::ResponseEvent::MessageStop {} => {}
4719            anthropic::ResponseEvent::Ping {} => {}
4720        }
4721    }
4722
4723    Ok(())
4724}
4725
4726struct CountTokensWithLanguageModelRateLimit;
4727
4728impl RateLimit for CountTokensWithLanguageModelRateLimit {
4729    fn capacity() -> usize {
4730        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4731            .ok()
4732            .and_then(|v| v.parse().ok())
4733            .unwrap_or(600) // Picked arbitrarily
4734    }
4735
4736    fn refill_duration() -> chrono::Duration {
4737        chrono::Duration::hours(1)
4738    }
4739
4740    fn db_name() -> &'static str {
4741        "count-tokens-with-language-model"
4742    }
4743}
4744
4745async fn count_tokens_with_language_model(
4746    request: proto::CountTokensWithLanguageModel,
4747    response: Response<proto::CountTokensWithLanguageModel>,
4748    session: UserSession,
4749    google_ai_api_key: Option<Arc<str>>,
4750) -> Result<()> {
4751    authorize_access_to_language_models(&session).await?;
4752
4753    if !request.model.starts_with("gemini") {
4754        return Err(anyhow!(
4755            "counting tokens for model: {:?} is not supported",
4756            request.model
4757        ))?;
4758    }
4759
4760    session
4761        .rate_limiter
4762        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4763        .await?;
4764
4765    let api_key = google_ai_api_key
4766        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4767    let tokens_response = google_ai::count_tokens(
4768        session.http_client.as_ref(),
4769        google_ai::API_URL,
4770        &api_key,
4771        crate::ai::count_tokens_request_to_google_ai(request)?,
4772    )
4773    .await?;
4774    response.send(proto::CountTokensResponse {
4775        token_count: tokens_response.total_tokens as u32,
4776    })?;
4777    Ok(())
4778}
4779
4780struct ComputeEmbeddingsRateLimit;
4781
4782impl RateLimit for ComputeEmbeddingsRateLimit {
4783    fn capacity() -> usize {
4784        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4785            .ok()
4786            .and_then(|v| v.parse().ok())
4787            .unwrap_or(5000) // Picked arbitrarily
4788    }
4789
4790    fn refill_duration() -> chrono::Duration {
4791        chrono::Duration::hours(1)
4792    }
4793
4794    fn db_name() -> &'static str {
4795        "compute-embeddings"
4796    }
4797}
4798
4799async fn compute_embeddings(
4800    request: proto::ComputeEmbeddings,
4801    response: Response<proto::ComputeEmbeddings>,
4802    session: UserSession,
4803    api_key: Option<Arc<str>>,
4804) -> Result<()> {
4805    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4806    authorize_access_to_language_models(&session).await?;
4807
4808    session
4809        .rate_limiter
4810        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4811        .await?;
4812
4813    let embeddings = match request.model.as_str() {
4814        "openai/text-embedding-3-small" => {
4815            open_ai::embed(
4816                session.http_client.as_ref(),
4817                OPEN_AI_API_URL,
4818                &api_key,
4819                OpenAiEmbeddingModel::TextEmbedding3Small,
4820                request.texts.iter().map(|text| text.as_str()),
4821            )
4822            .await?
4823        }
4824        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4825    };
4826
4827    let embeddings = request
4828        .texts
4829        .iter()
4830        .map(|text| {
4831            let mut hasher = sha2::Sha256::new();
4832            hasher.update(text.as_bytes());
4833            let result = hasher.finalize();
4834            result.to_vec()
4835        })
4836        .zip(
4837            embeddings
4838                .data
4839                .into_iter()
4840                .map(|embedding| embedding.embedding),
4841        )
4842        .collect::<HashMap<_, _>>();
4843
4844    let db = session.db().await;
4845    db.save_embeddings(&request.model, &embeddings)
4846        .await
4847        .context("failed to save embeddings")
4848        .trace_err();
4849
4850    response.send(proto::ComputeEmbeddingsResponse {
4851        embeddings: embeddings
4852            .into_iter()
4853            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4854            .collect(),
4855    })?;
4856    Ok(())
4857}
4858
4859async fn get_cached_embeddings(
4860    request: proto::GetCachedEmbeddings,
4861    response: Response<proto::GetCachedEmbeddings>,
4862    session: UserSession,
4863) -> Result<()> {
4864    authorize_access_to_language_models(&session).await?;
4865
4866    let db = session.db().await;
4867    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4868
4869    response.send(proto::GetCachedEmbeddingsResponse {
4870        embeddings: embeddings
4871            .into_iter()
4872            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4873            .collect(),
4874    })?;
4875    Ok(())
4876}
4877
4878async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4879    let db = session.db().await;
4880    let flags = db.get_user_flags(session.user_id()).await?;
4881    if flags.iter().any(|flag| flag == "language-models") {
4882        Ok(())
4883    } else {
4884        Err(anyhow!("permission denied"))?
4885    }
4886}
4887
4888/// Get a Supermaven API key for the user
4889async fn get_supermaven_api_key(
4890    _request: proto::GetSupermavenApiKey,
4891    response: Response<proto::GetSupermavenApiKey>,
4892    session: UserSession,
4893) -> Result<()> {
4894    let user_id: String = session.user_id().to_string();
4895    if !session.is_staff() {
4896        return Err(anyhow!("supermaven not enabled for this account"))?;
4897    }
4898
4899    let email = session
4900        .email()
4901        .ok_or_else(|| anyhow!("user must have an email"))?;
4902
4903    let supermaven_admin_api = session
4904        .supermaven_client
4905        .as_ref()
4906        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4907
4908    let result = supermaven_admin_api
4909        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4910        .await?;
4911
4912    response.send(proto::GetSupermavenApiKeyResponse {
4913        api_key: result.api_key,
4914    })?;
4915
4916    Ok(())
4917}
4918
4919/// Start receiving chat updates for a channel
4920async fn join_channel_chat(
4921    request: proto::JoinChannelChat,
4922    response: Response<proto::JoinChannelChat>,
4923    session: UserSession,
4924) -> Result<()> {
4925    let channel_id = ChannelId::from_proto(request.channel_id);
4926
4927    let db = session.db().await;
4928    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4929        .await?;
4930    let messages = db
4931        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4932        .await?;
4933    response.send(proto::JoinChannelChatResponse {
4934        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4935        messages,
4936    })?;
4937    Ok(())
4938}
4939
4940/// Stop receiving chat updates for a channel
4941async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4942    let channel_id = ChannelId::from_proto(request.channel_id);
4943    session
4944        .db()
4945        .await
4946        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4947        .await?;
4948    Ok(())
4949}
4950
4951/// Retrieve the chat history for a channel
4952async fn get_channel_messages(
4953    request: proto::GetChannelMessages,
4954    response: Response<proto::GetChannelMessages>,
4955    session: UserSession,
4956) -> Result<()> {
4957    let channel_id = ChannelId::from_proto(request.channel_id);
4958    let messages = session
4959        .db()
4960        .await
4961        .get_channel_messages(
4962            channel_id,
4963            session.user_id(),
4964            MESSAGE_COUNT_PER_PAGE,
4965            Some(MessageId::from_proto(request.before_message_id)),
4966        )
4967        .await?;
4968    response.send(proto::GetChannelMessagesResponse {
4969        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4970        messages,
4971    })?;
4972    Ok(())
4973}
4974
4975/// Retrieve specific chat messages
4976async fn get_channel_messages_by_id(
4977    request: proto::GetChannelMessagesById,
4978    response: Response<proto::GetChannelMessagesById>,
4979    session: UserSession,
4980) -> Result<()> {
4981    let message_ids = request
4982        .message_ids
4983        .iter()
4984        .map(|id| MessageId::from_proto(*id))
4985        .collect::<Vec<_>>();
4986    let messages = session
4987        .db()
4988        .await
4989        .get_channel_messages_by_id(session.user_id(), &message_ids)
4990        .await?;
4991    response.send(proto::GetChannelMessagesResponse {
4992        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4993        messages,
4994    })?;
4995    Ok(())
4996}
4997
4998/// Retrieve the current users notifications
4999async fn get_notifications(
5000    request: proto::GetNotifications,
5001    response: Response<proto::GetNotifications>,
5002    session: UserSession,
5003) -> Result<()> {
5004    let notifications = session
5005        .db()
5006        .await
5007        .get_notifications(
5008            session.user_id(),
5009            NOTIFICATION_COUNT_PER_PAGE,
5010            request
5011                .before_id
5012                .map(|id| db::NotificationId::from_proto(id)),
5013        )
5014        .await?;
5015    response.send(proto::GetNotificationsResponse {
5016        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5017        notifications,
5018    })?;
5019    Ok(())
5020}
5021
5022/// Mark notifications as read
5023async fn mark_notification_as_read(
5024    request: proto::MarkNotificationRead,
5025    response: Response<proto::MarkNotificationRead>,
5026    session: UserSession,
5027) -> Result<()> {
5028    let database = &session.db().await;
5029    let notifications = database
5030        .mark_notification_as_read_by_id(
5031            session.user_id(),
5032            NotificationId::from_proto(request.notification_id),
5033        )
5034        .await?;
5035    send_notifications(
5036        &*session.connection_pool().await,
5037        &session.peer,
5038        notifications,
5039    );
5040    response.send(proto::Ack {})?;
5041    Ok(())
5042}
5043
5044/// Get the current users information
5045async fn get_private_user_info(
5046    _request: proto::GetPrivateUserInfo,
5047    response: Response<proto::GetPrivateUserInfo>,
5048    session: UserSession,
5049) -> Result<()> {
5050    let db = session.db().await;
5051
5052    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5053    let user = db
5054        .get_user_by_id(session.user_id())
5055        .await?
5056        .ok_or_else(|| anyhow!("user not found"))?;
5057    let flags = db.get_user_flags(session.user_id()).await?;
5058
5059    response.send(proto::GetPrivateUserInfoResponse {
5060        metrics_id,
5061        staff: user.admin,
5062        flags,
5063    })?;
5064    Ok(())
5065}
5066
5067fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5068    match message {
5069        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5070        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5071        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5072        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5073        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5074            code: frame.code.into(),
5075            reason: frame.reason,
5076        })),
5077    }
5078}
5079
5080fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5081    match message {
5082        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5083        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5084        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5085        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5086        AxumMessage::Close(frame) => {
5087            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5088                code: frame.code.into(),
5089                reason: frame.reason,
5090            }))
5091        }
5092    }
5093}
5094
5095fn notify_membership_updated(
5096    connection_pool: &mut ConnectionPool,
5097    result: MembershipUpdated,
5098    user_id: UserId,
5099    peer: &Peer,
5100) {
5101    for membership in &result.new_channels.channel_memberships {
5102        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5103    }
5104    for channel_id in &result.removed_channels {
5105        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5106    }
5107
5108    let user_channels_update = proto::UpdateUserChannels {
5109        channel_memberships: result
5110            .new_channels
5111            .channel_memberships
5112            .iter()
5113            .map(|cm| proto::ChannelMembership {
5114                channel_id: cm.channel_id.to_proto(),
5115                role: cm.role.into(),
5116            })
5117            .collect(),
5118        ..Default::default()
5119    };
5120
5121    let mut update = build_channels_update(result.new_channels);
5122    update.delete_channels = result
5123        .removed_channels
5124        .into_iter()
5125        .map(|id| id.to_proto())
5126        .collect();
5127    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5128
5129    for connection_id in connection_pool.user_connection_ids(user_id) {
5130        peer.send(connection_id, user_channels_update.clone())
5131            .trace_err();
5132        peer.send(connection_id, update.clone()).trace_err();
5133    }
5134}
5135
5136fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5137    proto::UpdateUserChannels {
5138        channel_memberships: channels
5139            .channel_memberships
5140            .iter()
5141            .map(|m| proto::ChannelMembership {
5142                channel_id: m.channel_id.to_proto(),
5143                role: m.role.into(),
5144            })
5145            .collect(),
5146        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5147        observed_channel_message_id: channels.observed_channel_messages.clone(),
5148    }
5149}
5150
5151fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5152    let mut update = proto::UpdateChannels::default();
5153
5154    for channel in channels.channels {
5155        update.channels.push(channel.to_proto());
5156    }
5157
5158    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5159    update.latest_channel_message_ids = channels.latest_channel_messages;
5160
5161    for (channel_id, participants) in channels.channel_participants {
5162        update
5163            .channel_participants
5164            .push(proto::ChannelParticipants {
5165                channel_id: channel_id.to_proto(),
5166                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5167            });
5168    }
5169
5170    for channel in channels.invited_channels {
5171        update.channel_invitations.push(channel.to_proto());
5172    }
5173
5174    update.hosted_projects = channels.hosted_projects;
5175    update
5176}
5177
5178fn build_initial_contacts_update(
5179    contacts: Vec<db::Contact>,
5180    pool: &ConnectionPool,
5181) -> proto::UpdateContacts {
5182    let mut update = proto::UpdateContacts::default();
5183
5184    for contact in contacts {
5185        match contact {
5186            db::Contact::Accepted { user_id, busy } => {
5187                update.contacts.push(contact_for_user(user_id, busy, &pool));
5188            }
5189            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5190            db::Contact::Incoming { user_id } => {
5191                update
5192                    .incoming_requests
5193                    .push(proto::IncomingContactRequest {
5194                        requester_id: user_id.to_proto(),
5195                    })
5196            }
5197        }
5198    }
5199
5200    update
5201}
5202
5203fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5204    proto::Contact {
5205        user_id: user_id.to_proto(),
5206        online: pool.is_user_online(user_id),
5207        busy,
5208    }
5209}
5210
5211fn room_updated(room: &proto::Room, peer: &Peer) {
5212    broadcast(
5213        None,
5214        room.participants
5215            .iter()
5216            .filter_map(|participant| Some(participant.peer_id?.into())),
5217        |peer_id| {
5218            peer.send(
5219                peer_id,
5220                proto::RoomUpdated {
5221                    room: Some(room.clone()),
5222                },
5223            )
5224        },
5225    );
5226}
5227
5228fn channel_updated(
5229    channel: &db::channel::Model,
5230    room: &proto::Room,
5231    peer: &Peer,
5232    pool: &ConnectionPool,
5233) {
5234    let participants = room
5235        .participants
5236        .iter()
5237        .map(|p| p.user_id)
5238        .collect::<Vec<_>>();
5239
5240    broadcast(
5241        None,
5242        pool.channel_connection_ids(channel.root_id())
5243            .filter_map(|(channel_id, role)| {
5244                role.can_see_channel(channel.visibility).then(|| channel_id)
5245            }),
5246        |peer_id| {
5247            peer.send(
5248                peer_id,
5249                proto::UpdateChannels {
5250                    channel_participants: vec![proto::ChannelParticipants {
5251                        channel_id: channel.id.to_proto(),
5252                        participant_user_ids: participants.clone(),
5253                    }],
5254                    ..Default::default()
5255                },
5256            )
5257        },
5258    );
5259}
5260
5261async fn send_dev_server_projects_update(
5262    user_id: UserId,
5263    mut status: proto::DevServerProjectsUpdate,
5264    session: &Session,
5265) {
5266    let pool = session.connection_pool().await;
5267    for dev_server in &mut status.dev_servers {
5268        dev_server.status =
5269            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5270    }
5271    let connections = pool.user_connection_ids(user_id);
5272    for connection_id in connections {
5273        session.peer.send(connection_id, status.clone()).trace_err();
5274    }
5275}
5276
5277async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5278    let db = session.db().await;
5279
5280    let contacts = db.get_contacts(user_id).await?;
5281    let busy = db.is_user_busy(user_id).await?;
5282
5283    let pool = session.connection_pool().await;
5284    let updated_contact = contact_for_user(user_id, busy, &pool);
5285    for contact in contacts {
5286        if let db::Contact::Accepted {
5287            user_id: contact_user_id,
5288            ..
5289        } = contact
5290        {
5291            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5292                session
5293                    .peer
5294                    .send(
5295                        contact_conn_id,
5296                        proto::UpdateContacts {
5297                            contacts: vec![updated_contact.clone()],
5298                            remove_contacts: Default::default(),
5299                            incoming_requests: Default::default(),
5300                            remove_incoming_requests: Default::default(),
5301                            outgoing_requests: Default::default(),
5302                            remove_outgoing_requests: Default::default(),
5303                        },
5304                    )
5305                    .trace_err();
5306            }
5307        }
5308    }
5309    Ok(())
5310}
5311
5312async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5313    log::info!("lost dev server connection, unsharing projects");
5314    let project_ids = session
5315        .db()
5316        .await
5317        .get_stale_dev_server_projects(session.connection_id)
5318        .await?;
5319
5320    for project_id in project_ids {
5321        // not unshare re-checks the connection ids match, so we get away with no transaction
5322        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5323    }
5324
5325    let user_id = session.dev_server().user_id;
5326    let update = session
5327        .db()
5328        .await
5329        .dev_server_projects_update(user_id)
5330        .await?;
5331
5332    send_dev_server_projects_update(user_id, update, session).await;
5333
5334    Ok(())
5335}
5336
5337async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5338    let mut contacts_to_update = HashSet::default();
5339
5340    let room_id;
5341    let canceled_calls_to_user_ids;
5342    let live_kit_room;
5343    let delete_live_kit_room;
5344    let room;
5345    let channel;
5346
5347    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5348        contacts_to_update.insert(session.user_id());
5349
5350        for project in left_room.left_projects.values() {
5351            project_left(project, session);
5352        }
5353
5354        room_id = RoomId::from_proto(left_room.room.id);
5355        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5356        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5357        delete_live_kit_room = left_room.deleted;
5358        room = mem::take(&mut left_room.room);
5359        channel = mem::take(&mut left_room.channel);
5360
5361        room_updated(&room, &session.peer);
5362    } else {
5363        return Ok(());
5364    }
5365
5366    if let Some(channel) = channel {
5367        channel_updated(
5368            &channel,
5369            &room,
5370            &session.peer,
5371            &*session.connection_pool().await,
5372        );
5373    }
5374
5375    {
5376        let pool = session.connection_pool().await;
5377        for canceled_user_id in canceled_calls_to_user_ids {
5378            for connection_id in pool.user_connection_ids(canceled_user_id) {
5379                session
5380                    .peer
5381                    .send(
5382                        connection_id,
5383                        proto::CallCanceled {
5384                            room_id: room_id.to_proto(),
5385                        },
5386                    )
5387                    .trace_err();
5388            }
5389            contacts_to_update.insert(canceled_user_id);
5390        }
5391    }
5392
5393    for contact_user_id in contacts_to_update {
5394        update_user_contacts(contact_user_id, &session).await?;
5395    }
5396
5397    if let Some(live_kit) = session.live_kit_client.as_ref() {
5398        live_kit
5399            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5400            .await
5401            .trace_err();
5402
5403        if delete_live_kit_room {
5404            live_kit.delete_room(live_kit_room).await.trace_err();
5405        }
5406    }
5407
5408    Ok(())
5409}
5410
5411async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5412    let left_channel_buffers = session
5413        .db()
5414        .await
5415        .leave_channel_buffers(session.connection_id)
5416        .await?;
5417
5418    for left_buffer in left_channel_buffers {
5419        channel_buffer_updated(
5420            session.connection_id,
5421            left_buffer.connections,
5422            &proto::UpdateChannelBufferCollaborators {
5423                channel_id: left_buffer.channel_id.to_proto(),
5424                collaborators: left_buffer.collaborators,
5425            },
5426            &session.peer,
5427        );
5428    }
5429
5430    Ok(())
5431}
5432
5433fn project_left(project: &db::LeftProject, session: &UserSession) {
5434    for connection_id in &project.connection_ids {
5435        if project.should_unshare {
5436            session
5437                .peer
5438                .send(
5439                    *connection_id,
5440                    proto::UnshareProject {
5441                        project_id: project.id.to_proto(),
5442                    },
5443                )
5444                .trace_err();
5445        } else {
5446            session
5447                .peer
5448                .send(
5449                    *connection_id,
5450                    proto::RemoveProjectCollaborator {
5451                        project_id: project.id.to_proto(),
5452                        peer_id: Some(session.connection_id.into()),
5453                    },
5454                )
5455                .trace_err();
5456        }
5457    }
5458}
5459
5460pub trait ResultExt {
5461    type Ok;
5462
5463    fn trace_err(self) -> Option<Self::Ok>;
5464}
5465
5466impl<T, E> ResultExt for Result<T, E>
5467where
5468    E: std::fmt::Debug,
5469{
5470    type Ok = T;
5471
5472    #[track_caller]
5473    fn trace_err(self) -> Option<T> {
5474        match self {
5475            Ok(value) => Some(value),
5476            Err(error) => {
5477                tracing::error!("{:?}", error);
5478                None
5479            }
5480        }
5481    }
5482}