rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` return an internal error
    /// unless the session belongs to a user (direct or impersonated). Handlers wrapped in
    /// `dev_server_handler` do the same unless the session belongs to a dev server. Handlers
    /// registered without a wrapper receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently wraps/truncates if the id does not fit in a
            // `u32` — confirm the range of `ServerId`.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped immediately.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev-server projects and dev-server management, requested by users.
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests that only a dev server may send.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project membership and project/worktree state updates.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. The helper names distinguish owner-only,
            // read-only, mutating and versioned-mutating requests; the access checks live in
            // those helpers (not shown here).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer traffic and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Miscellaneous user requests and acknowledgements.
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model completion, token counting and embeddings. API keys are read from
            // `app_state.config` and cloned on every call. The streaming completion handler is
            // not wrapped in `user_handler`, so it receives the raw `Session`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 651
 652    pub async fn start(&self) -> Result<()> {
 653        let server_id = *self.id.lock();
 654        let app_state = self.app_state.clone();
 655        let peer = self.peer.clone();
 656        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 657        let pool = self.connection_pool.clone();
 658        let live_kit_client = self.app_state.live_kit_client.clone();
 659
 660        let span = info_span!("start server");
 661        self.app_state.executor.spawn_detached(
 662            async move {
 663                tracing::info!("waiting for cleanup timeout");
 664                timeout.await;
 665                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 666                if let Some((room_ids, channel_ids)) = app_state
 667                    .db
 668                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 669                    .await
 670                    .trace_err()
 671                {
 672                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 673                    tracing::info!(
 674                        stale_channel_buffer_count = channel_ids.len(),
 675                        "retrieved stale channel buffers"
 676                    );
 677
 678                    for channel_id in channel_ids {
 679                        if let Some(refreshed_channel_buffer) = app_state
 680                            .db
 681                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 682                            .await
 683                            .trace_err()
 684                        {
 685                            for connection_id in refreshed_channel_buffer.connection_ids {
 686                                peer.send(
 687                                    connection_id,
 688                                    proto::UpdateChannelBufferCollaborators {
 689                                        channel_id: channel_id.to_proto(),
 690                                        collaborators: refreshed_channel_buffer
 691                                            .collaborators
 692                                            .clone(),
 693                                    },
 694                                )
 695                                .trace_err();
 696                            }
 697                        }
 698                    }
 699
 700                    for room_id in room_ids {
 701                        let mut contacts_to_update = HashSet::default();
 702                        let mut canceled_calls_to_user_ids = Vec::new();
 703                        let mut live_kit_room = String::new();
 704                        let mut delete_live_kit_room = false;
 705
 706                        if let Some(mut refreshed_room) = app_state
 707                            .db
 708                            .clear_stale_room_participants(room_id, server_id)
 709                            .await
 710                            .trace_err()
 711                        {
 712                            tracing::info!(
 713                                room_id = room_id.0,
 714                                new_participant_count = refreshed_room.room.participants.len(),
 715                                "refreshed room"
 716                            );
 717                            room_updated(&refreshed_room.room, &peer);
 718                            if let Some(channel) = refreshed_room.channel.as_ref() {
 719                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 720                            }
 721                            contacts_to_update
 722                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 723                            contacts_to_update
 724                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 725                            canceled_calls_to_user_ids =
 726                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 727                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 728                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 729                        }
 730
 731                        {
 732                            let pool = pool.lock();
 733                            for canceled_user_id in canceled_calls_to_user_ids {
 734                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 735                                    peer.send(
 736                                        connection_id,
 737                                        proto::CallCanceled {
 738                                            room_id: room_id.to_proto(),
 739                                        },
 740                                    )
 741                                    .trace_err();
 742                                }
 743                            }
 744                        }
 745
 746                        for user_id in contacts_to_update {
 747                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 748                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 749                            if let Some((busy, contacts)) = busy.zip(contacts) {
 750                                let pool = pool.lock();
 751                                let updated_contact = contact_for_user(user_id, busy, &pool);
 752                                for contact in contacts {
 753                                    if let db::Contact::Accepted {
 754                                        user_id: contact_user_id,
 755                                        ..
 756                                    } = contact
 757                                    {
 758                                        for contact_conn_id in
 759                                            pool.user_connection_ids(contact_user_id)
 760                                        {
 761                                            peer.send(
 762                                                contact_conn_id,
 763                                                proto::UpdateContacts {
 764                                                    contacts: vec![updated_contact.clone()],
 765                                                    remove_contacts: Default::default(),
 766                                                    incoming_requests: Default::default(),
 767                                                    remove_incoming_requests: Default::default(),
 768                                                    outgoing_requests: Default::default(),
 769                                                    remove_outgoing_requests: Default::default(),
 770                                                },
 771                                            )
 772                                            .trace_err();
 773                                        }
 774                                    }
 775                                }
 776                            }
 777                        }
 778
 779                        if let Some(live_kit) = live_kit_client.as_ref() {
 780                            if delete_live_kit_room {
 781                                live_kit.delete_room(live_kit_room).await.trace_err();
 782                            }
 783                        }
 784                    }
 785                }
 786
 787                app_state
 788                    .db
 789                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 790                    .await
 791                    .trace_err();
 792            }
 793            .instrument(span),
 794        );
 795        Ok(())
 796    }
 797
 798    pub fn teardown(&self) {
 799        self.peer.teardown();
 800        self.connection_pool.lock().reset();
 801        let _ = self.teardown.send(true);
 802    }
 803
 804    #[cfg(test)]
 805    pub fn reset(&self, id: ServerId) {
 806        self.teardown();
 807        *self.id.lock() = id;
 808        self.peer.reset(id.0 as u32);
 809        let _ = self.teardown.send(false);
 810    }
 811
 812    #[cfg(test)]
 813    pub fn id(&self) -> ServerId {
 814        *self.id.lock()
 815    }
 816
 817    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 818    where
 819        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 820        Fut: 'static + Send + Future<Output = Result<()>>,
 821        M: EnvelopedMessage,
 822    {
 823        let prev_handler = self.handlers.insert(
 824            TypeId::of::<M>(),
 825            Box::new(move |envelope, session| {
 826                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 827                let received_at = envelope.received_at;
 828                tracing::info!("message received");
 829                let start_time = Instant::now();
 830                let future = (handler)(*envelope, session);
 831                async move {
 832                    let result = future.await;
 833                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 834                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 835                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 836                    let payload_type = M::NAME;
 837
 838                    match result {
 839                        Err(error) => {
 840                            tracing::error!(
 841                                ?error,
 842                                total_duration_ms,
 843                                processing_duration_ms,
 844                                queue_duration_ms,
 845                                payload_type,
 846                                "error handling message"
 847                            )
 848                        }
 849                        Ok(()) => tracing::info!(
 850                            total_duration_ms,
 851                            processing_duration_ms,
 852                            queue_duration_ms,
 853                            "finished handling message"
 854                        ),
 855                    }
 856                }
 857                .boxed()
 858            }),
 859        );
 860        if prev_handler.is_some() {
 861            panic!("registered a handler for the same message twice");
 862        }
 863        self
 864    }
 865
 866    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 869        Fut: 'static + Send + Future<Output = Result<()>>,
 870        M: EnvelopedMessage,
 871    {
 872        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 873        self
 874    }
 875
 876    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 879        Fut: Send + Future<Output = Result<()>>,
 880        M: RequestMessage,
 881    {
 882        let handler = Arc::new(handler);
 883        self.add_handler(move |envelope, session| {
 884            let receipt = envelope.receipt();
 885            let handler = handler.clone();
 886            async move {
 887                let peer = session.peer.clone();
 888                let responded = Arc::new(AtomicBool::default());
 889                let response = Response {
 890                    peer: peer.clone(),
 891                    responded: responded.clone(),
 892                    receipt,
 893                };
 894                match (handler)(envelope.payload, response, session).await {
 895                    Ok(()) => {
 896                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 897                            Ok(())
 898                        } else {
 899                            Err(anyhow!("handler did not send a response"))?
 900                        }
 901                    }
 902                    Err(error) => {
 903                        let proto_err = match &error {
 904                            Error::Internal(err) => err.to_proto(),
 905                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 906                        };
 907                        peer.respond_with_error(receipt, proto_err)?;
 908                        Err(error)
 909                    }
 910                }
 911            }
 912        })
 913    }
 914
 915    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 916    where
 917        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 918        Fut: Send + Future<Output = Result<()>>,
 919        M: RequestMessage,
 920    {
 921        let handler = Arc::new(handler);
 922        self.add_handler(move |envelope, session| {
 923            let receipt = envelope.receipt();
 924            let handler = handler.clone();
 925            async move {
 926                let peer = session.peer.clone();
 927                let response = StreamingResponse {
 928                    peer: peer.clone(),
 929                    receipt,
 930                };
 931                match (handler)(envelope.payload, response, session).await {
 932                    Ok(()) => {
 933                        peer.end_stream(receipt)?;
 934                        Ok(())
 935                    }
 936                    Err(error) => {
 937                        let proto_err = match &error {
 938                            Error::Internal(err) => err.to_proto(),
 939                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 940                        };
 941                        peer.respond_with_error(receipt, proto_err)?;
 942                        Err(error)
 943                    }
 944                }
 945            }
 946        })
 947    }
 948
 949    #[allow(clippy::too_many_arguments)]
 950    pub fn handle_connection(
 951        self: &Arc<Self>,
 952        connection: Connection,
 953        address: String,
 954        principal: Principal,
 955        zed_version: ZedVersion,
 956        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 957        executor: Executor,
 958    ) -> impl Future<Output = ()> {
 959        let this = self.clone();
 960        let span = info_span!("handle connection", %address,
 961            connection_id=field::Empty,
 962            user_id=field::Empty,
 963            login=field::Empty,
 964            impersonator=field::Empty,
 965            dev_server_id=field::Empty
 966        );
 967        principal.update_span(&span);
 968
 969        let mut teardown = self.teardown.subscribe();
 970        async move {
 971            if *teardown.borrow() {
 972                tracing::error!("server is tearing down");
 973                return
 974            }
 975            let (connection_id, handle_io, mut incoming_rx) = this
 976                .peer
 977                .add_connection(connection, {
 978                    let executor = executor.clone();
 979                    move |duration| executor.sleep(duration)
 980                });
 981            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 982            tracing::info!("connection opened");
 983
 984            let http_client = match IsahcHttpClient::new() {
 985                Ok(http_client) => Arc::new(http_client),
 986                Err(error) => {
 987                    tracing::error!(?error, "failed to create HTTP client");
 988                    return;
 989                }
 990            };
 991
 992            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 993                Some(Arc::new(SupermavenAdminApi::new(
 994                    supermaven_admin_api_key.to_string(),
 995                    http_client.clone(),
 996                )))
 997            } else {
 998                None
 999            };
1000
1001            let session = Session {
1002                principal: principal.clone(),
1003                connection_id,
1004                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1005                peer: this.peer.clone(),
1006                connection_pool: this.connection_pool.clone(),
1007                live_kit_client: this.app_state.live_kit_client.clone(),
1008                http_client,
1009                rate_limiter: this.app_state.rate_limiter.clone(),
1010                _executor: executor.clone(),
1011                supermaven_client,
1012            };
1013
1014            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1015                tracing::error!(?error, "failed to send initial client update");
1016                return;
1017            }
1018
1019            let handle_io = handle_io.fuse();
1020            futures::pin_mut!(handle_io);
1021
1022            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1023            // This prevents deadlocks when e.g., client A performs a request to client B and
1024            // client B performs a request to client A. If both clients stop processing further
1025            // messages until their respective request completes, they won't have a chance to
1026            // respond to the other client's request and cause a deadlock.
1027            //
1028            // This arrangement ensures we will attempt to process earlier messages first, but fall
1029            // back to processing messages arrived later in the spirit of making progress.
1030            let mut foreground_message_handlers = FuturesUnordered::new();
1031            let concurrent_handlers = Arc::new(Semaphore::new(256));
1032            loop {
1033                let next_message = async {
1034                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1035                    let message = incoming_rx.next().await;
1036                    (permit, message)
1037                }.fuse();
1038                futures::pin_mut!(next_message);
1039                futures::select_biased! {
1040                    _ = teardown.changed().fuse() => return,
1041                    result = handle_io => {
1042                        if let Err(error) = result {
1043                            tracing::error!(?error, "error handling I/O");
1044                        }
1045                        break;
1046                    }
1047                    _ = foreground_message_handlers.next() => {}
1048                    next_message = next_message => {
1049                        let (permit, message) = next_message;
1050                        if let Some(message) = message {
1051                            let type_name = message.payload_type_name();
1052                            // note: we copy all the fields from the parent span so we can query them in the logs.
1053                            // (https://github.com/tokio-rs/tracing/issues/2670).
1054                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1055                                user_id=field::Empty,
1056                                login=field::Empty,
1057                                impersonator=field::Empty,
1058                                dev_server_id=field::Empty
1059                            );
1060                            principal.update_span(&span);
1061                            let span_enter = span.enter();
1062                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1063                                let is_background = message.is_background();
1064                                let handle_message = (handler)(message, session.clone());
1065                                drop(span_enter);
1066
1067                                let handle_message = async move {
1068                                    handle_message.await;
1069                                    drop(permit);
1070                                }.instrument(span);
1071                                if is_background {
1072                                    executor.spawn_detached(handle_message);
1073                                } else {
1074                                    foreground_message_handlers.push(handle_message);
1075                                }
1076                            } else {
1077                                tracing::error!("no message handler");
1078                            }
1079                        } else {
1080                            tracing::info!("connection closed");
1081                            break;
1082                        }
1083                    }
1084                }
1085            }
1086
1087            drop(foreground_message_handlers);
1088            tracing::info!("signing out");
1089            if let Err(error) = connection_lost(session, teardown, executor).await {
1090                tracing::error!(?error, "error signing out");
1091            }
1092
1093        }.instrument(span)
1094    }
1095
1096    async fn send_initial_client_update(
1097        &self,
1098        connection_id: ConnectionId,
1099        principal: &Principal,
1100        zed_version: ZedVersion,
1101        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1102        session: &Session,
1103    ) -> Result<()> {
1104        self.peer.send(
1105            connection_id,
1106            proto::Hello {
1107                peer_id: Some(connection_id.into()),
1108            },
1109        )?;
1110        tracing::info!("sent hello message");
1111        if let Some(send_connection_id) = send_connection_id.take() {
1112            let _ = send_connection_id.send(connection_id);
1113        }
1114
1115        match principal {
1116            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1117                if !user.connected_once {
1118                    self.peer.send(connection_id, proto::ShowContacts {})?;
1119                    self.app_state
1120                        .db
1121                        .set_user_connected_once(user.id, true)
1122                        .await?;
1123                }
1124
1125                let (contacts, dev_server_projects) = future::try_join(
1126                    self.app_state.db.get_contacts(user.id),
1127                    self.app_state.db.dev_server_projects_update(user.id),
1128                )
1129                .await?;
1130
1131                {
1132                    let mut pool = self.connection_pool.lock();
1133                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1134                    self.peer.send(
1135                        connection_id,
1136                        build_initial_contacts_update(contacts, &pool),
1137                    )?;
1138                }
1139
1140                if should_auto_subscribe_to_channels(zed_version) {
1141                    subscribe_user_to_channels(user.id, session).await?;
1142                }
1143
1144                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1145
1146                if let Some(incoming_call) =
1147                    self.app_state.db.incoming_call_for_user(user.id).await?
1148                {
1149                    self.peer.send(connection_id, incoming_call)?;
1150                }
1151
1152                update_user_contacts(user.id, &session).await?;
1153            }
1154            Principal::DevServer(dev_server) => {
1155                {
1156                    let mut pool = self.connection_pool.lock();
1157                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1158                    {
1159                        self.peer.send(
1160                            stale_connection_id,
1161                            proto::ShutdownDevServer {
1162                                reason: Some(
1163                                    "another dev server connected with the same token".to_string(),
1164                                ),
1165                            },
1166                        )?;
1167                        pool.remove_connection(stale_connection_id)?;
1168                    };
1169                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1170                }
1171
1172                let projects = self
1173                    .app_state
1174                    .db
1175                    .get_projects_for_dev_server(dev_server.id)
1176                    .await?;
1177                self.peer
1178                    .send(connection_id, proto::DevServerInstructions { projects })?;
1179
1180                let status = self
1181                    .app_state
1182                    .db
1183                    .dev_server_projects_update(dev_server.user_id)
1184                    .await?;
1185                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1186            }
1187        }
1188
1189        Ok(())
1190    }
1191
1192    pub async fn invite_code_redeemed(
1193        self: &Arc<Self>,
1194        inviter_id: UserId,
1195        invitee_id: UserId,
1196    ) -> Result<()> {
1197        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1198            if let Some(code) = &user.invite_code {
1199                let pool = self.connection_pool.lock();
1200                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1201                for connection_id in pool.user_connection_ids(inviter_id) {
1202                    self.peer.send(
1203                        connection_id,
1204                        proto::UpdateContacts {
1205                            contacts: vec![invitee_contact.clone()],
1206                            ..Default::default()
1207                        },
1208                    )?;
1209                    self.peer.send(
1210                        connection_id,
1211                        proto::UpdateInviteInfo {
1212                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1213                            count: user.invite_count as u32,
1214                        },
1215                    )?;
1216                }
1217            }
1218        }
1219        Ok(())
1220    }
1221
1222    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1223        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1224            if let Some(invite_code) = &user.invite_code {
1225                let pool = self.connection_pool.lock();
1226                for connection_id in pool.user_connection_ids(user_id) {
1227                    self.peer.send(
1228                        connection_id,
1229                        proto::UpdateInviteInfo {
1230                            url: format!(
1231                                "{}{}",
1232                                self.app_state.config.invite_link_prefix, invite_code
1233                            ),
1234                            count: user.invite_count as u32,
1235                        },
1236                    )?;
1237                }
1238            }
1239        }
1240        Ok(())
1241    }
1242
1243    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1244        ServerSnapshot {
1245            connection_pool: ConnectionPoolGuard {
1246                guard: self.connection_pool.lock(),
1247                _not_send: PhantomData,
1248            },
1249            peer: &self.peer,
1250        }
1251    }
1252}
1253
1254impl<'a> Deref for ConnectionPoolGuard<'a> {
1255    type Target = ConnectionPool;
1256
1257    fn deref(&self) -> &Self::Target {
1258        &self.guard
1259    }
1260}
1261
1262impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1263    fn deref_mut(&mut self) -> &mut Self::Target {
1264        &mut self.guard
1265    }
1266}
1267
/// In test builds, validates the connection pool's invariants whenever a guard is dropped;
/// in other builds dropping the guard does nothing extra.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled into test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1274
1275fn broadcast<F>(
1276    sender_id: Option<ConnectionId>,
1277    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1278    mut f: F,
1279) where
1280    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1281{
1282    for receiver_id in receiver_ids {
1283        if Some(receiver_id) != sender_id {
1284            if let Err(error) = f(receiver_id) {
1285                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1286            }
1287        }
1288    }
1289}
1290
1291pub struct ProtocolVersion(u32);
1292
1293impl Header for ProtocolVersion {
1294    fn name() -> &'static HeaderName {
1295        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1297    }
1298
1299    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300    where
1301        Self: Sized,
1302        I: Iterator<Item = &'i axum::http::HeaderValue>,
1303    {
1304        let version = values
1305            .next()
1306            .ok_or_else(axum::headers::Error::invalid)?
1307            .to_str()
1308            .map_err(|_| axum::headers::Error::invalid())?
1309            .parse()
1310            .map_err(|_| axum::headers::Error::invalid())?;
1311        Ok(Self(version))
1312    }
1313
1314    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315        values.extend([self.0.to_string().parse().unwrap()]);
1316    }
1317}
1318
1319pub struct AppVersionHeader(SemanticVersion);
1320impl Header for AppVersionHeader {
1321    fn name() -> &'static HeaderName {
1322        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1323        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1324    }
1325
1326    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1327    where
1328        Self: Sized,
1329        I: Iterator<Item = &'i axum::http::HeaderValue>,
1330    {
1331        let version = values
1332            .next()
1333            .ok_or_else(axum::headers::Error::invalid)?
1334            .to_str()
1335            .map_err(|_| axum::headers::Error::invalid())?
1336            .parse()
1337            .map_err(|_| axum::headers::Error::invalid())?;
1338        Ok(Self(version))
1339    }
1340
1341    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1342        values.extend([self.0.to_string().parse().unwrap()]);
1343    }
1344}
1345
1346pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1347    Router::new()
1348        .route("/rpc", get(handle_websocket_request))
1349        .layer(
1350            ServiceBuilder::new()
1351                .layer(Extension(server.app_state.clone()))
1352                .layer(middleware::from_fn(auth::validate_header)),
1353        )
1354        .route("/metrics", get(handle_metrics))
1355        .layer(Extension(server))
1356}
1357
1358pub async fn handle_websocket_request(
1359    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1360    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1361    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1362    Extension(server): Extension<Arc<Server>>,
1363    Extension(principal): Extension<Principal>,
1364    ws: WebSocketUpgrade,
1365) -> axum::response::Response {
1366    if protocol_version != rpc::PROTOCOL_VERSION {
1367        return (
1368            StatusCode::UPGRADE_REQUIRED,
1369            "client must be upgraded".to_string(),
1370        )
1371            .into_response();
1372    }
1373
1374    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1375        return (
1376            StatusCode::UPGRADE_REQUIRED,
1377            "no version header found".to_string(),
1378        )
1379            .into_response();
1380    };
1381
1382    if !version.can_collaborate() {
1383        return (
1384            StatusCode::UPGRADE_REQUIRED,
1385            "client must be upgraded".to_string(),
1386        )
1387            .into_response();
1388    }
1389
1390    let socket_address = socket_address.to_string();
1391    ws.on_upgrade(move |socket| {
1392        let socket = socket
1393            .map_ok(to_tungstenite_message)
1394            .err_into()
1395            .with(|message| async move { Ok(to_axum_message(message)) });
1396        let connection = Connection::new(Box::pin(socket));
1397        async move {
1398            server
1399                .handle_connection(
1400                    connection,
1401                    socket_address,
1402                    principal,
1403                    version,
1404                    None,
1405                    Executor::Production,
1406                )
1407                .await;
1408        }
1409    })
1410}
1411
1412pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1413    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1414    let connections_metric = CONNECTIONS_METRIC
1415        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1416
1417    let connections = server
1418        .connection_pool
1419        .lock()
1420        .connections()
1421        .filter(|connection| !connection.admin)
1422        .count();
1423    connections_metric.set(connections as _);
1424
1425    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1426    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1427        register_int_gauge!(
1428            "shared_projects",
1429            "number of open projects with one or more guests"
1430        )
1431        .unwrap()
1432    });
1433
1434    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1435    shared_projects_metric.set(shared_projects as _);
1436
1437    let encoder = prometheus::TextEncoder::new();
1438    let metric_families = prometheus::gather();
1439    let encoded_metrics = encoder
1440        .encode_to_string(&metric_families)
1441        .map_err(|err| anyhow!("{}", err))?;
1442    Ok(encoded_metrics)
1443}
1444
/// Cleans up after a client connection has gone away.
///
/// Immediately: disconnects the peer, removes the connection from the
/// connection pool (returning an error if that fails), and tells the database
/// the connection was lost (a failure there is only passed through
/// `trace_err`).
///
/// Then it waits for either `RECONNECT_TIMEOUT` to elapse or `teardown` to
/// change. Only when the timeout wins are the principal's resources released:
/// - users (including impersonated users) leave their room and channel
///   buffers; if they have no other online connection, any pending call is
///   declined; finally their contacts are refreshed;
/// - dev servers are handled by `lost_dev_server_connection`.
///
/// If `teardown` changes first, no further cleanup is performed here.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in order, so the reconnect timeout is
    // checked before the teardown signal on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so this
                    // conversion is expected to succeed.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of
                    // this user remains online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown requested: skip the delayed cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1499
1500/// Acknowledges a ping from a client, used to keep the connection alive.
1501async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1502    response.send(proto::Ack {})?;
1503    Ok(())
1504}
1505
1506/// Creates a new room for calling (outside of channels)
1507async fn create_room(
1508    _request: proto::CreateRoom,
1509    response: Response<proto::CreateRoom>,
1510    session: UserSession,
1511) -> Result<()> {
1512    let live_kit_room = nanoid::nanoid!(30);
1513
1514    let live_kit_connection_info = util::maybe!(async {
1515        let live_kit = session.live_kit_client.as_ref();
1516        let live_kit = live_kit?;
1517        let user_id = session.user_id().to_string();
1518
1519        let token = live_kit
1520            .room_token(&live_kit_room, &user_id.to_string())
1521            .trace_err()?;
1522
1523        Some(proto::LiveKitConnectionInfo {
1524            server_url: live_kit.url().into(),
1525            token,
1526            can_publish: true,
1527        })
1528    })
1529    .await;
1530
1531    let room = session
1532        .db()
1533        .await
1534        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1535        .await?;
1536
1537    response.send(proto::CreateRoomResponse {
1538        room: Some(room.clone()),
1539        live_kit_connection_info,
1540    })?;
1541
1542    update_user_contacts(session.user_id(), &session).await?;
1543    Ok(())
1544}
1545
1546/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1547async fn join_room(
1548    request: proto::JoinRoom,
1549    response: Response<proto::JoinRoom>,
1550    session: UserSession,
1551) -> Result<()> {
1552    let room_id = RoomId::from_proto(request.id);
1553
1554    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1555
1556    if let Some(channel_id) = channel_id {
1557        return join_channel_internal(channel_id, Box::new(response), session).await;
1558    }
1559
1560    let joined_room = {
1561        let room = session
1562            .db()
1563            .await
1564            .join_room(room_id, session.user_id(), session.connection_id)
1565            .await?;
1566        room_updated(&room.room, &session.peer);
1567        room.into_inner()
1568    };
1569
1570    for connection_id in session
1571        .connection_pool()
1572        .await
1573        .user_connection_ids(session.user_id())
1574    {
1575        session
1576            .peer
1577            .send(
1578                connection_id,
1579                proto::CallCanceled {
1580                    room_id: room_id.to_proto(),
1581                },
1582            )
1583            .trace_err();
1584    }
1585
1586    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1587        if let Some(token) = live_kit
1588            .room_token(
1589                &joined_room.room.live_kit_room,
1590                &session.user_id().to_string(),
1591            )
1592            .trace_err()
1593        {
1594            Some(proto::LiveKitConnectionInfo {
1595                server_url: live_kit.url().into(),
1596                token,
1597                can_publish: true,
1598            })
1599        } else {
1600            None
1601        }
1602    } else {
1603        None
1604    };
1605
1606    response.send(proto::JoinRoomResponse {
1607        room: Some(joined_room.room),
1608        channel_id: None,
1609        live_kit_connection_info,
1610    })?;
1611
1612    update_user_contacts(session.user_id(), &session).await?;
1613    Ok(())
1614}
1615
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call returns the room together with `reshared_projects` and
/// `rejoined_projects` (presumably projects the caller hosts and projects it
/// had joined, respectively — inferred from how they are used below). This
/// handler:
/// 1. replies with the room, each reshared project's collaborators, and each
///    rejoined project's proto form;
/// 2. broadcasts the updated room;
/// 3. for each reshared project, tells every collaborator that the peer id
///    changed from the project's `old_connection_id` to the caller's current
///    connection, and forwards the project's worktree list to them;
/// 4. streams rejoined project state back via `notify_rejoined_projects`.
///
/// After the guard is consumed it broadcasts the channel update (when the
/// room belongs to a channel) and refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Only `room` and `channel` leave this block; the guard returned by the
    // database is consumed by `into_inner` at the end of it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators the caller's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list to the collaborators,
            // excluding the caller's own connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Takes the worktrees out of each rejoined project while streaming.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1707
1708fn notify_rejoined_projects(
1709    rejoined_projects: &mut Vec<RejoinedProject>,
1710    session: &UserSession,
1711) -> Result<()> {
1712    for project in rejoined_projects.iter() {
1713        for collaborator in &project.collaborators {
1714            session
1715                .peer
1716                .send(
1717                    collaborator.connection_id,
1718                    proto::UpdateProjectCollaborator {
1719                        project_id: project.id.to_proto(),
1720                        old_peer_id: Some(project.old_connection_id.into()),
1721                        new_peer_id: Some(session.connection_id.into()),
1722                    },
1723                )
1724                .trace_err();
1725        }
1726    }
1727
1728    for project in rejoined_projects {
1729        for worktree in mem::take(&mut project.worktrees) {
1730            #[cfg(any(test, feature = "test-support"))]
1731            const MAX_CHUNK_SIZE: usize = 2;
1732            #[cfg(not(any(test, feature = "test-support")))]
1733            const MAX_CHUNK_SIZE: usize = 256;
1734
1735            // Stream this worktree's entries.
1736            let message = proto::UpdateWorktree {
1737                project_id: project.id.to_proto(),
1738                worktree_id: worktree.id,
1739                abs_path: worktree.abs_path.clone(),
1740                root_name: worktree.root_name,
1741                updated_entries: worktree.updated_entries,
1742                removed_entries: worktree.removed_entries,
1743                scan_id: worktree.scan_id,
1744                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1745                updated_repositories: worktree.updated_repositories,
1746                removed_repositories: worktree.removed_repositories,
1747            };
1748            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1749                session.peer.send(session.connection_id, update.clone())?;
1750            }
1751
1752            // Stream this worktree's diagnostics.
1753            for summary in worktree.diagnostic_summaries {
1754                session.peer.send(
1755                    session.connection_id,
1756                    proto::UpdateDiagnosticSummary {
1757                        project_id: project.id.to_proto(),
1758                        worktree_id: worktree.id,
1759                        summary: Some(summary),
1760                    },
1761                )?;
1762            }
1763
1764            for settings_file in worktree.settings_files {
1765                session.peer.send(
1766                    session.connection_id,
1767                    proto::UpdateWorktreeSettings {
1768                        project_id: project.id.to_proto(),
1769                        worktree_id: worktree.id,
1770                        path: settings_file.path,
1771                        content: Some(settings_file.content),
1772                    },
1773                )?;
1774            }
1775        }
1776
1777        for language_server in &project.language_servers {
1778            session.peer.send(
1779                session.connection_id,
1780                proto::UpdateLanguageServer {
1781                    project_id: project.id.to_proto(),
1782                    language_server_id: language_server.id,
1783                    variant: Some(
1784                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1785                            proto::LspDiskBasedDiagnosticsUpdated {},
1786                        ),
1787                    ),
1788                },
1789            )?;
1790        }
1791    }
1792    Ok(())
1793}
1794
1795/// leave room disconnects from the room.
1796async fn leave_room(
1797    _: proto::LeaveRoom,
1798    response: Response<proto::LeaveRoom>,
1799    session: UserSession,
1800) -> Result<()> {
1801    leave_room_for_session(&session, session.connection_id).await?;
1802    response.send(proto::Ack {})?;
1803    Ok(())
1804}
1805
1806/// Updates the permissions of someone else in the room.
1807async fn set_room_participant_role(
1808    request: proto::SetRoomParticipantRole,
1809    response: Response<proto::SetRoomParticipantRole>,
1810    session: UserSession,
1811) -> Result<()> {
1812    let user_id = UserId::from_proto(request.user_id);
1813    let role = ChannelRole::from(request.role());
1814
1815    let (live_kit_room, can_publish) = {
1816        let room = session
1817            .db()
1818            .await
1819            .set_room_participant_role(
1820                session.user_id(),
1821                RoomId::from_proto(request.room_id),
1822                user_id,
1823                role,
1824            )
1825            .await?;
1826
1827        let live_kit_room = room.live_kit_room.clone();
1828        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1829        room_updated(&room, &session.peer);
1830        (live_kit_room, can_publish)
1831    };
1832
1833    if let Some(live_kit) = session.live_kit_client.as_ref() {
1834        live_kit
1835            .update_participant(
1836                live_kit_room.clone(),
1837                request.user_id.to_string(),
1838                live_kit_server::proto::ParticipantPermission {
1839                    can_subscribe: true,
1840                    can_publish,
1841                    can_publish_data: can_publish,
1842                    hidden: false,
1843                    recorder: false,
1844                },
1845            )
1846            .await
1847            .trace_err();
1848    }
1849
1850    response.send(proto::Ack {})?;
1851    Ok(())
1852}
1853
/// Call someone else into the current room
///
/// The callee must be a contact of the caller. The call is recorded in the
/// database, the updated room is broadcast and the callee's contacts are
/// refreshed. Every connection of the callee is then sent the incoming call
/// concurrently, and the request is acknowledged as soon as any one of them
/// responds successfully. If none does (including when the callee has no
/// connections), the call is marked failed in the database, the room is
/// broadcast again, the callee's contacts are refreshed, and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call message out so it outlives this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently. The connection-pool
    // guard is a temporary of this statement, so it is released before the
    // requests are awaited below.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful response acknowledges the call.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // Report this connection's failure and keep waiting on the rest.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1922
1923/// Cancel an outgoing call.
1924async fn cancel_call(
1925    request: proto::CancelCall,
1926    response: Response<proto::CancelCall>,
1927    session: UserSession,
1928) -> Result<()> {
1929    let called_user_id = UserId::from_proto(request.called_user_id);
1930    let room_id = RoomId::from_proto(request.room_id);
1931    {
1932        let room = session
1933            .db()
1934            .await
1935            .cancel_call(room_id, session.connection_id, called_user_id)
1936            .await?;
1937        room_updated(&room, &session.peer);
1938    }
1939
1940    for connection_id in session
1941        .connection_pool()
1942        .await
1943        .user_connection_ids(called_user_id)
1944    {
1945        session
1946            .peer
1947            .send(
1948                connection_id,
1949                proto::CallCanceled {
1950                    room_id: room_id.to_proto(),
1951                },
1952            )
1953            .trace_err();
1954    }
1955    response.send(proto::Ack {})?;
1956
1957    update_user_contacts(called_user_id, &session).await?;
1958    Ok(())
1959}
1960
1961/// Decline an incoming call.
1962async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1963    let room_id = RoomId::from_proto(message.room_id);
1964    {
1965        let room = session
1966            .db()
1967            .await
1968            .decline_call(Some(room_id), session.user_id())
1969            .await?
1970            .ok_or_else(|| anyhow!("failed to decline call"))?;
1971        room_updated(&room, &session.peer);
1972    }
1973
1974    for connection_id in session
1975        .connection_pool()
1976        .await
1977        .user_connection_ids(session.user_id())
1978    {
1979        session
1980            .peer
1981            .send(
1982                connection_id,
1983                proto::CallCanceled {
1984                    room_id: room_id.to_proto(),
1985                },
1986            )
1987            .trace_err();
1988    }
1989    update_user_contacts(session.user_id(), &session).await?;
1990    Ok(())
1991}
1992
1993/// Updates other participants in the room with your current location.
1994async fn update_participant_location(
1995    request: proto::UpdateParticipantLocation,
1996    response: Response<proto::UpdateParticipantLocation>,
1997    session: UserSession,
1998) -> Result<()> {
1999    let room_id = RoomId::from_proto(request.room_id);
2000    let location = request
2001        .location
2002        .ok_or_else(|| anyhow!("invalid location"))?;
2003
2004    let db = session.db().await;
2005    let room = db
2006        .update_room_participant_location(room_id, session.connection_id, location)
2007        .await?;
2008
2009    room_updated(&room, &session.peer);
2010    response.send(proto::Ack {})?;
2011    Ok(())
2012}
2013
2014/// Share a project into the room.
2015async fn share_project(
2016    request: proto::ShareProject,
2017    response: Response<proto::ShareProject>,
2018    session: UserSession,
2019) -> Result<()> {
2020    let (project_id, room) = &*session
2021        .db()
2022        .await
2023        .share_project(
2024            RoomId::from_proto(request.room_id),
2025            session.connection_id,
2026            &request.worktrees,
2027            request
2028                .dev_server_project_id
2029                .map(|id| DevServerProjectId::from_proto(id)),
2030        )
2031        .await?;
2032    response.send(proto::ShareProjectResponse {
2033        project_id: project_id.to_proto(),
2034    })?;
2035    room_updated(&room, &session.peer);
2036
2037    Ok(())
2038}
2039
2040/// Unshare a project from the room.
2041async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2042    let project_id = ProjectId::from_proto(message.project_id);
2043    unshare_project_internal(
2044        project_id,
2045        session.connection_id,
2046        session.user_id(),
2047        &session,
2048    )
2049    .await
2050}
2051
2052async fn unshare_project_internal(
2053    project_id: ProjectId,
2054    connection_id: ConnectionId,
2055    user_id: Option<UserId>,
2056    session: &Session,
2057) -> Result<()> {
2058    let delete = {
2059        let room_guard = session
2060            .db()
2061            .await
2062            .unshare_project(project_id, connection_id, user_id)
2063            .await?;
2064
2065        let (delete, room, guest_connection_ids) = &*room_guard;
2066
2067        let message = proto::UnshareProject {
2068            project_id: project_id.to_proto(),
2069        };
2070
2071        broadcast(
2072            Some(connection_id),
2073            guest_connection_ids.iter().copied(),
2074            |conn_id| session.peer.send(conn_id, message.clone()),
2075        );
2076        if let Some(room) = room {
2077            room_updated(room, &session.peer);
2078        }
2079
2080        *delete
2081    };
2082
2083    if delete {
2084        let db = session.db().await;
2085        db.delete_project(project_id).await?;
2086    }
2087
2088    Ok(())
2089}
2090
2091/// DevServer makes a project available online
2092async fn share_dev_server_project(
2093    request: proto::ShareDevServerProject,
2094    response: Response<proto::ShareDevServerProject>,
2095    session: DevServerSession,
2096) -> Result<()> {
2097    let (dev_server_project, user_id, status) = session
2098        .db()
2099        .await
2100        .share_dev_server_project(
2101            DevServerProjectId::from_proto(request.dev_server_project_id),
2102            session.dev_server_id(),
2103            session.connection_id,
2104            &request.worktrees,
2105        )
2106        .await?;
2107    let Some(project_id) = dev_server_project.project_id else {
2108        return Err(anyhow!("failed to share remote project"))?;
2109    };
2110
2111    send_dev_server_projects_update(user_id, status, &session).await;
2112
2113    response.send(proto::ShareProjectResponse { project_id })?;
2114
2115    Ok(())
2116}
2117
2118/// Join someone elses shared project.
2119async fn join_project(
2120    request: proto::JoinProject,
2121    response: Response<proto::JoinProject>,
2122    session: UserSession,
2123) -> Result<()> {
2124    let project_id = ProjectId::from_proto(request.project_id);
2125
2126    tracing::info!(%project_id, "join project");
2127
2128    let db = session.db().await;
2129    let (project, replica_id) = &mut *db
2130        .join_project(project_id, session.connection_id, session.user_id())
2131        .await?;
2132    drop(db);
2133    tracing::info!(%project_id, "join remote project");
2134    join_project_internal(response, session, project, replica_id)
2135}
2136
2137trait JoinProjectInternalResponse {
2138    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2139}
2140impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2141    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2142        Response::<proto::JoinProject>::send(self, result)
2143    }
2144}
2145impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2146    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2147        Response::<proto::JoinHostedProject>::send(self, result)
2148    }
2149}
2150
2151fn join_project_internal(
2152    response: impl JoinProjectInternalResponse,
2153    session: UserSession,
2154    project: &mut Project,
2155    replica_id: &ReplicaId,
2156) -> Result<()> {
2157    let collaborators = project
2158        .collaborators
2159        .iter()
2160        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2161        .map(|collaborator| collaborator.to_proto())
2162        .collect::<Vec<_>>();
2163    let project_id = project.id;
2164    let guest_user_id = session.user_id();
2165
2166    let worktrees = project
2167        .worktrees
2168        .iter()
2169        .map(|(id, worktree)| proto::WorktreeMetadata {
2170            id: *id,
2171            root_name: worktree.root_name.clone(),
2172            visible: worktree.visible,
2173            abs_path: worktree.abs_path.clone(),
2174        })
2175        .collect::<Vec<_>>();
2176
2177    let add_project_collaborator = proto::AddProjectCollaborator {
2178        project_id: project_id.to_proto(),
2179        collaborator: Some(proto::Collaborator {
2180            peer_id: Some(session.connection_id.into()),
2181            replica_id: replica_id.0 as u32,
2182            user_id: guest_user_id.to_proto(),
2183        }),
2184    };
2185
2186    for collaborator in &collaborators {
2187        session
2188            .peer
2189            .send(
2190                collaborator.peer_id.unwrap().into(),
2191                add_project_collaborator.clone(),
2192            )
2193            .trace_err();
2194    }
2195
2196    // First, we send the metadata associated with each worktree.
2197    response.send(proto::JoinProjectResponse {
2198        project_id: project.id.0 as u64,
2199        worktrees: worktrees.clone(),
2200        replica_id: replica_id.0 as u32,
2201        collaborators: collaborators.clone(),
2202        language_servers: project.language_servers.clone(),
2203        role: project.role.into(),
2204        dev_server_project_id: project
2205            .dev_server_project_id
2206            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2207    })?;
2208
2209    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2210        #[cfg(any(test, feature = "test-support"))]
2211        const MAX_CHUNK_SIZE: usize = 2;
2212        #[cfg(not(any(test, feature = "test-support")))]
2213        const MAX_CHUNK_SIZE: usize = 256;
2214
2215        // Stream this worktree's entries.
2216        let message = proto::UpdateWorktree {
2217            project_id: project_id.to_proto(),
2218            worktree_id,
2219            abs_path: worktree.abs_path.clone(),
2220            root_name: worktree.root_name,
2221            updated_entries: worktree.entries,
2222            removed_entries: Default::default(),
2223            scan_id: worktree.scan_id,
2224            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2225            updated_repositories: worktree.repository_entries.into_values().collect(),
2226            removed_repositories: Default::default(),
2227        };
2228        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2229            session.peer.send(session.connection_id, update.clone())?;
2230        }
2231
2232        // Stream this worktree's diagnostics.
2233        for summary in worktree.diagnostic_summaries {
2234            session.peer.send(
2235                session.connection_id,
2236                proto::UpdateDiagnosticSummary {
2237                    project_id: project_id.to_proto(),
2238                    worktree_id: worktree.id,
2239                    summary: Some(summary),
2240                },
2241            )?;
2242        }
2243
2244        for settings_file in worktree.settings_files {
2245            session.peer.send(
2246                session.connection_id,
2247                proto::UpdateWorktreeSettings {
2248                    project_id: project_id.to_proto(),
2249                    worktree_id: worktree.id,
2250                    path: settings_file.path,
2251                    content: Some(settings_file.content),
2252                },
2253            )?;
2254        }
2255    }
2256
2257    for language_server in &project.language_servers {
2258        session.peer.send(
2259            session.connection_id,
2260            proto::UpdateLanguageServer {
2261                project_id: project_id.to_proto(),
2262                language_server_id: language_server.id,
2263                variant: Some(
2264                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2265                        proto::LspDiskBasedDiagnosticsUpdated {},
2266                    ),
2267                ),
2268            },
2269        )?;
2270    }
2271
2272    Ok(())
2273}
2274
2275/// Leave someone elses shared project.
2276async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2277    let sender_id = session.connection_id;
2278    let project_id = ProjectId::from_proto(request.project_id);
2279    let db = session.db().await;
2280    if db.is_hosted_project(project_id).await? {
2281        let project = db.leave_hosted_project(project_id, sender_id).await?;
2282        project_left(&project, &session);
2283        return Ok(());
2284    }
2285
2286    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2287    tracing::info!(
2288        %project_id,
2289        "leave project"
2290    );
2291
2292    project_left(&project, &session);
2293    if let Some(room) = room {
2294        room_updated(&room, &session.peer);
2295    }
2296
2297    Ok(())
2298}
2299
2300async fn join_hosted_project(
2301    request: proto::JoinHostedProject,
2302    response: Response<proto::JoinHostedProject>,
2303    session: UserSession,
2304) -> Result<()> {
2305    let (mut project, replica_id) = session
2306        .db()
2307        .await
2308        .join_hosted_project(
2309            ProjectId(request.project_id as i32),
2310            session.user_id(),
2311            session.connection_id,
2312        )
2313        .await?;
2314
2315    join_project_internal(response, session, &mut project, &replica_id)
2316}
2317
2318async fn list_remote_directory(
2319    request: proto::ListRemoteDirectory,
2320    response: Response<proto::ListRemoteDirectory>,
2321    session: UserSession,
2322) -> Result<()> {
2323    let dev_server_id = DevServerId(request.dev_server_id as i32);
2324    let dev_server_connection_id = session
2325        .connection_pool()
2326        .await
2327        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2328
2329    session
2330        .db()
2331        .await
2332        .get_dev_server_for_user(dev_server_id, session.user_id())
2333        .await?;
2334
2335    response.send(
2336        session
2337            .peer
2338            .forward_request(session.connection_id, dev_server_connection_id, request)
2339            .await?,
2340    )?;
2341    Ok(())
2342}
2343
2344async fn update_dev_server_project(
2345    request: proto::UpdateDevServerProject,
2346    response: Response<proto::UpdateDevServerProject>,
2347    session: UserSession,
2348) -> Result<()> {
2349    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2350
2351    let (dev_server_project, update) = session
2352        .db()
2353        .await
2354        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2355        .await?;
2356
2357    let projects = session
2358        .db()
2359        .await
2360        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2361        .await?;
2362
2363    let dev_server_connection_id = session
2364        .connection_pool()
2365        .await
2366        .dev_server_connection_id_supporting(
2367            dev_server_project.dev_server_id,
2368            ZedVersion::with_list_directory(),
2369        )?;
2370
2371    session.peer.send(
2372        dev_server_connection_id,
2373        proto::DevServerInstructions { projects },
2374    )?;
2375
2376    send_dev_server_projects_update(session.user_id(), update, &session).await;
2377
2378    response.send(proto::Ack {})
2379}
2380
2381async fn create_dev_server_project(
2382    request: proto::CreateDevServerProject,
2383    response: Response<proto::CreateDevServerProject>,
2384    session: UserSession,
2385) -> Result<()> {
2386    let dev_server_id = DevServerId(request.dev_server_id as i32);
2387    let dev_server_connection_id = session
2388        .connection_pool()
2389        .await
2390        .dev_server_connection_id(dev_server_id);
2391    let Some(dev_server_connection_id) = dev_server_connection_id else {
2392        Err(ErrorCode::DevServerOffline
2393            .message("Cannot create a remote project when the dev server is offline".to_string())
2394            .anyhow())?
2395    };
2396
2397    let path = request.path.clone();
2398    //Check that the path exists on the dev server
2399    session
2400        .peer
2401        .forward_request(
2402            session.connection_id,
2403            dev_server_connection_id,
2404            proto::ValidateDevServerProjectRequest { path: path.clone() },
2405        )
2406        .await?;
2407
2408    let (dev_server_project, update) = session
2409        .db()
2410        .await
2411        .create_dev_server_project(
2412            DevServerId(request.dev_server_id as i32),
2413            &request.path,
2414            session.user_id(),
2415        )
2416        .await?;
2417
2418    let projects = session
2419        .db()
2420        .await
2421        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2422        .await?;
2423
2424    session.peer.send(
2425        dev_server_connection_id,
2426        proto::DevServerInstructions { projects },
2427    )?;
2428
2429    send_dev_server_projects_update(session.user_id(), update, &session).await;
2430
2431    response.send(proto::CreateDevServerProjectResponse {
2432        dev_server_project: Some(dev_server_project.to_proto(None)),
2433    })?;
2434    Ok(())
2435}
2436
2437async fn create_dev_server(
2438    request: proto::CreateDevServer,
2439    response: Response<proto::CreateDevServer>,
2440    session: UserSession,
2441) -> Result<()> {
2442    let access_token = auth::random_token();
2443    let hashed_access_token = auth::hash_access_token(&access_token);
2444
2445    if request.name.is_empty() {
2446        return Err(proto::ErrorCode::Forbidden
2447            .message("Dev server name cannot be empty".to_string())
2448            .anyhow())?;
2449    }
2450
2451    let (dev_server, status) = session
2452        .db()
2453        .await
2454        .create_dev_server(
2455            &request.name,
2456            request.ssh_connection_string.as_deref(),
2457            &hashed_access_token,
2458            session.user_id(),
2459        )
2460        .await?;
2461
2462    send_dev_server_projects_update(session.user_id(), status, &session).await;
2463
2464    response.send(proto::CreateDevServerResponse {
2465        dev_server_id: dev_server.id.0 as u64,
2466        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2467        name: request.name,
2468    })?;
2469    Ok(())
2470}
2471
2472async fn regenerate_dev_server_token(
2473    request: proto::RegenerateDevServerToken,
2474    response: Response<proto::RegenerateDevServerToken>,
2475    session: UserSession,
2476) -> Result<()> {
2477    let dev_server_id = DevServerId(request.dev_server_id as i32);
2478    let access_token = auth::random_token();
2479    let hashed_access_token = auth::hash_access_token(&access_token);
2480
2481    let connection_id = session
2482        .connection_pool()
2483        .await
2484        .dev_server_connection_id(dev_server_id);
2485    if let Some(connection_id) = connection_id {
2486        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2487        session.peer.send(
2488            connection_id,
2489            proto::ShutdownDevServer {
2490                reason: Some("dev server token was regenerated".to_string()),
2491            },
2492        )?;
2493        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2494    }
2495
2496    let status = session
2497        .db()
2498        .await
2499        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2500        .await?;
2501
2502    send_dev_server_projects_update(session.user_id(), status, &session).await;
2503
2504    response.send(proto::RegenerateDevServerTokenResponse {
2505        dev_server_id: dev_server_id.to_proto(),
2506        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2507    })?;
2508    Ok(())
2509}
2510
2511async fn rename_dev_server(
2512    request: proto::RenameDevServer,
2513    response: Response<proto::RenameDevServer>,
2514    session: UserSession,
2515) -> Result<()> {
2516    if request.name.trim().is_empty() {
2517        return Err(proto::ErrorCode::Forbidden
2518            .message("Dev server name cannot be empty".to_string())
2519            .anyhow())?;
2520    }
2521
2522    let dev_server_id = DevServerId(request.dev_server_id as i32);
2523    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2524    if dev_server.user_id != session.user_id() {
2525        return Err(anyhow!(ErrorCode::Forbidden))?;
2526    }
2527
2528    let status = session
2529        .db()
2530        .await
2531        .rename_dev_server(
2532            dev_server_id,
2533            &request.name,
2534            request.ssh_connection_string.as_deref(),
2535            session.user_id(),
2536        )
2537        .await?;
2538
2539    send_dev_server_projects_update(session.user_id(), status, &session).await;
2540
2541    response.send(proto::Ack {})?;
2542    Ok(())
2543}
2544
2545async fn delete_dev_server(
2546    request: proto::DeleteDevServer,
2547    response: Response<proto::DeleteDevServer>,
2548    session: UserSession,
2549) -> Result<()> {
2550    let dev_server_id = DevServerId(request.dev_server_id as i32);
2551    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2552    if dev_server.user_id != session.user_id() {
2553        return Err(anyhow!(ErrorCode::Forbidden))?;
2554    }
2555
2556    let connection_id = session
2557        .connection_pool()
2558        .await
2559        .dev_server_connection_id(dev_server_id);
2560    if let Some(connection_id) = connection_id {
2561        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2562        session.peer.send(
2563            connection_id,
2564            proto::ShutdownDevServer {
2565                reason: Some("dev server was deleted".to_string()),
2566            },
2567        )?;
2568        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2569    }
2570
2571    let status = session
2572        .db()
2573        .await
2574        .delete_dev_server(dev_server_id, session.user_id())
2575        .await?;
2576
2577    send_dev_server_projects_update(session.user_id(), status, &session).await;
2578
2579    response.send(proto::Ack {})?;
2580    Ok(())
2581}
2582
2583async fn delete_dev_server_project(
2584    request: proto::DeleteDevServerProject,
2585    response: Response<proto::DeleteDevServerProject>,
2586    session: UserSession,
2587) -> Result<()> {
2588    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2589    let dev_server_project = session
2590        .db()
2591        .await
2592        .get_dev_server_project(dev_server_project_id)
2593        .await?;
2594
2595    let dev_server = session
2596        .db()
2597        .await
2598        .get_dev_server(dev_server_project.dev_server_id)
2599        .await?;
2600    if dev_server.user_id != session.user_id() {
2601        return Err(anyhow!(ErrorCode::Forbidden))?;
2602    }
2603
2604    let dev_server_connection_id = session
2605        .connection_pool()
2606        .await
2607        .dev_server_connection_id(dev_server.id);
2608
2609    if let Some(dev_server_connection_id) = dev_server_connection_id {
2610        let project = session
2611            .db()
2612            .await
2613            .find_dev_server_project(dev_server_project_id)
2614            .await;
2615        if let Ok(project) = project {
2616            unshare_project_internal(
2617                project.id,
2618                dev_server_connection_id,
2619                Some(session.user_id()),
2620                &session,
2621            )
2622            .await?;
2623        }
2624    }
2625
2626    let (projects, status) = session
2627        .db()
2628        .await
2629        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2630        .await?;
2631
2632    if let Some(dev_server_connection_id) = dev_server_connection_id {
2633        session.peer.send(
2634            dev_server_connection_id,
2635            proto::DevServerInstructions { projects },
2636        )?;
2637    }
2638
2639    send_dev_server_projects_update(session.user_id(), status, &session).await;
2640
2641    response.send(proto::Ack {})?;
2642    Ok(())
2643}
2644
2645async fn rejoin_dev_server_projects(
2646    request: proto::RejoinRemoteProjects,
2647    response: Response<proto::RejoinRemoteProjects>,
2648    session: UserSession,
2649) -> Result<()> {
2650    let mut rejoined_projects = {
2651        let db = session.db().await;
2652        db.rejoin_dev_server_projects(
2653            &request.rejoined_projects,
2654            session.user_id(),
2655            session.0.connection_id,
2656        )
2657        .await?
2658    };
2659    response.send(proto::RejoinRemoteProjectsResponse {
2660        rejoined_projects: rejoined_projects
2661            .iter()
2662            .map(|project| project.to_proto())
2663            .collect(),
2664    })?;
2665    notify_rejoined_projects(&mut rejoined_projects, &session)
2666}
2667
2668async fn reconnect_dev_server(
2669    request: proto::ReconnectDevServer,
2670    response: Response<proto::ReconnectDevServer>,
2671    session: DevServerSession,
2672) -> Result<()> {
2673    let reshared_projects = {
2674        let db = session.db().await;
2675        db.reshare_dev_server_projects(
2676            &request.reshared_projects,
2677            session.dev_server_id(),
2678            session.0.connection_id,
2679        )
2680        .await?
2681    };
2682
2683    for project in &reshared_projects {
2684        for collaborator in &project.collaborators {
2685            session
2686                .peer
2687                .send(
2688                    collaborator.connection_id,
2689                    proto::UpdateProjectCollaborator {
2690                        project_id: project.id.to_proto(),
2691                        old_peer_id: Some(project.old_connection_id.into()),
2692                        new_peer_id: Some(session.connection_id.into()),
2693                    },
2694                )
2695                .trace_err();
2696        }
2697
2698        broadcast(
2699            Some(session.connection_id),
2700            project
2701                .collaborators
2702                .iter()
2703                .map(|collaborator| collaborator.connection_id),
2704            |connection_id| {
2705                session.peer.forward_send(
2706                    session.connection_id,
2707                    connection_id,
2708                    proto::UpdateProject {
2709                        project_id: project.id.to_proto(),
2710                        worktrees: project.worktrees.clone(),
2711                    },
2712                )
2713            },
2714        );
2715    }
2716
2717    response.send(proto::ReconnectDevServerResponse {
2718        reshared_projects: reshared_projects
2719            .iter()
2720            .map(|project| proto::ResharedProject {
2721                id: project.id.to_proto(),
2722                collaborators: project
2723                    .collaborators
2724                    .iter()
2725                    .map(|collaborator| collaborator.to_proto())
2726                    .collect(),
2727            })
2728            .collect(),
2729    })?;
2730
2731    Ok(())
2732}
2733
2734async fn shutdown_dev_server(
2735    _: proto::ShutdownDevServer,
2736    response: Response<proto::ShutdownDevServer>,
2737    session: DevServerSession,
2738) -> Result<()> {
2739    response.send(proto::Ack {})?;
2740    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2741    remove_dev_server_connection(session.dev_server_id(), &session).await
2742}
2743
2744async fn shutdown_dev_server_internal(
2745    dev_server_id: DevServerId,
2746    connection_id: ConnectionId,
2747    session: &Session,
2748) -> Result<()> {
2749    let (dev_server_projects, dev_server) = {
2750        let db = session.db().await;
2751        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2752        let dev_server = db.get_dev_server(dev_server_id).await?;
2753        (dev_server_projects, dev_server)
2754    };
2755
2756    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2757        unshare_project_internal(
2758            ProjectId::from_proto(project_id),
2759            connection_id,
2760            None,
2761            session,
2762        )
2763        .await?;
2764    }
2765
2766    session
2767        .connection_pool()
2768        .await
2769        .set_dev_server_offline(dev_server_id);
2770
2771    let status = session
2772        .db()
2773        .await
2774        .dev_server_projects_update(dev_server.user_id)
2775        .await?;
2776    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2777
2778    Ok(())
2779}
2780
2781async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2782    let dev_server_connection = session
2783        .connection_pool()
2784        .await
2785        .dev_server_connection_id(dev_server_id);
2786
2787    if let Some(dev_server_connection) = dev_server_connection {
2788        session
2789            .connection_pool()
2790            .await
2791            .remove_connection(dev_server_connection)?;
2792    }
2793    Ok(())
2794}
2795
2796/// Updates other participants with changes to the project
2797async fn update_project(
2798    request: proto::UpdateProject,
2799    response: Response<proto::UpdateProject>,
2800    session: Session,
2801) -> Result<()> {
2802    let project_id = ProjectId::from_proto(request.project_id);
2803    let (room, guest_connection_ids) = &*session
2804        .db()
2805        .await
2806        .update_project(project_id, session.connection_id, &request.worktrees)
2807        .await?;
2808    broadcast(
2809        Some(session.connection_id),
2810        guest_connection_ids.iter().copied(),
2811        |connection_id| {
2812            session
2813                .peer
2814                .forward_send(session.connection_id, connection_id, request.clone())
2815        },
2816    );
2817    if let Some(room) = room {
2818        room_updated(&room, &session.peer);
2819    }
2820    response.send(proto::Ack {})?;
2821
2822    Ok(())
2823}
2824
2825/// Updates other participants with changes to the worktree
2826async fn update_worktree(
2827    request: proto::UpdateWorktree,
2828    response: Response<proto::UpdateWorktree>,
2829    session: Session,
2830) -> Result<()> {
2831    let guest_connection_ids = session
2832        .db()
2833        .await
2834        .update_worktree(&request, session.connection_id)
2835        .await?;
2836
2837    broadcast(
2838        Some(session.connection_id),
2839        guest_connection_ids.iter().copied(),
2840        |connection_id| {
2841            session
2842                .peer
2843                .forward_send(session.connection_id, connection_id, request.clone())
2844        },
2845    );
2846    response.send(proto::Ack {})?;
2847    Ok(())
2848}
2849
2850/// Updates other participants with changes to the diagnostics
2851async fn update_diagnostic_summary(
2852    message: proto::UpdateDiagnosticSummary,
2853    session: Session,
2854) -> Result<()> {
2855    let guest_connection_ids = session
2856        .db()
2857        .await
2858        .update_diagnostic_summary(&message, session.connection_id)
2859        .await?;
2860
2861    broadcast(
2862        Some(session.connection_id),
2863        guest_connection_ids.iter().copied(),
2864        |connection_id| {
2865            session
2866                .peer
2867                .forward_send(session.connection_id, connection_id, message.clone())
2868        },
2869    );
2870
2871    Ok(())
2872}
2873
2874/// Updates other participants with changes to the worktree settings
2875async fn update_worktree_settings(
2876    message: proto::UpdateWorktreeSettings,
2877    session: Session,
2878) -> Result<()> {
2879    let guest_connection_ids = session
2880        .db()
2881        .await
2882        .update_worktree_settings(&message, session.connection_id)
2883        .await?;
2884
2885    broadcast(
2886        Some(session.connection_id),
2887        guest_connection_ids.iter().copied(),
2888        |connection_id| {
2889            session
2890                .peer
2891                .forward_send(session.connection_id, connection_id, message.clone())
2892        },
2893    );
2894
2895    Ok(())
2896}
2897
2898/// Notify other participants that a  language server has started.
2899async fn start_language_server(
2900    request: proto::StartLanguageServer,
2901    session: Session,
2902) -> Result<()> {
2903    let guest_connection_ids = session
2904        .db()
2905        .await
2906        .start_language_server(&request, session.connection_id)
2907        .await?;
2908
2909    broadcast(
2910        Some(session.connection_id),
2911        guest_connection_ids.iter().copied(),
2912        |connection_id| {
2913            session
2914                .peer
2915                .forward_send(session.connection_id, connection_id, request.clone())
2916        },
2917    );
2918    Ok(())
2919}
2920
2921/// Notify other participants that a language server has changed.
2922async fn update_language_server(
2923    request: proto::UpdateLanguageServer,
2924    session: Session,
2925) -> Result<()> {
2926    let project_id = ProjectId::from_proto(request.project_id);
2927    let project_connection_ids = session
2928        .db()
2929        .await
2930        .project_connection_ids(project_id, session.connection_id, true)
2931        .await?;
2932    broadcast(
2933        Some(session.connection_id),
2934        project_connection_ids.iter().copied(),
2935        |connection_id| {
2936            session
2937                .peer
2938                .forward_send(session.connection_id, connection_id, request.clone())
2939        },
2940    );
2941    Ok(())
2942}
2943
2944/// forward a project request to the host. These requests should be read only
2945/// as guests are allowed to send them.
2946async fn forward_read_only_project_request<T>(
2947    request: T,
2948    response: Response<T>,
2949    session: UserSession,
2950) -> Result<()>
2951where
2952    T: EntityMessage + RequestMessage,
2953{
2954    let project_id = ProjectId::from_proto(request.remote_entity_id());
2955    let host_connection_id = session
2956        .db()
2957        .await
2958        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2959        .await?;
2960    let payload = session
2961        .peer
2962        .forward_request(session.connection_id, host_connection_id, request)
2963        .await?;
2964    response.send(payload)?;
2965    Ok(())
2966}
2967
2968/// forward a project request to the dev server. Only allowed
2969/// if it's your dev server.
2970async fn forward_project_request_for_owner<T>(
2971    request: T,
2972    response: Response<T>,
2973    session: UserSession,
2974) -> Result<()>
2975where
2976    T: EntityMessage + RequestMessage,
2977{
2978    let project_id = ProjectId::from_proto(request.remote_entity_id());
2979
2980    let host_connection_id = session
2981        .db()
2982        .await
2983        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2984        .await?;
2985    let payload = session
2986        .peer
2987        .forward_request(session.connection_id, host_connection_id, request)
2988        .await?;
2989    response.send(payload)?;
2990    Ok(())
2991}
2992
2993/// forward a project request to the host. These requests are disallowed
2994/// for guests.
2995async fn forward_mutating_project_request<T>(
2996    request: T,
2997    response: Response<T>,
2998    session: UserSession,
2999) -> Result<()>
3000where
3001    T: EntityMessage + RequestMessage,
3002{
3003    let project_id = ProjectId::from_proto(request.remote_entity_id());
3004
3005    let host_connection_id = session
3006        .db()
3007        .await
3008        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3009        .await?;
3010    let payload = session
3011        .peer
3012        .forward_request(session.connection_id, host_connection_id, request)
3013        .await?;
3014    response.send(payload)?;
3015    Ok(())
3016}
3017
3018/// forward a project request to the host. These requests are disallowed
3019/// for guests.
3020async fn forward_versioned_mutating_project_request<T>(
3021    request: T,
3022    response: Response<T>,
3023    session: UserSession,
3024) -> Result<()>
3025where
3026    T: EntityMessage + RequestMessage + VersionedMessage,
3027{
3028    let project_id = ProjectId::from_proto(request.remote_entity_id());
3029
3030    let host_connection_id = session
3031        .db()
3032        .await
3033        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3034        .await?;
3035    if let Some(host_version) = session
3036        .connection_pool()
3037        .await
3038        .connection(host_connection_id)
3039        .map(|c| c.zed_version)
3040    {
3041        if let Some(min_required_version) = request.required_host_version() {
3042            if min_required_version > host_version {
3043                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3044                    .with_tag("required", &min_required_version.to_string())))?;
3045            }
3046        }
3047    }
3048
3049    let payload = session
3050        .peer
3051        .forward_request(session.connection_id, host_connection_id, request)
3052        .await?;
3053    response.send(payload)?;
3054    Ok(())
3055}
3056
3057/// Notify other participants that a new buffer has been created
3058async fn create_buffer_for_peer(
3059    request: proto::CreateBufferForPeer,
3060    session: Session,
3061) -> Result<()> {
3062    session
3063        .db()
3064        .await
3065        .check_user_is_project_host(
3066            ProjectId::from_proto(request.project_id),
3067            session.connection_id,
3068        )
3069        .await?;
3070    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3071    session
3072        .peer
3073        .forward_send(session.connection_id, peer_id.into(), request)?;
3074    Ok(())
3075}
3076
3077/// Notify other participants that a buffer has been updated. This is
3078/// allowed for guests as long as the update is limited to selections.
3079async fn update_buffer(
3080    request: proto::UpdateBuffer,
3081    response: Response<proto::UpdateBuffer>,
3082    session: Session,
3083) -> Result<()> {
3084    let project_id = ProjectId::from_proto(request.project_id);
3085    let mut capability = Capability::ReadOnly;
3086
3087    for op in request.operations.iter() {
3088        match op.variant {
3089            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3090            Some(_) => capability = Capability::ReadWrite,
3091        }
3092    }
3093
3094    let host = {
3095        let guard = session
3096            .db()
3097            .await
3098            .connections_for_buffer_update(
3099                project_id,
3100                session.principal_id(),
3101                session.connection_id,
3102                capability,
3103            )
3104            .await?;
3105
3106        let (host, guests) = &*guard;
3107
3108        broadcast(
3109            Some(session.connection_id),
3110            guests.clone(),
3111            |connection_id| {
3112                session
3113                    .peer
3114                    .forward_send(session.connection_id, connection_id, request.clone())
3115            },
3116        );
3117
3118        *host
3119    };
3120
3121    if host != session.connection_id {
3122        session
3123            .peer
3124            .forward_request(session.connection_id, host, request.clone())
3125            .await?;
3126    }
3127
3128    response.send(proto::Ack {})?;
3129    Ok(())
3130}
3131
3132async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3133    let project_id = ProjectId::from_proto(message.project_id);
3134
3135    let operation = message.operation.as_ref().context("invalid operation")?;
3136    let capability = match operation.variant.as_ref() {
3137        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3138            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3139                match buffer_op.variant {
3140                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3141                        Capability::ReadOnly
3142                    }
3143                    _ => Capability::ReadWrite,
3144                }
3145            } else {
3146                Capability::ReadWrite
3147            }
3148        }
3149        Some(_) => Capability::ReadWrite,
3150        None => Capability::ReadOnly,
3151    };
3152
3153    let guard = session
3154        .db()
3155        .await
3156        .connections_for_buffer_update(
3157            project_id,
3158            session.principal_id(),
3159            session.connection_id,
3160            capability,
3161        )
3162        .await?;
3163
3164    let (host, guests) = &*guard;
3165
3166    broadcast(
3167        Some(session.connection_id),
3168        guests.iter().chain([host]).copied(),
3169        |connection_id| {
3170            session
3171                .peer
3172                .forward_send(session.connection_id, connection_id, message.clone())
3173        },
3174    );
3175
3176    Ok(())
3177}
3178
3179/// Notify other participants that a project has been updated.
3180async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3181    request: T,
3182    session: Session,
3183) -> Result<()> {
3184    let project_id = ProjectId::from_proto(request.remote_entity_id());
3185    let project_connection_ids = session
3186        .db()
3187        .await
3188        .project_connection_ids(project_id, session.connection_id, false)
3189        .await?;
3190
3191    broadcast(
3192        Some(session.connection_id),
3193        project_connection_ids.iter().copied(),
3194        |connection_id| {
3195            session
3196                .peer
3197                .forward_send(session.connection_id, connection_id, request.clone())
3198        },
3199    );
3200    Ok(())
3201}
3202
3203/// Start following another user in a call.
3204async fn follow(
3205    request: proto::Follow,
3206    response: Response<proto::Follow>,
3207    session: UserSession,
3208) -> Result<()> {
3209    let room_id = RoomId::from_proto(request.room_id);
3210    let project_id = request.project_id.map(ProjectId::from_proto);
3211    let leader_id = request
3212        .leader_id
3213        .ok_or_else(|| anyhow!("invalid leader id"))?
3214        .into();
3215    let follower_id = session.connection_id;
3216
3217    session
3218        .db()
3219        .await
3220        .check_room_participants(room_id, leader_id, session.connection_id)
3221        .await?;
3222
3223    let response_payload = session
3224        .peer
3225        .forward_request(session.connection_id, leader_id, request)
3226        .await?;
3227    response.send(response_payload)?;
3228
3229    if let Some(project_id) = project_id {
3230        let room = session
3231            .db()
3232            .await
3233            .follow(room_id, project_id, leader_id, follower_id)
3234            .await?;
3235        room_updated(&room, &session.peer);
3236    }
3237
3238    Ok(())
3239}
3240
3241/// Stop following another user in a call.
3242async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3243    let room_id = RoomId::from_proto(request.room_id);
3244    let project_id = request.project_id.map(ProjectId::from_proto);
3245    let leader_id = request
3246        .leader_id
3247        .ok_or_else(|| anyhow!("invalid leader id"))?
3248        .into();
3249    let follower_id = session.connection_id;
3250
3251    session
3252        .db()
3253        .await
3254        .check_room_participants(room_id, leader_id, session.connection_id)
3255        .await?;
3256
3257    session
3258        .peer
3259        .forward_send(session.connection_id, leader_id, request)?;
3260
3261    if let Some(project_id) = project_id {
3262        let room = session
3263            .db()
3264            .await
3265            .unfollow(room_id, project_id, leader_id, follower_id)
3266            .await?;
3267        room_updated(&room, &session.peer);
3268    }
3269
3270    Ok(())
3271}
3272
3273/// Notify everyone following you of your current location.
3274async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3275    let room_id = RoomId::from_proto(request.room_id);
3276    let database = session.db.lock().await;
3277
3278    let connection_ids = if let Some(project_id) = request.project_id {
3279        let project_id = ProjectId::from_proto(project_id);
3280        database
3281            .project_connection_ids(project_id, session.connection_id, true)
3282            .await?
3283    } else {
3284        database
3285            .room_connection_ids(room_id, session.connection_id)
3286            .await?
3287    };
3288
3289    // For now, don't send view update messages back to that view's current leader.
3290    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3291        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3292        _ => None,
3293    });
3294
3295    for connection_id in connection_ids.iter().cloned() {
3296        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3297            session
3298                .peer
3299                .forward_send(session.connection_id, connection_id, request.clone())?;
3300        }
3301    }
3302    Ok(())
3303}
3304
3305/// Get public data about users.
3306async fn get_users(
3307    request: proto::GetUsers,
3308    response: Response<proto::GetUsers>,
3309    session: Session,
3310) -> Result<()> {
3311    let user_ids = request
3312        .user_ids
3313        .into_iter()
3314        .map(UserId::from_proto)
3315        .collect();
3316    let users = session
3317        .db()
3318        .await
3319        .get_users_by_ids(user_ids)
3320        .await?
3321        .into_iter()
3322        .map(|user| proto::User {
3323            id: user.id.to_proto(),
3324            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3325            github_login: user.github_login,
3326        })
3327        .collect();
3328    response.send(proto::UsersResponse { users })?;
3329    Ok(())
3330}
3331
3332/// Search for users (to invite) buy Github login
3333async fn fuzzy_search_users(
3334    request: proto::FuzzySearchUsers,
3335    response: Response<proto::FuzzySearchUsers>,
3336    session: UserSession,
3337) -> Result<()> {
3338    let query = request.query;
3339    let users = match query.len() {
3340        0 => vec![],
3341        1 | 2 => session
3342            .db()
3343            .await
3344            .get_user_by_github_login(&query)
3345            .await?
3346            .into_iter()
3347            .collect(),
3348        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3349    };
3350    let users = users
3351        .into_iter()
3352        .filter(|user| user.id != session.user_id())
3353        .map(|user| proto::User {
3354            id: user.id.to_proto(),
3355            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3356            github_login: user.github_login,
3357        })
3358        .collect();
3359    response.send(proto::UsersResponse { users })?;
3360    Ok(())
3361}
3362
3363/// Send a contact request to another user.
3364async fn request_contact(
3365    request: proto::RequestContact,
3366    response: Response<proto::RequestContact>,
3367    session: UserSession,
3368) -> Result<()> {
3369    let requester_id = session.user_id();
3370    let responder_id = UserId::from_proto(request.responder_id);
3371    if requester_id == responder_id {
3372        return Err(anyhow!("cannot add yourself as a contact"))?;
3373    }
3374
3375    let notifications = session
3376        .db()
3377        .await
3378        .send_contact_request(requester_id, responder_id)
3379        .await?;
3380
3381    // Update outgoing contact requests of requester
3382    let mut update = proto::UpdateContacts::default();
3383    update.outgoing_requests.push(responder_id.to_proto());
3384    for connection_id in session
3385        .connection_pool()
3386        .await
3387        .user_connection_ids(requester_id)
3388    {
3389        session.peer.send(connection_id, update.clone())?;
3390    }
3391
3392    // Update incoming contact requests of responder
3393    let mut update = proto::UpdateContacts::default();
3394    update
3395        .incoming_requests
3396        .push(proto::IncomingContactRequest {
3397            requester_id: requester_id.to_proto(),
3398        });
3399    let connection_pool = session.connection_pool().await;
3400    for connection_id in connection_pool.user_connection_ids(responder_id) {
3401        session.peer.send(connection_id, update.clone())?;
3402    }
3403
3404    send_notifications(&connection_pool, &session.peer, notifications);
3405
3406    response.send(proto::Ack {})?;
3407    Ok(())
3408}
3409
3410/// Accept or decline a contact request
3411async fn respond_to_contact_request(
3412    request: proto::RespondToContactRequest,
3413    response: Response<proto::RespondToContactRequest>,
3414    session: UserSession,
3415) -> Result<()> {
3416    let responder_id = session.user_id();
3417    let requester_id = UserId::from_proto(request.requester_id);
3418    let db = session.db().await;
3419    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3420        db.dismiss_contact_notification(responder_id, requester_id)
3421            .await?;
3422    } else {
3423        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3424
3425        let notifications = db
3426            .respond_to_contact_request(responder_id, requester_id, accept)
3427            .await?;
3428        let requester_busy = db.is_user_busy(requester_id).await?;
3429        let responder_busy = db.is_user_busy(responder_id).await?;
3430
3431        let pool = session.connection_pool().await;
3432        // Update responder with new contact
3433        let mut update = proto::UpdateContacts::default();
3434        if accept {
3435            update
3436                .contacts
3437                .push(contact_for_user(requester_id, requester_busy, &pool));
3438        }
3439        update
3440            .remove_incoming_requests
3441            .push(requester_id.to_proto());
3442        for connection_id in pool.user_connection_ids(responder_id) {
3443            session.peer.send(connection_id, update.clone())?;
3444        }
3445
3446        // Update requester with new contact
3447        let mut update = proto::UpdateContacts::default();
3448        if accept {
3449            update
3450                .contacts
3451                .push(contact_for_user(responder_id, responder_busy, &pool));
3452        }
3453        update
3454            .remove_outgoing_requests
3455            .push(responder_id.to_proto());
3456
3457        for connection_id in pool.user_connection_ids(requester_id) {
3458            session.peer.send(connection_id, update.clone())?;
3459        }
3460
3461        send_notifications(&pool, &session.peer, notifications);
3462    }
3463
3464    response.send(proto::Ack {})?;
3465    Ok(())
3466}
3467
3468/// Remove a contact.
3469async fn remove_contact(
3470    request: proto::RemoveContact,
3471    response: Response<proto::RemoveContact>,
3472    session: UserSession,
3473) -> Result<()> {
3474    let requester_id = session.user_id();
3475    let responder_id = UserId::from_proto(request.user_id);
3476    let db = session.db().await;
3477    let (contact_accepted, deleted_notification_id) =
3478        db.remove_contact(requester_id, responder_id).await?;
3479
3480    let pool = session.connection_pool().await;
3481    // Update outgoing contact requests of requester
3482    let mut update = proto::UpdateContacts::default();
3483    if contact_accepted {
3484        update.remove_contacts.push(responder_id.to_proto());
3485    } else {
3486        update
3487            .remove_outgoing_requests
3488            .push(responder_id.to_proto());
3489    }
3490    for connection_id in pool.user_connection_ids(requester_id) {
3491        session.peer.send(connection_id, update.clone())?;
3492    }
3493
3494    // Update incoming contact requests of responder
3495    let mut update = proto::UpdateContacts::default();
3496    if contact_accepted {
3497        update.remove_contacts.push(requester_id.to_proto());
3498    } else {
3499        update
3500            .remove_incoming_requests
3501            .push(requester_id.to_proto());
3502    }
3503    for connection_id in pool.user_connection_ids(responder_id) {
3504        session.peer.send(connection_id, update.clone())?;
3505        if let Some(notification_id) = deleted_notification_id {
3506            session.peer.send(
3507                connection_id,
3508                proto::DeleteNotification {
3509                    notification_id: notification_id.to_proto(),
3510                },
3511            )?;
3512        }
3513    }
3514
3515    response.send(proto::Ack {})?;
3516    Ok(())
3517}
3518
3519fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3520    version.0.minor() < 139
3521}
3522
3523async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3524    subscribe_user_to_channels(
3525        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3526        &session,
3527    )
3528    .await?;
3529    Ok(())
3530}
3531
3532async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3533    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3534    let mut pool = session.connection_pool().await;
3535    for membership in &channels_for_user.channel_memberships {
3536        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3537    }
3538    session.peer.send(
3539        session.connection_id,
3540        build_update_user_channels(&channels_for_user),
3541    )?;
3542    session.peer.send(
3543        session.connection_id,
3544        build_channels_update(channels_for_user),
3545    )?;
3546    Ok(())
3547}
3548
3549/// Creates a new channel.
3550async fn create_channel(
3551    request: proto::CreateChannel,
3552    response: Response<proto::CreateChannel>,
3553    session: UserSession,
3554) -> Result<()> {
3555    let db = session.db().await;
3556
3557    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3558    let (channel, membership) = db
3559        .create_channel(&request.name, parent_id, session.user_id())
3560        .await?;
3561
3562    let root_id = channel.root_id();
3563    let channel = Channel::from_model(channel);
3564
3565    response.send(proto::CreateChannelResponse {
3566        channel: Some(channel.to_proto()),
3567        parent_id: request.parent_id,
3568    })?;
3569
3570    let mut connection_pool = session.connection_pool().await;
3571    if let Some(membership) = membership {
3572        connection_pool.subscribe_to_channel(
3573            membership.user_id,
3574            membership.channel_id,
3575            membership.role,
3576        );
3577        let update = proto::UpdateUserChannels {
3578            channel_memberships: vec![proto::ChannelMembership {
3579                channel_id: membership.channel_id.to_proto(),
3580                role: membership.role.into(),
3581            }],
3582            ..Default::default()
3583        };
3584        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3585            session.peer.send(connection_id, update.clone())?;
3586        }
3587    }
3588
3589    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3590        if !role.can_see_channel(channel.visibility) {
3591            continue;
3592        }
3593
3594        let update = proto::UpdateChannels {
3595            channels: vec![channel.to_proto()],
3596            ..Default::default()
3597        };
3598        session.peer.send(connection_id, update.clone())?;
3599    }
3600
3601    Ok(())
3602}
3603
3604/// Delete a channel
3605async fn delete_channel(
3606    request: proto::DeleteChannel,
3607    response: Response<proto::DeleteChannel>,
3608    session: UserSession,
3609) -> Result<()> {
3610    let db = session.db().await;
3611
3612    let channel_id = request.channel_id;
3613    let (root_channel, removed_channels) = db
3614        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3615        .await?;
3616    response.send(proto::Ack {})?;
3617
3618    // Notify members of removed channels
3619    let mut update = proto::UpdateChannels::default();
3620    update
3621        .delete_channels
3622        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3623
3624    let connection_pool = session.connection_pool().await;
3625    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3626        session.peer.send(connection_id, update.clone())?;
3627    }
3628
3629    Ok(())
3630}
3631
3632/// Invite someone to join a channel.
3633async fn invite_channel_member(
3634    request: proto::InviteChannelMember,
3635    response: Response<proto::InviteChannelMember>,
3636    session: UserSession,
3637) -> Result<()> {
3638    let db = session.db().await;
3639    let channel_id = ChannelId::from_proto(request.channel_id);
3640    let invitee_id = UserId::from_proto(request.user_id);
3641    let InviteMemberResult {
3642        channel,
3643        notifications,
3644    } = db
3645        .invite_channel_member(
3646            channel_id,
3647            invitee_id,
3648            session.user_id(),
3649            request.role().into(),
3650        )
3651        .await?;
3652
3653    let update = proto::UpdateChannels {
3654        channel_invitations: vec![channel.to_proto()],
3655        ..Default::default()
3656    };
3657
3658    let connection_pool = session.connection_pool().await;
3659    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3660        session.peer.send(connection_id, update.clone())?;
3661    }
3662
3663    send_notifications(&connection_pool, &session.peer, notifications);
3664
3665    response.send(proto::Ack {})?;
3666    Ok(())
3667}
3668
3669/// remove someone from a channel
3670async fn remove_channel_member(
3671    request: proto::RemoveChannelMember,
3672    response: Response<proto::RemoveChannelMember>,
3673    session: UserSession,
3674) -> Result<()> {
3675    let db = session.db().await;
3676    let channel_id = ChannelId::from_proto(request.channel_id);
3677    let member_id = UserId::from_proto(request.user_id);
3678
3679    let RemoveChannelMemberResult {
3680        membership_update,
3681        notification_id,
3682    } = db
3683        .remove_channel_member(channel_id, member_id, session.user_id())
3684        .await?;
3685
3686    let mut connection_pool = session.connection_pool().await;
3687    notify_membership_updated(
3688        &mut connection_pool,
3689        membership_update,
3690        member_id,
3691        &session.peer,
3692    );
3693    for connection_id in connection_pool.user_connection_ids(member_id) {
3694        if let Some(notification_id) = notification_id {
3695            session
3696                .peer
3697                .send(
3698                    connection_id,
3699                    proto::DeleteNotification {
3700                        notification_id: notification_id.to_proto(),
3701                    },
3702                )
3703                .trace_err();
3704        }
3705    }
3706
3707    response.send(proto::Ack {})?;
3708    Ok(())
3709}
3710
3711/// Toggle the channel between public and private.
3712/// Care is taken to maintain the invariant that public channels only descend from public channels,
3713/// (though members-only channels can appear at any point in the hierarchy).
3714async fn set_channel_visibility(
3715    request: proto::SetChannelVisibility,
3716    response: Response<proto::SetChannelVisibility>,
3717    session: UserSession,
3718) -> Result<()> {
3719    let db = session.db().await;
3720    let channel_id = ChannelId::from_proto(request.channel_id);
3721    let visibility = request.visibility().into();
3722
3723    let channel_model = db
3724        .set_channel_visibility(channel_id, visibility, session.user_id())
3725        .await?;
3726    let root_id = channel_model.root_id();
3727    let channel = Channel::from_model(channel_model);
3728
3729    let mut connection_pool = session.connection_pool().await;
3730    for (user_id, role) in connection_pool
3731        .channel_user_ids(root_id)
3732        .collect::<Vec<_>>()
3733        .into_iter()
3734    {
3735        let update = if role.can_see_channel(channel.visibility) {
3736            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3737            proto::UpdateChannels {
3738                channels: vec![channel.to_proto()],
3739                ..Default::default()
3740            }
3741        } else {
3742            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3743            proto::UpdateChannels {
3744                delete_channels: vec![channel.id.to_proto()],
3745                ..Default::default()
3746            }
3747        };
3748
3749        for connection_id in connection_pool.user_connection_ids(user_id) {
3750            session.peer.send(connection_id, update.clone())?;
3751        }
3752    }
3753
3754    response.send(proto::Ack {})?;
3755    Ok(())
3756}
3757
3758/// Alter the role for a user in the channel.
3759async fn set_channel_member_role(
3760    request: proto::SetChannelMemberRole,
3761    response: Response<proto::SetChannelMemberRole>,
3762    session: UserSession,
3763) -> Result<()> {
3764    let db = session.db().await;
3765    let channel_id = ChannelId::from_proto(request.channel_id);
3766    let member_id = UserId::from_proto(request.user_id);
3767    let result = db
3768        .set_channel_member_role(
3769            channel_id,
3770            session.user_id(),
3771            member_id,
3772            request.role().into(),
3773        )
3774        .await?;
3775
3776    match result {
3777        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3778            let mut connection_pool = session.connection_pool().await;
3779            notify_membership_updated(
3780                &mut connection_pool,
3781                membership_update,
3782                member_id,
3783                &session.peer,
3784            )
3785        }
3786        db::SetMemberRoleResult::InviteUpdated(channel) => {
3787            let update = proto::UpdateChannels {
3788                channel_invitations: vec![channel.to_proto()],
3789                ..Default::default()
3790            };
3791
3792            for connection_id in session
3793                .connection_pool()
3794                .await
3795                .user_connection_ids(member_id)
3796            {
3797                session.peer.send(connection_id, update.clone())?;
3798            }
3799        }
3800    }
3801
3802    response.send(proto::Ack {})?;
3803    Ok(())
3804}
3805
3806/// Change the name of a channel
3807async fn rename_channel(
3808    request: proto::RenameChannel,
3809    response: Response<proto::RenameChannel>,
3810    session: UserSession,
3811) -> Result<()> {
3812    let db = session.db().await;
3813    let channel_id = ChannelId::from_proto(request.channel_id);
3814    let channel_model = db
3815        .rename_channel(channel_id, session.user_id(), &request.name)
3816        .await?;
3817    let root_id = channel_model.root_id();
3818    let channel = Channel::from_model(channel_model);
3819
3820    response.send(proto::RenameChannelResponse {
3821        channel: Some(channel.to_proto()),
3822    })?;
3823
3824    let connection_pool = session.connection_pool().await;
3825    let update = proto::UpdateChannels {
3826        channels: vec![channel.to_proto()],
3827        ..Default::default()
3828    };
3829    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3830        if role.can_see_channel(channel.visibility) {
3831            session.peer.send(connection_id, update.clone())?;
3832        }
3833    }
3834
3835    Ok(())
3836}
3837
3838/// Move a channel to a new parent.
3839async fn move_channel(
3840    request: proto::MoveChannel,
3841    response: Response<proto::MoveChannel>,
3842    session: UserSession,
3843) -> Result<()> {
3844    let channel_id = ChannelId::from_proto(request.channel_id);
3845    let to = ChannelId::from_proto(request.to);
3846
3847    let (root_id, channels) = session
3848        .db()
3849        .await
3850        .move_channel(channel_id, to, session.user_id())
3851        .await?;
3852
3853    let connection_pool = session.connection_pool().await;
3854    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3855        let channels = channels
3856            .iter()
3857            .filter_map(|channel| {
3858                if role.can_see_channel(channel.visibility) {
3859                    Some(channel.to_proto())
3860                } else {
3861                    None
3862                }
3863            })
3864            .collect::<Vec<_>>();
3865        if channels.is_empty() {
3866            continue;
3867        }
3868
3869        let update = proto::UpdateChannels {
3870            channels,
3871            ..Default::default()
3872        };
3873
3874        session.peer.send(connection_id, update.clone())?;
3875    }
3876
3877    response.send(Ack {})?;
3878    Ok(())
3879}
3880
3881/// Get the list of channel members
3882async fn get_channel_members(
3883    request: proto::GetChannelMembers,
3884    response: Response<proto::GetChannelMembers>,
3885    session: UserSession,
3886) -> Result<()> {
3887    let db = session.db().await;
3888    let channel_id = ChannelId::from_proto(request.channel_id);
3889    let limit = if request.limit == 0 {
3890        u16::MAX as u64
3891    } else {
3892        request.limit
3893    };
3894    let (members, users) = db
3895        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3896        .await?;
3897    response.send(proto::GetChannelMembersResponse { members, users })?;
3898    Ok(())
3899}
3900
3901/// Accept or decline a channel invitation.
3902async fn respond_to_channel_invite(
3903    request: proto::RespondToChannelInvite,
3904    response: Response<proto::RespondToChannelInvite>,
3905    session: UserSession,
3906) -> Result<()> {
3907    let db = session.db().await;
3908    let channel_id = ChannelId::from_proto(request.channel_id);
3909    let RespondToChannelInvite {
3910        membership_update,
3911        notifications,
3912    } = db
3913        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3914        .await?;
3915
3916    let mut connection_pool = session.connection_pool().await;
3917    if let Some(membership_update) = membership_update {
3918        notify_membership_updated(
3919            &mut connection_pool,
3920            membership_update,
3921            session.user_id(),
3922            &session.peer,
3923        );
3924    } else {
3925        let update = proto::UpdateChannels {
3926            remove_channel_invitations: vec![channel_id.to_proto()],
3927            ..Default::default()
3928        };
3929
3930        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3931            session.peer.send(connection_id, update.clone())?;
3932        }
3933    };
3934
3935    send_notifications(&connection_pool, &session.peer, notifications);
3936
3937    response.send(proto::Ack {})?;
3938
3939    Ok(())
3940}
3941
3942/// Join the channels' room
3943async fn join_channel(
3944    request: proto::JoinChannel,
3945    response: Response<proto::JoinChannel>,
3946    session: UserSession,
3947) -> Result<()> {
3948    let channel_id = ChannelId::from_proto(request.channel_id);
3949    join_channel_internal(channel_id, Box::new(response), session).await
3950}
3951
/// A response handle that can complete a channel join with a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply through
/// either one.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `Response::send` through a fully qualified, type-specific path
// so the call clearly targets the response's own `send` rather than this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3965
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database handle is released while doing so.
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. Reply through `response` with the room, the channel id, and LiveKit
///    connection info when available.
/// 4. Broadcast any membership update and the room update.
/// 5. After the guards from the inner block are dropped, broadcast the channel
///    update and refresh the user's contacts.
///
/// Returns an error if the database does not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The `db` and `connection_pool` guards taken inside this block are dropped at
    // its end, before the connection pool is acquired again for `channel_updated`.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the stale room
            // (`leave_room_for_session` presumably needs it), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Only produced when a LiveKit client is configured. Guests get a guest
        // token and cannot publish; every other role gets a room token and can.
        // A token error is logged by `trace_err` and yields `None` (no LiveKit
        // info) instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Broadcast any membership change caused by joining, then the new room state.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel here is reported as an error ("channel not returned").
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4056
4057/// Start editing the channel notes
4058async fn join_channel_buffer(
4059    request: proto::JoinChannelBuffer,
4060    response: Response<proto::JoinChannelBuffer>,
4061    session: UserSession,
4062) -> Result<()> {
4063    let db = session.db().await;
4064    let channel_id = ChannelId::from_proto(request.channel_id);
4065
4066    let open_response = db
4067        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4068        .await?;
4069
4070    let collaborators = open_response.collaborators.clone();
4071    response.send(open_response)?;
4072
4073    let update = UpdateChannelBufferCollaborators {
4074        channel_id: channel_id.to_proto(),
4075        collaborators: collaborators.clone(),
4076    };
4077    channel_buffer_updated(
4078        session.connection_id,
4079        collaborators
4080            .iter()
4081            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4082        &update,
4083        &session.peer,
4084    );
4085
4086    Ok(())
4087}
4088
4089/// Edit the channel notes
4090async fn update_channel_buffer(
4091    request: proto::UpdateChannelBuffer,
4092    session: UserSession,
4093) -> Result<()> {
4094    let db = session.db().await;
4095    let channel_id = ChannelId::from_proto(request.channel_id);
4096
4097    let (collaborators, epoch, version) = db
4098        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4099        .await?;
4100
4101    channel_buffer_updated(
4102        session.connection_id,
4103        collaborators.clone(),
4104        &proto::UpdateChannelBuffer {
4105            channel_id: channel_id.to_proto(),
4106            operations: request.operations,
4107        },
4108        &session.peer,
4109    );
4110
4111    let pool = &*session.connection_pool().await;
4112
4113    let non_collaborators =
4114        pool.channel_connection_ids(channel_id)
4115            .filter_map(|(connection_id, _)| {
4116                if collaborators.contains(&connection_id) {
4117                    None
4118                } else {
4119                    Some(connection_id)
4120                }
4121            });
4122
4123    broadcast(None, non_collaborators, |peer_id| {
4124        session.peer.send(
4125            peer_id,
4126            proto::UpdateChannels {
4127                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4128                    channel_id: channel_id.to_proto(),
4129                    epoch: epoch as u64,
4130                    version: version.clone(),
4131                }],
4132                ..Default::default()
4133            },
4134        )
4135    });
4136
4137    Ok(())
4138}
4139
4140/// Rejoin the channel notes after a connection blip
4141async fn rejoin_channel_buffers(
4142    request: proto::RejoinChannelBuffers,
4143    response: Response<proto::RejoinChannelBuffers>,
4144    session: UserSession,
4145) -> Result<()> {
4146    let db = session.db().await;
4147    let buffers = db
4148        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4149        .await?;
4150
4151    for rejoined_buffer in &buffers {
4152        let collaborators_to_notify = rejoined_buffer
4153            .buffer
4154            .collaborators
4155            .iter()
4156            .filter_map(|c| Some(c.peer_id?.into()));
4157        channel_buffer_updated(
4158            session.connection_id,
4159            collaborators_to_notify,
4160            &proto::UpdateChannelBufferCollaborators {
4161                channel_id: rejoined_buffer.buffer.channel_id,
4162                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4163            },
4164            &session.peer,
4165        );
4166    }
4167
4168    response.send(proto::RejoinChannelBuffersResponse {
4169        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4170    })?;
4171
4172    Ok(())
4173}
4174
4175/// Stop editing the channel notes
4176async fn leave_channel_buffer(
4177    request: proto::LeaveChannelBuffer,
4178    response: Response<proto::LeaveChannelBuffer>,
4179    session: UserSession,
4180) -> Result<()> {
4181    let db = session.db().await;
4182    let channel_id = ChannelId::from_proto(request.channel_id);
4183
4184    let left_buffer = db
4185        .leave_channel_buffer(channel_id, session.connection_id)
4186        .await?;
4187
4188    response.send(Ack {})?;
4189
4190    channel_buffer_updated(
4191        session.connection_id,
4192        left_buffer.connections,
4193        &proto::UpdateChannelBufferCollaborators {
4194            channel_id: channel_id.to_proto(),
4195            collaborators: left_buffer.collaborators,
4196        },
4197        &session.peer,
4198    );
4199
4200    Ok(())
4201}
4202
4203fn channel_buffer_updated<T: EnvelopedMessage>(
4204    sender_id: ConnectionId,
4205    collaborators: impl IntoIterator<Item = ConnectionId>,
4206    message: &T,
4207    peer: &Peer,
4208) {
4209    broadcast(Some(sender_id), collaborators, |peer_id| {
4210        peer.send(peer_id, message.clone())
4211    });
4212}
4213
4214fn send_notifications(
4215    connection_pool: &ConnectionPool,
4216    peer: &Peer,
4217    notifications: db::NotificationBatch,
4218) {
4219    for (user_id, notification) in notifications {
4220        for connection_id in connection_pool.user_connection_ids(user_id) {
4221            if let Err(error) = peer.send(
4222                connection_id,
4223                proto::AddNotification {
4224                    notification: Some(notification.clone()),
4225                },
4226            ) {
4227                tracing::error!(
4228                    "failed to send notification to {:?} {}",
4229                    connection_id,
4230                    error
4231                );
4232            }
4233        }
4234    }
4235}
4236
4237/// Send a message to the channel
4238async fn send_channel_message(
4239    request: proto::SendChannelMessage,
4240    response: Response<proto::SendChannelMessage>,
4241    session: UserSession,
4242) -> Result<()> {
4243    // Validate the message body.
4244    let body = request.body.trim().to_string();
4245    if body.len() > MAX_MESSAGE_LEN {
4246        return Err(anyhow!("message is too long"))?;
4247    }
4248    if body.is_empty() {
4249        return Err(anyhow!("message can't be blank"))?;
4250    }
4251
4252    // TODO: adjust mentions if body is trimmed
4253
4254    let timestamp = OffsetDateTime::now_utc();
4255    let nonce = request
4256        .nonce
4257        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4258
4259    let channel_id = ChannelId::from_proto(request.channel_id);
4260    let CreatedChannelMessage {
4261        message_id,
4262        participant_connection_ids,
4263        notifications,
4264    } = session
4265        .db()
4266        .await
4267        .create_channel_message(
4268            channel_id,
4269            session.user_id(),
4270            &body,
4271            &request.mentions,
4272            timestamp,
4273            nonce.clone().into(),
4274            match request.reply_to_message_id {
4275                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4276                None => None,
4277            },
4278        )
4279        .await?;
4280
4281    let message = proto::ChannelMessage {
4282        sender_id: session.user_id().to_proto(),
4283        id: message_id.to_proto(),
4284        body,
4285        mentions: request.mentions,
4286        timestamp: timestamp.unix_timestamp() as u64,
4287        nonce: Some(nonce),
4288        reply_to_message_id: request.reply_to_message_id,
4289        edited_at: None,
4290    };
4291    broadcast(
4292        Some(session.connection_id),
4293        participant_connection_ids.clone(),
4294        |connection| {
4295            session.peer.send(
4296                connection,
4297                proto::ChannelMessageSent {
4298                    channel_id: channel_id.to_proto(),
4299                    message: Some(message.clone()),
4300                },
4301            )
4302        },
4303    );
4304    response.send(proto::SendChannelMessageResponse {
4305        message: Some(message),
4306    })?;
4307
4308    let pool = &*session.connection_pool().await;
4309    let non_participants =
4310        pool.channel_connection_ids(channel_id)
4311            .filter_map(|(connection_id, _)| {
4312                if participant_connection_ids.contains(&connection_id) {
4313                    None
4314                } else {
4315                    Some(connection_id)
4316                }
4317            });
4318    broadcast(None, non_participants, |peer_id| {
4319        session.peer.send(
4320            peer_id,
4321            proto::UpdateChannels {
4322                latest_channel_message_ids: vec![proto::ChannelMessageId {
4323                    channel_id: channel_id.to_proto(),
4324                    message_id: message_id.to_proto(),
4325                }],
4326                ..Default::default()
4327            },
4328        )
4329    });
4330    send_notifications(pool, &session.peer, notifications);
4331
4332    Ok(())
4333}
4334
4335/// Delete a channel message
4336async fn remove_channel_message(
4337    request: proto::RemoveChannelMessage,
4338    response: Response<proto::RemoveChannelMessage>,
4339    session: UserSession,
4340) -> Result<()> {
4341    let channel_id = ChannelId::from_proto(request.channel_id);
4342    let message_id = MessageId::from_proto(request.message_id);
4343    let (connection_ids, existing_notification_ids) = session
4344        .db()
4345        .await
4346        .remove_channel_message(channel_id, message_id, session.user_id())
4347        .await?;
4348
4349    broadcast(
4350        Some(session.connection_id),
4351        connection_ids,
4352        move |connection| {
4353            session.peer.send(connection, request.clone())?;
4354
4355            for notification_id in &existing_notification_ids {
4356                session.peer.send(
4357                    connection,
4358                    proto::DeleteNotification {
4359                        notification_id: (*notification_id).to_proto(),
4360                    },
4361                )?;
4362            }
4363
4364            Ok(())
4365        },
4366    );
4367    response.send(proto::Ack {})?;
4368    Ok(())
4369}
4370
4371async fn update_channel_message(
4372    request: proto::UpdateChannelMessage,
4373    response: Response<proto::UpdateChannelMessage>,
4374    session: UserSession,
4375) -> Result<()> {
4376    let channel_id = ChannelId::from_proto(request.channel_id);
4377    let message_id = MessageId::from_proto(request.message_id);
4378    let updated_at = OffsetDateTime::now_utc();
4379    let UpdatedChannelMessage {
4380        message_id,
4381        participant_connection_ids,
4382        notifications,
4383        reply_to_message_id,
4384        timestamp,
4385        deleted_mention_notification_ids,
4386        updated_mention_notifications,
4387    } = session
4388        .db()
4389        .await
4390        .update_channel_message(
4391            channel_id,
4392            message_id,
4393            session.user_id(),
4394            request.body.as_str(),
4395            &request.mentions,
4396            updated_at,
4397        )
4398        .await?;
4399
4400    let nonce = request
4401        .nonce
4402        .clone()
4403        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4404
4405    let message = proto::ChannelMessage {
4406        sender_id: session.user_id().to_proto(),
4407        id: message_id.to_proto(),
4408        body: request.body.clone(),
4409        mentions: request.mentions.clone(),
4410        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4411        nonce: Some(nonce),
4412        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4413        edited_at: Some(updated_at.unix_timestamp() as u64),
4414    };
4415
4416    response.send(proto::Ack {})?;
4417
4418    let pool = &*session.connection_pool().await;
4419    broadcast(
4420        Some(session.connection_id),
4421        participant_connection_ids,
4422        |connection| {
4423            session.peer.send(
4424                connection,
4425                proto::ChannelMessageUpdate {
4426                    channel_id: channel_id.to_proto(),
4427                    message: Some(message.clone()),
4428                },
4429            )?;
4430
4431            for notification_id in &deleted_mention_notification_ids {
4432                session.peer.send(
4433                    connection,
4434                    proto::DeleteNotification {
4435                        notification_id: (*notification_id).to_proto(),
4436                    },
4437                )?;
4438            }
4439
4440            for notification in &updated_mention_notifications {
4441                session.peer.send(
4442                    connection,
4443                    proto::UpdateNotification {
4444                        notification: Some(notification.clone()),
4445                    },
4446                )?;
4447            }
4448
4449            Ok(())
4450        },
4451    );
4452
4453    send_notifications(pool, &session.peer, notifications);
4454
4455    Ok(())
4456}
4457
4458/// Mark a channel message as read
4459async fn acknowledge_channel_message(
4460    request: proto::AckChannelMessage,
4461    session: UserSession,
4462) -> Result<()> {
4463    let channel_id = ChannelId::from_proto(request.channel_id);
4464    let message_id = MessageId::from_proto(request.message_id);
4465    let notifications = session
4466        .db()
4467        .await
4468        .observe_channel_message(channel_id, session.user_id(), message_id)
4469        .await?;
4470    send_notifications(
4471        &*session.connection_pool().await,
4472        &session.peer,
4473        notifications,
4474    );
4475    Ok(())
4476}
4477
4478/// Mark a buffer version as synced
4479async fn acknowledge_buffer_version(
4480    request: proto::AckBufferOperation,
4481    session: UserSession,
4482) -> Result<()> {
4483    let buffer_id = BufferId::from_proto(request.buffer_id);
4484    session
4485        .db()
4486        .await
4487        .observe_buffer_version(
4488            buffer_id,
4489            session.user_id(),
4490            request.epoch as i32,
4491            &request.version,
4492        )
4493        .await?;
4494    Ok(())
4495}
4496
4497struct CompleteWithLanguageModelRateLimit;
4498
4499impl RateLimit for CompleteWithLanguageModelRateLimit {
4500    fn capacity() -> usize {
4501        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4502            .ok()
4503            .and_then(|v| v.parse().ok())
4504            .unwrap_or(120) // Picked arbitrarily
4505    }
4506
4507    fn refill_duration() -> chrono::Duration {
4508        chrono::Duration::hours(1)
4509    }
4510
4511    fn db_name() -> &'static str {
4512        "complete-with-language-model"
4513    }
4514}
4515
4516async fn complete_with_language_model(
4517    request: proto::CompleteWithLanguageModel,
4518    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4519    session: Session,
4520    open_ai_api_key: Option<Arc<str>>,
4521    google_ai_api_key: Option<Arc<str>>,
4522    anthropic_api_key: Option<Arc<str>>,
4523) -> Result<()> {
4524    let Some(session) = session.for_user() else {
4525        return Err(anyhow!("user not found"))?;
4526    };
4527    authorize_access_to_language_models(&session).await?;
4528    session
4529        .rate_limiter
4530        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4531        .await?;
4532
4533    if request.model.starts_with("gpt") {
4534        let api_key =
4535            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4536        complete_with_open_ai(request, response, session, api_key).await?;
4537    } else if request.model.starts_with("gemini") {
4538        let api_key = google_ai_api_key
4539            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4540        complete_with_google_ai(request, response, session, api_key).await?;
4541    } else if request.model.starts_with("claude") {
4542        let api_key = anthropic_api_key
4543            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4544        complete_with_anthropic(request, response, session, api_key).await?;
4545    }
4546
4547    Ok(())
4548}
4549
4550async fn complete_with_open_ai(
4551    request: proto::CompleteWithLanguageModel,
4552    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4553    session: UserSession,
4554    api_key: Arc<str>,
4555) -> Result<()> {
4556    let mut completion_stream = open_ai::stream_completion(
4557        session.http_client.as_ref(),
4558        OPEN_AI_API_URL,
4559        &api_key,
4560        crate::ai::language_model_request_to_open_ai(request)?,
4561        None,
4562    )
4563    .await
4564    .context("open_ai::stream_completion request failed within collab")?;
4565
4566    while let Some(event) = completion_stream.next().await {
4567        let event = event?;
4568        response.send(proto::LanguageModelResponse {
4569            choices: event
4570                .choices
4571                .into_iter()
4572                .map(|choice| proto::LanguageModelChoiceDelta {
4573                    index: choice.index,
4574                    delta: Some(proto::LanguageModelResponseMessage {
4575                        role: choice.delta.role.map(|role| match role {
4576                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4577                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4578                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4579                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4580                        } as i32),
4581                        content: choice.delta.content,
4582                        tool_calls: choice
4583                            .delta
4584                            .tool_calls
4585                            .unwrap_or_default()
4586                            .into_iter()
4587                            .map(|delta| proto::ToolCallDelta {
4588                                index: delta.index as u32,
4589                                id: delta.id,
4590                                variant: match delta.function {
4591                                    Some(function) => {
4592                                        let name = function.name;
4593                                        let arguments = function.arguments;
4594
4595                                        Some(proto::tool_call_delta::Variant::Function(
4596                                            proto::tool_call_delta::FunctionCallDelta {
4597                                                name,
4598                                                arguments,
4599                                            },
4600                                        ))
4601                                    }
4602                                    None => None,
4603                                },
4604                            })
4605                            .collect(),
4606                    }),
4607                    finish_reason: choice.finish_reason,
4608                })
4609                .collect(),
4610        })?;
4611    }
4612
4613    Ok(())
4614}
4615
4616async fn complete_with_google_ai(
4617    request: proto::CompleteWithLanguageModel,
4618    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4619    session: UserSession,
4620    api_key: Arc<str>,
4621) -> Result<()> {
4622    let mut stream = google_ai::stream_generate_content(
4623        session.http_client.clone(),
4624        google_ai::API_URL,
4625        api_key.as_ref(),
4626        &request.model.clone(),
4627        crate::ai::language_model_request_to_google_ai(request)?,
4628    )
4629    .await
4630    .context("google_ai::stream_generate_content request failed")?;
4631
4632    while let Some(event) = stream.next().await {
4633        let event = event?;
4634        response.send(proto::LanguageModelResponse {
4635            choices: event
4636                .candidates
4637                .unwrap_or_default()
4638                .into_iter()
4639                .map(|candidate| proto::LanguageModelChoiceDelta {
4640                    index: candidate.index as u32,
4641                    delta: Some(proto::LanguageModelResponseMessage {
4642                        role: Some(match candidate.content.role {
4643                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4644                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4645                        } as i32),
4646                        content: Some(
4647                            candidate
4648                                .content
4649                                .parts
4650                                .into_iter()
4651                                .filter_map(|part| match part {
4652                                    google_ai::Part::TextPart(part) => Some(part.text),
4653                                    google_ai::Part::InlineDataPart(_) => None,
4654                                })
4655                                .collect(),
4656                        ),
4657                        // Tool calls are not supported for Google
4658                        tool_calls: Vec::new(),
4659                    }),
4660                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4661                })
4662                .collect(),
4663        })?;
4664    }
4665
4666    Ok(())
4667}
4668
4669async fn complete_with_anthropic(
4670    request: proto::CompleteWithLanguageModel,
4671    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4672    session: UserSession,
4673    api_key: Arc<str>,
4674) -> Result<()> {
4675    let model = anthropic::Model::from_id(&request.model)?;
4676
4677    let mut system_message = String::new();
4678    let messages = request
4679        .messages
4680        .into_iter()
4681        .filter_map(|message| {
4682            match message.role() {
4683                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4684                    role: anthropic::Role::User,
4685                    content: message.content,
4686                }),
4687                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4688                    role: anthropic::Role::Assistant,
4689                    content: message.content,
4690                }),
4691                // Anthropic's API breaks system instructions out as a separate field rather
4692                // than having a system message role.
4693                LanguageModelRole::LanguageModelSystem => {
4694                    if !system_message.is_empty() {
4695                        system_message.push_str("\n\n");
4696                    }
4697                    system_message.push_str(&message.content);
4698
4699                    None
4700                }
4701                // We don't yet support tool calls for Anthropic
4702                LanguageModelRole::LanguageModelTool => None,
4703            }
4704        })
4705        .collect();
4706
4707    let mut stream = anthropic::stream_completion(
4708        session.http_client.as_ref(),
4709        anthropic::ANTHROPIC_API_URL,
4710        &api_key,
4711        anthropic::Request {
4712            model,
4713            messages,
4714            stream: true,
4715            system: system_message,
4716            max_tokens: 4092,
4717        },
4718        None,
4719    )
4720    .await?;
4721
4722    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4723
4724    while let Some(event) = stream.next().await {
4725        let event = event?;
4726
4727        match event {
4728            anthropic::ResponseEvent::MessageStart { message } => {
4729                if let Some(role) = message.role {
4730                    if role == "assistant" {
4731                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4732                    } else if role == "user" {
4733                        current_role = proto::LanguageModelRole::LanguageModelUser;
4734                    }
4735                }
4736            }
4737            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4738                match content_block {
4739                    anthropic::ContentBlock::Text { text } => {
4740                        if !text.is_empty() {
4741                            response.send(proto::LanguageModelResponse {
4742                                choices: vec![proto::LanguageModelChoiceDelta {
4743                                    index: 0,
4744                                    delta: Some(proto::LanguageModelResponseMessage {
4745                                        role: Some(current_role as i32),
4746                                        content: Some(text),
4747                                        tool_calls: Vec::new(),
4748                                    }),
4749                                    finish_reason: None,
4750                                }],
4751                            })?;
4752                        }
4753                    }
4754                }
4755            }
4756            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4757                anthropic::TextDelta::TextDelta { text } => {
4758                    response.send(proto::LanguageModelResponse {
4759                        choices: vec![proto::LanguageModelChoiceDelta {
4760                            index: 0,
4761                            delta: Some(proto::LanguageModelResponseMessage {
4762                                role: Some(current_role as i32),
4763                                content: Some(text),
4764                                tool_calls: Vec::new(),
4765                            }),
4766                            finish_reason: None,
4767                        }],
4768                    })?;
4769                }
4770            },
4771            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4772                if let Some(stop_reason) = delta.stop_reason {
4773                    response.send(proto::LanguageModelResponse {
4774                        choices: vec![proto::LanguageModelChoiceDelta {
4775                            index: 0,
4776                            delta: None,
4777                            finish_reason: Some(stop_reason),
4778                        }],
4779                    })?;
4780                }
4781            }
4782            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4783            anthropic::ResponseEvent::MessageStop {} => {}
4784            anthropic::ResponseEvent::Ping {} => {}
4785        }
4786    }
4787
4788    Ok(())
4789}
4790
4791struct CountTokensWithLanguageModelRateLimit;
4792
4793impl RateLimit for CountTokensWithLanguageModelRateLimit {
4794    fn capacity() -> usize {
4795        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4796            .ok()
4797            .and_then(|v| v.parse().ok())
4798            .unwrap_or(600) // Picked arbitrarily
4799    }
4800
4801    fn refill_duration() -> chrono::Duration {
4802        chrono::Duration::hours(1)
4803    }
4804
4805    fn db_name() -> &'static str {
4806        "count-tokens-with-language-model"
4807    }
4808}
4809
4810async fn count_tokens_with_language_model(
4811    request: proto::CountTokensWithLanguageModel,
4812    response: Response<proto::CountTokensWithLanguageModel>,
4813    session: UserSession,
4814    google_ai_api_key: Option<Arc<str>>,
4815) -> Result<()> {
4816    authorize_access_to_language_models(&session).await?;
4817
4818    if !request.model.starts_with("gemini") {
4819        return Err(anyhow!(
4820            "counting tokens for model: {:?} is not supported",
4821            request.model
4822        ))?;
4823    }
4824
4825    session
4826        .rate_limiter
4827        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4828        .await?;
4829
4830    let api_key = google_ai_api_key
4831        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4832    let tokens_response = google_ai::count_tokens(
4833        session.http_client.as_ref(),
4834        google_ai::API_URL,
4835        &api_key,
4836        crate::ai::count_tokens_request_to_google_ai(request)?,
4837    )
4838    .await?;
4839    response.send(proto::CountTokensResponse {
4840        token_count: tokens_response.total_tokens as u32,
4841    })?;
4842    Ok(())
4843}
4844
4845struct ComputeEmbeddingsRateLimit;
4846
4847impl RateLimit for ComputeEmbeddingsRateLimit {
4848    fn capacity() -> usize {
4849        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4850            .ok()
4851            .and_then(|v| v.parse().ok())
4852            .unwrap_or(5000) // Picked arbitrarily
4853    }
4854
4855    fn refill_duration() -> chrono::Duration {
4856        chrono::Duration::hours(1)
4857    }
4858
4859    fn db_name() -> &'static str {
4860        "compute-embeddings"
4861    }
4862}
4863
4864async fn compute_embeddings(
4865    request: proto::ComputeEmbeddings,
4866    response: Response<proto::ComputeEmbeddings>,
4867    session: UserSession,
4868    api_key: Option<Arc<str>>,
4869) -> Result<()> {
4870    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4871    authorize_access_to_language_models(&session).await?;
4872
4873    session
4874        .rate_limiter
4875        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4876        .await?;
4877
4878    let embeddings = match request.model.as_str() {
4879        "openai/text-embedding-3-small" => {
4880            open_ai::embed(
4881                session.http_client.as_ref(),
4882                OPEN_AI_API_URL,
4883                &api_key,
4884                OpenAiEmbeddingModel::TextEmbedding3Small,
4885                request.texts.iter().map(|text| text.as_str()),
4886            )
4887            .await?
4888        }
4889        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4890    };
4891
4892    let embeddings = request
4893        .texts
4894        .iter()
4895        .map(|text| {
4896            let mut hasher = sha2::Sha256::new();
4897            hasher.update(text.as_bytes());
4898            let result = hasher.finalize();
4899            result.to_vec()
4900        })
4901        .zip(
4902            embeddings
4903                .data
4904                .into_iter()
4905                .map(|embedding| embedding.embedding),
4906        )
4907        .collect::<HashMap<_, _>>();
4908
4909    let db = session.db().await;
4910    db.save_embeddings(&request.model, &embeddings)
4911        .await
4912        .context("failed to save embeddings")
4913        .trace_err();
4914
4915    response.send(proto::ComputeEmbeddingsResponse {
4916        embeddings: embeddings
4917            .into_iter()
4918            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4919            .collect(),
4920    })?;
4921    Ok(())
4922}
4923
4924async fn get_cached_embeddings(
4925    request: proto::GetCachedEmbeddings,
4926    response: Response<proto::GetCachedEmbeddings>,
4927    session: UserSession,
4928) -> Result<()> {
4929    authorize_access_to_language_models(&session).await?;
4930
4931    let db = session.db().await;
4932    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4933
4934    response.send(proto::GetCachedEmbeddingsResponse {
4935        embeddings: embeddings
4936            .into_iter()
4937            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4938            .collect(),
4939    })?;
4940    Ok(())
4941}
4942
4943async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4944    let db = session.db().await;
4945    let flags = db.get_user_flags(session.user_id()).await?;
4946    if flags.iter().any(|flag| flag == "language-models") {
4947        Ok(())
4948    } else {
4949        Err(anyhow!("permission denied"))?
4950    }
4951}
4952
4953/// Get a Supermaven API key for the user
4954async fn get_supermaven_api_key(
4955    _request: proto::GetSupermavenApiKey,
4956    response: Response<proto::GetSupermavenApiKey>,
4957    session: UserSession,
4958) -> Result<()> {
4959    let user_id: String = session.user_id().to_string();
4960    if !session.is_staff() {
4961        return Err(anyhow!("supermaven not enabled for this account"))?;
4962    }
4963
4964    let email = session
4965        .email()
4966        .ok_or_else(|| anyhow!("user must have an email"))?;
4967
4968    let supermaven_admin_api = session
4969        .supermaven_client
4970        .as_ref()
4971        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4972
4973    let result = supermaven_admin_api
4974        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4975        .await?;
4976
4977    response.send(proto::GetSupermavenApiKeyResponse {
4978        api_key: result.api_key,
4979    })?;
4980
4981    Ok(())
4982}
4983
4984/// Start receiving chat updates for a channel
4985async fn join_channel_chat(
4986    request: proto::JoinChannelChat,
4987    response: Response<proto::JoinChannelChat>,
4988    session: UserSession,
4989) -> Result<()> {
4990    let channel_id = ChannelId::from_proto(request.channel_id);
4991
4992    let db = session.db().await;
4993    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4994        .await?;
4995    let messages = db
4996        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4997        .await?;
4998    response.send(proto::JoinChannelChatResponse {
4999        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5000        messages,
5001    })?;
5002    Ok(())
5003}
5004
5005/// Stop receiving chat updates for a channel
5006async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5007    let channel_id = ChannelId::from_proto(request.channel_id);
5008    session
5009        .db()
5010        .await
5011        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5012        .await?;
5013    Ok(())
5014}
5015
5016/// Retrieve the chat history for a channel
5017async fn get_channel_messages(
5018    request: proto::GetChannelMessages,
5019    response: Response<proto::GetChannelMessages>,
5020    session: UserSession,
5021) -> Result<()> {
5022    let channel_id = ChannelId::from_proto(request.channel_id);
5023    let messages = session
5024        .db()
5025        .await
5026        .get_channel_messages(
5027            channel_id,
5028            session.user_id(),
5029            MESSAGE_COUNT_PER_PAGE,
5030            Some(MessageId::from_proto(request.before_message_id)),
5031        )
5032        .await?;
5033    response.send(proto::GetChannelMessagesResponse {
5034        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5035        messages,
5036    })?;
5037    Ok(())
5038}
5039
5040/// Retrieve specific chat messages
5041async fn get_channel_messages_by_id(
5042    request: proto::GetChannelMessagesById,
5043    response: Response<proto::GetChannelMessagesById>,
5044    session: UserSession,
5045) -> Result<()> {
5046    let message_ids = request
5047        .message_ids
5048        .iter()
5049        .map(|id| MessageId::from_proto(*id))
5050        .collect::<Vec<_>>();
5051    let messages = session
5052        .db()
5053        .await
5054        .get_channel_messages_by_id(session.user_id(), &message_ids)
5055        .await?;
5056    response.send(proto::GetChannelMessagesResponse {
5057        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5058        messages,
5059    })?;
5060    Ok(())
5061}
5062
5063/// Retrieve the current users notifications
5064async fn get_notifications(
5065    request: proto::GetNotifications,
5066    response: Response<proto::GetNotifications>,
5067    session: UserSession,
5068) -> Result<()> {
5069    let notifications = session
5070        .db()
5071        .await
5072        .get_notifications(
5073            session.user_id(),
5074            NOTIFICATION_COUNT_PER_PAGE,
5075            request
5076                .before_id
5077                .map(|id| db::NotificationId::from_proto(id)),
5078        )
5079        .await?;
5080    response.send(proto::GetNotificationsResponse {
5081        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5082        notifications,
5083    })?;
5084    Ok(())
5085}
5086
5087/// Mark notifications as read
5088async fn mark_notification_as_read(
5089    request: proto::MarkNotificationRead,
5090    response: Response<proto::MarkNotificationRead>,
5091    session: UserSession,
5092) -> Result<()> {
5093    let database = &session.db().await;
5094    let notifications = database
5095        .mark_notification_as_read_by_id(
5096            session.user_id(),
5097            NotificationId::from_proto(request.notification_id),
5098        )
5099        .await?;
5100    send_notifications(
5101        &*session.connection_pool().await,
5102        &session.peer,
5103        notifications,
5104    );
5105    response.send(proto::Ack {})?;
5106    Ok(())
5107}
5108
5109/// Get the current users information
5110async fn get_private_user_info(
5111    _request: proto::GetPrivateUserInfo,
5112    response: Response<proto::GetPrivateUserInfo>,
5113    session: UserSession,
5114) -> Result<()> {
5115    let db = session.db().await;
5116
5117    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5118    let user = db
5119        .get_user_by_id(session.user_id())
5120        .await?
5121        .ok_or_else(|| anyhow!("user not found"))?;
5122    let flags = db.get_user_flags(session.user_id()).await?;
5123
5124    response.send(proto::GetPrivateUserInfoResponse {
5125        metrics_id,
5126        staff: user.admin,
5127        flags,
5128    })?;
5129    Ok(())
5130}
5131
5132fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5133    match message {
5134        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5135        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5136        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5137        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5138        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5139            code: frame.code.into(),
5140            reason: frame.reason,
5141        })),
5142    }
5143}
5144
5145fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5146    match message {
5147        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5148        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5149        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5150        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5151        AxumMessage::Close(frame) => {
5152            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5153                code: frame.code.into(),
5154                reason: frame.reason,
5155            }))
5156        }
5157    }
5158}
5159
5160fn notify_membership_updated(
5161    connection_pool: &mut ConnectionPool,
5162    result: MembershipUpdated,
5163    user_id: UserId,
5164    peer: &Peer,
5165) {
5166    for membership in &result.new_channels.channel_memberships {
5167        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5168    }
5169    for channel_id in &result.removed_channels {
5170        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5171    }
5172
5173    let user_channels_update = proto::UpdateUserChannels {
5174        channel_memberships: result
5175            .new_channels
5176            .channel_memberships
5177            .iter()
5178            .map(|cm| proto::ChannelMembership {
5179                channel_id: cm.channel_id.to_proto(),
5180                role: cm.role.into(),
5181            })
5182            .collect(),
5183        ..Default::default()
5184    };
5185
5186    let mut update = build_channels_update(result.new_channels);
5187    update.delete_channels = result
5188        .removed_channels
5189        .into_iter()
5190        .map(|id| id.to_proto())
5191        .collect();
5192    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5193
5194    for connection_id in connection_pool.user_connection_ids(user_id) {
5195        peer.send(connection_id, user_channels_update.clone())
5196            .trace_err();
5197        peer.send(connection_id, update.clone()).trace_err();
5198    }
5199}
5200
5201fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5202    proto::UpdateUserChannels {
5203        channel_memberships: channels
5204            .channel_memberships
5205            .iter()
5206            .map(|m| proto::ChannelMembership {
5207                channel_id: m.channel_id.to_proto(),
5208                role: m.role.into(),
5209            })
5210            .collect(),
5211        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5212        observed_channel_message_id: channels.observed_channel_messages.clone(),
5213    }
5214}
5215
5216fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5217    let mut update = proto::UpdateChannels::default();
5218
5219    for channel in channels.channels {
5220        update.channels.push(channel.to_proto());
5221    }
5222
5223    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5224    update.latest_channel_message_ids = channels.latest_channel_messages;
5225
5226    for (channel_id, participants) in channels.channel_participants {
5227        update
5228            .channel_participants
5229            .push(proto::ChannelParticipants {
5230                channel_id: channel_id.to_proto(),
5231                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5232            });
5233    }
5234
5235    for channel in channels.invited_channels {
5236        update.channel_invitations.push(channel.to_proto());
5237    }
5238
5239    update.hosted_projects = channels.hosted_projects;
5240    update
5241}
5242
5243fn build_initial_contacts_update(
5244    contacts: Vec<db::Contact>,
5245    pool: &ConnectionPool,
5246) -> proto::UpdateContacts {
5247    let mut update = proto::UpdateContacts::default();
5248
5249    for contact in contacts {
5250        match contact {
5251            db::Contact::Accepted { user_id, busy } => {
5252                update.contacts.push(contact_for_user(user_id, busy, &pool));
5253            }
5254            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5255            db::Contact::Incoming { user_id } => {
5256                update
5257                    .incoming_requests
5258                    .push(proto::IncomingContactRequest {
5259                        requester_id: user_id.to_proto(),
5260                    })
5261            }
5262        }
5263    }
5264
5265    update
5266}
5267
5268fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5269    proto::Contact {
5270        user_id: user_id.to_proto(),
5271        online: pool.is_user_online(user_id),
5272        busy,
5273    }
5274}
5275
5276fn room_updated(room: &proto::Room, peer: &Peer) {
5277    broadcast(
5278        None,
5279        room.participants
5280            .iter()
5281            .filter_map(|participant| Some(participant.peer_id?.into())),
5282        |peer_id| {
5283            peer.send(
5284                peer_id,
5285                proto::RoomUpdated {
5286                    room: Some(room.clone()),
5287                },
5288            )
5289        },
5290    );
5291}
5292
5293fn channel_updated(
5294    channel: &db::channel::Model,
5295    room: &proto::Room,
5296    peer: &Peer,
5297    pool: &ConnectionPool,
5298) {
5299    let participants = room
5300        .participants
5301        .iter()
5302        .map(|p| p.user_id)
5303        .collect::<Vec<_>>();
5304
5305    broadcast(
5306        None,
5307        pool.channel_connection_ids(channel.root_id())
5308            .filter_map(|(channel_id, role)| {
5309                role.can_see_channel(channel.visibility).then(|| channel_id)
5310            }),
5311        |peer_id| {
5312            peer.send(
5313                peer_id,
5314                proto::UpdateChannels {
5315                    channel_participants: vec![proto::ChannelParticipants {
5316                        channel_id: channel.id.to_proto(),
5317                        participant_user_ids: participants.clone(),
5318                    }],
5319                    ..Default::default()
5320                },
5321            )
5322        },
5323    );
5324}
5325
5326async fn send_dev_server_projects_update(
5327    user_id: UserId,
5328    mut status: proto::DevServerProjectsUpdate,
5329    session: &Session,
5330) {
5331    let pool = session.connection_pool().await;
5332    for dev_server in &mut status.dev_servers {
5333        dev_server.status =
5334            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5335    }
5336    let connections = pool.user_connection_ids(user_id);
5337    for connection_id in connections {
5338        session.peer.send(connection_id, status.clone()).trace_err();
5339    }
5340}
5341
5342async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5343    let db = session.db().await;
5344
5345    let contacts = db.get_contacts(user_id).await?;
5346    let busy = db.is_user_busy(user_id).await?;
5347
5348    let pool = session.connection_pool().await;
5349    let updated_contact = contact_for_user(user_id, busy, &pool);
5350    for contact in contacts {
5351        if let db::Contact::Accepted {
5352            user_id: contact_user_id,
5353            ..
5354        } = contact
5355        {
5356            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5357                session
5358                    .peer
5359                    .send(
5360                        contact_conn_id,
5361                        proto::UpdateContacts {
5362                            contacts: vec![updated_contact.clone()],
5363                            remove_contacts: Default::default(),
5364                            incoming_requests: Default::default(),
5365                            remove_incoming_requests: Default::default(),
5366                            outgoing_requests: Default::default(),
5367                            remove_outgoing_requests: Default::default(),
5368                        },
5369                    )
5370                    .trace_err();
5371            }
5372        }
5373    }
5374    Ok(())
5375}
5376
5377async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5378    log::info!("lost dev server connection, unsharing projects");
5379    let project_ids = session
5380        .db()
5381        .await
5382        .get_stale_dev_server_projects(session.connection_id)
5383        .await?;
5384
5385    for project_id in project_ids {
5386        // not unshare re-checks the connection ids match, so we get away with no transaction
5387        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5388    }
5389
5390    let user_id = session.dev_server().user_id;
5391    let update = session
5392        .db()
5393        .await
5394        .dev_server_projects_update(user_id)
5395        .await?;
5396
5397    send_dev_server_projects_update(user_id, update, session).await;
5398
5399    Ok(())
5400}
5401
/// Removes `connection_id` from the room it is currently in, if any, and propagates the
/// change to everyone affected.
///
/// If the database reports that the connection was not in a room, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// - notifies the connections of each project that was left (via `project_left`),
/// - broadcasts the updated room to its participants and, when the room belongs to a
///   channel, sends the channel's updated participant list via `channel_updated`,
/// - sends `CallCanceled` to every connection of each user whose pending call was canceled,
/// - refreshes the contact entry of the leaving user and of each canceled callee,
/// - if a LiveKit client is configured, removes the user from the LiveKit room and deletes
///   that room when the database reports the room as deleted.
///
/// # Errors
/// Returns the first error from the database or from `update_user_contacts`; send and
/// LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Initialized inside the `if let` below; the `else` branch returns early, so all of
    // these are definitely assigned afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): under Rust 2021 temporary rules the `session.db()` guard created in this
    // scrutinee stays alive for the whole `if let` body, so everything inside runs while it
    // is held. Keep that in mind before restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out before `room` is moved below, so the `room` broadcast by `room_updated`
        // carries a defaulted `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before the `update_user_contacts` calls below,
    // which acquire the connection pool themselves.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged by `trace_err` and otherwise ignored.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5475
