rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, bail, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http_client::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
/// Reconnect window of 30 seconds. Its use sites are outside this section;
/// presumably how long a disconnected client's state is kept so it can
/// rejoin — TODO confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
// clean up their old resources; 15s leaves some margin past that window.
/// How long `Server::start` sleeps before querying the database for stale
/// room and channel-buffer resources and clearing them.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are used outside this section; the
// descriptions are inferred from their names — confirm at the use sites.
/// Presumably the page size for channel-message history queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum channel-message length (bytes vs. chars not shown).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers` (keyed by the
/// message's `TypeId`): it takes a boxed envelope and the caller's `Session`
/// and returns a `Send` boxed future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// RPC server state: the peer, the shared connection pool, and the table of
/// registered message handlers.
pub struct Server {
    /// This server's id, held behind a mutex (any code that changes it is
    /// outside this section).
    id: parking_lot::Mutex<ServerId>,
    /// Messaging peer, created in `Server::new` from the server id.
    peer: Arc<Peer>,
    /// Shared pool of live connections, also cloned into each `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal; `Server::new` initializes it to `false`.
    teardown: watch::Sender<bool>,
}
 384
/// Guard over the locked connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the guard `!Send`. A future
    // that holds the guard across an `.await` is then `!Send` and cannot become
    // the `Send` `BoxFuture` that handlers return, which keeps the synchronous
    // lock from being held across suspension points.
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable view of the server's peer and connection pool.
///
/// The pool is stored as its lock guard, so the pool stays locked for as long
/// as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself rather than the guard wrapping it.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Creates a server with the given id and shared application state,
    /// registers every RPC message/request handler, and returns it in an `Arc`.
    ///
    /// Wrappers used in the registrations below:
    /// - `user_handler` / `user_message_handler` reject sessions whose principal
    ///   is not a user or impersonated user.
    /// - `dev_server_handler` rejects sessions that are not dev servers.
    /// - Handlers registered without a wrapper receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): if `ServerId`'s inner value is signed, `as u32` wraps a
            // negative id; presumably server ids are non-negative — confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; the initial receiver is dropped here.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Only callable by dev server principals.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the `forward_*` helpers; each message
            // type is routed through the helper variant chosen for it here.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The closures below each capture a clone of `app_state` and read the
            // provider API keys from its config on every call.
            // This streaming handler has no user wrapper: it receives the raw `Session`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 651
 652    pub async fn start(&self) -> Result<()> {
 653        let server_id = *self.id.lock();
 654        let app_state = self.app_state.clone();
 655        let peer = self.peer.clone();
 656        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 657        let pool = self.connection_pool.clone();
 658        let live_kit_client = self.app_state.live_kit_client.clone();
 659
 660        let span = info_span!("start server");
 661        self.app_state.executor.spawn_detached(
 662            async move {
 663                tracing::info!("waiting for cleanup timeout");
 664                timeout.await;
 665                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 666                if let Some((room_ids, channel_ids)) = app_state
 667                    .db
 668                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 669                    .await
 670                    .trace_err()
 671                {
 672                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 673                    tracing::info!(
 674                        stale_channel_buffer_count = channel_ids.len(),
 675                        "retrieved stale channel buffers"
 676                    );
 677
 678                    for channel_id in channel_ids {
 679                        if let Some(refreshed_channel_buffer) = app_state
 680                            .db
 681                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 682                            .await
 683                            .trace_err()
 684                        {
 685                            for connection_id in refreshed_channel_buffer.connection_ids {
 686                                peer.send(
 687                                    connection_id,
 688                                    proto::UpdateChannelBufferCollaborators {
 689                                        channel_id: channel_id.to_proto(),
 690                                        collaborators: refreshed_channel_buffer
 691                                            .collaborators
 692                                            .clone(),
 693                                    },
 694                                )
 695                                .trace_err();
 696                            }
 697                        }
 698                    }
 699
 700                    for room_id in room_ids {
 701                        let mut contacts_to_update = HashSet::default();
 702                        let mut canceled_calls_to_user_ids = Vec::new();
 703                        let mut live_kit_room = String::new();
 704                        let mut delete_live_kit_room = false;
 705
 706                        if let Some(mut refreshed_room) = app_state
 707                            .db
 708                            .clear_stale_room_participants(room_id, server_id)
 709                            .await
 710                            .trace_err()
 711                        {
 712                            tracing::info!(
 713                                room_id = room_id.0,
 714                                new_participant_count = refreshed_room.room.participants.len(),
 715                                "refreshed room"
 716                            );
 717                            room_updated(&refreshed_room.room, &peer);
 718                            if let Some(channel) = refreshed_room.channel.as_ref() {
 719                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 720                            }
 721                            contacts_to_update
 722                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 723                            contacts_to_update
 724                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 725                            canceled_calls_to_user_ids =
 726                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 727                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 728                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 729                        }
 730
 731                        {
 732                            let pool = pool.lock();
 733                            for canceled_user_id in canceled_calls_to_user_ids {
 734                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 735                                    peer.send(
 736                                        connection_id,
 737                                        proto::CallCanceled {
 738                                            room_id: room_id.to_proto(),
 739                                        },
 740                                    )
 741                                    .trace_err();
 742                                }
 743                            }
 744                        }
 745
 746                        for user_id in contacts_to_update {
 747                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 748                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 749                            if let Some((busy, contacts)) = busy.zip(contacts) {
 750                                let pool = pool.lock();
 751                                let updated_contact = contact_for_user(user_id, busy, &pool);
 752                                for contact in contacts {
 753                                    if let db::Contact::Accepted {
 754                                        user_id: contact_user_id,
 755                                        ..
 756                                    } = contact
 757                                    {
 758                                        for contact_conn_id in
 759                                            pool.user_connection_ids(contact_user_id)
 760                                        {
 761                                            peer.send(
 762                                                contact_conn_id,
 763                                                proto::UpdateContacts {
 764                                                    contacts: vec![updated_contact.clone()],
 765                                                    remove_contacts: Default::default(),
 766                                                    incoming_requests: Default::default(),
 767                                                    remove_incoming_requests: Default::default(),
 768                                                    outgoing_requests: Default::default(),
 769                                                    remove_outgoing_requests: Default::default(),
 770                                                },
 771                                            )
 772                                            .trace_err();
 773                                        }
 774                                    }
 775                                }
 776                            }
 777                        }
 778
 779                        if let Some(live_kit) = live_kit_client.as_ref() {
 780                            if delete_live_kit_room {
 781                                live_kit.delete_room(live_kit_room).await.trace_err();
 782                            }
 783                        }
 784                    }
 785                }
 786
 787                app_state
 788                    .db
 789                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 790                    .await
 791                    .trace_err();
 792            }
 793            .instrument(span),
 794        );
 795        Ok(())
 796    }
 797
 798    pub fn teardown(&self) {
 799        self.peer.teardown();
 800        self.connection_pool.lock().reset();
 801        let _ = self.teardown.send(true);
 802    }
 803
 804    #[cfg(test)]
 805    pub fn reset(&self, id: ServerId) {
 806        self.teardown();
 807        *self.id.lock() = id;
 808        self.peer.reset(id.0 as u32);
 809        let _ = self.teardown.send(false);
 810    }
 811
 812    #[cfg(test)]
 813    pub fn id(&self) -> ServerId {
 814        *self.id.lock()
 815    }
 816
 817    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 818    where
 819        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 820        Fut: 'static + Send + Future<Output = Result<()>>,
 821        M: EnvelopedMessage,
 822    {
 823        let prev_handler = self.handlers.insert(
 824            TypeId::of::<M>(),
 825            Box::new(move |envelope, session| {
 826                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 827                let received_at = envelope.received_at;
 828                tracing::info!("message received");
 829                let start_time = Instant::now();
 830                let future = (handler)(*envelope, session);
 831                async move {
 832                    let result = future.await;
 833                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 834                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 835                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 836                    let payload_type = M::NAME;
 837
 838                    match result {
 839                        Err(error) => {
 840                            tracing::error!(
 841                                ?error,
 842                                total_duration_ms,
 843                                processing_duration_ms,
 844                                queue_duration_ms,
 845                                payload_type,
 846                                "error handling message"
 847                            )
 848                        }
 849                        Ok(()) => tracing::info!(
 850                            total_duration_ms,
 851                            processing_duration_ms,
 852                            queue_duration_ms,
 853                            "finished handling message"
 854                        ),
 855                    }
 856                }
 857                .boxed()
 858            }),
 859        );
 860        if prev_handler.is_some() {
 861            panic!("registered a handler for the same message twice");
 862        }
 863        self
 864    }
 865
 866    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 869        Fut: 'static + Send + Future<Output = Result<()>>,
 870        M: EnvelopedMessage,
 871    {
 872        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 873        self
 874    }
 875
 876    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 877    where
 878        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 879        Fut: Send + Future<Output = Result<()>>,
 880        M: RequestMessage,
 881    {
 882        let handler = Arc::new(handler);
 883        self.add_handler(move |envelope, session| {
 884            let receipt = envelope.receipt();
 885            let handler = handler.clone();
 886            async move {
 887                let peer = session.peer.clone();
 888                let responded = Arc::new(AtomicBool::default());
 889                let response = Response {
 890                    peer: peer.clone(),
 891                    responded: responded.clone(),
 892                    receipt,
 893                };
 894                match (handler)(envelope.payload, response, session).await {
 895                    Ok(()) => {
 896                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 897                            Ok(())
 898                        } else {
 899                            Err(anyhow!("handler did not send a response"))?
 900                        }
 901                    }
 902                    Err(error) => {
 903                        let proto_err = match &error {
 904                            Error::Internal(err) => err.to_proto(),
 905                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 906                        };
 907                        peer.respond_with_error(receipt, proto_err)?;
 908                        Err(error)
 909                    }
 910                }
 911            }
 912        })
 913    }
 914
 915    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 916    where
 917        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 918        Fut: Send + Future<Output = Result<()>>,
 919        M: RequestMessage,
 920    {
 921        let handler = Arc::new(handler);
 922        self.add_handler(move |envelope, session| {
 923            let receipt = envelope.receipt();
 924            let handler = handler.clone();
 925            async move {
 926                let peer = session.peer.clone();
 927                let response = StreamingResponse {
 928                    peer: peer.clone(),
 929                    receipt,
 930                };
 931                match (handler)(envelope.payload, response, session).await {
 932                    Ok(()) => {
 933                        peer.end_stream(receipt)?;
 934                        Ok(())
 935                    }
 936                    Err(error) => {
 937                        let proto_err = match &error {
 938                            Error::Internal(err) => err.to_proto(),
 939                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 940                        };
 941                        peer.respond_with_error(receipt, proto_err)?;
 942                        Err(error)
 943                    }
 944                }
 945            }
 946        })
 947    }
 948
 949    #[allow(clippy::too_many_arguments)]
 950    pub fn handle_connection(
 951        self: &Arc<Self>,
 952        connection: Connection,
 953        address: String,
 954        principal: Principal,
 955        zed_version: ZedVersion,
 956        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 957        executor: Executor,
 958    ) -> impl Future<Output = ()> {
 959        let this = self.clone();
 960        let span = info_span!("handle connection", %address,
 961            connection_id=field::Empty,
 962            user_id=field::Empty,
 963            login=field::Empty,
 964            impersonator=field::Empty,
 965            dev_server_id=field::Empty
 966        );
 967        principal.update_span(&span);
 968
 969        let mut teardown = self.teardown.subscribe();
 970        async move {
 971            if *teardown.borrow() {
 972                tracing::error!("server is tearing down");
 973                return
 974            }
 975            let (connection_id, handle_io, mut incoming_rx) = this
 976                .peer
 977                .add_connection(connection, {
 978                    let executor = executor.clone();
 979                    move |duration| executor.sleep(duration)
 980                });
 981            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 982            tracing::info!("connection opened");
 983
 984            let http_client = match IsahcHttpClient::new() {
 985                Ok(http_client) => Arc::new(http_client),
 986                Err(error) => {
 987                    tracing::error!(?error, "failed to create HTTP client");
 988                    return;
 989                }
 990            };
 991
 992            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
 993                Some(Arc::new(SupermavenAdminApi::new(
 994                    supermaven_admin_api_key.to_string(),
 995                    http_client.clone(),
 996                )))
 997            } else {
 998                None
 999            };
1000
1001            let session = Session {
1002                principal: principal.clone(),
1003                connection_id,
1004                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1005                peer: this.peer.clone(),
1006                connection_pool: this.connection_pool.clone(),
1007                live_kit_client: this.app_state.live_kit_client.clone(),
1008                http_client,
1009                rate_limiter: this.app_state.rate_limiter.clone(),
1010                _executor: executor.clone(),
1011                supermaven_client,
1012            };
1013
1014            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1015                tracing::error!(?error, "failed to send initial client update");
1016                return;
1017            }
1018
1019            let handle_io = handle_io.fuse();
1020            futures::pin_mut!(handle_io);
1021
1022            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1023            // This prevents deadlocks when e.g., client A performs a request to client B and
1024            // client B performs a request to client A. If both clients stop processing further
1025            // messages until their respective request completes, they won't have a chance to
1026            // respond to the other client's request and cause a deadlock.
1027            //
1028            // This arrangement ensures we will attempt to process earlier messages first, but fall
1029            // back to processing messages arrived later in the spirit of making progress.
1030            let mut foreground_message_handlers = FuturesUnordered::new();
1031            let concurrent_handlers = Arc::new(Semaphore::new(256));
1032            loop {
1033                let next_message = async {
1034                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1035                    let message = incoming_rx.next().await;
1036                    (permit, message)
1037                }.fuse();
1038                futures::pin_mut!(next_message);
1039                futures::select_biased! {
1040                    _ = teardown.changed().fuse() => return,
1041                    result = handle_io => {
1042                        if let Err(error) = result {
1043                            tracing::error!(?error, "error handling I/O");
1044                        }
1045                        break;
1046                    }
1047                    _ = foreground_message_handlers.next() => {}
1048                    next_message = next_message => {
1049                        let (permit, message) = next_message;
1050                        if let Some(message) = message {
1051                            let type_name = message.payload_type_name();
1052                            // note: we copy all the fields from the parent span so we can query them in the logs.
1053                            // (https://github.com/tokio-rs/tracing/issues/2670).
1054                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1055                                user_id=field::Empty,
1056                                login=field::Empty,
1057                                impersonator=field::Empty,
1058                                dev_server_id=field::Empty
1059                            );
1060                            principal.update_span(&span);
1061                            let span_enter = span.enter();
1062                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1063                                let is_background = message.is_background();
1064                                let handle_message = (handler)(message, session.clone());
1065                                drop(span_enter);
1066
1067                                let handle_message = async move {
1068                                    handle_message.await;
1069                                    drop(permit);
1070                                }.instrument(span);
1071                                if is_background {
1072                                    executor.spawn_detached(handle_message);
1073                                } else {
1074                                    foreground_message_handlers.push(handle_message);
1075                                }
1076                            } else {
1077                                tracing::error!("no message handler");
1078                            }
1079                        } else {
1080                            tracing::info!("connection closed");
1081                            break;
1082                        }
1083                    }
1084                }
1085            }
1086
1087            drop(foreground_message_handlers);
1088            tracing::info!("signing out");
1089            if let Err(error) = connection_lost(session, teardown, executor).await {
1090                tracing::error!(?error, "error signing out");
1091            }
1092
1093        }.instrument(span)
1094    }
1095
1096    async fn send_initial_client_update(
1097        &self,
1098        connection_id: ConnectionId,
1099        principal: &Principal,
1100        zed_version: ZedVersion,
1101        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1102        session: &Session,
1103    ) -> Result<()> {
1104        self.peer.send(
1105            connection_id,
1106            proto::Hello {
1107                peer_id: Some(connection_id.into()),
1108            },
1109        )?;
1110        tracing::info!("sent hello message");
1111        if let Some(send_connection_id) = send_connection_id.take() {
1112            let _ = send_connection_id.send(connection_id);
1113        }
1114
1115        match principal {
1116            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1117                if !user.connected_once {
1118                    self.peer.send(connection_id, proto::ShowContacts {})?;
1119                    self.app_state
1120                        .db
1121                        .set_user_connected_once(user.id, true)
1122                        .await?;
1123                }
1124
1125                let (contacts, dev_server_projects) = future::try_join(
1126                    self.app_state.db.get_contacts(user.id),
1127                    self.app_state.db.dev_server_projects_update(user.id),
1128                )
1129                .await?;
1130
1131                {
1132                    let mut pool = self.connection_pool.lock();
1133                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1134                    self.peer.send(
1135                        connection_id,
1136                        build_initial_contacts_update(contacts, &pool),
1137                    )?;
1138                }
1139
1140                if should_auto_subscribe_to_channels(zed_version) {
1141                    subscribe_user_to_channels(user.id, session).await?;
1142                }
1143
1144                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1145
1146                if let Some(incoming_call) =
1147                    self.app_state.db.incoming_call_for_user(user.id).await?
1148                {
1149                    self.peer.send(connection_id, incoming_call)?;
1150                }
1151
1152                update_user_contacts(user.id, &session).await?;
1153            }
1154            Principal::DevServer(dev_server) => {
1155                {
1156                    let mut pool = self.connection_pool.lock();
1157                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1158                    {
1159                        self.peer.send(
1160                            stale_connection_id,
1161                            proto::ShutdownDevServer {
1162                                reason: Some(
1163                                    "another dev server connected with the same token".to_string(),
1164                                ),
1165                            },
1166                        )?;
1167                        pool.remove_connection(stale_connection_id)?;
1168                    };
1169                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1170                }
1171
1172                let projects = self
1173                    .app_state
1174                    .db
1175                    .get_projects_for_dev_server(dev_server.id)
1176                    .await?;
1177                self.peer
1178                    .send(connection_id, proto::DevServerInstructions { projects })?;
1179
1180                let status = self
1181                    .app_state
1182                    .db
1183                    .dev_server_projects_update(dev_server.user_id)
1184                    .await?;
1185                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1186            }
1187        }
1188
1189        Ok(())
1190    }
1191
1192    pub async fn invite_code_redeemed(
1193        self: &Arc<Self>,
1194        inviter_id: UserId,
1195        invitee_id: UserId,
1196    ) -> Result<()> {
1197        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1198            if let Some(code) = &user.invite_code {
1199                let pool = self.connection_pool.lock();
1200                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1201                for connection_id in pool.user_connection_ids(inviter_id) {
1202                    self.peer.send(
1203                        connection_id,
1204                        proto::UpdateContacts {
1205                            contacts: vec![invitee_contact.clone()],
1206                            ..Default::default()
1207                        },
1208                    )?;
1209                    self.peer.send(
1210                        connection_id,
1211                        proto::UpdateInviteInfo {
1212                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1213                            count: user.invite_count as u32,
1214                        },
1215                    )?;
1216                }
1217            }
1218        }
1219        Ok(())
1220    }
1221
1222    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1223        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1224            if let Some(invite_code) = &user.invite_code {
1225                let pool = self.connection_pool.lock();
1226                for connection_id in pool.user_connection_ids(user_id) {
1227                    self.peer.send(
1228                        connection_id,
1229                        proto::UpdateInviteInfo {
1230                            url: format!(
1231                                "{}{}",
1232                                self.app_state.config.invite_link_prefix, invite_code
1233                            ),
1234                            count: user.invite_count as u32,
1235                        },
1236                    )?;
1237                }
1238            }
1239        }
1240        Ok(())
1241    }
1242
1243    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1244        ServerSnapshot {
1245            connection_pool: ConnectionPoolGuard {
1246                guard: self.connection_pool.lock(),
1247                _not_send: PhantomData,
1248            },
1249            peer: &self.peer,
1250        }
1251    }
1252}
1253
1254impl<'a> Deref for ConnectionPoolGuard<'a> {
1255    type Target = ConnectionPool;
1256
1257    fn deref(&self) -> &Self::Target {
1258        &self.guard
1259    }
1260}
1261
1262impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1263    fn deref_mut(&mut self) -> &mut Self::Target {
1264        &mut self.guard
1265    }
1266}
1267
1268impl<'a> Drop for ConnectionPoolGuard<'a> {
1269    fn drop(&mut self) {
1270        #[cfg(test)]
1271        self.check_invariants();
1272    }
1273}
1274
1275fn broadcast<F>(
1276    sender_id: Option<ConnectionId>,
1277    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1278    mut f: F,
1279) where
1280    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1281{
1282    for receiver_id in receiver_ids {
1283        if Some(receiver_id) != sender_id {
1284            if let Err(error) = f(receiver_id) {
1285                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1286            }
1287        }
1288    }
1289}
1290
1291pub struct ProtocolVersion(u32);
1292
1293impl Header for ProtocolVersion {
1294    fn name() -> &'static HeaderName {
1295        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1297    }
1298
1299    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300    where
1301        Self: Sized,
1302        I: Iterator<Item = &'i axum::http::HeaderValue>,
1303    {
1304        let version = values
1305            .next()
1306            .ok_or_else(axum::headers::Error::invalid)?
1307            .to_str()
1308            .map_err(|_| axum::headers::Error::invalid())?
1309            .parse()
1310            .map_err(|_| axum::headers::Error::invalid())?;
1311        Ok(Self(version))
1312    }
1313
1314    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315        values.extend([self.0.to_string().parse().unwrap()]);
1316    }
1317}
1318
1319pub struct AppVersionHeader(SemanticVersion);
1320impl Header for AppVersionHeader {
1321    fn name() -> &'static HeaderName {
1322        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1323        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1324    }
1325
1326    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1327    where
1328        Self: Sized,
1329        I: Iterator<Item = &'i axum::http::HeaderValue>,
1330    {
1331        let version = values
1332            .next()
1333            .ok_or_else(axum::headers::Error::invalid)?
1334            .to_str()
1335            .map_err(|_| axum::headers::Error::invalid())?
1336            .parse()
1337            .map_err(|_| axum::headers::Error::invalid())?;
1338        Ok(Self(version))
1339    }
1340
1341    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1342        values.extend([self.0.to_string().parse().unwrap()]);
1343    }
1344}
1345
1346pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1347    Router::new()
1348        .route("/rpc", get(handle_websocket_request))
1349        .layer(
1350            ServiceBuilder::new()
1351                .layer(Extension(server.app_state.clone()))
1352                .layer(middleware::from_fn(auth::validate_header)),
1353        )
1354        .route("/metrics", get(handle_metrics))
1355        .layer(Extension(server))
1356}
1357
1358pub async fn handle_websocket_request(
1359    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1360    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1361    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1362    Extension(server): Extension<Arc<Server>>,
1363    Extension(principal): Extension<Principal>,
1364    ws: WebSocketUpgrade,
1365) -> axum::response::Response {
1366    if protocol_version != rpc::PROTOCOL_VERSION {
1367        return (
1368            StatusCode::UPGRADE_REQUIRED,
1369            "client must be upgraded".to_string(),
1370        )
1371            .into_response();
1372    }
1373
1374    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1375        return (
1376            StatusCode::UPGRADE_REQUIRED,
1377            "no version header found".to_string(),
1378        )
1379            .into_response();
1380    };
1381
1382    if !version.can_collaborate() {
1383        return (
1384            StatusCode::UPGRADE_REQUIRED,
1385            "client must be upgraded".to_string(),
1386        )
1387            .into_response();
1388    }
1389
1390    let socket_address = socket_address.to_string();
1391    ws.on_upgrade(move |socket| {
1392        let socket = socket
1393            .map_ok(to_tungstenite_message)
1394            .err_into()
1395            .with(|message| async move {
1396                if let Some(message) = to_axum_message(message) {
1397                    Ok(message)
1398                } else {
1399                    bail!("Could not convert a tungstenite message to axum message");
1400                }
1401            });
1402        let connection = Connection::new(Box::pin(socket));
1403        async move {
1404            server
1405                .handle_connection(
1406                    connection,
1407                    socket_address,
1408                    principal,
1409                    version,
1410                    None,
1411                    Executor::Production,
1412                )
1413                .await;
1414        }
1415    })
1416}
1417
1418pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1419    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1420    let connections_metric = CONNECTIONS_METRIC
1421        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1422
1423    let connections = server
1424        .connection_pool
1425        .lock()
1426        .connections()
1427        .filter(|connection| !connection.admin)
1428        .count();
1429    connections_metric.set(connections as _);
1430
1431    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1432    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1433        register_int_gauge!(
1434            "shared_projects",
1435            "number of open projects with one or more guests"
1436        )
1437        .unwrap()
1438    });
1439
1440    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1441    shared_projects_metric.set(shared_projects as _);
1442
1443    let encoder = prometheus::TextEncoder::new();
1444    let metric_families = prometheus::gather();
1445    let encoded_metrics = encoder
1446        .encode_to_string(&metric_families)
1447        .map_err(|err| anyhow!("{}", err))?;
1448    Ok(encoded_metrics)
1449}
1450
/// Handles a dropped client connection.
///
/// The peer is disconnected, the connection is removed from the pool, and the
/// database is told the connection was lost right away. Releasing the
/// principal's resources is then deferred for `RECONNECT_TIMEOUT` (presumably
/// to give the client a chance to reconnect); if `teardown` changes before the
/// timeout elapses, that cleanup is skipped entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // A failure here is reported through `trace_err` instead of aborting cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the timeout
    // branch wins if both futures happen to be ready together.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so `for_user` is expected to succeed.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only once the user has no other connection online, decline any
                    // pending call (`None` presumably means "no specific room").
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Teardown requested: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1505
1506/// Acknowledges a ping from a client, used to keep the connection alive.
1507async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1508    response.send(proto::Ack {})?;
1509    Ok(())
1510}
1511
1512/// Creates a new room for calling (outside of channels)
1513async fn create_room(
1514    _request: proto::CreateRoom,
1515    response: Response<proto::CreateRoom>,
1516    session: UserSession,
1517) -> Result<()> {
1518    let live_kit_room = nanoid::nanoid!(30);
1519
1520    let live_kit_connection_info = util::maybe!(async {
1521        let live_kit = session.live_kit_client.as_ref();
1522        let live_kit = live_kit?;
1523        let user_id = session.user_id().to_string();
1524
1525        let token = live_kit
1526            .room_token(&live_kit_room, &user_id.to_string())
1527            .trace_err()?;
1528
1529        Some(proto::LiveKitConnectionInfo {
1530            server_url: live_kit.url().into(),
1531            token,
1532            can_publish: true,
1533        })
1534    })
1535    .await;
1536
1537    let room = session
1538        .db()
1539        .await
1540        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1541        .await?;
1542
1543    response.send(proto::CreateRoomResponse {
1544        room: Some(room.clone()),
1545        live_kit_connection_info,
1546    })?;
1547
1548    update_user_contacts(session.user_id(), &session).await?;
1549    Ok(())
1550}
1551
1552/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1553async fn join_room(
1554    request: proto::JoinRoom,
1555    response: Response<proto::JoinRoom>,
1556    session: UserSession,
1557) -> Result<()> {
1558    let room_id = RoomId::from_proto(request.id);
1559
1560    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1561
1562    if let Some(channel_id) = channel_id {
1563        return join_channel_internal(channel_id, Box::new(response), session).await;
1564    }
1565
1566    let joined_room = {
1567        let room = session
1568            .db()
1569            .await
1570            .join_room(room_id, session.user_id(), session.connection_id)
1571            .await?;
1572        room_updated(&room.room, &session.peer);
1573        room.into_inner()
1574    };
1575
1576    for connection_id in session
1577        .connection_pool()
1578        .await
1579        .user_connection_ids(session.user_id())
1580    {
1581        session
1582            .peer
1583            .send(
1584                connection_id,
1585                proto::CallCanceled {
1586                    room_id: room_id.to_proto(),
1587                },
1588            )
1589            .trace_err();
1590    }
1591
1592    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1593        if let Some(token) = live_kit
1594            .room_token(
1595                &joined_room.room.live_kit_room,
1596                &session.user_id().to_string(),
1597            )
1598            .trace_err()
1599        {
1600            Some(proto::LiveKitConnectionInfo {
1601                server_url: live_kit.url().into(),
1602                token,
1603                can_publish: true,
1604            })
1605        } else {
1606            None
1607        }
1608    } else {
1609        None
1610    };
1611
1612    response.send(proto::JoinRoomResponse {
1613        room: Some(joined_room.room),
1614        channel_id: None,
1615        live_kit_connection_info,
1616    })?;
1617
1618    update_user_contacts(session.user_id(), &session).await?;
1619    Ok(())
1620}
1621
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The caller's new connection is re-admitted to the room in the database.
/// The response carries the room, the "reshared" projects (with their
/// collaborators) and the "rejoined" projects.
///
/// For each reshared project, its collaborators are told about the caller's new
/// peer id and are sent the project's worktree list. Rejoined projects are
/// handled by `notify_rejoined_projects`. If the room belongs to a channel, the
/// channel update is broadcast, and finally the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they remain usable after the rejoined-room
    // value produced inside it has been consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the caller first, then notify everyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the caller's peer id changed from the
            // old connection to this one (best-effort; errors go to `trace_err`).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to its collaborators, forwarded
            // as coming from the caller's connection (passed as the sender).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take ownership of the rejoined-room data so `room` and `channel` can be moved out.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Rooms attached to a channel also broadcast the channel's updated state.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1713
1714fn notify_rejoined_projects(
1715    rejoined_projects: &mut Vec<RejoinedProject>,
1716    session: &UserSession,
1717) -> Result<()> {
1718    for project in rejoined_projects.iter() {
1719        for collaborator in &project.collaborators {
1720            session
1721                .peer
1722                .send(
1723                    collaborator.connection_id,
1724                    proto::UpdateProjectCollaborator {
1725                        project_id: project.id.to_proto(),
1726                        old_peer_id: Some(project.old_connection_id.into()),
1727                        new_peer_id: Some(session.connection_id.into()),
1728                    },
1729                )
1730                .trace_err();
1731        }
1732    }
1733
1734    for project in rejoined_projects {
1735        for worktree in mem::take(&mut project.worktrees) {
1736            #[cfg(any(test, feature = "test-support"))]
1737            const MAX_CHUNK_SIZE: usize = 2;
1738            #[cfg(not(any(test, feature = "test-support")))]
1739            const MAX_CHUNK_SIZE: usize = 256;
1740
1741            // Stream this worktree's entries.
1742            let message = proto::UpdateWorktree {
1743                project_id: project.id.to_proto(),
1744                worktree_id: worktree.id,
1745                abs_path: worktree.abs_path.clone(),
1746                root_name: worktree.root_name,
1747                updated_entries: worktree.updated_entries,
1748                removed_entries: worktree.removed_entries,
1749                scan_id: worktree.scan_id,
1750                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1751                updated_repositories: worktree.updated_repositories,
1752                removed_repositories: worktree.removed_repositories,
1753            };
1754            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1755                session.peer.send(session.connection_id, update.clone())?;
1756            }
1757
1758            // Stream this worktree's diagnostics.
1759            for summary in worktree.diagnostic_summaries {
1760                session.peer.send(
1761                    session.connection_id,
1762                    proto::UpdateDiagnosticSummary {
1763                        project_id: project.id.to_proto(),
1764                        worktree_id: worktree.id,
1765                        summary: Some(summary),
1766                    },
1767                )?;
1768            }
1769
1770            for settings_file in worktree.settings_files {
1771                session.peer.send(
1772                    session.connection_id,
1773                    proto::UpdateWorktreeSettings {
1774                        project_id: project.id.to_proto(),
1775                        worktree_id: worktree.id,
1776                        path: settings_file.path,
1777                        content: Some(settings_file.content),
1778                    },
1779                )?;
1780            }
1781        }
1782
1783        for language_server in &project.language_servers {
1784            session.peer.send(
1785                session.connection_id,
1786                proto::UpdateLanguageServer {
1787                    project_id: project.id.to_proto(),
1788                    language_server_id: language_server.id,
1789                    variant: Some(
1790                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1791                            proto::LspDiskBasedDiagnosticsUpdated {},
1792                        ),
1793                    ),
1794                },
1795            )?;
1796        }
1797    }
1798    Ok(())
1799}
1800
1801/// leave room disconnects from the room.
1802async fn leave_room(
1803    _: proto::LeaveRoom,
1804    response: Response<proto::LeaveRoom>,
1805    session: UserSession,
1806) -> Result<()> {
1807    leave_room_for_session(&session, session.connection_id).await?;
1808    response.send(proto::Ack {})?;
1809    Ok(())
1810}
1811
1812/// Updates the permissions of someone else in the room.
1813async fn set_room_participant_role(
1814    request: proto::SetRoomParticipantRole,
1815    response: Response<proto::SetRoomParticipantRole>,
1816    session: UserSession,
1817) -> Result<()> {
1818    let user_id = UserId::from_proto(request.user_id);
1819    let role = ChannelRole::from(request.role());
1820
1821    let (live_kit_room, can_publish) = {
1822        let room = session
1823            .db()
1824            .await
1825            .set_room_participant_role(
1826                session.user_id(),
1827                RoomId::from_proto(request.room_id),
1828                user_id,
1829                role,
1830            )
1831            .await?;
1832
1833        let live_kit_room = room.live_kit_room.clone();
1834        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1835        room_updated(&room, &session.peer);
1836        (live_kit_room, can_publish)
1837    };
1838
1839    if let Some(live_kit) = session.live_kit_client.as_ref() {
1840        live_kit
1841            .update_participant(
1842                live_kit_room.clone(),
1843                request.user_id.to_string(),
1844                live_kit_server::proto::ParticipantPermission {
1845                    can_subscribe: true,
1846                    can_publish,
1847                    can_publish_data: can_publish,
1848                    hidden: false,
1849                    recorder: false,
1850                },
1851            )
1852            .await
1853            .trace_err();
1854    }
1855
1856    response.send(proto::Ack {})?;
1857    Ok(())
1858}
1859
1860/// Call someone else into the current room
1861async fn call(
1862    request: proto::Call,
1863    response: Response<proto::Call>,
1864    session: UserSession,
1865) -> Result<()> {
1866    let room_id = RoomId::from_proto(request.room_id);
1867    let calling_user_id = session.user_id();
1868    let calling_connection_id = session.connection_id;
1869    let called_user_id = UserId::from_proto(request.called_user_id);
1870    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1871    if !session
1872        .db()
1873        .await
1874        .has_contact(calling_user_id, called_user_id)
1875        .await?
1876    {
1877        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1878    }
1879
1880    let incoming_call = {
1881        let (room, incoming_call) = &mut *session
1882            .db()
1883            .await
1884            .call(
1885                room_id,
1886                calling_user_id,
1887                calling_connection_id,
1888                called_user_id,
1889                initial_project_id,
1890            )
1891            .await?;
1892        room_updated(&room, &session.peer);
1893        mem::take(incoming_call)
1894    };
1895    update_user_contacts(called_user_id, &session).await?;
1896
1897    let mut calls = session
1898        .connection_pool()
1899        .await
1900        .user_connection_ids(called_user_id)
1901        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1902        .collect::<FuturesUnordered<_>>();
1903
1904    while let Some(call_response) = calls.next().await {
1905        match call_response.as_ref() {
1906            Ok(_) => {
1907                response.send(proto::Ack {})?;
1908                return Ok(());
1909            }
1910            Err(_) => {
1911                call_response.trace_err();
1912            }
1913        }
1914    }
1915
1916    {
1917        let room = session
1918            .db()
1919            .await
1920            .call_failed(room_id, called_user_id)
1921            .await?;
1922        room_updated(&room, &session.peer);
1923    }
1924    update_user_contacts(called_user_id, &session).await?;
1925
1926    Err(anyhow!("failed to ring user"))?
1927}
1928
1929/// Cancel an outgoing call.
1930async fn cancel_call(
1931    request: proto::CancelCall,
1932    response: Response<proto::CancelCall>,
1933    session: UserSession,
1934) -> Result<()> {
1935    let called_user_id = UserId::from_proto(request.called_user_id);
1936    let room_id = RoomId::from_proto(request.room_id);
1937    {
1938        let room = session
1939            .db()
1940            .await
1941            .cancel_call(room_id, session.connection_id, called_user_id)
1942            .await?;
1943        room_updated(&room, &session.peer);
1944    }
1945
1946    for connection_id in session
1947        .connection_pool()
1948        .await
1949        .user_connection_ids(called_user_id)
1950    {
1951        session
1952            .peer
1953            .send(
1954                connection_id,
1955                proto::CallCanceled {
1956                    room_id: room_id.to_proto(),
1957                },
1958            )
1959            .trace_err();
1960    }
1961    response.send(proto::Ack {})?;
1962
1963    update_user_contacts(called_user_id, &session).await?;
1964    Ok(())
1965}
1966
1967/// Decline an incoming call.
1968async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1969    let room_id = RoomId::from_proto(message.room_id);
1970    {
1971        let room = session
1972            .db()
1973            .await
1974            .decline_call(Some(room_id), session.user_id())
1975            .await?
1976            .ok_or_else(|| anyhow!("failed to decline call"))?;
1977        room_updated(&room, &session.peer);
1978    }
1979
1980    for connection_id in session
1981        .connection_pool()
1982        .await
1983        .user_connection_ids(session.user_id())
1984    {
1985        session
1986            .peer
1987            .send(
1988                connection_id,
1989                proto::CallCanceled {
1990                    room_id: room_id.to_proto(),
1991                },
1992            )
1993            .trace_err();
1994    }
1995    update_user_contacts(session.user_id(), &session).await?;
1996    Ok(())
1997}
1998
1999/// Updates other participants in the room with your current location.
2000async fn update_participant_location(
2001    request: proto::UpdateParticipantLocation,
2002    response: Response<proto::UpdateParticipantLocation>,
2003    session: UserSession,
2004) -> Result<()> {
2005    let room_id = RoomId::from_proto(request.room_id);
2006    let location = request
2007        .location
2008        .ok_or_else(|| anyhow!("invalid location"))?;
2009
2010    let db = session.db().await;
2011    let room = db
2012        .update_room_participant_location(room_id, session.connection_id, location)
2013        .await?;
2014
2015    room_updated(&room, &session.peer);
2016    response.send(proto::Ack {})?;
2017    Ok(())
2018}
2019
2020/// Share a project into the room.
2021async fn share_project(
2022    request: proto::ShareProject,
2023    response: Response<proto::ShareProject>,
2024    session: UserSession,
2025) -> Result<()> {
2026    let (project_id, room) = &*session
2027        .db()
2028        .await
2029        .share_project(
2030            RoomId::from_proto(request.room_id),
2031            session.connection_id,
2032            &request.worktrees,
2033            request
2034                .dev_server_project_id
2035                .map(|id| DevServerProjectId::from_proto(id)),
2036        )
2037        .await?;
2038    response.send(proto::ShareProjectResponse {
2039        project_id: project_id.to_proto(),
2040    })?;
2041    room_updated(&room, &session.peer);
2042
2043    Ok(())
2044}
2045
2046/// Unshare a project from the room.
2047async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2048    let project_id = ProjectId::from_proto(message.project_id);
2049    unshare_project_internal(
2050        project_id,
2051        session.connection_id,
2052        session.user_id(),
2053        &session,
2054    )
2055    .await
2056}
2057
/// Unshares `project_id` on behalf of `connection_id` (and `user_id`, when present).
///
/// Each guest connection returned by the database is sent an `UnshareProject`
/// message via `broadcast`, with `connection_id` passed as the sender. If the
/// project belonged to a room, the room update is broadcast too. When the
/// database reports that the project should be deleted, it is deleted after
/// the unshare result has gone out of scope.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so the unshare result (and anything it holds) is dropped before the
    // optional delete below reacquires the database.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2096