5476async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5477    let left_channel_buffers = session
5478        .db()
5479        .await
5480        .leave_channel_buffers(session.connection_id)
5481        .await?;
5482
5483    for left_buffer in left_channel_buffers {
5484        channel_buffer_updated(
5485            session.connection_id,
5486            left_buffer.connections,
5487            &proto::UpdateChannelBufferCollaborators {
5488                channel_id: left_buffer.channel_id.to_proto(),
5489                collaborators: left_buffer.collaborators,
5490            },
5491            &session.peer,
5492        );
5493    }
5494
5495    Ok(())
5496}
5497
5498fn project_left(project: &db::LeftProject, session: &UserSession) {
5499    for connection_id in &project.connection_ids {
5500        if project.should_unshare {
5501            session
5502                .peer
5503                .send(
5504                    *connection_id,
5505                    proto::UnshareProject {
5506                        project_id: project.id.to_proto(),
5507                    },
5508                )
5509                .trace_err();
5510        } else {
5511            session
5512                .peer
5513                .send(
5514                    *connection_id,
5515                    proto::RemoveProjectCollaborator {
5516                        project_id: project.id.to_proto(),
5517                        peer_id: Some(session.connection_id.into()),
5518                    },
5519                )
5520                .trace_err();
5521        }
5522    }
5523}
5524
5525pub trait ResultExt {
5526    type Ok;
5527
5528    fn trace_err(self) -> Option<Self::Ok>;
5529}
5530
5531impl<T, E> ResultExt for Result<T, E>
5532where
5533    E: std::fmt::Debug,
5534{
5535    type Ok = T;
5536
5537    #[track_caller]
5538    fn trace_err(self) -> Option<T> {
5539        match self {
5540            Ok(value) => Some(value),
5541            Err(error) => {
5542                tracing::error!("{:?}", error);
5543                None
5544            }
5545        }
5546    }
5547}