2097/// DevServer makes a project available online
2098async fn share_dev_server_project(
2099    request: proto::ShareDevServerProject,
2100    response: Response<proto::ShareDevServerProject>,
2101    session: DevServerSession,
2102) -> Result<()> {
2103    let (dev_server_project, user_id, status) = session
2104        .db()
2105        .await
2106        .share_dev_server_project(
2107            DevServerProjectId::from_proto(request.dev_server_project_id),
2108            session.dev_server_id(),
2109            session.connection_id,
2110            &request.worktrees,
2111        )
2112        .await?;
2113    let Some(project_id) = dev_server_project.project_id else {
2114        return Err(anyhow!("failed to share remote project"))?;
2115    };
2116
2117    send_dev_server_projects_update(user_id, status, &session).await;
2118
2119    response.send(proto::ShareProjectResponse { project_id })?;
2120
2121    Ok(())
2122}
2123
/// Join someone else's shared project.
///
/// The database adds this connection to the project and returns the project
/// together with the caller's replica id. The project's state is then streamed
/// to the caller by `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state; `project` and
    // `replica_id` belong to the join result, not to `db`.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2142
/// Abstracts over the response handles that can carry a `JoinProjectResponse`,
/// so `join_project_internal` can serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Each impl forwards to `Response::<T>::send` via a fully-qualified path, which
// is expected to resolve to `Response`'s own `send` rather than recursing into
// this trait's method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2156
2157fn join_project_internal(
2158    response: impl JoinProjectInternalResponse,
2159    session: UserSession,
2160    project: &mut Project,
2161    replica_id: &ReplicaId,
2162) -> Result<()> {
2163    let collaborators = project
2164        .collaborators
2165        .iter()
2166        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2167        .map(|collaborator| collaborator.to_proto())
2168        .collect::<Vec<_>>();
2169    let project_id = project.id;
2170    let guest_user_id = session.user_id();
2171
2172    let worktrees = project
2173        .worktrees
2174        .iter()
2175        .map(|(id, worktree)| proto::WorktreeMetadata {
2176            id: *id,
2177            root_name: worktree.root_name.clone(),
2178            visible: worktree.visible,
2179            abs_path: worktree.abs_path.clone(),
2180        })
2181        .collect::<Vec<_>>();
2182
2183    let add_project_collaborator = proto::AddProjectCollaborator {
2184        project_id: project_id.to_proto(),
2185        collaborator: Some(proto::Collaborator {
2186            peer_id: Some(session.connection_id.into()),
2187            replica_id: replica_id.0 as u32,
2188            user_id: guest_user_id.to_proto(),
2189        }),
2190    };
2191
2192    for collaborator in &collaborators {
2193        session
2194            .peer
2195            .send(
2196                collaborator.peer_id.unwrap().into(),
2197                add_project_collaborator.clone(),
2198            )
2199            .trace_err();
2200    }
2201
2202    // First, we send the metadata associated with each worktree.
2203    response.send(proto::JoinProjectResponse {
2204        project_id: project.id.0 as u64,
2205        worktrees: worktrees.clone(),
2206        replica_id: replica_id.0 as u32,
2207        collaborators: collaborators.clone(),
2208        language_servers: project.language_servers.clone(),
2209        role: project.role.into(),
2210        dev_server_project_id: project
2211            .dev_server_project_id
2212            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2213    })?;
2214
2215    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2216        #[cfg(any(test, feature = "test-support"))]
2217        const MAX_CHUNK_SIZE: usize = 2;
2218        #[cfg(not(any(test, feature = "test-support")))]
2219        const MAX_CHUNK_SIZE: usize = 256;
2220
2221        // Stream this worktree's entries.
2222        let message = proto::UpdateWorktree {
2223            project_id: project_id.to_proto(),
2224            worktree_id,
2225            abs_path: worktree.abs_path.clone(),
2226            root_name: worktree.root_name,
2227            updated_entries: worktree.entries,
2228            removed_entries: Default::default(),
2229            scan_id: worktree.scan_id,
2230            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2231            updated_repositories: worktree.repository_entries.into_values().collect(),
2232            removed_repositories: Default::default(),
2233        };
2234        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2235            session.peer.send(session.connection_id, update.clone())?;
2236        }
2237
2238        // Stream this worktree's diagnostics.
2239        for summary in worktree.diagnostic_summaries {
2240            session.peer.send(
2241                session.connection_id,
2242                proto::UpdateDiagnosticSummary {
2243                    project_id: project_id.to_proto(),
2244                    worktree_id: worktree.id,
2245                    summary: Some(summary),
2246                },
2247            )?;
2248        }
2249
2250        for settings_file in worktree.settings_files {
2251            session.peer.send(
2252                session.connection_id,
2253                proto::UpdateWorktreeSettings {
2254                    project_id: project_id.to_proto(),
2255                    worktree_id: worktree.id,
2256                    path: settings_file.path,
2257                    content: Some(settings_file.content),
2258                },
2259            )?;
2260        }
2261    }
2262
2263    for language_server in &project.language_servers {
2264        session.peer.send(
2265            session.connection_id,
2266            proto::UpdateLanguageServer {
2267                project_id: project_id.to_proto(),
2268                language_server_id: language_server.id,
2269                variant: Some(
2270                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2271                        proto::LspDiskBasedDiagnosticsUpdated {},
2272                    ),
2273                ),
2274            },
2275        )?;
2276    }
2277
2278    Ok(())
2279}
2280
2281/// Leave someone elses shared project.
2282async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2283    let sender_id = session.connection_id;
2284    let project_id = ProjectId::from_proto(request.project_id);
2285    let db = session.db().await;
2286    if db.is_hosted_project(project_id).await? {
2287        let project = db.leave_hosted_project(project_id, sender_id).await?;
2288        project_left(&project, &session);
2289        return Ok(());
2290    }
2291
2292    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2293    tracing::info!(
2294        %project_id,
2295        "leave project"
2296    );
2297
2298    project_left(&project, &session);
2299    if let Some(room) = room {
2300        room_updated(&room, &session.peer);
2301    }
2302
2303    Ok(())
2304}
2305
2306async fn join_hosted_project(
2307    request: proto::JoinHostedProject,
2308    response: Response<proto::JoinHostedProject>,
2309    session: UserSession,
2310) -> Result<()> {
2311    let (mut project, replica_id) = session
2312        .db()
2313        .await
2314        .join_hosted_project(
2315            ProjectId(request.project_id as i32),
2316            session.user_id(),
2317            session.connection_id,
2318        )
2319        .await?;
2320
2321    join_project_internal(response, session, &mut project, &replica_id)
2322}
2323
2324async fn list_remote_directory(
2325    request: proto::ListRemoteDirectory,
2326    response: Response<proto::ListRemoteDirectory>,
2327    session: UserSession,
2328) -> Result<()> {
2329    let dev_server_id = DevServerId(request.dev_server_id as i32);
2330    let dev_server_connection_id = session
2331        .connection_pool()
2332        .await
2333        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2334
2335    session
2336        .db()
2337        .await
2338        .get_dev_server_for_user(dev_server_id, session.user_id())
2339        .await?;
2340
2341    response.send(
2342        session
2343            .peer
2344            .forward_request(session.connection_id, dev_server_connection_id, request)
2345            .await?,
2346    )?;
2347    Ok(())
2348}
2349
2350async fn update_dev_server_project(
2351    request: proto::UpdateDevServerProject,
2352    response: Response<proto::UpdateDevServerProject>,
2353    session: UserSession,
2354) -> Result<()> {
2355    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2356
2357    let (dev_server_project, update) = session
2358        .db()
2359        .await
2360        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2361        .await?;
2362
2363    let projects = session
2364        .db()
2365        .await
2366        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2367        .await?;
2368
2369    let dev_server_connection_id = session
2370        .connection_pool()
2371        .await
2372        .dev_server_connection_id_supporting(
2373            dev_server_project.dev_server_id,
2374            ZedVersion::with_list_directory(),
2375        )?;
2376
2377    session.peer.send(
2378        dev_server_connection_id,
2379        proto::DevServerInstructions { projects },
2380    )?;
2381
2382    send_dev_server_projects_update(session.user_id(), update, &session).await;
2383
2384    response.send(proto::Ack {})
2385}
2386
2387async fn create_dev_server_project(
2388    request: proto::CreateDevServerProject,
2389    response: Response<proto::CreateDevServerProject>,
2390    session: UserSession,
2391) -> Result<()> {
2392    let dev_server_id = DevServerId(request.dev_server_id as i32);
2393    let dev_server_connection_id = session
2394        .connection_pool()
2395        .await
2396        .dev_server_connection_id(dev_server_id);
2397    let Some(dev_server_connection_id) = dev_server_connection_id else {
2398        Err(ErrorCode::DevServerOffline
2399            .message("Cannot create a remote project when the dev server is offline".to_string())
2400            .anyhow())?
2401    };
2402
2403    let path = request.path.clone();
2404    //Check that the path exists on the dev server
2405    session
2406        .peer
2407        .forward_request(
2408            session.connection_id,
2409            dev_server_connection_id,
2410            proto::ValidateDevServerProjectRequest { path: path.clone() },
2411        )
2412        .await?;
2413
2414    let (dev_server_project, update) = session
2415        .db()
2416        .await
2417        .create_dev_server_project(
2418            DevServerId(request.dev_server_id as i32),
2419            &request.path,
2420            session.user_id(),
2421        )
2422        .await?;
2423
2424    let projects = session
2425        .db()
2426        .await
2427        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2428        .await?;
2429
2430    session.peer.send(
2431        dev_server_connection_id,
2432        proto::DevServerInstructions { projects },
2433    )?;
2434
2435    send_dev_server_projects_update(session.user_id(), update, &session).await;
2436
2437    response.send(proto::CreateDevServerProjectResponse {
2438        dev_server_project: Some(dev_server_project.to_proto(None)),
2439    })?;
2440    Ok(())
2441}
2442
2443async fn create_dev_server(
2444    request: proto::CreateDevServer,
2445    response: Response<proto::CreateDevServer>,
2446    session: UserSession,
2447) -> Result<()> {
2448    let access_token = auth::random_token();
2449    let hashed_access_token = auth::hash_access_token(&access_token);
2450
2451    if request.name.is_empty() {
2452        return Err(proto::ErrorCode::Forbidden
2453            .message("Dev server name cannot be empty".to_string())
2454            .anyhow())?;
2455    }
2456
2457    let (dev_server, status) = session
2458        .db()
2459        .await
2460        .create_dev_server(
2461            &request.name,
2462            request.ssh_connection_string.as_deref(),
2463            &hashed_access_token,
2464            session.user_id(),
2465        )
2466        .await?;
2467
2468    send_dev_server_projects_update(session.user_id(), status, &session).await;
2469
2470    response.send(proto::CreateDevServerResponse {
2471        dev_server_id: dev_server.id.0 as u64,
2472        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2473        name: request.name,
2474    })?;
2475    Ok(())
2476}
2477
2478async fn regenerate_dev_server_token(
2479    request: proto::RegenerateDevServerToken,
2480    response: Response<proto::RegenerateDevServerToken>,
2481    session: UserSession,
2482) -> Result<()> {
2483    let dev_server_id = DevServerId(request.dev_server_id as i32);
2484    let access_token = auth::random_token();
2485    let hashed_access_token = auth::hash_access_token(&access_token);
2486
2487    let connection_id = session
2488        .connection_pool()
2489        .await
2490        .dev_server_connection_id(dev_server_id);
2491    if let Some(connection_id) = connection_id {
2492        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2493        session.peer.send(
2494            connection_id,
2495            proto::ShutdownDevServer {
2496                reason: Some("dev server token was regenerated".to_string()),
2497            },
2498        )?;
2499        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2500    }
2501
2502    let status = session
2503        .db()
2504        .await
2505        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2506        .await?;
2507
2508    send_dev_server_projects_update(session.user_id(), status, &session).await;
2509
2510    response.send(proto::RegenerateDevServerTokenResponse {
2511        dev_server_id: dev_server_id.to_proto(),
2512        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2513    })?;
2514    Ok(())
2515}
2516
2517async fn rename_dev_server(
2518    request: proto::RenameDevServer,
2519    response: Response<proto::RenameDevServer>,
2520    session: UserSession,
2521) -> Result<()> {
2522    if request.name.trim().is_empty() {
2523        return Err(proto::ErrorCode::Forbidden
2524            .message("Dev server name cannot be empty".to_string())
2525            .anyhow())?;
2526    }
2527
2528    let dev_server_id = DevServerId(request.dev_server_id as i32);
2529    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2530    if dev_server.user_id != session.user_id() {
2531        return Err(anyhow!(ErrorCode::Forbidden))?;
2532    }
2533
2534    let status = session
2535        .db()
2536        .await
2537        .rename_dev_server(
2538            dev_server_id,
2539            &request.name,
2540            request.ssh_connection_string.as_deref(),
2541            session.user_id(),
2542        )
2543        .await?;
2544
2545    send_dev_server_projects_update(session.user_id(), status, &session).await;
2546
2547    response.send(proto::Ack {})?;
2548    Ok(())
2549}
2550
2551async fn delete_dev_server(
2552    request: proto::DeleteDevServer,
2553    response: Response<proto::DeleteDevServer>,
2554    session: UserSession,
2555) -> Result<()> {
2556    let dev_server_id = DevServerId(request.dev_server_id as i32);
2557    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2558    if dev_server.user_id != session.user_id() {
2559        return Err(anyhow!(ErrorCode::Forbidden))?;
2560    }
2561
2562    let connection_id = session
2563        .connection_pool()
2564        .await
2565        .dev_server_connection_id(dev_server_id);
2566    if let Some(connection_id) = connection_id {
2567        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2568        session.peer.send(
2569            connection_id,
2570            proto::ShutdownDevServer {
2571                reason: Some("dev server was deleted".to_string()),
2572            },
2573        )?;
2574        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2575    }
2576
2577    let status = session
2578        .db()
2579        .await
2580        .delete_dev_server(dev_server_id, session.user_id())
2581        .await?;
2582
2583    send_dev_server_projects_update(session.user_id(), status, &session).await;
2584
2585    response.send(proto::Ack {})?;
2586    Ok(())
2587}
2588
2589async fn delete_dev_server_project(
2590    request: proto::DeleteDevServerProject,
2591    response: Response<proto::DeleteDevServerProject>,
2592    session: UserSession,
2593) -> Result<()> {
2594    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2595    let dev_server_project = session
2596        .db()
2597        .await
2598        .get_dev_server_project(dev_server_project_id)
2599        .await?;
2600
2601    let dev_server = session
2602        .db()
2603        .await
2604        .get_dev_server(dev_server_project.dev_server_id)
2605        .await?;
2606    if dev_server.user_id != session.user_id() {
2607        return Err(anyhow!(ErrorCode::Forbidden))?;
2608    }
2609
2610    let dev_server_connection_id = session
2611        .connection_pool()
2612        .await
2613        .dev_server_connection_id(dev_server.id);
2614
2615    if let Some(dev_server_connection_id) = dev_server_connection_id {
2616        let project = session
2617            .db()
2618            .await
2619            .find_dev_server_project(dev_server_project_id)
2620            .await;
2621        if let Ok(project) = project {
2622            unshare_project_internal(
2623                project.id,
2624                dev_server_connection_id,
2625                Some(session.user_id()),
2626                &session,
2627            )
2628            .await?;
2629        }
2630    }
2631
2632    let (projects, status) = session
2633        .db()
2634        .await
2635        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2636        .await?;
2637
2638    if let Some(dev_server_connection_id) = dev_server_connection_id {
2639        session.peer.send(
2640            dev_server_connection_id,
2641            proto::DevServerInstructions { projects },
2642        )?;
2643    }
2644
2645    send_dev_server_projects_update(session.user_id(), status, &session).await;
2646
2647    response.send(proto::Ack {})?;
2648    Ok(())
2649}
2650
2651async fn rejoin_dev_server_projects(
2652    request: proto::RejoinRemoteProjects,
2653    response: Response<proto::RejoinRemoteProjects>,
2654    session: UserSession,
2655) -> Result<()> {
2656    let mut rejoined_projects = {
2657        let db = session.db().await;
2658        db.rejoin_dev_server_projects(
2659            &request.rejoined_projects,
2660            session.user_id(),
2661            session.0.connection_id,
2662        )
2663        .await?
2664    };
2665    response.send(proto::RejoinRemoteProjectsResponse {
2666        rejoined_projects: rejoined_projects
2667            .iter()
2668            .map(|project| project.to_proto())
2669            .collect(),
2670    })?;
2671    notify_rejoined_projects(&mut rejoined_projects, &session)
2672}
2673
2674async fn reconnect_dev_server(
2675    request: proto::ReconnectDevServer,
2676    response: Response<proto::ReconnectDevServer>,
2677    session: DevServerSession,
2678) -> Result<()> {
2679    let reshared_projects = {
2680        let db = session.db().await;
2681        db.reshare_dev_server_projects(
2682            &request.reshared_projects,
2683            session.dev_server_id(),
2684            session.0.connection_id,
2685        )
2686        .await?
2687    };
2688
2689    for project in &reshared_projects {
2690        for collaborator in &project.collaborators {
2691            session
2692                .peer
2693                .send(
2694                    collaborator.connection_id,
2695                    proto::UpdateProjectCollaborator {
2696                        project_id: project.id.to_proto(),
2697                        old_peer_id: Some(project.old_connection_id.into()),
2698                        new_peer_id: Some(session.connection_id.into()),
2699                    },
2700                )
2701                .trace_err();
2702        }
2703
2704        broadcast(
2705            Some(session.connection_id),
2706            project
2707                .collaborators
2708                .iter()
2709                .map(|collaborator| collaborator.connection_id),
2710            |connection_id| {
2711                session.peer.forward_send(
2712                    session.connection_id,
2713                    connection_id,
2714                    proto::UpdateProject {
2715                        project_id: project.id.to_proto(),
2716                        worktrees: project.worktrees.clone(),
2717                    },
2718                )
2719            },
2720        );
2721    }
2722
2723    response.send(proto::ReconnectDevServerResponse {
2724        reshared_projects: reshared_projects
2725            .iter()
2726            .map(|project| proto::ResharedProject {
2727                id: project.id.to_proto(),
2728                collaborators: project
2729                    .collaborators
2730                    .iter()
2731                    .map(|collaborator| collaborator.to_proto())
2732                    .collect(),
2733            })
2734            .collect(),
2735    })?;
2736
2737    Ok(())
2738}
2739
2740async fn shutdown_dev_server(
2741    _: proto::ShutdownDevServer,
2742    response: Response<proto::ShutdownDevServer>,
2743    session: DevServerSession,
2744) -> Result<()> {
2745    response.send(proto::Ack {})?;
2746    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2747    remove_dev_server_connection(session.dev_server_id(), &session).await
2748}
2749
2750async fn shutdown_dev_server_internal(
2751    dev_server_id: DevServerId,
2752    connection_id: ConnectionId,
2753    session: &Session,
2754) -> Result<()> {
2755    let (dev_server_projects, dev_server) = {
2756        let db = session.db().await;
2757        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2758        let dev_server = db.get_dev_server(dev_server_id).await?;
2759        (dev_server_projects, dev_server)
2760    };
2761
2762    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2763        unshare_project_internal(
2764            ProjectId::from_proto(project_id),
2765            connection_id,
2766            None,
2767            session,
2768        )
2769        .await?;
2770    }
2771
2772    session
2773        .connection_pool()
2774        .await
2775        .set_dev_server_offline(dev_server_id);
2776
2777    let status = session
2778        .db()
2779        .await
2780        .dev_server_projects_update(dev_server.user_id)
2781        .await?;
2782    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2783
2784    Ok(())
2785}
2786
2787async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2788    let dev_server_connection = session
2789        .connection_pool()
2790        .await
2791        .dev_server_connection_id(dev_server_id);
2792
2793    if let Some(dev_server_connection) = dev_server_connection {
2794        session
2795            .connection_pool()
2796            .await
2797            .remove_connection(dev_server_connection)?;
2798    }
2799    Ok(())
2800}
2801
2802/// Updates other participants with changes to the project
2803async fn update_project(
2804    request: proto::UpdateProject,
2805    response: Response<proto::UpdateProject>,
2806    session: Session,
2807) -> Result<()> {
2808    let project_id = ProjectId::from_proto(request.project_id);
2809    let (room, guest_connection_ids) = &*session
2810        .db()
2811        .await
2812        .update_project(project_id, session.connection_id, &request.worktrees)
2813        .await?;
2814    broadcast(
2815        Some(session.connection_id),
2816        guest_connection_ids.iter().copied(),
2817        |connection_id| {
2818            session
2819                .peer
2820                .forward_send(session.connection_id, connection_id, request.clone())
2821        },
2822    );
2823    if let Some(room) = room {
2824        room_updated(&room, &session.peer);
2825    }
2826    response.send(proto::Ack {})?;
2827
2828    Ok(())
2829}
2830
2831/// Updates other participants with changes to the worktree
2832async fn update_worktree(
2833    request: proto::UpdateWorktree,
2834    response: Response<proto::UpdateWorktree>,
2835    session: Session,
2836) -> Result<()> {
2837    let guest_connection_ids = session
2838        .db()
2839        .await
2840        .update_worktree(&request, session.connection_id)
2841        .await?;
2842
2843    broadcast(
2844        Some(session.connection_id),
2845        guest_connection_ids.iter().copied(),
2846        |connection_id| {
2847            session
2848                .peer
2849                .forward_send(session.connection_id, connection_id, request.clone())
2850        },
2851    );
2852    response.send(proto::Ack {})?;
2853    Ok(())
2854}
2855
2856/// Updates other participants with changes to the diagnostics
2857async fn update_diagnostic_summary(
2858    message: proto::UpdateDiagnosticSummary,
2859    session: Session,
2860) -> Result<()> {
2861    let guest_connection_ids = session
2862        .db()
2863        .await
2864        .update_diagnostic_summary(&message, session.connection_id)
2865        .await?;
2866
2867    broadcast(
2868        Some(session.connection_id),
2869        guest_connection_ids.iter().copied(),
2870        |connection_id| {
2871            session
2872                .peer
2873                .forward_send(session.connection_id, connection_id, message.clone())
2874        },
2875    );
2876
2877    Ok(())
2878}
2879
2880/// Updates other participants with changes to the worktree settings
2881async fn update_worktree_settings(
2882    message: proto::UpdateWorktreeSettings,
2883    session: Session,
2884) -> Result<()> {
2885    let guest_connection_ids = session
2886        .db()
2887        .await
2888        .update_worktree_settings(&message, session.connection_id)
2889        .await?;
2890
2891    broadcast(
2892        Some(session.connection_id),
2893        guest_connection_ids.iter().copied(),
2894        |connection_id| {
2895            session
2896                .peer
2897                .forward_send(session.connection_id, connection_id, message.clone())
2898        },
2899    );
2900
2901    Ok(())
2902}
2903
2904/// Notify other participants that a  language server has started.
2905async fn start_language_server(
2906    request: proto::StartLanguageServer,
2907    session: Session,
2908) -> Result<()> {
2909    let guest_connection_ids = session
2910        .db()
2911        .await
2912        .start_language_server(&request, session.connection_id)
2913        .await?;
2914
2915    broadcast(
2916        Some(session.connection_id),
2917        guest_connection_ids.iter().copied(),
2918        |connection_id| {
2919            session
2920                .peer
2921                .forward_send(session.connection_id, connection_id, request.clone())
2922        },
2923    );
2924    Ok(())
2925}
2926
2927/// Notify other participants that a language server has changed.
2928async fn update_language_server(
2929    request: proto::UpdateLanguageServer,
2930    session: Session,
2931) -> Result<()> {
2932    let project_id = ProjectId::from_proto(request.project_id);
2933    let project_connection_ids = session
2934        .db()
2935        .await
2936        .project_connection_ids(project_id, session.connection_id, true)
2937        .await?;
2938    broadcast(
2939        Some(session.connection_id),
2940        project_connection_ids.iter().copied(),
2941        |connection_id| {
2942            session
2943                .peer
2944                .forward_send(session.connection_id, connection_id, request.clone())
2945        },
2946    );
2947    Ok(())
2948}
2949
2950/// forward a project request to the host. These requests should be read only
2951/// as guests are allowed to send them.
2952async fn forward_read_only_project_request<T>(
2953    request: T,
2954    response: Response<T>,
2955    session: UserSession,
2956) -> Result<()>
2957where
2958    T: EntityMessage + RequestMessage,
2959{
2960    let project_id = ProjectId::from_proto(request.remote_entity_id());
2961    let host_connection_id = session
2962        .db()
2963        .await
2964        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2965        .await?;
2966    let payload = session
2967        .peer
2968        .forward_request(session.connection_id, host_connection_id, request)
2969        .await?;
2970    response.send(payload)?;
2971    Ok(())
2972}
2973
2974/// forward a project request to the dev server. Only allowed
2975/// if it's your dev server.
2976async fn forward_project_request_for_owner<T>(
2977    request: T,
2978    response: Response<T>,
2979    session: UserSession,
2980) -> Result<()>
2981where
2982    T: EntityMessage + RequestMessage,
2983{
2984    let project_id = ProjectId::from_proto(request.remote_entity_id());
2985
2986    let host_connection_id = session
2987        .db()
2988        .await
2989        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2990        .await?;
2991    let payload = session
2992        .peer
2993        .forward_request(session.connection_id, host_connection_id, request)
2994        .await?;
2995    response.send(payload)?;
2996    Ok(())
2997}
2998
2999/// forward a project request to the host. These requests are disallowed
3000/// for guests.
3001async fn forward_mutating_project_request<T>(
3002    request: T,
3003    response: Response<T>,
3004    session: UserSession,
3005) -> Result<()>
3006where
3007    T: EntityMessage + RequestMessage,
3008{
3009    let project_id = ProjectId::from_proto(request.remote_entity_id());
3010
3011    let host_connection_id = session
3012        .db()
3013        .await
3014        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3015        .await?;
3016    let payload = session
3017        .peer
3018        .forward_request(session.connection_id, host_connection_id, request)
3019        .await?;
3020    response.send(payload)?;
3021    Ok(())
3022}
3023
3024/// forward a project request to the host. These requests are disallowed
3025/// for guests.
3026async fn forward_versioned_mutating_project_request<T>(
3027    request: T,
3028    response: Response<T>,
3029    session: UserSession,
3030) -> Result<()>
3031where
3032    T: EntityMessage + RequestMessage + VersionedMessage,
3033{
3034    let project_id = ProjectId::from_proto(request.remote_entity_id());
3035
3036    let host_connection_id = session
3037        .db()
3038        .await
3039        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3040        .await?;
3041    if let Some(host_version) = session
3042        .connection_pool()
3043        .await
3044        .connection(host_connection_id)
3045        .map(|c| c.zed_version)
3046    {
3047        if let Some(min_required_version) = request.required_host_version() {
3048            if min_required_version > host_version {
3049                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3050                    .with_tag("required", &min_required_version.to_string())))?;
3051            }
3052        }
3053    }
3054
3055    let payload = session
3056        .peer
3057        .forward_request(session.connection_id, host_connection_id, request)
3058        .await?;
3059    response.send(payload)?;
3060    Ok(())
3061}
3062
3063/// Notify other participants that a new buffer has been created
3064async fn create_buffer_for_peer(
3065    request: proto::CreateBufferForPeer,
3066    session: Session,
3067) -> Result<()> {
3068    session
3069        .db()
3070        .await
3071        .check_user_is_project_host(
3072            ProjectId::from_proto(request.project_id),
3073            session.connection_id,
3074        )
3075        .await?;
3076    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3077    session
3078        .peer
3079        .forward_send(session.connection_id, peer_id.into(), request)?;
3080    Ok(())
3081}
3082
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The update is broadcast to the other guests without waiting for them. It is
/// then forwarded to the host as a request, unless the sender is the host, and
/// the host's reply is awaited before the sender is acknowledged.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Empty or selection-only operations need read-only access; any other
    // operation raises the required capability to read-write.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scoped so the database guard is dropped before the host request below is awaited.
    let host = {
        // NOTE(review): presumably fails if the sender lacks `capability` on this project.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host receives the update as a request, so the ack below is sent only after the host replies.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3137
3138async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3139    let project_id = ProjectId::from_proto(message.project_id);
3140
3141    let operation = message.operation.as_ref().context("invalid operation")?;
3142    let capability = match operation.variant.as_ref() {
3143        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3144            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3145                match buffer_op.variant {
3146                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3147                        Capability::ReadOnly
3148                    }
3149                    _ => Capability::ReadWrite,
3150                }
3151            } else {
3152                Capability::ReadWrite
3153            }
3154        }
3155        Some(_) => Capability::ReadWrite,
3156        None => Capability::ReadOnly,
3157    };
3158
3159    let guard = session
3160        .db()
3161        .await
3162        .connections_for_buffer_update(
3163            project_id,
3164            session.principal_id(),
3165            session.connection_id,
3166            capability,
3167        )
3168        .await?;
3169
3170    let (host, guests) = &*guard;
3171
3172    broadcast(
3173        Some(session.connection_id),
3174        guests.iter().chain([host]).copied(),
3175        |connection_id| {
3176            session
3177                .peer
3178                .forward_send(session.connection_id, connection_id, message.clone())
3179        },
3180    );
3181
3182    Ok(())
3183}
3184
3185/// Notify other participants that a project has been updated.
3186async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3187    request: T,
3188    session: Session,
3189) -> Result<()> {
3190    let project_id = ProjectId::from_proto(request.remote_entity_id());
3191    let project_connection_ids = session
3192        .db()
3193        .await
3194        .project_connection_ids(project_id, session.connection_id, false)
3195        .await?;
3196
3197    broadcast(
3198        Some(session.connection_id),
3199        project_connection_ids.iter().copied(),
3200        |connection_id| {
3201            session
3202                .peer
3203                .forward_send(session.connection_id, connection_id, request.clone())
3204        },
3205    );
3206    Ok(())
3207}
3208
3209/// Start following another user in a call.
3210async fn follow(
3211    request: proto::Follow,
3212    response: Response<proto::Follow>,
3213    session: UserSession,
3214) -> Result<()> {
3215    let room_id = RoomId::from_proto(request.room_id);
3216    let project_id = request.project_id.map(ProjectId::from_proto);
3217    let leader_id = request
3218        .leader_id
3219        .ok_or_else(|| anyhow!("invalid leader id"))?
3220        .into();
3221    let follower_id = session.connection_id;
3222
3223    session
3224        .db()
3225        .await
3226        .check_room_participants(room_id, leader_id, session.connection_id)
3227        .await?;
3228
3229    let response_payload = session
3230        .peer
3231        .forward_request(session.connection_id, leader_id, request)
3232        .await?;
3233    response.send(response_payload)?;
3234
3235    if let Some(project_id) = project_id {
3236        let room = session
3237            .db()
3238            .await
3239            .follow(room_id, project_id, leader_id, follower_id)
3240            .await?;
3241        room_updated(&room, &session.peer);
3242    }
3243
3244    Ok(())
3245}
3246
3247/// Stop following another user in a call.
3248async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3249    let room_id = RoomId::from_proto(request.room_id);
3250    let project_id = request.project_id.map(ProjectId::from_proto);
3251    let leader_id = request
3252        .leader_id
3253        .ok_or_else(|| anyhow!("invalid leader id"))?
3254        .into();
3255    let follower_id = session.connection_id;
3256
3257    session
3258        .db()
3259        .await
3260        .check_room_participants(room_id, leader_id, session.connection_id)
3261        .await?;
3262
3263    session
3264        .peer
3265        .forward_send(session.connection_id, leader_id, request)?;
3266
3267    if let Some(project_id) = project_id {
3268        let room = session
3269            .db()
3270            .await
3271            .unfollow(room_id, project_id, leader_id, follower_id)
3272            .await?;
3273        room_updated(&room, &session.peer);
3274    }
3275
3276    Ok(())
3277}
3278
3279/// Notify everyone following you of your current location.
3280async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3281    let room_id = RoomId::from_proto(request.room_id);
3282    let database = session.db.lock().await;
3283
3284    let connection_ids = if let Some(project_id) = request.project_id {
3285        let project_id = ProjectId::from_proto(project_id);
3286        database
3287            .project_connection_ids(project_id, session.connection_id, true)
3288            .await?
3289    } else {
3290        database
3291            .room_connection_ids(room_id, session.connection_id)
3292            .await?
3293    };
3294
3295    // For now, don't send view update messages back to that view's current leader.
3296    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3297        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3298        _ => None,
3299    });
3300
3301    for connection_id in connection_ids.iter().cloned() {
3302        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3303            session
3304                .peer
3305                .forward_send(session.connection_id, connection_id, request.clone())?;
3306        }
3307    }
3308    Ok(())
3309}
3310
3311/// Get public data about users.
3312async fn get_users(
3313    request: proto::GetUsers,
3314    response: Response<proto::GetUsers>,
3315    session: Session,
3316) -> Result<()> {
3317    let user_ids = request
3318        .user_ids
3319        .into_iter()
3320        .map(UserId::from_proto)
3321        .collect();
3322    let users = session
3323        .db()
3324        .await
3325        .get_users_by_ids(user_ids)
3326        .await?
3327        .into_iter()
3328        .map(|user| proto::User {
3329            id: user.id.to_proto(),
3330            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3331            github_login: user.github_login,
3332        })
3333        .collect();
3334    response.send(proto::UsersResponse { users })?;
3335    Ok(())
3336}
3337
3338/// Search for users (to invite) buy Github login
3339async fn fuzzy_search_users(
3340    request: proto::FuzzySearchUsers,
3341    response: Response<proto::FuzzySearchUsers>,
3342    session: UserSession,
3343) -> Result<()> {
3344    let query = request.query;
3345    let users = match query.len() {
3346        0 => vec![],
3347        1 | 2 => session
3348            .db()
3349            .await
3350            .get_user_by_github_login(&query)
3351            .await?
3352            .into_iter()
3353            .collect(),
3354        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3355    };
3356    let users = users
3357        .into_iter()
3358        .filter(|user| user.id != session.user_id())
3359        .map(|user| proto::User {
3360            id: user.id.to_proto(),
3361            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3362            github_login: user.github_login,
3363        })
3364        .collect();
3365    response.send(proto::UsersResponse { users })?;
3366    Ok(())
3367}
3368
3369/// Send a contact request to another user.
3370async fn request_contact(
3371    request: proto::RequestContact,
3372    response: Response<proto::RequestContact>,
3373    session: UserSession,
3374) -> Result<()> {
3375    let requester_id = session.user_id();
3376    let responder_id = UserId::from_proto(request.responder_id);
3377    if requester_id == responder_id {
3378        return Err(anyhow!("cannot add yourself as a contact"))?;
3379    }
3380
3381    let notifications = session
3382        .db()
3383        .await
3384        .send_contact_request(requester_id, responder_id)
3385        .await?;
3386
3387    // Update outgoing contact requests of requester
3388    let mut update = proto::UpdateContacts::default();
3389    update.outgoing_requests.push(responder_id.to_proto());
3390    for connection_id in session
3391        .connection_pool()
3392        .await
3393        .user_connection_ids(requester_id)
3394    {
3395        session.peer.send(connection_id, update.clone())?;
3396    }
3397
3398    // Update incoming contact requests of responder
3399    let mut update = proto::UpdateContacts::default();
3400    update
3401        .incoming_requests
3402        .push(proto::IncomingContactRequest {
3403            requester_id: requester_id.to_proto(),
3404        });
3405    let connection_pool = session.connection_pool().await;
3406    for connection_id in connection_pool.user_connection_ids(responder_id) {
3407        session.peer.send(connection_id, update.clone())?;
3408    }
3409
3410    send_notifications(&connection_pool, &session.peer, notifications);
3411
3412    response.send(proto::Ack {})?;
3413    Ok(())
3414}
3415
3416/// Accept or decline a contact request
3417async fn respond_to_contact_request(
3418    request: proto::RespondToContactRequest,
3419    response: Response<proto::RespondToContactRequest>,
3420    session: UserSession,
3421) -> Result<()> {
3422    let responder_id = session.user_id();
3423    let requester_id = UserId::from_proto(request.requester_id);
3424    let db = session.db().await;
3425    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3426        db.dismiss_contact_notification(responder_id, requester_id)
3427            .await?;
3428    } else {
3429        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3430
3431        let notifications = db
3432            .respond_to_contact_request(responder_id, requester_id, accept)
3433            .await?;
3434        let requester_busy = db.is_user_busy(requester_id).await?;
3435        let responder_busy = db.is_user_busy(responder_id).await?;
3436
3437        let pool = session.connection_pool().await;
3438        // Update responder with new contact
3439        let mut update = proto::UpdateContacts::default();
3440        if accept {
3441            update
3442                .contacts
3443                .push(contact_for_user(requester_id, requester_busy, &pool));
3444        }
3445        update
3446            .remove_incoming_requests
3447            .push(requester_id.to_proto());
3448        for connection_id in pool.user_connection_ids(responder_id) {
3449            session.peer.send(connection_id, update.clone())?;
3450        }
3451
3452        // Update requester with new contact
3453        let mut update = proto::UpdateContacts::default();
3454        if accept {
3455            update
3456                .contacts
3457                .push(contact_for_user(responder_id, responder_busy, &pool));
3458        }
3459        update
3460            .remove_outgoing_requests
3461            .push(responder_id.to_proto());
3462
3463        for connection_id in pool.user_connection_ids(requester_id) {
3464            session.peer.send(connection_id, update.clone())?;
3465        }
3466
3467        send_notifications(&pool, &session.peer, notifications);
3468    }
3469
3470    response.send(proto::Ack {})?;
3471    Ok(())
3472}
3473
3474/// Remove a contact.
3475async fn remove_contact(
3476    request: proto::RemoveContact,
3477    response: Response<proto::RemoveContact>,
3478    session: UserSession,
3479) -> Result<()> {
3480    let requester_id = session.user_id();
3481    let responder_id = UserId::from_proto(request.user_id);
3482    let db = session.db().await;
3483    let (contact_accepted, deleted_notification_id) =
3484        db.remove_contact(requester_id, responder_id).await?;
3485
3486    let pool = session.connection_pool().await;
3487    // Update outgoing contact requests of requester
3488    let mut update = proto::UpdateContacts::default();
3489    if contact_accepted {
3490        update.remove_contacts.push(responder_id.to_proto());
3491    } else {
3492        update
3493            .remove_outgoing_requests
3494            .push(responder_id.to_proto());
3495    }
3496    for connection_id in pool.user_connection_ids(requester_id) {
3497        session.peer.send(connection_id, update.clone())?;
3498    }
3499
3500    // Update incoming contact requests of responder
3501    let mut update = proto::UpdateContacts::default();
3502    if contact_accepted {
3503        update.remove_contacts.push(requester_id.to_proto());
3504    } else {
3505        update
3506            .remove_incoming_requests
3507            .push(requester_id.to_proto());
3508    }
3509    for connection_id in pool.user_connection_ids(responder_id) {
3510        session.peer.send(connection_id, update.clone())?;
3511        if let Some(notification_id) = deleted_notification_id {
3512            session.peer.send(
3513                connection_id,
3514                proto::DeleteNotification {
3515                    notification_id: notification_id.to_proto(),
3516                },
3517            )?;
3518        }
3519    }
3520
3521    response.send(proto::Ack {})?;
3522    Ok(())
3523}
3524
3525fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3526    version.0.minor() < 139
3527}
3528
3529async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3530    subscribe_user_to_channels(
3531        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3532        &session,
3533    )
3534    .await?;
3535    Ok(())
3536}
3537
3538async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3539    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3540    let mut pool = session.connection_pool().await;
3541    for membership in &channels_for_user.channel_memberships {
3542        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3543    }
3544    session.peer.send(
3545        session.connection_id,
3546        build_update_user_channels(&channels_for_user),
3547    )?;
3548    session.peer.send(
3549        session.connection_id,
3550        build_channels_update(channels_for_user),
3551    )?;
3552    Ok(())
3553}
3554
3555/// Creates a new channel.
3556async fn create_channel(
3557    request: proto::CreateChannel,
3558    response: Response<proto::CreateChannel>,
3559    session: UserSession,
3560) -> Result<()> {
3561    let db = session.db().await;
3562
3563    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3564    let (channel, membership) = db
3565        .create_channel(&request.name, parent_id, session.user_id())
3566        .await?;
3567
3568    let root_id = channel.root_id();
3569    let channel = Channel::from_model(channel);
3570
3571    response.send(proto::CreateChannelResponse {
3572        channel: Some(channel.to_proto()),
3573        parent_id: request.parent_id,
3574    })?;
3575
3576    let mut connection_pool = session.connection_pool().await;
3577    if let Some(membership) = membership {
3578        connection_pool.subscribe_to_channel(
3579            membership.user_id,
3580            membership.channel_id,
3581            membership.role,
3582        );
3583        let update = proto::UpdateUserChannels {
3584            channel_memberships: vec![proto::ChannelMembership {
3585                channel_id: membership.channel_id.to_proto(),
3586                role: membership.role.into(),
3587            }],
3588            ..Default::default()
3589        };
3590        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3591            session.peer.send(connection_id, update.clone())?;
3592        }
3593    }
3594
3595    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3596        if !role.can_see_channel(channel.visibility) {
3597            continue;
3598        }
3599
3600        let update = proto::UpdateChannels {
3601            channels: vec![channel.to_proto()],
3602            ..Default::default()
3603        };
3604        session.peer.send(connection_id, update.clone())?;
3605    }
3606
3607    Ok(())
3608}
3609
3610/// Delete a channel
3611async fn delete_channel(
3612    request: proto::DeleteChannel,
3613    response: Response<proto::DeleteChannel>,
3614    session: UserSession,
3615) -> Result<()> {
3616    let db = session.db().await;
3617
3618    let channel_id = request.channel_id;
3619    let (root_channel, removed_channels) = db
3620        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3621        .await?;
3622    response.send(proto::Ack {})?;
3623
3624    // Notify members of removed channels
3625    let mut update = proto::UpdateChannels::default();
3626    update
3627        .delete_channels
3628        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3629
3630    let connection_pool = session.connection_pool().await;
3631    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3632        session.peer.send(connection_id, update.clone())?;
3633    }
3634
3635    Ok(())
3636}
3637
3638/// Invite someone to join a channel.
3639async fn invite_channel_member(
3640    request: proto::InviteChannelMember,
3641    response: Response<proto::InviteChannelMember>,
3642    session: UserSession,
3643) -> Result<()> {
3644    let db = session.db().await;
3645    let channel_id = ChannelId::from_proto(request.channel_id);
3646    let invitee_id = UserId::from_proto(request.user_id);
3647    let InviteMemberResult {
3648        channel,
3649        notifications,
3650    } = db
3651        .invite_channel_member(
3652            channel_id,
3653            invitee_id,
3654            session.user_id(),
3655            request.role().into(),
3656        )
3657        .await?;
3658
3659    let update = proto::UpdateChannels {
3660        channel_invitations: vec![channel.to_proto()],
3661        ..Default::default()
3662    };
3663
3664    let connection_pool = session.connection_pool().await;
3665    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3666        session.peer.send(connection_id, update.clone())?;
3667    }
3668
3669    send_notifications(&connection_pool, &session.peer, notifications);
3670
3671    response.send(proto::Ack {})?;
3672    Ok(())
3673}
3674
3675/// remove someone from a channel
3676async fn remove_channel_member(
3677    request: proto::RemoveChannelMember,
3678    response: Response<proto::RemoveChannelMember>,
3679    session: UserSession,
3680) -> Result<()> {
3681    let db = session.db().await;
3682    let channel_id = ChannelId::from_proto(request.channel_id);
3683    let member_id = UserId::from_proto(request.user_id);
3684
3685    let RemoveChannelMemberResult {
3686        membership_update,
3687        notification_id,
3688    } = db
3689        .remove_channel_member(channel_id, member_id, session.user_id())
3690        .await?;
3691
3692    let mut connection_pool = session.connection_pool().await;
3693    notify_membership_updated(
3694        &mut connection_pool,
3695        membership_update,
3696        member_id,
3697        &session.peer,
3698    );
3699    for connection_id in connection_pool.user_connection_ids(member_id) {
3700        if let Some(notification_id) = notification_id {
3701            session
3702                .peer
3703                .send(
3704                    connection_id,
3705                    proto::DeleteNotification {
3706                        notification_id: notification_id.to_proto(),
3707                    },
3708                )
3709                .trace_err();
3710        }
3711    }
3712
3713    response.send(proto::Ack {})?;
3714    Ok(())
3715}
3716
3717/// Toggle the channel between public and private.
3718/// Care is taken to maintain the invariant that public channels only descend from public channels,
3719/// (though members-only channels can appear at any point in the hierarchy).
3720async fn set_channel_visibility(
3721    request: proto::SetChannelVisibility,
3722    response: Response<proto::SetChannelVisibility>,
3723    session: UserSession,
3724) -> Result<()> {
3725    let db = session.db().await;
3726    let channel_id = ChannelId::from_proto(request.channel_id);
3727    let visibility = request.visibility().into();
3728
3729    let channel_model = db
3730        .set_channel_visibility(channel_id, visibility, session.user_id())
3731        .await?;
3732    let root_id = channel_model.root_id();
3733    let channel = Channel::from_model(channel_model);
3734
3735    let mut connection_pool = session.connection_pool().await;
3736    for (user_id, role) in connection_pool
3737        .channel_user_ids(root_id)
3738        .collect::<Vec<_>>()
3739        .into_iter()
3740    {
3741        let update = if role.can_see_channel(channel.visibility) {
3742            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3743            proto::UpdateChannels {
3744                channels: vec![channel.to_proto()],
3745                ..Default::default()
3746            }
3747        } else {
3748            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3749            proto::UpdateChannels {
3750                delete_channels: vec![channel.id.to_proto()],
3751                ..Default::default()
3752            }
3753        };
3754
3755        for connection_id in connection_pool.user_connection_ids(user_id) {
3756            session.peer.send(connection_id, update.clone())?;
3757        }
3758    }
3759
3760    response.send(proto::Ack {})?;
3761    Ok(())
3762}
3763
3764/// Alter the role for a user in the channel.
3765async fn set_channel_member_role(
3766    request: proto::SetChannelMemberRole,
3767    response: Response<proto::SetChannelMemberRole>,
3768    session: UserSession,
3769) -> Result<()> {
3770    let db = session.db().await;
3771    let channel_id = ChannelId::from_proto(request.channel_id);
3772    let member_id = UserId::from_proto(request.user_id);
3773    let result = db
3774        .set_channel_member_role(
3775            channel_id,
3776            session.user_id(),
3777            member_id,
3778            request.role().into(),
3779        )
3780        .await?;
3781
3782    match result {
3783        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3784            let mut connection_pool = session.connection_pool().await;
3785            notify_membership_updated(
3786                &mut connection_pool,
3787                membership_update,
3788                member_id,
3789                &session.peer,
3790            )
3791        }
3792        db::SetMemberRoleResult::InviteUpdated(channel) => {
3793            let update = proto::UpdateChannels {
3794                channel_invitations: vec![channel.to_proto()],
3795                ..Default::default()
3796            };
3797
3798            for connection_id in session
3799                .connection_pool()
3800                .await
3801                .user_connection_ids(member_id)
3802            {
3803                session.peer.send(connection_id, update.clone())?;
3804            }
3805        }
3806    }
3807
3808    response.send(proto::Ack {})?;
3809    Ok(())
3810}
3811
3812/// Change the name of a channel
3813async fn rename_channel(
3814    request: proto::RenameChannel,
3815    response: Response<proto::RenameChannel>,
3816    session: UserSession,
3817) -> Result<()> {
3818    let db = session.db().await;
3819    let channel_id = ChannelId::from_proto(request.channel_id);
3820    let channel_model = db
3821        .rename_channel(channel_id, session.user_id(), &request.name)
3822        .await?;
3823    let root_id = channel_model.root_id();
3824    let channel = Channel::from_model(channel_model);
3825
3826    response.send(proto::RenameChannelResponse {
3827        channel: Some(channel.to_proto()),
3828    })?;
3829
3830    let connection_pool = session.connection_pool().await;
3831    let update = proto::UpdateChannels {
3832        channels: vec![channel.to_proto()],
3833        ..Default::default()
3834    };
3835    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3836        if role.can_see_channel(channel.visibility) {
3837            session.peer.send(connection_id, update.clone())?;
3838        }
3839    }
3840
3841    Ok(())
3842}
3843
3844/// Move a channel to a new parent.
3845async fn move_channel(
3846    request: proto::MoveChannel,
3847    response: Response<proto::MoveChannel>,
3848    session: UserSession,
3849) -> Result<()> {
3850    let channel_id = ChannelId::from_proto(request.channel_id);
3851    let to = ChannelId::from_proto(request.to);
3852
3853    let (root_id, channels) = session
3854        .db()
3855        .await
3856        .move_channel(channel_id, to, session.user_id())
3857        .await?;
3858
3859    let connection_pool = session.connection_pool().await;
3860    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3861        let channels = channels
3862            .iter()
3863            .filter_map(|channel| {
3864                if role.can_see_channel(channel.visibility) {
3865                    Some(channel.to_proto())
3866                } else {
3867                    None
3868                }
3869            })
3870            .collect::<Vec<_>>();
3871        if channels.is_empty() {
3872            continue;
3873        }
3874
3875        let update = proto::UpdateChannels {
3876            channels,
3877            ..Default::default()
3878        };
3879
3880        session.peer.send(connection_id, update.clone())?;
3881    }
3882
3883    response.send(Ack {})?;
3884    Ok(())
3885}
3886
3887/// Get the list of channel members
3888async fn get_channel_members(
3889    request: proto::GetChannelMembers,
3890    response: Response<proto::GetChannelMembers>,
3891    session: UserSession,
3892) -> Result<()> {
3893    let db = session.db().await;
3894    let channel_id = ChannelId::from_proto(request.channel_id);
3895    let limit = if request.limit == 0 {
3896        u16::MAX as u64
3897    } else {
3898        request.limit
3899    };
3900    let (members, users) = db
3901        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3902        .await?;
3903    response.send(proto::GetChannelMembersResponse { members, users })?;
3904    Ok(())
3905}
3906
3907/// Accept or decline a channel invitation.
3908async fn respond_to_channel_invite(
3909    request: proto::RespondToChannelInvite,
3910    response: Response<proto::RespondToChannelInvite>,
3911    session: UserSession,
3912) -> Result<()> {
3913    let db = session.db().await;
3914    let channel_id = ChannelId::from_proto(request.channel_id);
3915    let RespondToChannelInvite {
3916        membership_update,
3917        notifications,
3918    } = db
3919        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3920        .await?;
3921
3922    let mut connection_pool = session.connection_pool().await;
3923    if let Some(membership_update) = membership_update {
3924        notify_membership_updated(
3925            &mut connection_pool,
3926            membership_update,
3927            session.user_id(),
3928            &session.peer,
3929        );
3930    } else {
3931        let update = proto::UpdateChannels {
3932            remove_channel_invitations: vec![channel_id.to_proto()],
3933            ..Default::default()
3934        };
3935
3936        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3937            session.peer.send(connection_id, update.clone())?;
3938        }
3939    };
3940
3941    send_notifications(&connection_pool, &session.peer, notifications);
3942
3943    response.send(proto::Ack {})?;
3944
3945    Ok(())
3946}
3947
3948/// Join the channels' room
3949async fn join_channel(
3950    request: proto::JoinChannel,
3951    response: Response<proto::JoinChannel>,
3952    session: UserSession,
3953) -> Result<()> {
3954    let channel_id = ChannelId::from_proto(request.channel_id);
3955    join_channel_internal(channel_id, Box::new(response), session).await
3956}
3957
/// Abstracts over the response handles that can carry a `proto::JoinRoomResponse`,
/// so `join_channel_internal` can reply to either a `JoinChannel` or a `JoinRoom`
/// request.
trait JoinChannelInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response`'s own inherent `send`. The explicit path resolves to the
// inherent method first, not to this trait method, so this does not recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same delegation for `JoinRoom` replies, which also carry a `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3971
/// Shared implementation for joining a channel's room.
///
/// `join_channel` calls this. The trait is also implemented for
/// `Response<proto::JoinRoom>`, so presumably a room-join handler outside this view
/// uses it as well.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first.
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token: a guest token (no
///    publishing) for `ChannelRole::Guest`, a room token with publishing otherwise.
///    Reply to the caller with the room, its channel id, and that connection info.
/// 4. Broadcast the membership update (if any) and the room update. Outside the
///    locked scope, broadcast the channel update and refresh the user's contacts.
///
/// Returns an error if any database call, the reply, or leaving the stale room
/// fails, or if the join result carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Everything that needs the db guard (and the first connection-pool guard) lives
    // inside this block, so both guards are released before the channel update below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room. `leave_room_for_session`
            // receives the whole session and presumably takes the db guard itself
            // (TODO confirm). Re-acquire the guard afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the token fails:
        // `trace_err()` turns the error into `None`, presumably after logging it.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is locked again here, briefly, only for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4062
4063/// Start editing the channel notes
4064async fn join_channel_buffer(
4065    request: proto::JoinChannelBuffer,
4066    response: Response<proto::JoinChannelBuffer>,
4067    session: UserSession,
4068) -> Result<()> {
4069    let db = session.db().await;
4070    let channel_id = ChannelId::from_proto(request.channel_id);
4071
4072    let open_response = db
4073        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4074        .await?;
4075
4076    let collaborators = open_response.collaborators.clone();
4077    response.send(open_response)?;
4078
4079    let update = UpdateChannelBufferCollaborators {
4080        channel_id: channel_id.to_proto(),
4081        collaborators: collaborators.clone(),
4082    };
4083    channel_buffer_updated(
4084        session.connection_id,
4085        collaborators
4086            .iter()
4087            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4088        &update,
4089        &session.peer,
4090    );
4091
4092    Ok(())
4093}
4094
4095/// Edit the channel notes
4096async fn update_channel_buffer(
4097    request: proto::UpdateChannelBuffer,
4098    session: UserSession,
4099) -> Result<()> {
4100    let db = session.db().await;
4101    let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103    let (collaborators, epoch, version) = db
4104        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4105        .await?;
4106
4107    channel_buffer_updated(
4108        session.connection_id,
4109        collaborators.clone(),
4110        &proto::UpdateChannelBuffer {
4111            channel_id: channel_id.to_proto(),
4112            operations: request.operations,
4113        },
4114        &session.peer,
4115    );
4116
4117    let pool = &*session.connection_pool().await;
4118
4119    let non_collaborators =
4120        pool.channel_connection_ids(channel_id)
4121            .filter_map(|(connection_id, _)| {
4122                if collaborators.contains(&connection_id) {
4123                    None
4124                } else {
4125                    Some(connection_id)
4126                }
4127            });
4128
4129    broadcast(None, non_collaborators, |peer_id| {
4130        session.peer.send(
4131            peer_id,
4132            proto::UpdateChannels {
4133                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4134                    channel_id: channel_id.to_proto(),
4135                    epoch: epoch as u64,
4136                    version: version.clone(),
4137                }],
4138                ..Default::default()
4139            },
4140        )
4141    });
4142
4143    Ok(())
4144}
4145
4146/// Rejoin the channel notes after a connection blip
4147async fn rejoin_channel_buffers(
4148    request: proto::RejoinChannelBuffers,
4149    response: Response<proto::RejoinChannelBuffers>,
4150    session: UserSession,
4151) -> Result<()> {
4152    let db = session.db().await;
4153    let buffers = db
4154        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4155        .await?;
4156
4157    for rejoined_buffer in &buffers {
4158        let collaborators_to_notify = rejoined_buffer
4159            .buffer
4160            .collaborators
4161            .iter()
4162            .filter_map(|c| Some(c.peer_id?.into()));
4163        channel_buffer_updated(
4164            session.connection_id,
4165            collaborators_to_notify,
4166            &proto::UpdateChannelBufferCollaborators {
4167                channel_id: rejoined_buffer.buffer.channel_id,
4168                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4169            },
4170            &session.peer,
4171        );
4172    }
4173
4174    response.send(proto::RejoinChannelBuffersResponse {
4175        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4176    })?;
4177
4178    Ok(())
4179}
4180
4181/// Stop editing the channel notes
4182async fn leave_channel_buffer(
4183    request: proto::LeaveChannelBuffer,
4184    response: Response<proto::LeaveChannelBuffer>,
4185    session: UserSession,
4186) -> Result<()> {
4187    let db = session.db().await;
4188    let channel_id = ChannelId::from_proto(request.channel_id);
4189
4190    let left_buffer = db
4191        .leave_channel_buffer(channel_id, session.connection_id)
4192        .await?;
4193
4194    response.send(Ack {})?;
4195
4196    channel_buffer_updated(
4197        session.connection_id,
4198        left_buffer.connections,
4199        &proto::UpdateChannelBufferCollaborators {
4200            channel_id: channel_id.to_proto(),
4201            collaborators: left_buffer.collaborators,
4202        },
4203        &session.peer,
4204    );
4205
4206    Ok(())
4207}
4208
4209fn channel_buffer_updated<T: EnvelopedMessage>(
4210    sender_id: ConnectionId,
4211    collaborators: impl IntoIterator<Item = ConnectionId>,
4212    message: &T,
4213    peer: &Peer,
4214) {
4215    broadcast(Some(sender_id), collaborators, |peer_id| {
4216        peer.send(peer_id, message.clone())
4217    });
4218}
4219
4220fn send_notifications(
4221    connection_pool: &ConnectionPool,
4222    peer: &Peer,
4223    notifications: db::NotificationBatch,
4224) {
4225    for (user_id, notification) in notifications {
4226        for connection_id in connection_pool.user_connection_ids(user_id) {
4227            if let Err(error) = peer.send(
4228                connection_id,
4229                proto::AddNotification {
4230                    notification: Some(notification.clone()),
4231                },
4232            ) {
4233                tracing::error!(
4234                    "failed to send notification to {:?} {}",
4235                    connection_id,
4236                    error
4237                );
4238            }
4239        }
4240    }
4241}
4242
4243/// Send a message to the channel
4244async fn send_channel_message(
4245    request: proto::SendChannelMessage,
4246    response: Response<proto::SendChannelMessage>,
4247    session: UserSession,
4248) -> Result<()> {
4249    // Validate the message body.
4250    let body = request.body.trim().to_string();
4251    if body.len() > MAX_MESSAGE_LEN {
4252        return Err(anyhow!("message is too long"))?;
4253    }
4254    if body.is_empty() {
4255        return Err(anyhow!("message can't be blank"))?;
4256    }
4257
4258    // TODO: adjust mentions if body is trimmed
4259
4260    let timestamp = OffsetDateTime::now_utc();
4261    let nonce = request
4262        .nonce
4263        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4264
4265    let channel_id = ChannelId::from_proto(request.channel_id);
4266    let CreatedChannelMessage {
4267        message_id,
4268        participant_connection_ids,
4269        notifications,
4270    } = session
4271        .db()
4272        .await
4273        .create_channel_message(
4274            channel_id,
4275            session.user_id(),
4276            &body,
4277            &request.mentions,
4278            timestamp,
4279            nonce.clone().into(),
4280            match request.reply_to_message_id {
4281                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4282                None => None,
4283            },
4284        )
4285        .await?;
4286
4287    let message = proto::ChannelMessage {
4288        sender_id: session.user_id().to_proto(),
4289        id: message_id.to_proto(),
4290        body,
4291        mentions: request.mentions,
4292        timestamp: timestamp.unix_timestamp() as u64,
4293        nonce: Some(nonce),
4294        reply_to_message_id: request.reply_to_message_id,
4295        edited_at: None,
4296    };
4297    broadcast(
4298        Some(session.connection_id),
4299        participant_connection_ids.clone(),
4300        |connection| {
4301            session.peer.send(
4302                connection,
4303                proto::ChannelMessageSent {
4304                    channel_id: channel_id.to_proto(),
4305                    message: Some(message.clone()),
4306                },
4307            )
4308        },
4309    );
4310    response.send(proto::SendChannelMessageResponse {
4311        message: Some(message),
4312    })?;
4313
4314    let pool = &*session.connection_pool().await;
4315    let non_participants =
4316        pool.channel_connection_ids(channel_id)
4317            .filter_map(|(connection_id, _)| {
4318                if participant_connection_ids.contains(&connection_id) {
4319                    None
4320                } else {
4321                    Some(connection_id)
4322                }
4323            });
4324    broadcast(None, non_participants, |peer_id| {
4325        session.peer.send(
4326            peer_id,
4327            proto::UpdateChannels {
4328                latest_channel_message_ids: vec![proto::ChannelMessageId {
4329                    channel_id: channel_id.to_proto(),
4330                    message_id: message_id.to_proto(),
4331                }],
4332                ..Default::default()
4333            },
4334        )
4335    });
4336    send_notifications(pool, &session.peer, notifications);
4337
4338    Ok(())
4339}
4340
4341/// Delete a channel message
4342async fn remove_channel_message(
4343    request: proto::RemoveChannelMessage,
4344    response: Response<proto::RemoveChannelMessage>,
4345    session: UserSession,
4346) -> Result<()> {
4347    let channel_id = ChannelId::from_proto(request.channel_id);
4348    let message_id = MessageId::from_proto(request.message_id);
4349    let (connection_ids, existing_notification_ids) = session
4350        .db()
4351        .await
4352        .remove_channel_message(channel_id, message_id, session.user_id())
4353        .await?;
4354
4355    broadcast(
4356        Some(session.connection_id),
4357        connection_ids,
4358        move |connection| {
4359            session.peer.send(connection, request.clone())?;
4360
4361            for notification_id in &existing_notification_ids {
4362                session.peer.send(
4363                    connection,
4364                    proto::DeleteNotification {
4365                        notification_id: (*notification_id).to_proto(),
4366                    },
4367                )?;
4368            }
4369
4370            Ok(())
4371        },
4372    );
4373    response.send(proto::Ack {})?;
4374    Ok(())
4375}
4376
4377async fn update_channel_message(
4378    request: proto::UpdateChannelMessage,
4379    response: Response<proto::UpdateChannelMessage>,
4380    session: UserSession,
4381) -> Result<()> {
4382    let channel_id = ChannelId::from_proto(request.channel_id);
4383    let message_id = MessageId::from_proto(request.message_id);
4384    let updated_at = OffsetDateTime::now_utc();
4385    let UpdatedChannelMessage {
4386        message_id,
4387        participant_connection_ids,
4388        notifications,
4389        reply_to_message_id,
4390        timestamp,
4391        deleted_mention_notification_ids,
4392        updated_mention_notifications,
4393    } = session
4394        .db()
4395        .await
4396        .update_channel_message(
4397            channel_id,
4398            message_id,
4399            session.user_id(),
4400            request.body.as_str(),
4401            &request.mentions,
4402            updated_at,
4403        )
4404        .await?;
4405
4406    let nonce = request
4407        .nonce
4408        .clone()
4409        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4410
4411    let message = proto::ChannelMessage {
4412        sender_id: session.user_id().to_proto(),
4413        id: message_id.to_proto(),
4414        body: request.body.clone(),
4415        mentions: request.mentions.clone(),
4416        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4417        nonce: Some(nonce),
4418        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4419        edited_at: Some(updated_at.unix_timestamp() as u64),
4420    };
4421
4422    response.send(proto::Ack {})?;
4423
4424    let pool = &*session.connection_pool().await;
4425    broadcast(
4426        Some(session.connection_id),
4427        participant_connection_ids,
4428        |connection| {
4429            session.peer.send(
4430                connection,
4431                proto::ChannelMessageUpdate {
4432                    channel_id: channel_id.to_proto(),
4433                    message: Some(message.clone()),
4434                },
4435            )?;
4436
4437            for notification_id in &deleted_mention_notification_ids {
4438                session.peer.send(
4439                    connection,
4440                    proto::DeleteNotification {
4441                        notification_id: (*notification_id).to_proto(),
4442                    },
4443                )?;
4444            }
4445
4446            for notification in &updated_mention_notifications {
4447                session.peer.send(
4448                    connection,
4449                    proto::UpdateNotification {
4450                        notification: Some(notification.clone()),
4451                    },
4452                )?;
4453            }
4454
4455            Ok(())
4456        },
4457    );
4458
4459    send_notifications(pool, &session.peer, notifications);
4460
4461    Ok(())
4462}
4463
4464/// Mark a channel message as read
4465async fn acknowledge_channel_message(
4466    request: proto::AckChannelMessage,
4467    session: UserSession,
4468) -> Result<()> {
4469    let channel_id = ChannelId::from_proto(request.channel_id);
4470    let message_id = MessageId::from_proto(request.message_id);
4471    let notifications = session
4472        .db()
4473        .await
4474        .observe_channel_message(channel_id, session.user_id(), message_id)
4475        .await?;
4476    send_notifications(
4477        &*session.connection_pool().await,
4478        &session.peer,
4479        notifications,
4480    );
4481    Ok(())
4482}
4483
4484/// Mark a buffer version as synced
4485async fn acknowledge_buffer_version(
4486    request: proto::AckBufferOperation,
4487    session: UserSession,
4488) -> Result<()> {
4489    let buffer_id = BufferId::from_proto(request.buffer_id);
4490    session
4491        .db()
4492        .await
4493        .observe_buffer_version(
4494            buffer_id,
4495            session.user_id(),
4496            request.epoch as i32,
4497            &request.version,
4498        )
4499        .await?;
4500    Ok(())
4501}
4502
4503struct CompleteWithLanguageModelRateLimit;
4504
4505impl RateLimit for CompleteWithLanguageModelRateLimit {
4506    fn capacity() -> usize {
4507        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4508            .ok()
4509            .and_then(|v| v.parse().ok())
4510            .unwrap_or(120) // Picked arbitrarily
4511    }
4512
4513    fn refill_duration() -> chrono::Duration {
4514        chrono::Duration::hours(1)
4515    }
4516
4517    fn db_name() -> &'static str {
4518        "complete-with-language-model"
4519    }
4520}
4521
4522async fn complete_with_language_model(
4523    mut request: proto::CompleteWithLanguageModel,
4524    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4525    session: Session,
4526    open_ai_api_key: Option<Arc<str>>,
4527    google_ai_api_key: Option<Arc<str>>,
4528    anthropic_api_key: Option<Arc<str>>,
4529) -> Result<()> {
4530    let Some(session) = session.for_user() else {
4531        return Err(anyhow!("user not found"))?;
4532    };
4533    authorize_access_to_language_models(&session).await?;
4534    session
4535        .rate_limiter
4536        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4537        .await?;
4538
4539    let mut provider_and_model = request.model.split('/');
4540    let (provider, model) = match (
4541        provider_and_model.next().unwrap(),
4542        provider_and_model.next(),
4543    ) {
4544        (provider, Some(model)) => (provider, model),
4545        (model, None) => {
4546            if model.starts_with("gpt") {
4547                ("openai", model)
4548            } else if model.starts_with("gemini") {
4549                ("google", model)
4550            } else if model.starts_with("claude") {
4551                ("anthropic", model)
4552            } else {
4553                ("unknown", model)
4554            }
4555        }
4556    };
4557    let provider = provider.to_string();
4558    request.model = model.to_string();
4559
4560    match provider.as_str() {
4561        "openai" => {
4562            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
4563            complete_with_open_ai(request, response, session, api_key).await?;
4564        }
4565        "anthropic" => {
4566            let api_key =
4567                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
4568            complete_with_anthropic(request, response, session, api_key).await?;
4569        }
4570        "google" => {
4571            let api_key =
4572                google_ai_api_key.context("no Google AI API key configured on the server")?;
4573            complete_with_google_ai(request, response, session, api_key).await?;
4574        }
4575        provider => return Err(anyhow!("unknown provider {:?}", provider))?,
4576    }
4577
4578    Ok(())
4579}
4580
4581async fn complete_with_open_ai(
4582    request: proto::CompleteWithLanguageModel,
4583    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4584    session: UserSession,
4585    api_key: Arc<str>,
4586) -> Result<()> {
4587    let mut completion_stream = open_ai::stream_completion(
4588        session.http_client.as_ref(),
4589        OPEN_AI_API_URL,
4590        &api_key,
4591        crate::ai::language_model_request_to_open_ai(request)?,
4592        None,
4593    )
4594    .await
4595    .context("open_ai::stream_completion request failed within collab")?;
4596
4597    while let Some(event) = completion_stream.next().await {
4598        let event = event?;
4599        response.send(proto::LanguageModelResponse {
4600            choices: event
4601                .choices
4602                .into_iter()
4603                .map(|choice| proto::LanguageModelChoiceDelta {
4604                    index: choice.index,
4605                    delta: Some(proto::LanguageModelResponseMessage {
4606                        role: choice.delta.role.map(|role| match role {
4607                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4608                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4609                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4610                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4611                        } as i32),
4612                        content: choice.delta.content,
4613                        tool_calls: choice
4614                            .delta
4615                            .tool_calls
4616                            .unwrap_or_default()
4617                            .into_iter()
4618                            .map(|delta| proto::ToolCallDelta {
4619                                index: delta.index as u32,
4620                                id: delta.id,
4621                                variant: match delta.function {
4622                                    Some(function) => {
4623                                        let name = function.name;
4624                                        let arguments = function.arguments;
4625
4626                                        Some(proto::tool_call_delta::Variant::Function(
4627                                            proto::tool_call_delta::FunctionCallDelta {
4628                                                name,
4629                                                arguments,
4630                                            },
4631                                        ))
4632                                    }
4633                                    None => None,
4634                                },
4635                            })
4636                            .collect(),
4637                    }),
4638                    finish_reason: choice.finish_reason,
4639                })
4640                .collect(),
4641        })?;
4642    }
4643
4644    Ok(())
4645}
4646
4647async fn complete_with_google_ai(
4648    request: proto::CompleteWithLanguageModel,
4649    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4650    session: UserSession,
4651    api_key: Arc<str>,
4652) -> Result<()> {
4653    let mut stream = google_ai::stream_generate_content(
4654        session.http_client.clone(),
4655        google_ai::API_URL,
4656        api_key.as_ref(),
4657        &request.model.clone(),
4658        crate::ai::language_model_request_to_google_ai(request)?,
4659    )
4660    .await
4661    .context("google_ai::stream_generate_content request failed")?;
4662
4663    while let Some(event) = stream.next().await {
4664        let event = event?;
4665        response.send(proto::LanguageModelResponse {
4666            choices: event
4667                .candidates
4668                .unwrap_or_default()
4669                .into_iter()
4670                .map(|candidate| proto::LanguageModelChoiceDelta {
4671                    index: candidate.index as u32,
4672                    delta: Some(proto::LanguageModelResponseMessage {
4673                        role: Some(match candidate.content.role {
4674                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4675                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4676                        } as i32),
4677                        content: Some(
4678                            candidate
4679                                .content
4680                                .parts
4681                                .into_iter()
4682                                .filter_map(|part| match part {
4683                                    google_ai::Part::TextPart(part) => Some(part.text),
4684                                    google_ai::Part::InlineDataPart(_) => None,
4685                                })
4686                                .collect(),
4687                        ),
4688                        // Tool calls are not supported for Google
4689                        tool_calls: Vec::new(),
4690                    }),
4691                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4692                })
4693                .collect(),
4694        })?;
4695    }
4696
4697    Ok(())
4698}
4699
4700async fn complete_with_anthropic(
4701    request: proto::CompleteWithLanguageModel,
4702    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4703    session: UserSession,
4704    api_key: Arc<str>,
4705) -> Result<()> {
4706    let model = anthropic::Model::from_id(&request.model)?;
4707
4708    let mut system_message = String::new();
4709    let messages = request
4710        .messages
4711        .into_iter()
4712        .filter_map(|message| {
4713            match message.role() {
4714                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4715                    role: anthropic::Role::User,
4716                    content: message.content,
4717                }),
4718                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4719                    role: anthropic::Role::Assistant,
4720                    content: message.content,
4721                }),
4722                // Anthropic's API breaks system instructions out as a separate field rather
4723                // than having a system message role.
4724                LanguageModelRole::LanguageModelSystem => {
4725                    if !system_message.is_empty() {
4726                        system_message.push_str("\n\n");
4727                    }
4728                    system_message.push_str(&message.content);
4729
4730                    None
4731                }
4732                // We don't yet support tool calls for Anthropic
4733                LanguageModelRole::LanguageModelTool => None,
4734            }
4735        })
4736        .collect();
4737
4738    let mut stream = anthropic::stream_completion(
4739        session.http_client.as_ref(),
4740        anthropic::ANTHROPIC_API_URL,
4741        &api_key,
4742        anthropic::Request {
4743            model,
4744            messages,
4745            stream: true,
4746            system: system_message,
4747            max_tokens: 4092,
4748        },
4749        None,
4750    )
4751    .await?;
4752
4753    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4754
4755    while let Some(event) = stream.next().await {
4756        let event = event?;
4757
4758        match event {
4759            anthropic::ResponseEvent::MessageStart { message } => {
4760                if let Some(role) = message.role {
4761                    if role == "assistant" {
4762                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4763                    } else if role == "user" {
4764                        current_role = proto::LanguageModelRole::LanguageModelUser;
4765                    }
4766                }
4767            }
4768            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4769                match content_block {
4770                    anthropic::ContentBlock::Text { text } => {
4771                        if !text.is_empty() {
4772                            response.send(proto::LanguageModelResponse {
4773                                choices: vec![proto::LanguageModelChoiceDelta {
4774                                    index: 0,
4775                                    delta: Some(proto::LanguageModelResponseMessage {
4776                                        role: Some(current_role as i32),
4777                                        content: Some(text),
4778                                        tool_calls: Vec::new(),
4779                                    }),
4780                                    finish_reason: None,
4781                                }],
4782                            })?;
4783                        }
4784                    }
4785                }
4786            }
4787            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4788                anthropic::TextDelta::TextDelta { text } => {
4789                    response.send(proto::LanguageModelResponse {
4790                        choices: vec![proto::LanguageModelChoiceDelta {
4791                            index: 0,
4792                            delta: Some(proto::LanguageModelResponseMessage {
4793                                role: Some(current_role as i32),
4794                                content: Some(text),
4795                                tool_calls: Vec::new(),
4796                            }),
4797                            finish_reason: None,
4798                        }],
4799                    })?;
4800                }
4801            },
4802            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4803                if let Some(stop_reason) = delta.stop_reason {
4804                    response.send(proto::LanguageModelResponse {
4805                        choices: vec![proto::LanguageModelChoiceDelta {
4806                            index: 0,
4807                            delta: None,
4808                            finish_reason: Some(stop_reason),
4809                        }],
4810                    })?;
4811                }
4812            }
4813            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4814            anthropic::ResponseEvent::MessageStop {} => {}
4815            anthropic::ResponseEvent::Ping {} => {}
4816        }
4817    }
4818
4819    Ok(())
4820}
4821
4822struct CountTokensWithLanguageModelRateLimit;
4823
4824impl RateLimit for CountTokensWithLanguageModelRateLimit {
4825    fn capacity() -> usize {
4826        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4827            .ok()
4828            .and_then(|v| v.parse().ok())
4829            .unwrap_or(600) // Picked arbitrarily
4830    }
4831
4832    fn refill_duration() -> chrono::Duration {
4833        chrono::Duration::hours(1)
4834    }
4835
4836    fn db_name() -> &'static str {
4837        "count-tokens-with-language-model"
4838    }
4839}
4840
4841async fn count_tokens_with_language_model(
4842    request: proto::CountTokensWithLanguageModel,
4843    response: Response<proto::CountTokensWithLanguageModel>,
4844    session: UserSession,
4845    google_ai_api_key: Option<Arc<str>>,
4846) -> Result<()> {
4847    authorize_access_to_language_models(&session).await?;
4848
4849    if !request.model.starts_with("gemini") {
4850        return Err(anyhow!(
4851            "counting tokens for model: {:?} is not supported",
4852            request.model
4853        ))?;
4854    }
4855
4856    session
4857        .rate_limiter
4858        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4859        .await?;
4860
4861    let api_key = google_ai_api_key
4862        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4863    let tokens_response = google_ai::count_tokens(
4864        session.http_client.as_ref(),
4865        google_ai::API_URL,
4866        &api_key,
4867        crate::ai::count_tokens_request_to_google_ai(request)?,
4868    )
4869    .await?;
4870    response.send(proto::CountTokensResponse {
4871        token_count: tokens_response.total_tokens as u32,
4872    })?;
4873    Ok(())
4874}
4875
4876struct ComputeEmbeddingsRateLimit;
4877
4878impl RateLimit for ComputeEmbeddingsRateLimit {
4879    fn capacity() -> usize {
4880        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4881            .ok()
4882            .and_then(|v| v.parse().ok())
4883            .unwrap_or(5000) // Picked arbitrarily
4884    }
4885
4886    fn refill_duration() -> chrono::Duration {
4887        chrono::Duration::hours(1)
4888    }
4889
4890    fn db_name() -> &'static str {
4891        "compute-embeddings"
4892    }
4893}
4894
4895async fn compute_embeddings(
4896    request: proto::ComputeEmbeddings,
4897    response: Response<proto::ComputeEmbeddings>,
4898    session: UserSession,
4899    api_key: Option<Arc<str>>,
4900) -> Result<()> {
4901    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4902    authorize_access_to_language_models(&session).await?;
4903
4904    session
4905        .rate_limiter
4906        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4907        .await?;
4908
4909    let embeddings = match request.model.as_str() {
4910        "openai/text-embedding-3-small" => {
4911            open_ai::embed(
4912                session.http_client.as_ref(),
4913                OPEN_AI_API_URL,
4914                &api_key,
4915                OpenAiEmbeddingModel::TextEmbedding3Small,
4916                request.texts.iter().map(|text| text.as_str()),
4917            )
4918            .await?
4919        }
4920        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4921    };
4922
4923    let embeddings = request
4924        .texts
4925        .iter()
4926        .map(|text| {
4927            let mut hasher = sha2::Sha256::new();
4928            hasher.update(text.as_bytes());
4929            let result = hasher.finalize();
4930            result.to_vec()
4931        })
4932        .zip(
4933            embeddings
4934                .data
4935                .into_iter()
4936                .map(|embedding| embedding.embedding),
4937        )
4938        .collect::<HashMap<_, _>>();
4939
4940    let db = session.db().await;
4941    db.save_embeddings(&request.model, &embeddings)
4942        .await
4943        .context("failed to save embeddings")
4944        .trace_err();
4945
4946    response.send(proto::ComputeEmbeddingsResponse {
4947        embeddings: embeddings
4948            .into_iter()
4949            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4950            .collect(),
4951    })?;
4952    Ok(())
4953}
4954
4955async fn get_cached_embeddings(
4956    request: proto::GetCachedEmbeddings,
4957    response: Response<proto::GetCachedEmbeddings>,
4958    session: UserSession,
4959) -> Result<()> {
4960    authorize_access_to_language_models(&session).await?;
4961
4962    let db = session.db().await;
4963    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4964
4965    response.send(proto::GetCachedEmbeddingsResponse {
4966        embeddings: embeddings
4967            .into_iter()
4968            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4969            .collect(),
4970    })?;
4971    Ok(())
4972}
4973
4974async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4975    let db = session.db().await;
4976    let flags = db.get_user_flags(session.user_id()).await?;
4977    if flags.iter().any(|flag| flag == "language-models") {
4978        Ok(())
4979    } else {
4980        Err(anyhow!("permission denied"))?
4981    }
4982}
4983
4984/// Get a Supermaven API key for the user
4985async fn get_supermaven_api_key(
4986    _request: proto::GetSupermavenApiKey,
4987    response: Response<proto::GetSupermavenApiKey>,
4988    session: UserSession,
4989) -> Result<()> {
4990    let user_id: String = session.user_id().to_string();
4991    if !session.is_staff() {
4992        return Err(anyhow!("supermaven not enabled for this account"))?;
4993    }
4994
4995    let email = session
4996        .email()
4997        .ok_or_else(|| anyhow!("user must have an email"))?;
4998
4999    let supermaven_admin_api = session
5000        .supermaven_client
5001        .as_ref()
5002        .ok_or_else(|| anyhow!("supermaven not configured"))?;
5003
5004    let result = supermaven_admin_api
5005        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
5006        .await?;
5007
5008    response.send(proto::GetSupermavenApiKeyResponse {
5009        api_key: result.api_key,
5010    })?;
5011
5012    Ok(())
5013}
5014
5015/// Start receiving chat updates for a channel
5016async fn join_channel_chat(
5017    request: proto::JoinChannelChat,
5018    response: Response<proto::JoinChannelChat>,
5019    session: UserSession,
5020) -> Result<()> {
5021    let channel_id = ChannelId::from_proto(request.channel_id);
5022
5023    let db = session.db().await;
5024    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5025        .await?;
5026    let messages = db
5027        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5028        .await?;
5029    response.send(proto::JoinChannelChatResponse {
5030        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5031        messages,
5032    })?;
5033    Ok(())
5034}
5035
5036/// Stop receiving chat updates for a channel
5037async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5038    let channel_id = ChannelId::from_proto(request.channel_id);
5039    session
5040        .db()
5041        .await
5042        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5043        .await?;
5044    Ok(())
5045}
5046
5047/// Retrieve the chat history for a channel
5048async fn get_channel_messages(
5049    request: proto::GetChannelMessages,
5050    response: Response<proto::GetChannelMessages>,
5051    session: UserSession,
5052) -> Result<()> {
5053    let channel_id = ChannelId::from_proto(request.channel_id);
5054    let messages = session
5055        .db()
5056        .await
5057        .get_channel_messages(
5058            channel_id,
5059            session.user_id(),
5060            MESSAGE_COUNT_PER_PAGE,
5061            Some(MessageId::from_proto(request.before_message_id)),
5062        )
5063        .await?;
5064    response.send(proto::GetChannelMessagesResponse {
5065        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5066        messages,
5067    })?;
5068    Ok(())
5069}
5070
5071/// Retrieve specific chat messages
5072async fn get_channel_messages_by_id(
5073    request: proto::GetChannelMessagesById,
5074    response: Response<proto::GetChannelMessagesById>,
5075    session: UserSession,
5076) -> Result<()> {
5077    let message_ids = request
5078        .message_ids
5079        .iter()
5080        .map(|id| MessageId::from_proto(*id))
5081        .collect::<Vec<_>>();
5082    let messages = session
5083        .db()
5084        .await
5085        .get_channel_messages_by_id(session.user_id(), &message_ids)
5086        .await?;
5087    response.send(proto::GetChannelMessagesResponse {
5088        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5089        messages,
5090    })?;
5091    Ok(())
5092}
5093
5094/// Retrieve the current users notifications
5095async fn get_notifications(
5096    request: proto::GetNotifications,
5097    response: Response<proto::GetNotifications>,
5098    session: UserSession,
5099) -> Result<()> {
5100    let notifications = session
5101        .db()
5102        .await
5103        .get_notifications(
5104            session.user_id(),
5105            NOTIFICATION_COUNT_PER_PAGE,
5106            request
5107                .before_id
5108                .map(|id| db::NotificationId::from_proto(id)),
5109        )
5110        .await?;
5111    response.send(proto::GetNotificationsResponse {
5112        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5113        notifications,
5114    })?;
5115    Ok(())
5116}
5117
5118/// Mark notifications as read
5119async fn mark_notification_as_read(
5120    request: proto::MarkNotificationRead,
5121    response: Response<proto::MarkNotificationRead>,
5122    session: UserSession,
5123) -> Result<()> {
5124    let database = &session.db().await;
5125    let notifications = database
5126        .mark_notification_as_read_by_id(
5127            session.user_id(),
5128            NotificationId::from_proto(request.notification_id),
5129        )
5130        .await?;
5131    send_notifications(
5132        &*session.connection_pool().await,
5133        &session.peer,
5134        notifications,
5135    );
5136    response.send(proto::Ack {})?;
5137    Ok(())
5138}
5139
5140/// Get the current users information
5141async fn get_private_user_info(
5142    _request: proto::GetPrivateUserInfo,
5143    response: Response<proto::GetPrivateUserInfo>,
5144    session: UserSession,
5145) -> Result<()> {
5146    let db = session.db().await;
5147
5148    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5149    let user = db
5150        .get_user_by_id(session.user_id())
5151        .await?
5152        .ok_or_else(|| anyhow!("user not found"))?;
5153    let flags = db.get_user_flags(session.user_id()).await?;
5154
5155    response.send(proto::GetPrivateUserInfoResponse {
5156        metrics_id,
5157        staff: user.admin,
5158        flags,
5159    })?;
5160    Ok(())
5161}
5162
5163fn to_axum_message(message: TungsteniteMessage) -> Option<AxumMessage> {
5164    match message {
5165        TungsteniteMessage::Text(payload) => Some(AxumMessage::Text(payload)),
5166        TungsteniteMessage::Binary(payload) => Some(AxumMessage::Binary(payload)),
5167        TungsteniteMessage::Ping(payload) => Some(AxumMessage::Ping(payload)),
5168        TungsteniteMessage::Pong(payload) => Some(AxumMessage::Pong(payload)),
5169        TungsteniteMessage::Close(frame) => {
5170            Some(AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5171                code: frame.code.into(),
5172                reason: frame.reason,
5173            })))
5174        }
5175        // we can ignore `Frame` frames as recommended by the tungstenite maintainers
5176        // https://github.com/snapview/tungstenite-rs/issues/268
5177        TungsteniteMessage::Frame(_) => None,
5178    }
5179}
5180
5181fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5182    match message {
5183        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5184        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5185        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5186        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5187        AxumMessage::Close(frame) => {
5188            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5189                code: frame.code.into(),
5190                reason: frame.reason,
5191            }))
5192        }
5193    }
5194}
5195
5196fn notify_membership_updated(
5197    connection_pool: &mut ConnectionPool,
5198    result: MembershipUpdated,
5199    user_id: UserId,
5200    peer: &Peer,
5201) {
5202    for membership in &result.new_channels.channel_memberships {
5203        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5204    }
5205    for channel_id in &result.removed_channels {
5206        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5207    }
5208
5209    let user_channels_update = proto::UpdateUserChannels {
5210        channel_memberships: result
5211            .new_channels
5212            .channel_memberships
5213            .iter()
5214            .map(|cm| proto::ChannelMembership {
5215                channel_id: cm.channel_id.to_proto(),
5216                role: cm.role.into(),
5217            })
5218            .collect(),
5219        ..Default::default()
5220    };
5221
5222    let mut update = build_channels_update(result.new_channels);
5223    update.delete_channels = result
5224        .removed_channels
5225        .into_iter()
5226        .map(|id| id.to_proto())
5227        .collect();
5228    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5229
5230    for connection_id in connection_pool.user_connection_ids(user_id) {
5231        peer.send(connection_id, user_channels_update.clone())
5232            .trace_err();
5233        peer.send(connection_id, update.clone()).trace_err();
5234    }
5235}
5236
5237fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5238    proto::UpdateUserChannels {
5239        channel_memberships: channels
5240            .channel_memberships
5241            .iter()
5242            .map(|m| proto::ChannelMembership {
5243                channel_id: m.channel_id.to_proto(),
5244                role: m.role.into(),
5245            })
5246            .collect(),
5247        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5248        observed_channel_message_id: channels.observed_channel_messages.clone(),
5249    }
5250}
5251
5252fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5253    let mut update = proto::UpdateChannels::default();
5254
5255    for channel in channels.channels {
5256        update.channels.push(channel.to_proto());
5257    }
5258
5259    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5260    update.latest_channel_message_ids = channels.latest_channel_messages;
5261
5262    for (channel_id, participants) in channels.channel_participants {
5263        update
5264            .channel_participants
5265            .push(proto::ChannelParticipants {
5266                channel_id: channel_id.to_proto(),
5267                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5268            });
5269    }
5270
5271    for channel in channels.invited_channels {
5272        update.channel_invitations.push(channel.to_proto());
5273    }
5274
5275    update.hosted_projects = channels.hosted_projects;
5276    update
5277}
5278
5279fn build_initial_contacts_update(
5280    contacts: Vec<db::Contact>,
5281    pool: &ConnectionPool,
5282) -> proto::UpdateContacts {
5283    let mut update = proto::UpdateContacts::default();
5284
5285    for contact in contacts {
5286        match contact {
5287            db::Contact::Accepted { user_id, busy } => {
5288                update.contacts.push(contact_for_user(user_id, busy, &pool));
5289            }
5290            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5291            db::Contact::Incoming { user_id } => {
5292                update
5293                    .incoming_requests
5294                    .push(proto::IncomingContactRequest {
5295                        requester_id: user_id.to_proto(),
5296                    })
5297            }
5298        }
5299    }
5300
5301    update
5302}
5303
5304fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5305    proto::Contact {
5306        user_id: user_id.to_proto(),
5307        online: pool.is_user_online(user_id),
5308        busy,
5309    }
5310}
5311
5312fn room_updated(room: &proto::Room, peer: &Peer) {
5313    broadcast(
5314        None,
5315        room.participants
5316            .iter()
5317            .filter_map(|participant| Some(participant.peer_id?.into())),
5318        |peer_id| {
5319            peer.send(
5320                peer_id,
5321                proto::RoomUpdated {
5322                    room: Some(room.clone()),
5323                },
5324            )
5325        },
5326    );
5327}
5328
5329fn channel_updated(
5330    channel: &db::channel::Model,
5331    room: &proto::Room,
5332    peer: &Peer,
5333    pool: &ConnectionPool,
5334) {
5335    let participants = room
5336        .participants
5337        .iter()
5338        .map(|p| p.user_id)
5339        .collect::<Vec<_>>();
5340
5341    broadcast(
5342        None,
5343        pool.channel_connection_ids(channel.root_id())
5344            .filter_map(|(channel_id, role)| {
5345                role.can_see_channel(channel.visibility).then(|| channel_id)
5346            }),
5347        |peer_id| {
5348            peer.send(
5349                peer_id,
5350                proto::UpdateChannels {
5351                    channel_participants: vec![proto::ChannelParticipants {
5352                        channel_id: channel.id.to_proto(),
5353                        participant_user_ids: participants.clone(),
5354                    }],
5355                    ..Default::default()
5356                },
5357            )
5358        },
5359    );
5360}
5361
5362async fn send_dev_server_projects_update(
5363    user_id: UserId,
5364    mut status: proto::DevServerProjectsUpdate,
5365    session: &Session,
5366) {
5367    let pool = session.connection_pool().await;
5368    for dev_server in &mut status.dev_servers {
5369        dev_server.status =
5370            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5371    }
5372    let connections = pool.user_connection_ids(user_id);
5373    for connection_id in connections {
5374        session.peer.send(connection_id, status.clone()).trace_err();
5375    }
5376}
5377
5378async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5379    let db = session.db().await;
5380
5381    let contacts = db.get_contacts(user_id).await?;
5382    let busy = db.is_user_busy(user_id).await?;
5383
5384    let pool = session.connection_pool().await;
5385    let updated_contact = contact_for_user(user_id, busy, &pool);
5386    for contact in contacts {
5387        if let db::Contact::Accepted {
5388            user_id: contact_user_id,
5389            ..
5390        } = contact
5391        {
5392            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5393                session
5394                    .peer
5395                    .send(
5396                        contact_conn_id,
5397                        proto::UpdateContacts {
5398                            contacts: vec![updated_contact.clone()],
5399                            remove_contacts: Default::default(),
5400                            incoming_requests: Default::default(),
5401                            remove_incoming_requests: Default::default(),
5402                            outgoing_requests: Default::default(),
5403                            remove_outgoing_requests: Default::default(),
5404                        },
5405                    )
5406                    .trace_err();
5407            }
5408        }
5409    }
5410    Ok(())
5411}
5412
5413async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5414    log::info!("lost dev server connection, unsharing projects");
5415    let project_ids = session
5416        .db()
5417        .await
5418        .get_stale_dev_server_projects(session.connection_id)
5419        .await?;
5420
5421    for project_id in project_ids {
5422        // not unshare re-checks the connection ids match, so we get away with no transaction
5423        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5424    }
5425
5426    let user_id = session.dev_server().user_id;
5427    let update = session
5428        .db()
5429        .await
5430        .dev_server_projects_update(user_id)
5431        .await?;
5432
5433    send_dev_server_projects_update(user_id, update, session).await;
5434
5435    Ok(())
5436}
5437
5438async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5439    let mut contacts_to_update = HashSet::default();
5440
5441    let room_id;
5442    let canceled_calls_to_user_ids;
5443    let live_kit_room;
5444    let delete_live_kit_room;
5445    let room;
5446    let channel;
5447
5448    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5449        contacts_to_update.insert(session.user_id());
5450
5451        for project in left_room.left_projects.values() {
5452            project_left(project, session);
5453        }
5454
5455        room_id = RoomId::from_proto(left_room.room.id);
5456        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5457        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5458        delete_live_kit_room = left_room.deleted;
5459        room = mem::take(&mut left_room.room);
5460        channel = mem::take(&mut left_room.channel);
5461
5462        room_updated(&room, &session.peer);
5463    } else {
5464        return Ok(());
5465    }
5466
5467    if let Some(channel) = channel {
5468        channel_updated(
5469            &channel,
5470            &room,
5471            &session.peer,
5472            &*session.connection_pool().await,
5473        );
5474    }
5475
5476    {
5477        let pool = session.connection_pool().await;
5478        for canceled_user_id in canceled_calls_to_user_ids {
5479            for connection_id in pool.user_connection_ids(canceled_user_id) {
5480                session
5481                    .peer
5482                    .send(
5483                        connection_id,
5484                        proto::CallCanceled {
5485                            room_id: room_id.to_proto(),
5486                        },
5487                    )
5488                    .trace_err();
5489            }
5490            contacts_to_update.insert(canceled_user_id);
5491        }
5492    }
5493
5494    for contact_user_id in contacts_to_update {
5495        update_user_contacts(contact_user_id, &session).await?;
5496    }
5497
5498    if let Some(live_kit) = session.live_kit_client.as_ref() {
5499        live_kit
5500            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5501            .await
5502            .trace_err();
5503
5504        if delete_live_kit_room {
5505            live_kit.delete_room(live_kit_room).await.trace_err();
5506        }
5507    }
5508
5509    Ok(())
5510}
5511
5512async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5513    let left_channel_buffers = session
5514        .db()
5515        .await
5516        .leave_channel_buffers(session.connection_id)
5517        .await?;
5518
5519    for left_buffer in left_channel_buffers {
5520        channel_buffer_updated(
5521            session.connection_id,
5522            left_buffer.connections,
5523            &proto::UpdateChannelBufferCollaborators {
5524                channel_id: left_buffer.channel_id.to_proto(),
5525                collaborators: left_buffer.collaborators,
5526            },
5527            &session.peer,
5528        );
5529    }
5530
5531    Ok(())
5532}
5533
5534fn project_left(project: &db::LeftProject, session: &UserSession) {
5535    for connection_id in &project.connection_ids {
5536        if project.should_unshare {
5537            session
5538                .peer
5539                .send(
5540                    *connection_id,
5541                    proto::UnshareProject {
5542                        project_id: project.id.to_proto(),
5543                    },
5544                )
5545                .trace_err();
5546        } else {
5547            session
5548                .peer
5549                .send(
5550                    *connection_id,
5551                    proto::RemoveProjectCollaborator {
5552                        project_id: project.id.to_proto(),
5553                        peer_id: Some(session.connection_id.into()),
5554                    },
5555                )
5556                .trace_err();
5557        }
5558    }
5559}
5560
/// Extension for results whose errors should be reported and then discarded
/// rather than propagated.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Converts the result into an `Option`, returning `None` on failure.
    /// The implementation for `Result` in this file logs the error first.
    fn trace_err(self) -> Option<Self::Ok>;
}
5566
5567impl<T, E> ResultExt for Result<T, E>
5568where
5569    E: std::fmt::Debug,
5570{
5571    type Ok = T;
5572
5573    #[track_caller]
5574    fn trace_err(self) -> Option<T> {
5575        match self {
5576            Ok(value) => Some(value),
5577            Err(error) => {
5578                tracing::error!("{:?}", error);
5579                None
5580            }
5581        }
5582    }
5583}