rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
/// A 30-second timeout.
// NOTE(review): presumably the grace period a disconnected client gets to
// reconnect before its state is released; the use site is not visible here — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// How long `Server::start` sleeps before it looks up stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in the visible part of
// the file; their names suggest paging sizes and a chat message length cap — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
/// Type-erased handler stored in `Server::handlers`: it receives a boxed
/// envelope plus the `Session` it arrived on and returns a boxed future that
/// yields `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
/// Handle a request handler uses to answer a request of type `R`.
///
/// `send` takes the handle by value, so it can be used at most once.
struct Response<R> {
    /// Peer the response is sent through.
    peer: Arc<Peer>,
    /// Receipt of the request being answered.
    receipt: Receipt<R>,
    /// Shared flag that `send` sets to `true`.
    responded: Arc<AtomicBool>,
}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
/// Handle used to answer a request of type `R` with a stream of responses.
///
/// Unlike `Response`, `send` borrows the handle, so it may be called repeatedly.
struct StreamingResponse<R: RequestMessage> {
    /// Peer the responses are sent through.
    peer: Arc<Peer>,
    /// Receipt of the request being answered; reused for every chunk.
    receipt: Receipt<R>,
}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
/// The identity a `Session` acts as.
#[derive(Clone, Debug)]
pub enum Principal {
    /// A user acting as themselves.
    User(User),
    /// `user` is the identity in effect; `admin` is recorded as the
    /// impersonator in tracing spans and in `Session`'s `Debug` output.
    Impersonated { user: User, admin: User },
    /// A dev server; it has a dev-server id but no user id.
    DevServer(dev_server::Model),
}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
/// Per-connection state passed to every message handler (see `MessageHandler`).
///
/// Cloning is cheap for the shared parts: they are held behind `Arc`.
#[derive(Clone)]
struct Session {
    /// Identity this session acts as.
    principal: Principal,
    /// Id of the connection this session belongs to.
    connection_id: ConnectionId,
    /// Database handle behind an async (tokio) mutex; access it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool behind a synchronous (parking_lot) mutex; access it via
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client, when one is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client shared by handlers.
    http_client: Arc<IsahcHttpClient>,
    /// Shared rate limiter.
    rate_limiter: Arc<RateLimiter>,
    // NOTE(review): underscore-prefixed, so presumably held but not read — confirm.
    _executor: Executor,
}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
/// A `Session` whose principal has a user id; `UserSession::new` returns `None`
/// for any other session.
struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
/// A `Session` whose principal is a dev server; `DevServerSession::new` returns
/// `None` for any other session.
struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
/// Newtype around a shared `Database`; dereferences to `Database`.
struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
/// The RPC server: owns the peer, the connection pool and the table of message
/// handlers registered in `Server::new`.
pub struct Server {
    /// This server's id, behind a mutex.
    id: parking_lot::Mutex<ServerId>,
    /// Peer used to send messages to connections.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with sessions.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Message handlers keyed by `TypeId`.
    // NOTE(review): presumably the `TypeId` of the message type — confirm in `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Sender half of a `bool` watch channel, initialised with `false`.
    teardown: watch::Sender<bool>,
}
 384
/// A locked view of the `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying parking_lot mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 389
/// Serializable view of the server's peer and (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    /// Borrowed peer.
    peer: &'a Peer,
    /// Locked connection pool; serialized through its `Deref` target via
    /// `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
    /// Creates a server with the given `id` and `app_state`, registers a handler
    /// for every message type listed below, and returns it wrapped in an `Arc`.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` return an
    /// error for sessions whose principal is not a user; those wrapped in
    /// `dev_server_handler` return an error for sessions whose principal is not
    /// a dev server. Handlers passed without a wrapper receive the plain
    /// `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if the id is wider than
            // 32 bits — confirm `ServerId`'s range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial receiver is dropped.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            // The only handlers restricted to dev-server principals.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests registered through the generic `forward_*` handlers,
            // instantiated once per proto message type.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Messages registered through the generic
            // `broadcast_project_message_from_host` handler.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The closures below capture their own clone of `app_state` and clone
            // the relevant API keys out of its config on every call.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            })
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetSignatureHelp>,
            ));

        Arc::new(server)
    }
 652
 653    pub async fn start(&self) -> Result<()> {
 654        let server_id = *self.id.lock();
 655        let app_state = self.app_state.clone();
 656        let peer = self.peer.clone();
 657        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 658        let pool = self.connection_pool.clone();
 659        let live_kit_client = self.app_state.live_kit_client.clone();
 660
 661        let span = info_span!("start server");
 662        self.app_state.executor.spawn_detached(
 663            async move {
 664                tracing::info!("waiting for cleanup timeout");
 665                timeout.await;
 666                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 667                if let Some((room_ids, channel_ids)) = app_state
 668                    .db
 669                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 670                    .await
 671                    .trace_err()
 672                {
 673                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 674                    tracing::info!(
 675                        stale_channel_buffer_count = channel_ids.len(),
 676                        "retrieved stale channel buffers"
 677                    );
 678
 679                    for channel_id in channel_ids {
 680                        if let Some(refreshed_channel_buffer) = app_state
 681                            .db
 682                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 683                            .await
 684                            .trace_err()
 685                        {
 686                            for connection_id in refreshed_channel_buffer.connection_ids {
 687                                peer.send(
 688                                    connection_id,
 689                                    proto::UpdateChannelBufferCollaborators {
 690                                        channel_id: channel_id.to_proto(),
 691                                        collaborators: refreshed_channel_buffer
 692                                            .collaborators
 693                                            .clone(),
 694                                    },
 695                                )
 696                                .trace_err();
 697                            }
 698                        }
 699                    }
 700
 701                    for room_id in room_ids {
 702                        let mut contacts_to_update = HashSet::default();
 703                        let mut canceled_calls_to_user_ids = Vec::new();
 704                        let mut live_kit_room = String::new();
 705                        let mut delete_live_kit_room = false;
 706
 707                        if let Some(mut refreshed_room) = app_state
 708                            .db
 709                            .clear_stale_room_participants(room_id, server_id)
 710                            .await
 711                            .trace_err()
 712                        {
 713                            tracing::info!(
 714                                room_id = room_id.0,
 715                                new_participant_count = refreshed_room.room.participants.len(),
 716                                "refreshed room"
 717                            );
 718                            room_updated(&refreshed_room.room, &peer);
 719                            if let Some(channel) = refreshed_room.channel.as_ref() {
 720                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 721                            }
 722                            contacts_to_update
 723                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 724                            contacts_to_update
 725                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 726                            canceled_calls_to_user_ids =
 727                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 728                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 729                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 730                        }
 731
 732                        {
 733                            let pool = pool.lock();
 734                            for canceled_user_id in canceled_calls_to_user_ids {
 735                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 736                                    peer.send(
 737                                        connection_id,
 738                                        proto::CallCanceled {
 739                                            room_id: room_id.to_proto(),
 740                                        },
 741                                    )
 742                                    .trace_err();
 743                                }
 744                            }
 745                        }
 746
 747                        for user_id in contacts_to_update {
 748                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 749                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 750                            if let Some((busy, contacts)) = busy.zip(contacts) {
 751                                let pool = pool.lock();
 752                                let updated_contact = contact_for_user(user_id, busy, &pool);
 753                                for contact in contacts {
 754                                    if let db::Contact::Accepted {
 755                                        user_id: contact_user_id,
 756                                        ..
 757                                    } = contact
 758                                    {
 759                                        for contact_conn_id in
 760                                            pool.user_connection_ids(contact_user_id)
 761                                        {
 762                                            peer.send(
 763                                                contact_conn_id,
 764                                                proto::UpdateContacts {
 765                                                    contacts: vec![updated_contact.clone()],
 766                                                    remove_contacts: Default::default(),
 767                                                    incoming_requests: Default::default(),
 768                                                    remove_incoming_requests: Default::default(),
 769                                                    outgoing_requests: Default::default(),
 770                                                    remove_outgoing_requests: Default::default(),
 771                                                },
 772                                            )
 773                                            .trace_err();
 774                                        }
 775                                    }
 776                                }
 777                            }
 778                        }
 779
 780                        if let Some(live_kit) = live_kit_client.as_ref() {
 781                            if delete_live_kit_room {
 782                                live_kit.delete_room(live_kit_room).await.trace_err();
 783                            }
 784                        }
 785                    }
 786                }
 787
 788                app_state
 789                    .db
 790                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 791                    .await
 792                    .trace_err();
 793            }
 794            .instrument(span),
 795        );
 796        Ok(())
 797    }
 798
 799    pub fn teardown(&self) {
 800        self.peer.teardown();
 801        self.connection_pool.lock().reset();
 802        let _ = self.teardown.send(true);
 803    }
 804
 805    #[cfg(test)]
 806    pub fn reset(&self, id: ServerId) {
 807        self.teardown();
 808        *self.id.lock() = id;
 809        self.peer.reset(id.0 as u32);
 810        let _ = self.teardown.send(false);
 811    }
 812
 813    #[cfg(test)]
 814    pub fn id(&self) -> ServerId {
 815        *self.id.lock()
 816    }
 817
 818    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 819    where
 820        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 821        Fut: 'static + Send + Future<Output = Result<()>>,
 822        M: EnvelopedMessage,
 823    {
 824        let prev_handler = self.handlers.insert(
 825            TypeId::of::<M>(),
 826            Box::new(move |envelope, session| {
 827                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 828                let received_at = envelope.received_at;
 829                tracing::info!("message received");
 830                let start_time = Instant::now();
 831                let future = (handler)(*envelope, session);
 832                async move {
 833                    let result = future.await;
 834                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 835                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 836                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 837                    let payload_type = M::NAME;
 838
 839                    match result {
 840                        Err(error) => {
 841                            tracing::error!(
 842                                ?error,
 843                                total_duration_ms,
 844                                processing_duration_ms,
 845                                queue_duration_ms,
 846                                payload_type,
 847                                "error handling message"
 848                            )
 849                        }
 850                        Ok(()) => tracing::info!(
 851                            total_duration_ms,
 852                            processing_duration_ms,
 853                            queue_duration_ms,
 854                            "finished handling message"
 855                        ),
 856                    }
 857                }
 858                .boxed()
 859            }),
 860        );
 861        if prev_handler.is_some() {
 862            panic!("registered a handler for the same message twice");
 863        }
 864        self
 865    }
 866
 867    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 868    where
 869        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 870        Fut: 'static + Send + Future<Output = Result<()>>,
 871        M: EnvelopedMessage,
 872    {
 873        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 874        self
 875    }
 876
 877    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 878    where
 879        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 880        Fut: Send + Future<Output = Result<()>>,
 881        M: RequestMessage,
 882    {
 883        let handler = Arc::new(handler);
 884        self.add_handler(move |envelope, session| {
 885            let receipt = envelope.receipt();
 886            let handler = handler.clone();
 887            async move {
 888                let peer = session.peer.clone();
 889                let responded = Arc::new(AtomicBool::default());
 890                let response = Response {
 891                    peer: peer.clone(),
 892                    responded: responded.clone(),
 893                    receipt,
 894                };
 895                match (handler)(envelope.payload, response, session).await {
 896                    Ok(()) => {
 897                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 898                            Ok(())
 899                        } else {
 900                            Err(anyhow!("handler did not send a response"))?
 901                        }
 902                    }
 903                    Err(error) => {
 904                        let proto_err = match &error {
 905                            Error::Internal(err) => err.to_proto(),
 906                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 907                        };
 908                        peer.respond_with_error(receipt, proto_err)?;
 909                        Err(error)
 910                    }
 911                }
 912            }
 913        })
 914    }
 915
 916    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 917    where
 918        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 919        Fut: Send + Future<Output = Result<()>>,
 920        M: RequestMessage,
 921    {
 922        let handler = Arc::new(handler);
 923        self.add_handler(move |envelope, session| {
 924            let receipt = envelope.receipt();
 925            let handler = handler.clone();
 926            async move {
 927                let peer = session.peer.clone();
 928                let response = StreamingResponse {
 929                    peer: peer.clone(),
 930                    receipt,
 931                };
 932                match (handler)(envelope.payload, response, session).await {
 933                    Ok(()) => {
 934                        peer.end_stream(receipt)?;
 935                        Ok(())
 936                    }
 937                    Err(error) => {
 938                        let proto_err = match &error {
 939                            Error::Internal(err) => err.to_proto(),
 940                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 941                        };
 942                        peer.respond_with_error(receipt, proto_err)?;
 943                        Err(error)
 944                    }
 945                }
 946            }
 947        })
 948    }
 949
    /// Drives a single client connection from registration until sign-out.
    ///
    /// The returned future, instrumented with a "handle connection" span:
    ///
    /// 1. returns immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer, yielding a connection id, an I/O future
    ///    and a stream of incoming messages;
    /// 3. builds the per-connection `Session` and sends the initial client update
    ///    (`send_connection_id` is forwarded to `send_initial_client_update`);
    /// 4. loops, dispatching each incoming message to the handler registered for its
    ///    payload type, until the I/O future finishes or the message stream ends;
    /// 5. calls `connection_lost` for the session.
    ///
    /// If the teardown channel changes while the loop is running, the future returns
    /// right away without calling `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Identity fields start out empty; `principal.update_span` and the
        // `connection_id` record below fill them in.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribe before entering the future so the flag can be checked on first poll.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): this return happens after `add_connection` and skips
                    // `connection_lost` — confirm the peer-side connection is cleaned up elsewhere.
                    return;
                }
            };

            // The Supermaven client is only created when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): this return also skips `connection_lost`; the initial update may
                // already have added the connection to the pool — confirm it is removed elsewhere.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection at 256: a permit is
            // acquired before the next message is read and is dropped once its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A fresh future is built on every iteration; it first waits for a permit and
                // only then for the next incoming message.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written: teardown, I/O,
                // in-flight foreground handlers, and finally the next incoming message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Polling this branch drives the queued foreground handlers; their output is ignored.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The span is entered only while the handler future is being created;
                            // the future itself is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // Release the concurrency permit only after the handler has completed.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // messages are driven by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream yielded `None`: no more messages will arrive.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the collection drops any foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1096
1097    async fn send_initial_client_update(
1098        &self,
1099        connection_id: ConnectionId,
1100        principal: &Principal,
1101        zed_version: ZedVersion,
1102        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1103        session: &Session,
1104    ) -> Result<()> {
1105        self.peer.send(
1106            connection_id,
1107            proto::Hello {
1108                peer_id: Some(connection_id.into()),
1109            },
1110        )?;
1111        tracing::info!("sent hello message");
1112        if let Some(send_connection_id) = send_connection_id.take() {
1113            let _ = send_connection_id.send(connection_id);
1114        }
1115
1116        match principal {
1117            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1118                if !user.connected_once {
1119                    self.peer.send(connection_id, proto::ShowContacts {})?;
1120                    self.app_state
1121                        .db
1122                        .set_user_connected_once(user.id, true)
1123                        .await?;
1124                }
1125
1126                let (contacts, dev_server_projects) = future::try_join(
1127                    self.app_state.db.get_contacts(user.id),
1128                    self.app_state.db.dev_server_projects_update(user.id),
1129                )
1130                .await?;
1131
1132                {
1133                    let mut pool = self.connection_pool.lock();
1134                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1135                    self.peer.send(
1136                        connection_id,
1137                        build_initial_contacts_update(contacts, &pool),
1138                    )?;
1139                }
1140
1141                if should_auto_subscribe_to_channels(zed_version) {
1142                    subscribe_user_to_channels(user.id, session).await?;
1143                }
1144
1145                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1146
1147                if let Some(incoming_call) =
1148                    self.app_state.db.incoming_call_for_user(user.id).await?
1149                {
1150                    self.peer.send(connection_id, incoming_call)?;
1151                }
1152
1153                update_user_contacts(user.id, &session).await?;
1154            }
1155            Principal::DevServer(dev_server) => {
1156                {
1157                    let mut pool = self.connection_pool.lock();
1158                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1159                    {
1160                        self.peer.send(
1161                            stale_connection_id,
1162                            proto::ShutdownDevServer {
1163                                reason: Some(
1164                                    "another dev server connected with the same token".to_string(),
1165                                ),
1166                            },
1167                        )?;
1168                        pool.remove_connection(stale_connection_id)?;
1169                    };
1170                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1171                }
1172
1173                let projects = self
1174                    .app_state
1175                    .db
1176                    .get_projects_for_dev_server(dev_server.id)
1177                    .await?;
1178                self.peer
1179                    .send(connection_id, proto::DevServerInstructions { projects })?;
1180
1181                let status = self
1182                    .app_state
1183                    .db
1184                    .dev_server_projects_update(dev_server.user_id)
1185                    .await?;
1186                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1187            }
1188        }
1189
1190        Ok(())
1191    }
1192
1193    pub async fn invite_code_redeemed(
1194        self: &Arc<Self>,
1195        inviter_id: UserId,
1196        invitee_id: UserId,
1197    ) -> Result<()> {
1198        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1199            if let Some(code) = &user.invite_code {
1200                let pool = self.connection_pool.lock();
1201                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1202                for connection_id in pool.user_connection_ids(inviter_id) {
1203                    self.peer.send(
1204                        connection_id,
1205                        proto::UpdateContacts {
1206                            contacts: vec![invitee_contact.clone()],
1207                            ..Default::default()
1208                        },
1209                    )?;
1210                    self.peer.send(
1211                        connection_id,
1212                        proto::UpdateInviteInfo {
1213                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1214                            count: user.invite_count as u32,
1215                        },
1216                    )?;
1217                }
1218            }
1219        }
1220        Ok(())
1221    }
1222
1223    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1224        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1225            if let Some(invite_code) = &user.invite_code {
1226                let pool = self.connection_pool.lock();
1227                for connection_id in pool.user_connection_ids(user_id) {
1228                    self.peer.send(
1229                        connection_id,
1230                        proto::UpdateInviteInfo {
1231                            url: format!(
1232                                "{}{}",
1233                                self.app_state.config.invite_link_prefix, invite_code
1234                            ),
1235                            count: user.invite_count as u32,
1236                        },
1237                    )?;
1238                }
1239            }
1240        }
1241        Ok(())
1242    }
1243
1244    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1245        ServerSnapshot {
1246            connection_pool: ConnectionPoolGuard {
1247                guard: self.connection_pool.lock(),
1248                _not_send: PhantomData,
1249            },
1250            peer: &self.peer,
1251        }
1252    }
1253}
1254
1255impl<'a> Deref for ConnectionPoolGuard<'a> {
1256    type Target = ConnectionPool;
1257
1258    fn deref(&self) -> &Self::Target {
1259        &self.guard
1260    }
1261}
1262
1263impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1264    fn deref_mut(&mut self) -> &mut Self::Target {
1265        &mut self.guard
1266    }
1267}
1268
1269impl<'a> Drop for ConnectionPoolGuard<'a> {
1270    fn drop(&mut self) {
1271        #[cfg(test)]
1272        self.check_invariants();
1273    }
1274}
1275
1276fn broadcast<F>(
1277    sender_id: Option<ConnectionId>,
1278    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1279    mut f: F,
1280) where
1281    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1282{
1283    for receiver_id in receiver_ids {
1284        if Some(receiver_id) != sender_id {
1285            if let Err(error) = f(receiver_id) {
1286                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1287            }
1288        }
1289    }
1290}
1291
1292pub struct ProtocolVersion(u32);
1293
1294impl Header for ProtocolVersion {
1295    fn name() -> &'static HeaderName {
1296        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1298    }
1299
1300    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301    where
1302        Self: Sized,
1303        I: Iterator<Item = &'i axum::http::HeaderValue>,
1304    {
1305        let version = values
1306            .next()
1307            .ok_or_else(axum::headers::Error::invalid)?
1308            .to_str()
1309            .map_err(|_| axum::headers::Error::invalid())?
1310            .parse()
1311            .map_err(|_| axum::headers::Error::invalid())?;
1312        Ok(Self(version))
1313    }
1314
1315    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316        values.extend([self.0.to_string().parse().unwrap()]);
1317    }
1318}
1319
1320pub struct AppVersionHeader(SemanticVersion);
1321impl Header for AppVersionHeader {
1322    fn name() -> &'static HeaderName {
1323        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1324        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1325    }
1326
1327    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1328    where
1329        Self: Sized,
1330        I: Iterator<Item = &'i axum::http::HeaderValue>,
1331    {
1332        let version = values
1333            .next()
1334            .ok_or_else(axum::headers::Error::invalid)?
1335            .to_str()
1336            .map_err(|_| axum::headers::Error::invalid())?
1337            .parse()
1338            .map_err(|_| axum::headers::Error::invalid())?;
1339        Ok(Self(version))
1340    }
1341
1342    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1343        values.extend([self.0.to_string().parse().unwrap()]);
1344    }
1345}
1346
1347pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1348    Router::new()
1349        .route("/rpc", get(handle_websocket_request))
1350        .layer(
1351            ServiceBuilder::new()
1352                .layer(Extension(server.app_state.clone()))
1353                .layer(middleware::from_fn(auth::validate_header)),
1354        )
1355        .route("/metrics", get(handle_metrics))
1356        .layer(Extension(server))
1357}
1358
1359pub async fn handle_websocket_request(
1360    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1361    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1362    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1363    Extension(server): Extension<Arc<Server>>,
1364    Extension(principal): Extension<Principal>,
1365    ws: WebSocketUpgrade,
1366) -> axum::response::Response {
1367    if protocol_version != rpc::PROTOCOL_VERSION {
1368        return (
1369            StatusCode::UPGRADE_REQUIRED,
1370            "client must be upgraded".to_string(),
1371        )
1372            .into_response();
1373    }
1374
1375    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1376        return (
1377            StatusCode::UPGRADE_REQUIRED,
1378            "no version header found".to_string(),
1379        )
1380            .into_response();
1381    };
1382
1383    if !version.can_collaborate() {
1384        return (
1385            StatusCode::UPGRADE_REQUIRED,
1386            "client must be upgraded".to_string(),
1387        )
1388            .into_response();
1389    }
1390
1391    let socket_address = socket_address.to_string();
1392    ws.on_upgrade(move |socket| {
1393        let socket = socket
1394            .map_ok(to_tungstenite_message)
1395            .err_into()
1396            .with(|message| async move { Ok(to_axum_message(message)) });
1397        let connection = Connection::new(Box::pin(socket));
1398        async move {
1399            server
1400                .handle_connection(
1401                    connection,
1402                    socket_address,
1403                    principal,
1404                    version,
1405                    None,
1406                    Executor::Production,
1407                )
1408                .await;
1409        }
1410    })
1411}
1412
1413pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1414    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415    let connections_metric = CONNECTIONS_METRIC
1416        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1417
1418    let connections = server
1419        .connection_pool
1420        .lock()
1421        .connections()
1422        .filter(|connection| !connection.admin)
1423        .count();
1424    connections_metric.set(connections as _);
1425
1426    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1427    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1428        register_int_gauge!(
1429            "shared_projects",
1430            "number of open projects with one or more guests"
1431        )
1432        .unwrap()
1433    });
1434
1435    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1436    shared_projects_metric.set(shared_projects as _);
1437
1438    let encoder = prometheus::TextEncoder::new();
1439    let metric_families = prometheus::gather();
1440    let encoded_metrics = encoder
1441        .encode_to_string(&metric_families)
1442        .map_err(|err| anyhow!("{}", err))?;
1443    Ok(encoded_metrics)
1444}
1445
/// Cleans up after the connection behind `session` has gone away.
///
/// Immediately disconnects the connection from the peer, removes it from the connection
/// pool, and records the loss in the database (a database error there is passed to
/// `trace_err` rather than propagated).
///
/// It then races a `RECONNECT_TIMEOUT` sleep against a change on the `teardown` channel:
/// - if the timeout elapses first, the remaining per-principal cleanup runs (leaving
///   rooms and channel buffers and updating contacts for users; `lost_dev_server_connection`
///   for dev servers);
/// - if `teardown` changes first, that cleanup is skipped.
///
/// # Errors
///
/// Fails if the connection cannot be removed from the pool, or if the final contact
/// update / dev server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout is checked first.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user variant, which is what the
                    // `unwrap` relies on.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when the pool reports the user as offline.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // The teardown channel changed before the timeout elapsed: skip the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1500
1501/// Acknowledges a ping from a client, used to keep the connection alive.
1502async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1503    response.send(proto::Ack {})?;
1504    Ok(())
1505}
1506
1507/// Creates a new room for calling (outside of channels)
1508async fn create_room(
1509    _request: proto::CreateRoom,
1510    response: Response<proto::CreateRoom>,
1511    session: UserSession,
1512) -> Result<()> {
1513    let live_kit_room = nanoid::nanoid!(30);
1514
1515    let live_kit_connection_info = util::maybe!(async {
1516        let live_kit = session.live_kit_client.as_ref();
1517        let live_kit = live_kit?;
1518        let user_id = session.user_id().to_string();
1519
1520        let token = live_kit
1521            .room_token(&live_kit_room, &user_id.to_string())
1522            .trace_err()?;
1523
1524        Some(proto::LiveKitConnectionInfo {
1525            server_url: live_kit.url().into(),
1526            token,
1527            can_publish: true,
1528        })
1529    })
1530    .await;
1531
1532    let room = session
1533        .db()
1534        .await
1535        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1536        .await?;
1537
1538    response.send(proto::CreateRoomResponse {
1539        room: Some(room.clone()),
1540        live_kit_connection_info,
1541    })?;
1542
1543    update_user_contacts(session.user_id(), &session).await?;
1544    Ok(())
1545}
1546
1547/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1548async fn join_room(
1549    request: proto::JoinRoom,
1550    response: Response<proto::JoinRoom>,
1551    session: UserSession,
1552) -> Result<()> {
1553    let room_id = RoomId::from_proto(request.id);
1554
1555    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1556
1557    if let Some(channel_id) = channel_id {
1558        return join_channel_internal(channel_id, Box::new(response), session).await;
1559    }
1560
1561    let joined_room = {
1562        let room = session
1563            .db()
1564            .await
1565            .join_room(room_id, session.user_id(), session.connection_id)
1566            .await?;
1567        room_updated(&room.room, &session.peer);
1568        room.into_inner()
1569    };
1570
1571    for connection_id in session
1572        .connection_pool()
1573        .await
1574        .user_connection_ids(session.user_id())
1575    {
1576        session
1577            .peer
1578            .send(
1579                connection_id,
1580                proto::CallCanceled {
1581                    room_id: room_id.to_proto(),
1582                },
1583            )
1584            .trace_err();
1585    }
1586
1587    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1588        if let Some(token) = live_kit
1589            .room_token(
1590                &joined_room.room.live_kit_room,
1591                &session.user_id().to_string(),
1592            )
1593            .trace_err()
1594        {
1595            Some(proto::LiveKitConnectionInfo {
1596                server_url: live_kit.url().into(),
1597                token,
1598                can_publish: true,
1599            })
1600        } else {
1601            None
1602        }
1603    } else {
1604        None
1605    };
1606
1607    response.send(proto::JoinRoomResponse {
1608        room: Some(joined_room.room),
1609        channel_id: None,
1610        live_kit_connection_info,
1611    })?;
1612
1613    update_user_contacts(session.user_id(), &session).await?;
1614    Ok(())
1615}
1616
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this connection. The response then lists
/// the room, the projects that were reshared and the projects that were rejoined.
/// Afterwards:
///
/// * room participants are notified through `room_updated`;
/// * for every reshared project, each collaborator is told that the peer id changed from
///   the project's old connection id to the current one, and an `UpdateProject` carrying
///   the project's worktrees is forwarded to the collaborators;
/// * rejoined projects are handled by `notify_rejoined_projects`;
/// * if the room belongs to a channel, `channel_updated` is called;
/// * the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so the values survive once `rejoined_room` is consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the request first; the notifications below go out afterwards.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator which peer id replaced the old connection.
            // Send failures go to `trace_err` and do not abort the rejoin.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to the collaborators, with this
            // connection passed to `broadcast` as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1708
1709fn notify_rejoined_projects(
1710    rejoined_projects: &mut Vec<RejoinedProject>,
1711    session: &UserSession,
1712) -> Result<()> {
1713    for project in rejoined_projects.iter() {
1714        for collaborator in &project.collaborators {
1715            session
1716                .peer
1717                .send(
1718                    collaborator.connection_id,
1719                    proto::UpdateProjectCollaborator {
1720                        project_id: project.id.to_proto(),
1721                        old_peer_id: Some(project.old_connection_id.into()),
1722                        new_peer_id: Some(session.connection_id.into()),
1723                    },
1724                )
1725                .trace_err();
1726        }
1727    }
1728
1729    for project in rejoined_projects {
1730        for worktree in mem::take(&mut project.worktrees) {
1731            #[cfg(any(test, feature = "test-support"))]
1732            const MAX_CHUNK_SIZE: usize = 2;
1733            #[cfg(not(any(test, feature = "test-support")))]
1734            const MAX_CHUNK_SIZE: usize = 256;
1735
1736            // Stream this worktree's entries.
1737            let message = proto::UpdateWorktree {
1738                project_id: project.id.to_proto(),
1739                worktree_id: worktree.id,
1740                abs_path: worktree.abs_path.clone(),
1741                root_name: worktree.root_name,
1742                updated_entries: worktree.updated_entries,
1743                removed_entries: worktree.removed_entries,
1744                scan_id: worktree.scan_id,
1745                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1746                updated_repositories: worktree.updated_repositories,
1747                removed_repositories: worktree.removed_repositories,
1748            };
1749            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1750                session.peer.send(session.connection_id, update.clone())?;
1751            }
1752
1753            // Stream this worktree's diagnostics.
1754            for summary in worktree.diagnostic_summaries {
1755                session.peer.send(
1756                    session.connection_id,
1757                    proto::UpdateDiagnosticSummary {
1758                        project_id: project.id.to_proto(),
1759                        worktree_id: worktree.id,
1760                        summary: Some(summary),
1761                    },
1762                )?;
1763            }
1764
1765            for settings_file in worktree.settings_files {
1766                session.peer.send(
1767                    session.connection_id,
1768                    proto::UpdateWorktreeSettings {
1769                        project_id: project.id.to_proto(),
1770                        worktree_id: worktree.id,
1771                        path: settings_file.path,
1772                        content: Some(settings_file.content),
1773                    },
1774                )?;
1775            }
1776        }
1777
1778        for language_server in &project.language_servers {
1779            session.peer.send(
1780                session.connection_id,
1781                proto::UpdateLanguageServer {
1782                    project_id: project.id.to_proto(),
1783                    language_server_id: language_server.id,
1784                    variant: Some(
1785                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1786                            proto::LspDiskBasedDiagnosticsUpdated {},
1787                        ),
1788                    ),
1789                },
1790            )?;
1791        }
1792    }
1793    Ok(())
1794}
1795
1796/// leave room disconnects from the room.
1797async fn leave_room(
1798    _: proto::LeaveRoom,
1799    response: Response<proto::LeaveRoom>,
1800    session: UserSession,
1801) -> Result<()> {
1802    leave_room_for_session(&session, session.connection_id).await?;
1803    response.send(proto::Ack {})?;
1804    Ok(())
1805}
1806
1807/// Updates the permissions of someone else in the room.
1808async fn set_room_participant_role(
1809    request: proto::SetRoomParticipantRole,
1810    response: Response<proto::SetRoomParticipantRole>,
1811    session: UserSession,
1812) -> Result<()> {
1813    let user_id = UserId::from_proto(request.user_id);
1814    let role = ChannelRole::from(request.role());
1815
1816    let (live_kit_room, can_publish) = {
1817        let room = session
1818            .db()
1819            .await
1820            .set_room_participant_role(
1821                session.user_id(),
1822                RoomId::from_proto(request.room_id),
1823                user_id,
1824                role,
1825            )
1826            .await?;
1827
1828        let live_kit_room = room.live_kit_room.clone();
1829        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1830        room_updated(&room, &session.peer);
1831        (live_kit_room, can_publish)
1832    };
1833
1834    if let Some(live_kit) = session.live_kit_client.as_ref() {
1835        live_kit
1836            .update_participant(
1837                live_kit_room.clone(),
1838                request.user_id.to_string(),
1839                live_kit_server::proto::ParticipantPermission {
1840                    can_subscribe: true,
1841                    can_publish,
1842                    can_publish_data: can_publish,
1843                    hidden: false,
1844                    recorder: false,
1845                },
1846            )
1847            .await
1848            .trace_err();
1849    }
1850
1851    response.send(proto::Ack {})?;
1852    Ok(())
1853}
1854
1855/// Call someone else into the current room
1856async fn call(
1857    request: proto::Call,
1858    response: Response<proto::Call>,
1859    session: UserSession,
1860) -> Result<()> {
1861    let room_id = RoomId::from_proto(request.room_id);
1862    let calling_user_id = session.user_id();
1863    let calling_connection_id = session.connection_id;
1864    let called_user_id = UserId::from_proto(request.called_user_id);
1865    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1866    if !session
1867        .db()
1868        .await
1869        .has_contact(calling_user_id, called_user_id)
1870        .await?
1871    {
1872        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1873    }
1874
1875    let incoming_call = {
1876        let (room, incoming_call) = &mut *session
1877            .db()
1878            .await
1879            .call(
1880                room_id,
1881                calling_user_id,
1882                calling_connection_id,
1883                called_user_id,
1884                initial_project_id,
1885            )
1886            .await?;
1887        room_updated(&room, &session.peer);
1888        mem::take(incoming_call)
1889    };
1890    update_user_contacts(called_user_id, &session).await?;
1891
1892    let mut calls = session
1893        .connection_pool()
1894        .await
1895        .user_connection_ids(called_user_id)
1896        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1897        .collect::<FuturesUnordered<_>>();
1898
1899    while let Some(call_response) = calls.next().await {
1900        match call_response.as_ref() {
1901            Ok(_) => {
1902                response.send(proto::Ack {})?;
1903                return Ok(());
1904            }
1905            Err(_) => {
1906                call_response.trace_err();
1907            }
1908        }
1909    }
1910
1911    {
1912        let room = session
1913            .db()
1914            .await
1915            .call_failed(room_id, called_user_id)
1916            .await?;
1917        room_updated(&room, &session.peer);
1918    }
1919    update_user_contacts(called_user_id, &session).await?;
1920
1921    Err(anyhow!("failed to ring user"))?
1922}
1923
1924/// Cancel an outgoing call.
1925async fn cancel_call(
1926    request: proto::CancelCall,
1927    response: Response<proto::CancelCall>,
1928    session: UserSession,
1929) -> Result<()> {
1930    let called_user_id = UserId::from_proto(request.called_user_id);
1931    let room_id = RoomId::from_proto(request.room_id);
1932    {
1933        let room = session
1934            .db()
1935            .await
1936            .cancel_call(room_id, session.connection_id, called_user_id)
1937            .await?;
1938        room_updated(&room, &session.peer);
1939    }
1940
1941    for connection_id in session
1942        .connection_pool()
1943        .await
1944        .user_connection_ids(called_user_id)
1945    {
1946        session
1947            .peer
1948            .send(
1949                connection_id,
1950                proto::CallCanceled {
1951                    room_id: room_id.to_proto(),
1952                },
1953            )
1954            .trace_err();
1955    }
1956    response.send(proto::Ack {})?;
1957
1958    update_user_contacts(called_user_id, &session).await?;
1959    Ok(())
1960}
1961
1962/// Decline an incoming call.
1963async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1964    let room_id = RoomId::from_proto(message.room_id);
1965    {
1966        let room = session
1967            .db()
1968            .await
1969            .decline_call(Some(room_id), session.user_id())
1970            .await?
1971            .ok_or_else(|| anyhow!("failed to decline call"))?;
1972        room_updated(&room, &session.peer);
1973    }
1974
1975    for connection_id in session
1976        .connection_pool()
1977        .await
1978        .user_connection_ids(session.user_id())
1979    {
1980        session
1981            .peer
1982            .send(
1983                connection_id,
1984                proto::CallCanceled {
1985                    room_id: room_id.to_proto(),
1986                },
1987            )
1988            .trace_err();
1989    }
1990    update_user_contacts(session.user_id(), &session).await?;
1991    Ok(())
1992}
1993
1994/// Updates other participants in the room with your current location.
1995async fn update_participant_location(
1996    request: proto::UpdateParticipantLocation,
1997    response: Response<proto::UpdateParticipantLocation>,
1998    session: UserSession,
1999) -> Result<()> {
2000    let room_id = RoomId::from_proto(request.room_id);
2001    let location = request
2002        .location
2003        .ok_or_else(|| anyhow!("invalid location"))?;
2004
2005    let db = session.db().await;
2006    let room = db
2007        .update_room_participant_location(room_id, session.connection_id, location)
2008        .await?;
2009
2010    room_updated(&room, &session.peer);
2011    response.send(proto::Ack {})?;
2012    Ok(())
2013}
2014
2015/// Share a project into the room.
2016async fn share_project(
2017    request: proto::ShareProject,
2018    response: Response<proto::ShareProject>,
2019    session: UserSession,
2020) -> Result<()> {
2021    let (project_id, room) = &*session
2022        .db()
2023        .await
2024        .share_project(
2025            RoomId::from_proto(request.room_id),
2026            session.connection_id,
2027            &request.worktrees,
2028            request
2029                .dev_server_project_id
2030                .map(|id| DevServerProjectId::from_proto(id)),
2031        )
2032        .await?;
2033    response.send(proto::ShareProjectResponse {
2034        project_id: project_id.to_proto(),
2035    })?;
2036    room_updated(&room, &session.peer);
2037
2038    Ok(())
2039}
2040
2041/// Unshare a project from the room.
2042async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2043    let project_id = ProjectId::from_proto(message.project_id);
2044    unshare_project_internal(
2045        project_id,
2046        session.connection_id,
2047        session.user_id(),
2048        &session,
2049    )
2050    .await
2051}
2052
2053async fn unshare_project_internal(
2054    project_id: ProjectId,
2055    connection_id: ConnectionId,
2056    user_id: Option<UserId>,
2057    session: &Session,
2058) -> Result<()> {
2059    let delete = {
2060        let room_guard = session
2061            .db()
2062            .await
2063            .unshare_project(project_id, connection_id, user_id)
2064            .await?;
2065
2066        let (delete, room, guest_connection_ids) = &*room_guard;
2067
2068        let message = proto::UnshareProject {
2069            project_id: project_id.to_proto(),
2070        };
2071
2072        broadcast(
2073            Some(connection_id),
2074            guest_connection_ids.iter().copied(),
2075            |conn_id| session.peer.send(conn_id, message.clone()),
2076        );
2077        if let Some(room) = room {
2078            room_updated(room, &session.peer);
2079        }
2080
2081        *delete
2082    };
2083
2084    if delete {
2085        let db = session.db().await;
2086        db.delete_project(project_id).await?;
2087    }
2088
2089    Ok(())
2090}
2091
2092/// DevServer makes a project available online
2093async fn share_dev_server_project(
2094    request: proto::ShareDevServerProject,
2095    response: Response<proto::ShareDevServerProject>,
2096    session: DevServerSession,
2097) -> Result<()> {
2098    let (dev_server_project, user_id, status) = session
2099        .db()
2100        .await
2101        .share_dev_server_project(
2102            DevServerProjectId::from_proto(request.dev_server_project_id),
2103            session.dev_server_id(),
2104            session.connection_id,
2105            &request.worktrees,
2106        )
2107        .await?;
2108    let Some(project_id) = dev_server_project.project_id else {
2109        return Err(anyhow!("failed to share remote project"))?;
2110    };
2111
2112    send_dev_server_projects_update(user_id, status, &session).await;
2113
2114    response.send(proto::ShareProjectResponse { project_id })?;
2115
2116    Ok(())
2117}
2118
/// Join someone else's shared project.
///
/// Registers this connection as a collaborator of the project in the database, then hands
/// the project and the assigned replica id to `join_project_internal`, which answers the
/// request and streams the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the value returned by `db.join_project`,
    // which is kept alive as a temporary until the end of this function.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // The database handle itself is no longer needed; drop it before streaming.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2137
/// Abstraction over the response handles that `join_project_internal` can answer.
///
/// Both `JoinProject` and `JoinHostedProject` requests are answered with a
/// `proto::JoinProjectResponse`; this trait lets the shared join logic accept either one.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so the call targets `Response::send` and not this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so the call targets `Response::send` and not this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2151
2152fn join_project_internal(
2153    response: impl JoinProjectInternalResponse,
2154    session: UserSession,
2155    project: &mut Project,
2156    replica_id: &ReplicaId,
2157) -> Result<()> {
2158    let collaborators = project
2159        .collaborators
2160        .iter()
2161        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2162        .map(|collaborator| collaborator.to_proto())
2163        .collect::<Vec<_>>();
2164    let project_id = project.id;
2165    let guest_user_id = session.user_id();
2166
2167    let worktrees = project
2168        .worktrees
2169        .iter()
2170        .map(|(id, worktree)| proto::WorktreeMetadata {
2171            id: *id,
2172            root_name: worktree.root_name.clone(),
2173            visible: worktree.visible,
2174            abs_path: worktree.abs_path.clone(),
2175        })
2176        .collect::<Vec<_>>();
2177
2178    let add_project_collaborator = proto::AddProjectCollaborator {
2179        project_id: project_id.to_proto(),
2180        collaborator: Some(proto::Collaborator {
2181            peer_id: Some(session.connection_id.into()),
2182            replica_id: replica_id.0 as u32,
2183            user_id: guest_user_id.to_proto(),
2184        }),
2185    };
2186
2187    for collaborator in &collaborators {
2188        session
2189            .peer
2190            .send(
2191                collaborator.peer_id.unwrap().into(),
2192                add_project_collaborator.clone(),
2193            )
2194            .trace_err();
2195    }
2196
2197    // First, we send the metadata associated with each worktree.
2198    response.send(proto::JoinProjectResponse {
2199        project_id: project.id.0 as u64,
2200        worktrees: worktrees.clone(),
2201        replica_id: replica_id.0 as u32,
2202        collaborators: collaborators.clone(),
2203        language_servers: project.language_servers.clone(),
2204        role: project.role.into(),
2205        dev_server_project_id: project
2206            .dev_server_project_id
2207            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2208    })?;
2209
2210    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2211        #[cfg(any(test, feature = "test-support"))]
2212        const MAX_CHUNK_SIZE: usize = 2;
2213        #[cfg(not(any(test, feature = "test-support")))]
2214        const MAX_CHUNK_SIZE: usize = 256;
2215
2216        // Stream this worktree's entries.
2217        let message = proto::UpdateWorktree {
2218            project_id: project_id.to_proto(),
2219            worktree_id,
2220            abs_path: worktree.abs_path.clone(),
2221            root_name: worktree.root_name,
2222            updated_entries: worktree.entries,
2223            removed_entries: Default::default(),
2224            scan_id: worktree.scan_id,
2225            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2226            updated_repositories: worktree.repository_entries.into_values().collect(),
2227            removed_repositories: Default::default(),
2228        };
2229        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2230            session.peer.send(session.connection_id, update.clone())?;
2231        }
2232
2233        // Stream this worktree's diagnostics.
2234        for summary in worktree.diagnostic_summaries {
2235            session.peer.send(
2236                session.connection_id,
2237                proto::UpdateDiagnosticSummary {
2238                    project_id: project_id.to_proto(),
2239                    worktree_id: worktree.id,
2240                    summary: Some(summary),
2241                },
2242            )?;
2243        }
2244
2245        for settings_file in worktree.settings_files {
2246            session.peer.send(
2247                session.connection_id,
2248                proto::UpdateWorktreeSettings {
2249                    project_id: project_id.to_proto(),
2250                    worktree_id: worktree.id,
2251                    path: settings_file.path,
2252                    content: Some(settings_file.content),
2253                },
2254            )?;
2255        }
2256    }
2257
2258    for language_server in &project.language_servers {
2259        session.peer.send(
2260            session.connection_id,
2261            proto::UpdateLanguageServer {
2262                project_id: project_id.to_proto(),
2263                language_server_id: language_server.id,
2264                variant: Some(
2265                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2266                        proto::LspDiskBasedDiagnosticsUpdated {},
2267                    ),
2268                ),
2269            },
2270        )?;
2271    }
2272
2273    Ok(())
2274}
2275
2276/// Leave someone elses shared project.
2277async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2278    let sender_id = session.connection_id;
2279    let project_id = ProjectId::from_proto(request.project_id);
2280    let db = session.db().await;
2281    if db.is_hosted_project(project_id).await? {
2282        let project = db.leave_hosted_project(project_id, sender_id).await?;
2283        project_left(&project, &session);
2284        return Ok(());
2285    }
2286
2287    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2288    tracing::info!(
2289        %project_id,
2290        "leave project"
2291    );
2292
2293    project_left(&project, &session);
2294    if let Some(room) = room {
2295        room_updated(&room, &session.peer);
2296    }
2297
2298    Ok(())
2299}
2300
2301async fn join_hosted_project(
2302    request: proto::JoinHostedProject,
2303    response: Response<proto::JoinHostedProject>,
2304    session: UserSession,
2305) -> Result<()> {
2306    let (mut project, replica_id) = session
2307        .db()
2308        .await
2309        .join_hosted_project(
2310            ProjectId(request.project_id as i32),
2311            session.user_id(),
2312            session.connection_id,
2313        )
2314        .await?;
2315
2316    join_project_internal(response, session, &mut project, &replica_id)
2317}
2318
2319async fn create_dev_server_project(
2320    request: proto::CreateDevServerProject,
2321    response: Response<proto::CreateDevServerProject>,
2322    session: UserSession,
2323) -> Result<()> {
2324    let dev_server_id = DevServerId(request.dev_server_id as i32);
2325    let dev_server_connection_id = session
2326        .connection_pool()
2327        .await
2328        .dev_server_connection_id(dev_server_id);
2329    let Some(dev_server_connection_id) = dev_server_connection_id else {
2330        Err(ErrorCode::DevServerOffline
2331            .message("Cannot create a remote project when the dev server is offline".to_string())
2332            .anyhow())?
2333    };
2334
2335    let path = request.path.clone();
2336    //Check that the path exists on the dev server
2337    session
2338        .peer
2339        .forward_request(
2340            session.connection_id,
2341            dev_server_connection_id,
2342            proto::ValidateDevServerProjectRequest { path: path.clone() },
2343        )
2344        .await?;
2345
2346    let (dev_server_project, update) = session
2347        .db()
2348        .await
2349        .create_dev_server_project(
2350            DevServerId(request.dev_server_id as i32),
2351            &request.path,
2352            session.user_id(),
2353        )
2354        .await?;
2355
2356    let projects = session
2357        .db()
2358        .await
2359        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2360        .await?;
2361
2362    session.peer.send(
2363        dev_server_connection_id,
2364        proto::DevServerInstructions { projects },
2365    )?;
2366
2367    send_dev_server_projects_update(session.user_id(), update, &session).await;
2368
2369    response.send(proto::CreateDevServerProjectResponse {
2370        dev_server_project: Some(dev_server_project.to_proto(None)),
2371    })?;
2372    Ok(())
2373}
2374
2375async fn create_dev_server(
2376    request: proto::CreateDevServer,
2377    response: Response<proto::CreateDevServer>,
2378    session: UserSession,
2379) -> Result<()> {
2380    let access_token = auth::random_token();
2381    let hashed_access_token = auth::hash_access_token(&access_token);
2382
2383    if request.name.is_empty() {
2384        return Err(proto::ErrorCode::Forbidden
2385            .message("Dev server name cannot be empty".to_string())
2386            .anyhow())?;
2387    }
2388
2389    let (dev_server, status) = session
2390        .db()
2391        .await
2392        .create_dev_server(
2393            &request.name,
2394            request.ssh_connection_string.as_deref(),
2395            &hashed_access_token,
2396            session.user_id(),
2397        )
2398        .await?;
2399
2400    send_dev_server_projects_update(session.user_id(), status, &session).await;
2401
2402    response.send(proto::CreateDevServerResponse {
2403        dev_server_id: dev_server.id.0 as u64,
2404        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2405        name: request.name,
2406    })?;
2407    Ok(())
2408}
2409
2410async fn regenerate_dev_server_token(
2411    request: proto::RegenerateDevServerToken,
2412    response: Response<proto::RegenerateDevServerToken>,
2413    session: UserSession,
2414) -> Result<()> {
2415    let dev_server_id = DevServerId(request.dev_server_id as i32);
2416    let access_token = auth::random_token();
2417    let hashed_access_token = auth::hash_access_token(&access_token);
2418
2419    let connection_id = session
2420        .connection_pool()
2421        .await
2422        .dev_server_connection_id(dev_server_id);
2423    if let Some(connection_id) = connection_id {
2424        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2425        session.peer.send(
2426            connection_id,
2427            proto::ShutdownDevServer {
2428                reason: Some("dev server token was regenerated".to_string()),
2429            },
2430        )?;
2431        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2432    }
2433
2434    let status = session
2435        .db()
2436        .await
2437        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2438        .await?;
2439
2440    send_dev_server_projects_update(session.user_id(), status, &session).await;
2441
2442    response.send(proto::RegenerateDevServerTokenResponse {
2443        dev_server_id: dev_server_id.to_proto(),
2444        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2445    })?;
2446    Ok(())
2447}
2448
2449async fn rename_dev_server(
2450    request: proto::RenameDevServer,
2451    response: Response<proto::RenameDevServer>,
2452    session: UserSession,
2453) -> Result<()> {
2454    if request.name.trim().is_empty() {
2455        return Err(proto::ErrorCode::Forbidden
2456            .message("Dev server name cannot be empty".to_string())
2457            .anyhow())?;
2458    }
2459
2460    let dev_server_id = DevServerId(request.dev_server_id as i32);
2461    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2462    if dev_server.user_id != session.user_id() {
2463        return Err(anyhow!(ErrorCode::Forbidden))?;
2464    }
2465
2466    let status = session
2467        .db()
2468        .await
2469        .rename_dev_server(
2470            dev_server_id,
2471            &request.name,
2472            request.ssh_connection_string.as_deref(),
2473            session.user_id(),
2474        )
2475        .await?;
2476
2477    send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479    response.send(proto::Ack {})?;
2480    Ok(())
2481}
2482
2483async fn delete_dev_server(
2484    request: proto::DeleteDevServer,
2485    response: Response<proto::DeleteDevServer>,
2486    session: UserSession,
2487) -> Result<()> {
2488    let dev_server_id = DevServerId(request.dev_server_id as i32);
2489    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2490    if dev_server.user_id != session.user_id() {
2491        return Err(anyhow!(ErrorCode::Forbidden))?;
2492    }
2493
2494    let connection_id = session
2495        .connection_pool()
2496        .await
2497        .dev_server_connection_id(dev_server_id);
2498    if let Some(connection_id) = connection_id {
2499        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2500        session.peer.send(
2501            connection_id,
2502            proto::ShutdownDevServer {
2503                reason: Some("dev server was deleted".to_string()),
2504            },
2505        )?;
2506        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2507    }
2508
2509    let status = session
2510        .db()
2511        .await
2512        .delete_dev_server(dev_server_id, session.user_id())
2513        .await?;
2514
2515    send_dev_server_projects_update(session.user_id(), status, &session).await;
2516
2517    response.send(proto::Ack {})?;
2518    Ok(())
2519}
2520
2521async fn delete_dev_server_project(
2522    request: proto::DeleteDevServerProject,
2523    response: Response<proto::DeleteDevServerProject>,
2524    session: UserSession,
2525) -> Result<()> {
2526    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2527    let dev_server_project = session
2528        .db()
2529        .await
2530        .get_dev_server_project(dev_server_project_id)
2531        .await?;
2532
2533    let dev_server = session
2534        .db()
2535        .await
2536        .get_dev_server(dev_server_project.dev_server_id)
2537        .await?;
2538    if dev_server.user_id != session.user_id() {
2539        return Err(anyhow!(ErrorCode::Forbidden))?;
2540    }
2541
2542    let dev_server_connection_id = session
2543        .connection_pool()
2544        .await
2545        .dev_server_connection_id(dev_server.id);
2546
2547    if let Some(dev_server_connection_id) = dev_server_connection_id {
2548        let project = session
2549            .db()
2550            .await
2551            .find_dev_server_project(dev_server_project_id)
2552            .await;
2553        if let Ok(project) = project {
2554            unshare_project_internal(
2555                project.id,
2556                dev_server_connection_id,
2557                Some(session.user_id()),
2558                &session,
2559            )
2560            .await?;
2561        }
2562    }
2563
2564    let (projects, status) = session
2565        .db()
2566        .await
2567        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2568        .await?;
2569
2570    if let Some(dev_server_connection_id) = dev_server_connection_id {
2571        session.peer.send(
2572            dev_server_connection_id,
2573            proto::DevServerInstructions { projects },
2574        )?;
2575    }
2576
2577    send_dev_server_projects_update(session.user_id(), status, &session).await;
2578
2579    response.send(proto::Ack {})?;
2580    Ok(())
2581}
2582
2583async fn rejoin_dev_server_projects(
2584    request: proto::RejoinRemoteProjects,
2585    response: Response<proto::RejoinRemoteProjects>,
2586    session: UserSession,
2587) -> Result<()> {
2588    let mut rejoined_projects = {
2589        let db = session.db().await;
2590        db.rejoin_dev_server_projects(
2591            &request.rejoined_projects,
2592            session.user_id(),
2593            session.0.connection_id,
2594        )
2595        .await?
2596    };
2597    response.send(proto::RejoinRemoteProjectsResponse {
2598        rejoined_projects: rejoined_projects
2599            .iter()
2600            .map(|project| project.to_proto())
2601            .collect(),
2602    })?;
2603    notify_rejoined_projects(&mut rejoined_projects, &session)
2604}
2605
2606async fn reconnect_dev_server(
2607    request: proto::ReconnectDevServer,
2608    response: Response<proto::ReconnectDevServer>,
2609    session: DevServerSession,
2610) -> Result<()> {
2611    let reshared_projects = {
2612        let db = session.db().await;
2613        db.reshare_dev_server_projects(
2614            &request.reshared_projects,
2615            session.dev_server_id(),
2616            session.0.connection_id,
2617        )
2618        .await?
2619    };
2620
2621    for project in &reshared_projects {
2622        for collaborator in &project.collaborators {
2623            session
2624                .peer
2625                .send(
2626                    collaborator.connection_id,
2627                    proto::UpdateProjectCollaborator {
2628                        project_id: project.id.to_proto(),
2629                        old_peer_id: Some(project.old_connection_id.into()),
2630                        new_peer_id: Some(session.connection_id.into()),
2631                    },
2632                )
2633                .trace_err();
2634        }
2635
2636        broadcast(
2637            Some(session.connection_id),
2638            project
2639                .collaborators
2640                .iter()
2641                .map(|collaborator| collaborator.connection_id),
2642            |connection_id| {
2643                session.peer.forward_send(
2644                    session.connection_id,
2645                    connection_id,
2646                    proto::UpdateProject {
2647                        project_id: project.id.to_proto(),
2648                        worktrees: project.worktrees.clone(),
2649                    },
2650                )
2651            },
2652        );
2653    }
2654
2655    response.send(proto::ReconnectDevServerResponse {
2656        reshared_projects: reshared_projects
2657            .iter()
2658            .map(|project| proto::ResharedProject {
2659                id: project.id.to_proto(),
2660                collaborators: project
2661                    .collaborators
2662                    .iter()
2663                    .map(|collaborator| collaborator.to_proto())
2664                    .collect(),
2665            })
2666            .collect(),
2667    })?;
2668
2669    Ok(())
2670}
2671
2672async fn shutdown_dev_server(
2673    _: proto::ShutdownDevServer,
2674    response: Response<proto::ShutdownDevServer>,
2675    session: DevServerSession,
2676) -> Result<()> {
2677    response.send(proto::Ack {})?;
2678    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2679    remove_dev_server_connection(session.dev_server_id(), &session).await
2680}
2681
2682async fn shutdown_dev_server_internal(
2683    dev_server_id: DevServerId,
2684    connection_id: ConnectionId,
2685    session: &Session,
2686) -> Result<()> {
2687    let (dev_server_projects, dev_server) = {
2688        let db = session.db().await;
2689        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2690        let dev_server = db.get_dev_server(dev_server_id).await?;
2691        (dev_server_projects, dev_server)
2692    };
2693
2694    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2695        unshare_project_internal(
2696            ProjectId::from_proto(project_id),
2697            connection_id,
2698            None,
2699            session,
2700        )
2701        .await?;
2702    }
2703
2704    session
2705        .connection_pool()
2706        .await
2707        .set_dev_server_offline(dev_server_id);
2708
2709    let status = session
2710        .db()
2711        .await
2712        .dev_server_projects_update(dev_server.user_id)
2713        .await?;
2714    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2715
2716    Ok(())
2717}
2718
2719async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2720    let dev_server_connection = session
2721        .connection_pool()
2722        .await
2723        .dev_server_connection_id(dev_server_id);
2724
2725    if let Some(dev_server_connection) = dev_server_connection {
2726        session
2727            .connection_pool()
2728            .await
2729            .remove_connection(dev_server_connection)?;
2730    }
2731    Ok(())
2732}
2733
2734/// Updates other participants with changes to the project
2735async fn update_project(
2736    request: proto::UpdateProject,
2737    response: Response<proto::UpdateProject>,
2738    session: Session,
2739) -> Result<()> {
2740    let project_id = ProjectId::from_proto(request.project_id);
2741    let (room, guest_connection_ids) = &*session
2742        .db()
2743        .await
2744        .update_project(project_id, session.connection_id, &request.worktrees)
2745        .await?;
2746    broadcast(
2747        Some(session.connection_id),
2748        guest_connection_ids.iter().copied(),
2749        |connection_id| {
2750            session
2751                .peer
2752                .forward_send(session.connection_id, connection_id, request.clone())
2753        },
2754    );
2755    if let Some(room) = room {
2756        room_updated(&room, &session.peer);
2757    }
2758    response.send(proto::Ack {})?;
2759
2760    Ok(())
2761}
2762
2763/// Updates other participants with changes to the worktree
2764async fn update_worktree(
2765    request: proto::UpdateWorktree,
2766    response: Response<proto::UpdateWorktree>,
2767    session: Session,
2768) -> Result<()> {
2769    let guest_connection_ids = session
2770        .db()
2771        .await
2772        .update_worktree(&request, session.connection_id)
2773        .await?;
2774
2775    broadcast(
2776        Some(session.connection_id),
2777        guest_connection_ids.iter().copied(),
2778        |connection_id| {
2779            session
2780                .peer
2781                .forward_send(session.connection_id, connection_id, request.clone())
2782        },
2783    );
2784    response.send(proto::Ack {})?;
2785    Ok(())
2786}
2787
2788/// Updates other participants with changes to the diagnostics
2789async fn update_diagnostic_summary(
2790    message: proto::UpdateDiagnosticSummary,
2791    session: Session,
2792) -> Result<()> {
2793    let guest_connection_ids = session
2794        .db()
2795        .await
2796        .update_diagnostic_summary(&message, session.connection_id)
2797        .await?;
2798
2799    broadcast(
2800        Some(session.connection_id),
2801        guest_connection_ids.iter().copied(),
2802        |connection_id| {
2803            session
2804                .peer
2805                .forward_send(session.connection_id, connection_id, message.clone())
2806        },
2807    );
2808
2809    Ok(())
2810}
2811
2812/// Updates other participants with changes to the worktree settings
2813async fn update_worktree_settings(
2814    message: proto::UpdateWorktreeSettings,
2815    session: Session,
2816) -> Result<()> {
2817    let guest_connection_ids = session
2818        .db()
2819        .await
2820        .update_worktree_settings(&message, session.connection_id)
2821        .await?;
2822
2823    broadcast(
2824        Some(session.connection_id),
2825        guest_connection_ids.iter().copied(),
2826        |connection_id| {
2827            session
2828                .peer
2829                .forward_send(session.connection_id, connection_id, message.clone())
2830        },
2831    );
2832
2833    Ok(())
2834}
2835
2836/// Notify other participants that a  language server has started.
2837async fn start_language_server(
2838    request: proto::StartLanguageServer,
2839    session: Session,
2840) -> Result<()> {
2841    let guest_connection_ids = session
2842        .db()
2843        .await
2844        .start_language_server(&request, session.connection_id)
2845        .await?;
2846
2847    broadcast(
2848        Some(session.connection_id),
2849        guest_connection_ids.iter().copied(),
2850        |connection_id| {
2851            session
2852                .peer
2853                .forward_send(session.connection_id, connection_id, request.clone())
2854        },
2855    );
2856    Ok(())
2857}
2858
2859/// Notify other participants that a language server has changed.
2860async fn update_language_server(
2861    request: proto::UpdateLanguageServer,
2862    session: Session,
2863) -> Result<()> {
2864    let project_id = ProjectId::from_proto(request.project_id);
2865    let project_connection_ids = session
2866        .db()
2867        .await
2868        .project_connection_ids(project_id, session.connection_id, true)
2869        .await?;
2870    broadcast(
2871        Some(session.connection_id),
2872        project_connection_ids.iter().copied(),
2873        |connection_id| {
2874            session
2875                .peer
2876                .forward_send(session.connection_id, connection_id, request.clone())
2877        },
2878    );
2879    Ok(())
2880}
2881
2882/// forward a project request to the host. These requests should be read only
2883/// as guests are allowed to send them.
2884async fn forward_read_only_project_request<T>(
2885    request: T,
2886    response: Response<T>,
2887    session: UserSession,
2888) -> Result<()>
2889where
2890    T: EntityMessage + RequestMessage,
2891{
2892    let project_id = ProjectId::from_proto(request.remote_entity_id());
2893    let host_connection_id = session
2894        .db()
2895        .await
2896        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2897        .await?;
2898    let payload = session
2899        .peer
2900        .forward_request(session.connection_id, host_connection_id, request)
2901        .await?;
2902    response.send(payload)?;
2903    Ok(())
2904}
2905
2906/// forward a project request to the dev server. Only allowed
2907/// if it's your dev server.
2908async fn forward_project_request_for_owner<T>(
2909    request: T,
2910    response: Response<T>,
2911    session: UserSession,
2912) -> Result<()>
2913where
2914    T: EntityMessage + RequestMessage,
2915{
2916    let project_id = ProjectId::from_proto(request.remote_entity_id());
2917
2918    let host_connection_id = session
2919        .db()
2920        .await
2921        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2922        .await?;
2923    let payload = session
2924        .peer
2925        .forward_request(session.connection_id, host_connection_id, request)
2926        .await?;
2927    response.send(payload)?;
2928    Ok(())
2929}
2930
2931/// forward a project request to the host. These requests are disallowed
2932/// for guests.
2933async fn forward_mutating_project_request<T>(
2934    request: T,
2935    response: Response<T>,
2936    session: UserSession,
2937) -> Result<()>
2938where
2939    T: EntityMessage + RequestMessage,
2940{
2941    let project_id = ProjectId::from_proto(request.remote_entity_id());
2942
2943    let host_connection_id = session
2944        .db()
2945        .await
2946        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2947        .await?;
2948    let payload = session
2949        .peer
2950        .forward_request(session.connection_id, host_connection_id, request)
2951        .await?;
2952    response.send(payload)?;
2953    Ok(())
2954}
2955
2956/// forward a project request to the host. These requests are disallowed
2957/// for guests.
2958async fn forward_versioned_mutating_project_request<T>(
2959    request: T,
2960    response: Response<T>,
2961    session: UserSession,
2962) -> Result<()>
2963where
2964    T: EntityMessage + RequestMessage + VersionedMessage,
2965{
2966    let project_id = ProjectId::from_proto(request.remote_entity_id());
2967
2968    let host_connection_id = session
2969        .db()
2970        .await
2971        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2972        .await?;
2973    if let Some(host_version) = session
2974        .connection_pool()
2975        .await
2976        .connection(host_connection_id)
2977        .map(|c| c.zed_version)
2978    {
2979        if let Some(min_required_version) = request.required_host_version() {
2980            if min_required_version > host_version {
2981                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2982                    .with_tag("required", &min_required_version.to_string())))?;
2983            }
2984        }
2985    }
2986
2987    let payload = session
2988        .peer
2989        .forward_request(session.connection_id, host_connection_id, request)
2990        .await?;
2991    response.send(payload)?;
2992    Ok(())
2993}
2994
2995/// Notify other participants that a new buffer has been created
2996async fn create_buffer_for_peer(
2997    request: proto::CreateBufferForPeer,
2998    session: Session,
2999) -> Result<()> {
3000    session
3001        .db()
3002        .await
3003        .check_user_is_project_host(
3004            ProjectId::from_proto(request.project_id),
3005            session.connection_id,
3006        )
3007        .await?;
3008    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3009    session
3010        .peer
3011        .forward_send(session.connection_id, peer_id.into(), request)?;
3012    Ok(())
3013}
3014
3015/// Notify other participants that a buffer has been updated. This is
3016/// allowed for guests as long as the update is limited to selections.
3017async fn update_buffer(
3018    request: proto::UpdateBuffer,
3019    response: Response<proto::UpdateBuffer>,
3020    session: Session,
3021) -> Result<()> {
3022    let project_id = ProjectId::from_proto(request.project_id);
3023    let mut capability = Capability::ReadOnly;
3024
3025    for op in request.operations.iter() {
3026        match op.variant {
3027            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3028            Some(_) => capability = Capability::ReadWrite,
3029        }
3030    }
3031
3032    let host = {
3033        let guard = session
3034            .db()
3035            .await
3036            .connections_for_buffer_update(
3037                project_id,
3038                session.principal_id(),
3039                session.connection_id,
3040                capability,
3041            )
3042            .await?;
3043
3044        let (host, guests) = &*guard;
3045
3046        broadcast(
3047            Some(session.connection_id),
3048            guests.clone(),
3049            |connection_id| {
3050                session
3051                    .peer
3052                    .forward_send(session.connection_id, connection_id, request.clone())
3053            },
3054        );
3055
3056        *host
3057    };
3058
3059    if host != session.connection_id {
3060        session
3061            .peer
3062            .forward_request(session.connection_id, host, request.clone())
3063            .await?;
3064    }
3065
3066    response.send(proto::Ack {})?;
3067    Ok(())
3068}
3069
3070async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3071    let project_id = ProjectId::from_proto(message.project_id);
3072
3073    let operation = message.operation.as_ref().context("invalid operation")?;
3074    let capability = match operation.variant.as_ref() {
3075        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3076            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3077                match buffer_op.variant {
3078                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3079                        Capability::ReadOnly
3080                    }
3081                    _ => Capability::ReadWrite,
3082                }
3083            } else {
3084                Capability::ReadWrite
3085            }
3086        }
3087        Some(_) => Capability::ReadWrite,
3088        None => Capability::ReadOnly,
3089    };
3090
3091    let guard = session
3092        .db()
3093        .await
3094        .connections_for_buffer_update(
3095            project_id,
3096            session.principal_id(),
3097            session.connection_id,
3098            capability,
3099        )
3100        .await?;
3101
3102    let (host, guests) = &*guard;
3103
3104    broadcast(
3105        Some(session.connection_id),
3106        guests.iter().chain([host]).copied(),
3107        |connection_id| {
3108            session
3109                .peer
3110                .forward_send(session.connection_id, connection_id, message.clone())
3111        },
3112    );
3113
3114    Ok(())
3115}
3116
3117/// Notify other participants that a project has been updated.
3118async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3119    request: T,
3120    session: Session,
3121) -> Result<()> {
3122    let project_id = ProjectId::from_proto(request.remote_entity_id());
3123    let project_connection_ids = session
3124        .db()
3125        .await
3126        .project_connection_ids(project_id, session.connection_id, false)
3127        .await?;
3128
3129    broadcast(
3130        Some(session.connection_id),
3131        project_connection_ids.iter().copied(),
3132        |connection_id| {
3133            session
3134                .peer
3135                .forward_send(session.connection_id, connection_id, request.clone())
3136        },
3137    );
3138    Ok(())
3139}
3140
3141/// Start following another user in a call.
3142async fn follow(
3143    request: proto::Follow,
3144    response: Response<proto::Follow>,
3145    session: UserSession,
3146) -> Result<()> {
3147    let room_id = RoomId::from_proto(request.room_id);
3148    let project_id = request.project_id.map(ProjectId::from_proto);
3149    let leader_id = request
3150        .leader_id
3151        .ok_or_else(|| anyhow!("invalid leader id"))?
3152        .into();
3153    let follower_id = session.connection_id;
3154
3155    session
3156        .db()
3157        .await
3158        .check_room_participants(room_id, leader_id, session.connection_id)
3159        .await?;
3160
3161    let response_payload = session
3162        .peer
3163        .forward_request(session.connection_id, leader_id, request)
3164        .await?;
3165    response.send(response_payload)?;
3166
3167    if let Some(project_id) = project_id {
3168        let room = session
3169            .db()
3170            .await
3171            .follow(room_id, project_id, leader_id, follower_id)
3172            .await?;
3173        room_updated(&room, &session.peer);
3174    }
3175
3176    Ok(())
3177}
3178
3179/// Stop following another user in a call.
3180async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3181    let room_id = RoomId::from_proto(request.room_id);
3182    let project_id = request.project_id.map(ProjectId::from_proto);
3183    let leader_id = request
3184        .leader_id
3185        .ok_or_else(|| anyhow!("invalid leader id"))?
3186        .into();
3187    let follower_id = session.connection_id;
3188
3189    session
3190        .db()
3191        .await
3192        .check_room_participants(room_id, leader_id, session.connection_id)
3193        .await?;
3194
3195    session
3196        .peer
3197        .forward_send(session.connection_id, leader_id, request)?;
3198
3199    if let Some(project_id) = project_id {
3200        let room = session
3201            .db()
3202            .await
3203            .unfollow(room_id, project_id, leader_id, follower_id)
3204            .await?;
3205        room_updated(&room, &session.peer);
3206    }
3207
3208    Ok(())
3209}
3210
3211/// Notify everyone following you of your current location.
3212async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3213    let room_id = RoomId::from_proto(request.room_id);
3214    let database = session.db.lock().await;
3215
3216    let connection_ids = if let Some(project_id) = request.project_id {
3217        let project_id = ProjectId::from_proto(project_id);
3218        database
3219            .project_connection_ids(project_id, session.connection_id, true)
3220            .await?
3221    } else {
3222        database
3223            .room_connection_ids(room_id, session.connection_id)
3224            .await?
3225    };
3226
3227    // For now, don't send view update messages back to that view's current leader.
3228    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3229        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3230        _ => None,
3231    });
3232
3233    for connection_id in connection_ids.iter().cloned() {
3234        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3235            session
3236                .peer
3237                .forward_send(session.connection_id, connection_id, request.clone())?;
3238        }
3239    }
3240    Ok(())
3241}
3242
3243/// Get public data about users.
3244async fn get_users(
3245    request: proto::GetUsers,
3246    response: Response<proto::GetUsers>,
3247    session: Session,
3248) -> Result<()> {
3249    let user_ids = request
3250        .user_ids
3251        .into_iter()
3252        .map(UserId::from_proto)
3253        .collect();
3254    let users = session
3255        .db()
3256        .await
3257        .get_users_by_ids(user_ids)
3258        .await?
3259        .into_iter()
3260        .map(|user| proto::User {
3261            id: user.id.to_proto(),
3262            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3263            github_login: user.github_login,
3264        })
3265        .collect();
3266    response.send(proto::UsersResponse { users })?;
3267    Ok(())
3268}
3269
3270/// Search for users (to invite) buy Github login
3271async fn fuzzy_search_users(
3272    request: proto::FuzzySearchUsers,
3273    response: Response<proto::FuzzySearchUsers>,
3274    session: UserSession,
3275) -> Result<()> {
3276    let query = request.query;
3277    let users = match query.len() {
3278        0 => vec![],
3279        1 | 2 => session
3280            .db()
3281            .await
3282            .get_user_by_github_login(&query)
3283            .await?
3284            .into_iter()
3285            .collect(),
3286        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3287    };
3288    let users = users
3289        .into_iter()
3290        .filter(|user| user.id != session.user_id())
3291        .map(|user| proto::User {
3292            id: user.id.to_proto(),
3293            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3294            github_login: user.github_login,
3295        })
3296        .collect();
3297    response.send(proto::UsersResponse { users })?;
3298    Ok(())
3299}
3300
3301/// Send a contact request to another user.
3302async fn request_contact(
3303    request: proto::RequestContact,
3304    response: Response<proto::RequestContact>,
3305    session: UserSession,
3306) -> Result<()> {
3307    let requester_id = session.user_id();
3308    let responder_id = UserId::from_proto(request.responder_id);
3309    if requester_id == responder_id {
3310        return Err(anyhow!("cannot add yourself as a contact"))?;
3311    }
3312
3313    let notifications = session
3314        .db()
3315        .await
3316        .send_contact_request(requester_id, responder_id)
3317        .await?;
3318
3319    // Update outgoing contact requests of requester
3320    let mut update = proto::UpdateContacts::default();
3321    update.outgoing_requests.push(responder_id.to_proto());
3322    for connection_id in session
3323        .connection_pool()
3324        .await
3325        .user_connection_ids(requester_id)
3326    {
3327        session.peer.send(connection_id, update.clone())?;
3328    }
3329
3330    // Update incoming contact requests of responder
3331    let mut update = proto::UpdateContacts::default();
3332    update
3333        .incoming_requests
3334        .push(proto::IncomingContactRequest {
3335            requester_id: requester_id.to_proto(),
3336        });
3337    let connection_pool = session.connection_pool().await;
3338    for connection_id in connection_pool.user_connection_ids(responder_id) {
3339        session.peer.send(connection_id, update.clone())?;
3340    }
3341
3342    send_notifications(&connection_pool, &session.peer, notifications);
3343
3344    response.send(proto::Ack {})?;
3345    Ok(())
3346}
3347
3348/// Accept or decline a contact request
3349async fn respond_to_contact_request(
3350    request: proto::RespondToContactRequest,
3351    response: Response<proto::RespondToContactRequest>,
3352    session: UserSession,
3353) -> Result<()> {
3354    let responder_id = session.user_id();
3355    let requester_id = UserId::from_proto(request.requester_id);
3356    let db = session.db().await;
3357    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3358        db.dismiss_contact_notification(responder_id, requester_id)
3359            .await?;
3360    } else {
3361        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3362
3363        let notifications = db
3364            .respond_to_contact_request(responder_id, requester_id, accept)
3365            .await?;
3366        let requester_busy = db.is_user_busy(requester_id).await?;
3367        let responder_busy = db.is_user_busy(responder_id).await?;
3368
3369        let pool = session.connection_pool().await;
3370        // Update responder with new contact
3371        let mut update = proto::UpdateContacts::default();
3372        if accept {
3373            update
3374                .contacts
3375                .push(contact_for_user(requester_id, requester_busy, &pool));
3376        }
3377        update
3378            .remove_incoming_requests
3379            .push(requester_id.to_proto());
3380        for connection_id in pool.user_connection_ids(responder_id) {
3381            session.peer.send(connection_id, update.clone())?;
3382        }
3383
3384        // Update requester with new contact
3385        let mut update = proto::UpdateContacts::default();
3386        if accept {
3387            update
3388                .contacts
3389                .push(contact_for_user(responder_id, responder_busy, &pool));
3390        }
3391        update
3392            .remove_outgoing_requests
3393            .push(responder_id.to_proto());
3394
3395        for connection_id in pool.user_connection_ids(requester_id) {
3396            session.peer.send(connection_id, update.clone())?;
3397        }
3398
3399        send_notifications(&pool, &session.peer, notifications);
3400    }
3401
3402    response.send(proto::Ack {})?;
3403    Ok(())
3404}
3405
3406/// Remove a contact.
3407async fn remove_contact(
3408    request: proto::RemoveContact,
3409    response: Response<proto::RemoveContact>,
3410    session: UserSession,
3411) -> Result<()> {
3412    let requester_id = session.user_id();
3413    let responder_id = UserId::from_proto(request.user_id);
3414    let db = session.db().await;
3415    let (contact_accepted, deleted_notification_id) =
3416        db.remove_contact(requester_id, responder_id).await?;
3417
3418    let pool = session.connection_pool().await;
3419    // Update outgoing contact requests of requester
3420    let mut update = proto::UpdateContacts::default();
3421    if contact_accepted {
3422        update.remove_contacts.push(responder_id.to_proto());
3423    } else {
3424        update
3425            .remove_outgoing_requests
3426            .push(responder_id.to_proto());
3427    }
3428    for connection_id in pool.user_connection_ids(requester_id) {
3429        session.peer.send(connection_id, update.clone())?;
3430    }
3431
3432    // Update incoming contact requests of responder
3433    let mut update = proto::UpdateContacts::default();
3434    if contact_accepted {
3435        update.remove_contacts.push(requester_id.to_proto());
3436    } else {
3437        update
3438            .remove_incoming_requests
3439            .push(requester_id.to_proto());
3440    }
3441    for connection_id in pool.user_connection_ids(responder_id) {
3442        session.peer.send(connection_id, update.clone())?;
3443        if let Some(notification_id) = deleted_notification_id {
3444            session.peer.send(
3445                connection_id,
3446                proto::DeleteNotification {
3447                    notification_id: notification_id.to_proto(),
3448                },
3449            )?;
3450        }
3451    }
3452
3453    response.send(proto::Ack {})?;
3454    Ok(())
3455}
3456
3457fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3458    version.0.minor() < 139
3459}
3460
3461async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3462    subscribe_user_to_channels(
3463        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3464        &session,
3465    )
3466    .await?;
3467    Ok(())
3468}
3469
3470async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3471    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3472    let mut pool = session.connection_pool().await;
3473    for membership in &channels_for_user.channel_memberships {
3474        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3475    }
3476    session.peer.send(
3477        session.connection_id,
3478        build_update_user_channels(&channels_for_user),
3479    )?;
3480    session.peer.send(
3481        session.connection_id,
3482        build_channels_update(channels_for_user),
3483    )?;
3484    Ok(())
3485}
3486
3487/// Creates a new channel.
3488async fn create_channel(
3489    request: proto::CreateChannel,
3490    response: Response<proto::CreateChannel>,
3491    session: UserSession,
3492) -> Result<()> {
3493    let db = session.db().await;
3494
3495    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3496    let (channel, membership) = db
3497        .create_channel(&request.name, parent_id, session.user_id())
3498        .await?;
3499
3500    let root_id = channel.root_id();
3501    let channel = Channel::from_model(channel);
3502
3503    response.send(proto::CreateChannelResponse {
3504        channel: Some(channel.to_proto()),
3505        parent_id: request.parent_id,
3506    })?;
3507
3508    let mut connection_pool = session.connection_pool().await;
3509    if let Some(membership) = membership {
3510        connection_pool.subscribe_to_channel(
3511            membership.user_id,
3512            membership.channel_id,
3513            membership.role,
3514        );
3515        let update = proto::UpdateUserChannels {
3516            channel_memberships: vec![proto::ChannelMembership {
3517                channel_id: membership.channel_id.to_proto(),
3518                role: membership.role.into(),
3519            }],
3520            ..Default::default()
3521        };
3522        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3523            session.peer.send(connection_id, update.clone())?;
3524        }
3525    }
3526
3527    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3528        if !role.can_see_channel(channel.visibility) {
3529            continue;
3530        }
3531
3532        let update = proto::UpdateChannels {
3533            channels: vec![channel.to_proto()],
3534            ..Default::default()
3535        };
3536        session.peer.send(connection_id, update.clone())?;
3537    }
3538
3539    Ok(())
3540}
3541
3542/// Delete a channel
3543async fn delete_channel(
3544    request: proto::DeleteChannel,
3545    response: Response<proto::DeleteChannel>,
3546    session: UserSession,
3547) -> Result<()> {
3548    let db = session.db().await;
3549
3550    let channel_id = request.channel_id;
3551    let (root_channel, removed_channels) = db
3552        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3553        .await?;
3554    response.send(proto::Ack {})?;
3555
3556    // Notify members of removed channels
3557    let mut update = proto::UpdateChannels::default();
3558    update
3559        .delete_channels
3560        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3561
3562    let connection_pool = session.connection_pool().await;
3563    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3564        session.peer.send(connection_id, update.clone())?;
3565    }
3566
3567    Ok(())
3568}
3569
3570/// Invite someone to join a channel.
3571async fn invite_channel_member(
3572    request: proto::InviteChannelMember,
3573    response: Response<proto::InviteChannelMember>,
3574    session: UserSession,
3575) -> Result<()> {
3576    let db = session.db().await;
3577    let channel_id = ChannelId::from_proto(request.channel_id);
3578    let invitee_id = UserId::from_proto(request.user_id);
3579    let InviteMemberResult {
3580        channel,
3581        notifications,
3582    } = db
3583        .invite_channel_member(
3584            channel_id,
3585            invitee_id,
3586            session.user_id(),
3587            request.role().into(),
3588        )
3589        .await?;
3590
3591    let update = proto::UpdateChannels {
3592        channel_invitations: vec![channel.to_proto()],
3593        ..Default::default()
3594    };
3595
3596    let connection_pool = session.connection_pool().await;
3597    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3598        session.peer.send(connection_id, update.clone())?;
3599    }
3600
3601    send_notifications(&connection_pool, &session.peer, notifications);
3602
3603    response.send(proto::Ack {})?;
3604    Ok(())
3605}
3606
3607/// remove someone from a channel
3608async fn remove_channel_member(
3609    request: proto::RemoveChannelMember,
3610    response: Response<proto::RemoveChannelMember>,
3611    session: UserSession,
3612) -> Result<()> {
3613    let db = session.db().await;
3614    let channel_id = ChannelId::from_proto(request.channel_id);
3615    let member_id = UserId::from_proto(request.user_id);
3616
3617    let RemoveChannelMemberResult {
3618        membership_update,
3619        notification_id,
3620    } = db
3621        .remove_channel_member(channel_id, member_id, session.user_id())
3622        .await?;
3623
3624    let mut connection_pool = session.connection_pool().await;
3625    notify_membership_updated(
3626        &mut connection_pool,
3627        membership_update,
3628        member_id,
3629        &session.peer,
3630    );
3631    for connection_id in connection_pool.user_connection_ids(member_id) {
3632        if let Some(notification_id) = notification_id {
3633            session
3634                .peer
3635                .send(
3636                    connection_id,
3637                    proto::DeleteNotification {
3638                        notification_id: notification_id.to_proto(),
3639                    },
3640                )
3641                .trace_err();
3642        }
3643    }
3644
3645    response.send(proto::Ack {})?;
3646    Ok(())
3647}
3648
3649/// Toggle the channel between public and private.
3650/// Care is taken to maintain the invariant that public channels only descend from public channels,
3651/// (though members-only channels can appear at any point in the hierarchy).
3652async fn set_channel_visibility(
3653    request: proto::SetChannelVisibility,
3654    response: Response<proto::SetChannelVisibility>,
3655    session: UserSession,
3656) -> Result<()> {
3657    let db = session.db().await;
3658    let channel_id = ChannelId::from_proto(request.channel_id);
3659    let visibility = request.visibility().into();
3660
3661    let channel_model = db
3662        .set_channel_visibility(channel_id, visibility, session.user_id())
3663        .await?;
3664    let root_id = channel_model.root_id();
3665    let channel = Channel::from_model(channel_model);
3666
3667    let mut connection_pool = session.connection_pool().await;
3668    for (user_id, role) in connection_pool
3669        .channel_user_ids(root_id)
3670        .collect::<Vec<_>>()
3671        .into_iter()
3672    {
3673        let update = if role.can_see_channel(channel.visibility) {
3674            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3675            proto::UpdateChannels {
3676                channels: vec![channel.to_proto()],
3677                ..Default::default()
3678            }
3679        } else {
3680            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3681            proto::UpdateChannels {
3682                delete_channels: vec![channel.id.to_proto()],
3683                ..Default::default()
3684            }
3685        };
3686
3687        for connection_id in connection_pool.user_connection_ids(user_id) {
3688            session.peer.send(connection_id, update.clone())?;
3689        }
3690    }
3691
3692    response.send(proto::Ack {})?;
3693    Ok(())
3694}
3695
3696/// Alter the role for a user in the channel.
3697async fn set_channel_member_role(
3698    request: proto::SetChannelMemberRole,
3699    response: Response<proto::SetChannelMemberRole>,
3700    session: UserSession,
3701) -> Result<()> {
3702    let db = session.db().await;
3703    let channel_id = ChannelId::from_proto(request.channel_id);
3704    let member_id = UserId::from_proto(request.user_id);
3705    let result = db
3706        .set_channel_member_role(
3707            channel_id,
3708            session.user_id(),
3709            member_id,
3710            request.role().into(),
3711        )
3712        .await?;
3713
3714    match result {
3715        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3716            let mut connection_pool = session.connection_pool().await;
3717            notify_membership_updated(
3718                &mut connection_pool,
3719                membership_update,
3720                member_id,
3721                &session.peer,
3722            )
3723        }
3724        db::SetMemberRoleResult::InviteUpdated(channel) => {
3725            let update = proto::UpdateChannels {
3726                channel_invitations: vec![channel.to_proto()],
3727                ..Default::default()
3728            };
3729
3730            for connection_id in session
3731                .connection_pool()
3732                .await
3733                .user_connection_ids(member_id)
3734            {
3735                session.peer.send(connection_id, update.clone())?;
3736            }
3737        }
3738    }
3739
3740    response.send(proto::Ack {})?;
3741    Ok(())
3742}
3743
3744/// Change the name of a channel
3745async fn rename_channel(
3746    request: proto::RenameChannel,
3747    response: Response<proto::RenameChannel>,
3748    session: UserSession,
3749) -> Result<()> {
3750    let db = session.db().await;
3751    let channel_id = ChannelId::from_proto(request.channel_id);
3752    let channel_model = db
3753        .rename_channel(channel_id, session.user_id(), &request.name)
3754        .await?;
3755    let root_id = channel_model.root_id();
3756    let channel = Channel::from_model(channel_model);
3757
3758    response.send(proto::RenameChannelResponse {
3759        channel: Some(channel.to_proto()),
3760    })?;
3761
3762    let connection_pool = session.connection_pool().await;
3763    let update = proto::UpdateChannels {
3764        channels: vec![channel.to_proto()],
3765        ..Default::default()
3766    };
3767    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3768        if role.can_see_channel(channel.visibility) {
3769            session.peer.send(connection_id, update.clone())?;
3770        }
3771    }
3772
3773    Ok(())
3774}
3775
3776/// Move a channel to a new parent.
3777async fn move_channel(
3778    request: proto::MoveChannel,
3779    response: Response<proto::MoveChannel>,
3780    session: UserSession,
3781) -> Result<()> {
3782    let channel_id = ChannelId::from_proto(request.channel_id);
3783    let to = ChannelId::from_proto(request.to);
3784
3785    let (root_id, channels) = session
3786        .db()
3787        .await
3788        .move_channel(channel_id, to, session.user_id())
3789        .await?;
3790
3791    let connection_pool = session.connection_pool().await;
3792    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3793        let channels = channels
3794            .iter()
3795            .filter_map(|channel| {
3796                if role.can_see_channel(channel.visibility) {
3797                    Some(channel.to_proto())
3798                } else {
3799                    None
3800                }
3801            })
3802            .collect::<Vec<_>>();
3803        if channels.is_empty() {
3804            continue;
3805        }
3806
3807        let update = proto::UpdateChannels {
3808            channels,
3809            ..Default::default()
3810        };
3811
3812        session.peer.send(connection_id, update.clone())?;
3813    }
3814
3815    response.send(Ack {})?;
3816    Ok(())
3817}
3818
3819/// Get the list of channel members
3820async fn get_channel_members(
3821    request: proto::GetChannelMembers,
3822    response: Response<proto::GetChannelMembers>,
3823    session: UserSession,
3824) -> Result<()> {
3825    let db = session.db().await;
3826    let channel_id = ChannelId::from_proto(request.channel_id);
3827    let limit = if request.limit == 0 {
3828        u16::MAX as u64
3829    } else {
3830        request.limit
3831    };
3832    let (members, users) = db
3833        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3834        .await?;
3835    response.send(proto::GetChannelMembersResponse { members, users })?;
3836    Ok(())
3837}
3838
3839/// Accept or decline a channel invitation.
3840async fn respond_to_channel_invite(
3841    request: proto::RespondToChannelInvite,
3842    response: Response<proto::RespondToChannelInvite>,
3843    session: UserSession,
3844) -> Result<()> {
3845    let db = session.db().await;
3846    let channel_id = ChannelId::from_proto(request.channel_id);
3847    let RespondToChannelInvite {
3848        membership_update,
3849        notifications,
3850    } = db
3851        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3852        .await?;
3853
3854    let mut connection_pool = session.connection_pool().await;
3855    if let Some(membership_update) = membership_update {
3856        notify_membership_updated(
3857            &mut connection_pool,
3858            membership_update,
3859            session.user_id(),
3860            &session.peer,
3861        );
3862    } else {
3863        let update = proto::UpdateChannels {
3864            remove_channel_invitations: vec![channel_id.to_proto()],
3865            ..Default::default()
3866        };
3867
3868        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3869            session.peer.send(connection_id, update.clone())?;
3870        }
3871    };
3872
3873    send_notifications(&connection_pool, &session.peer, notifications);
3874
3875    response.send(proto::Ack {})?;
3876
3877    Ok(())
3878}
3879
3880/// Join the channels' room
3881async fn join_channel(
3882    request: proto::JoinChannel,
3883    response: Response<proto::JoinChannel>,
3884    session: UserSession,
3885) -> Result<()> {
3886    let channel_id = ChannelId::from_proto(request.channel_id);
3887    join_channel_internal(channel_id, Box::new(response), session).await
3888}
3889
3890trait JoinChannelInternalResponse {
3891    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3892}
3893impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3894    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3895        Response::<proto::JoinChannel>::send(self, result)
3896    }
3897}
3898impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3899    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3900        Response::<proto::JoinRoom>::send(self, result)
3901    }
3902}
3903
3904async fn join_channel_internal(
3905    channel_id: ChannelId,
3906    response: Box<impl JoinChannelInternalResponse>,
3907    session: UserSession,
3908) -> Result<()> {
3909    let joined_room = {
3910        let mut db = session.db().await;
3911        // If zed quits without leaving the room, and the user re-opens zed before the
3912        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3913        // room they were in.
3914        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3915            tracing::info!(
3916                stale_connection_id = %connection,
3917                "cleaning up stale connection",
3918            );
3919            drop(db);
3920            leave_room_for_session(&session, connection).await?;
3921            db = session.db().await;
3922        }
3923
3924        let (joined_room, membership_updated, role) = db
3925            .join_channel(channel_id, session.user_id(), session.connection_id)
3926            .await?;
3927
3928        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3929            let (can_publish, token) = if role == ChannelRole::Guest {
3930                (
3931                    false,
3932                    live_kit
3933                        .guest_token(
3934                            &joined_room.room.live_kit_room,
3935                            &session.user_id().to_string(),
3936                        )
3937                        .trace_err()?,
3938                )
3939            } else {
3940                (
3941                    true,
3942                    live_kit
3943                        .room_token(
3944                            &joined_room.room.live_kit_room,
3945                            &session.user_id().to_string(),
3946                        )
3947                        .trace_err()?,
3948                )
3949            };
3950
3951            Some(LiveKitConnectionInfo {
3952                server_url: live_kit.url().into(),
3953                token,
3954                can_publish,
3955            })
3956        });
3957
3958        response.send(proto::JoinRoomResponse {
3959            room: Some(joined_room.room.clone()),
3960            channel_id: joined_room
3961                .channel
3962                .as_ref()
3963                .map(|channel| channel.id.to_proto()),
3964            live_kit_connection_info,
3965        })?;
3966
3967        let mut connection_pool = session.connection_pool().await;
3968        if let Some(membership_updated) = membership_updated {
3969            notify_membership_updated(
3970                &mut connection_pool,
3971                membership_updated,
3972                session.user_id(),
3973                &session.peer,
3974            );
3975        }
3976
3977        room_updated(&joined_room.room, &session.peer);
3978
3979        joined_room
3980    };
3981
3982    channel_updated(
3983        &joined_room
3984            .channel
3985            .ok_or_else(|| anyhow!("channel not returned"))?,
3986        &joined_room.room,
3987        &session.peer,
3988        &*session.connection_pool().await,
3989    );
3990
3991    update_user_contacts(session.user_id(), &session).await?;
3992    Ok(())
3993}
3994
3995/// Start editing the channel notes
3996async fn join_channel_buffer(
3997    request: proto::JoinChannelBuffer,
3998    response: Response<proto::JoinChannelBuffer>,
3999    session: UserSession,
4000) -> Result<()> {
4001    let db = session.db().await;
4002    let channel_id = ChannelId::from_proto(request.channel_id);
4003
4004    let open_response = db
4005        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4006        .await?;
4007
4008    let collaborators = open_response.collaborators.clone();
4009    response.send(open_response)?;
4010
4011    let update = UpdateChannelBufferCollaborators {
4012        channel_id: channel_id.to_proto(),
4013        collaborators: collaborators.clone(),
4014    };
4015    channel_buffer_updated(
4016        session.connection_id,
4017        collaborators
4018            .iter()
4019            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4020        &update,
4021        &session.peer,
4022    );
4023
4024    Ok(())
4025}
4026
4027/// Edit the channel notes
4028async fn update_channel_buffer(
4029    request: proto::UpdateChannelBuffer,
4030    session: UserSession,
4031) -> Result<()> {
4032    let db = session.db().await;
4033    let channel_id = ChannelId::from_proto(request.channel_id);
4034
4035    let (collaborators, epoch, version) = db
4036        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4037        .await?;
4038
4039    channel_buffer_updated(
4040        session.connection_id,
4041        collaborators.clone(),
4042        &proto::UpdateChannelBuffer {
4043            channel_id: channel_id.to_proto(),
4044            operations: request.operations,
4045        },
4046        &session.peer,
4047    );
4048
4049    let pool = &*session.connection_pool().await;
4050
4051    let non_collaborators =
4052        pool.channel_connection_ids(channel_id)
4053            .filter_map(|(connection_id, _)| {
4054                if collaborators.contains(&connection_id) {
4055                    None
4056                } else {
4057                    Some(connection_id)
4058                }
4059            });
4060
4061    broadcast(None, non_collaborators, |peer_id| {
4062        session.peer.send(
4063            peer_id,
4064            proto::UpdateChannels {
4065                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4066                    channel_id: channel_id.to_proto(),
4067                    epoch: epoch as u64,
4068                    version: version.clone(),
4069                }],
4070                ..Default::default()
4071            },
4072        )
4073    });
4074
4075    Ok(())
4076}
4077
4078/// Rejoin the channel notes after a connection blip
4079async fn rejoin_channel_buffers(
4080    request: proto::RejoinChannelBuffers,
4081    response: Response<proto::RejoinChannelBuffers>,
4082    session: UserSession,
4083) -> Result<()> {
4084    let db = session.db().await;
4085    let buffers = db
4086        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4087        .await?;
4088
4089    for rejoined_buffer in &buffers {
4090        let collaborators_to_notify = rejoined_buffer
4091            .buffer
4092            .collaborators
4093            .iter()
4094            .filter_map(|c| Some(c.peer_id?.into()));
4095        channel_buffer_updated(
4096            session.connection_id,
4097            collaborators_to_notify,
4098            &proto::UpdateChannelBufferCollaborators {
4099                channel_id: rejoined_buffer.buffer.channel_id,
4100                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4101            },
4102            &session.peer,
4103        );
4104    }
4105
4106    response.send(proto::RejoinChannelBuffersResponse {
4107        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4108    })?;
4109
4110    Ok(())
4111}
4112
4113/// Stop editing the channel notes
4114async fn leave_channel_buffer(
4115    request: proto::LeaveChannelBuffer,
4116    response: Response<proto::LeaveChannelBuffer>,
4117    session: UserSession,
4118) -> Result<()> {
4119    let db = session.db().await;
4120    let channel_id = ChannelId::from_proto(request.channel_id);
4121
4122    let left_buffer = db
4123        .leave_channel_buffer(channel_id, session.connection_id)
4124        .await?;
4125
4126    response.send(Ack {})?;
4127
4128    channel_buffer_updated(
4129        session.connection_id,
4130        left_buffer.connections,
4131        &proto::UpdateChannelBufferCollaborators {
4132            channel_id: channel_id.to_proto(),
4133            collaborators: left_buffer.collaborators,
4134        },
4135        &session.peer,
4136    );
4137
4138    Ok(())
4139}
4140
4141fn channel_buffer_updated<T: EnvelopedMessage>(
4142    sender_id: ConnectionId,
4143    collaborators: impl IntoIterator<Item = ConnectionId>,
4144    message: &T,
4145    peer: &Peer,
4146) {
4147    broadcast(Some(sender_id), collaborators, |peer_id| {
4148        peer.send(peer_id, message.clone())
4149    });
4150}
4151
4152fn send_notifications(
4153    connection_pool: &ConnectionPool,
4154    peer: &Peer,
4155    notifications: db::NotificationBatch,
4156) {
4157    for (user_id, notification) in notifications {
4158        for connection_id in connection_pool.user_connection_ids(user_id) {
4159            if let Err(error) = peer.send(
4160                connection_id,
4161                proto::AddNotification {
4162                    notification: Some(notification.clone()),
4163                },
4164            ) {
4165                tracing::error!(
4166                    "failed to send notification to {:?} {}",
4167                    connection_id,
4168                    error
4169                );
4170            }
4171        }
4172    }
4173}
4174
4175/// Send a message to the channel
4176async fn send_channel_message(
4177    request: proto::SendChannelMessage,
4178    response: Response<proto::SendChannelMessage>,
4179    session: UserSession,
4180) -> Result<()> {
4181    // Validate the message body.
4182    let body = request.body.trim().to_string();
4183    if body.len() > MAX_MESSAGE_LEN {
4184        return Err(anyhow!("message is too long"))?;
4185    }
4186    if body.is_empty() {
4187        return Err(anyhow!("message can't be blank"))?;
4188    }
4189
4190    // TODO: adjust mentions if body is trimmed
4191
4192    let timestamp = OffsetDateTime::now_utc();
4193    let nonce = request
4194        .nonce
4195        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4196
4197    let channel_id = ChannelId::from_proto(request.channel_id);
4198    let CreatedChannelMessage {
4199        message_id,
4200        participant_connection_ids,
4201        notifications,
4202    } = session
4203        .db()
4204        .await
4205        .create_channel_message(
4206            channel_id,
4207            session.user_id(),
4208            &body,
4209            &request.mentions,
4210            timestamp,
4211            nonce.clone().into(),
4212            match request.reply_to_message_id {
4213                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4214                None => None,
4215            },
4216        )
4217        .await?;
4218
4219    let message = proto::ChannelMessage {
4220        sender_id: session.user_id().to_proto(),
4221        id: message_id.to_proto(),
4222        body,
4223        mentions: request.mentions,
4224        timestamp: timestamp.unix_timestamp() as u64,
4225        nonce: Some(nonce),
4226        reply_to_message_id: request.reply_to_message_id,
4227        edited_at: None,
4228    };
4229    broadcast(
4230        Some(session.connection_id),
4231        participant_connection_ids.clone(),
4232        |connection| {
4233            session.peer.send(
4234                connection,
4235                proto::ChannelMessageSent {
4236                    channel_id: channel_id.to_proto(),
4237                    message: Some(message.clone()),
4238                },
4239            )
4240        },
4241    );
4242    response.send(proto::SendChannelMessageResponse {
4243        message: Some(message),
4244    })?;
4245
4246    let pool = &*session.connection_pool().await;
4247    let non_participants =
4248        pool.channel_connection_ids(channel_id)
4249            .filter_map(|(connection_id, _)| {
4250                if participant_connection_ids.contains(&connection_id) {
4251                    None
4252                } else {
4253                    Some(connection_id)
4254                }
4255            });
4256    broadcast(None, non_participants, |peer_id| {
4257        session.peer.send(
4258            peer_id,
4259            proto::UpdateChannels {
4260                latest_channel_message_ids: vec![proto::ChannelMessageId {
4261                    channel_id: channel_id.to_proto(),
4262                    message_id: message_id.to_proto(),
4263                }],
4264                ..Default::default()
4265            },
4266        )
4267    });
4268    send_notifications(pool, &session.peer, notifications);
4269
4270    Ok(())
4271}
4272
4273/// Delete a channel message
4274async fn remove_channel_message(
4275    request: proto::RemoveChannelMessage,
4276    response: Response<proto::RemoveChannelMessage>,
4277    session: UserSession,
4278) -> Result<()> {
4279    let channel_id = ChannelId::from_proto(request.channel_id);
4280    let message_id = MessageId::from_proto(request.message_id);
4281    let (connection_ids, existing_notification_ids) = session
4282        .db()
4283        .await
4284        .remove_channel_message(channel_id, message_id, session.user_id())
4285        .await?;
4286
4287    broadcast(
4288        Some(session.connection_id),
4289        connection_ids,
4290        move |connection| {
4291            session.peer.send(connection, request.clone())?;
4292
4293            for notification_id in &existing_notification_ids {
4294                session.peer.send(
4295                    connection,
4296                    proto::DeleteNotification {
4297                        notification_id: (*notification_id).to_proto(),
4298                    },
4299                )?;
4300            }
4301
4302            Ok(())
4303        },
4304    );
4305    response.send(proto::Ack {})?;
4306    Ok(())
4307}
4308
4309async fn update_channel_message(
4310    request: proto::UpdateChannelMessage,
4311    response: Response<proto::UpdateChannelMessage>,
4312    session: UserSession,
4313) -> Result<()> {
4314    let channel_id = ChannelId::from_proto(request.channel_id);
4315    let message_id = MessageId::from_proto(request.message_id);
4316    let updated_at = OffsetDateTime::now_utc();
4317    let UpdatedChannelMessage {
4318        message_id,
4319        participant_connection_ids,
4320        notifications,
4321        reply_to_message_id,
4322        timestamp,
4323        deleted_mention_notification_ids,
4324        updated_mention_notifications,
4325    } = session
4326        .db()
4327        .await
4328        .update_channel_message(
4329            channel_id,
4330            message_id,
4331            session.user_id(),
4332            request.body.as_str(),
4333            &request.mentions,
4334            updated_at,
4335        )
4336        .await?;
4337
4338    let nonce = request
4339        .nonce
4340        .clone()
4341        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4342
4343    let message = proto::ChannelMessage {
4344        sender_id: session.user_id().to_proto(),
4345        id: message_id.to_proto(),
4346        body: request.body.clone(),
4347        mentions: request.mentions.clone(),
4348        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4349        nonce: Some(nonce),
4350        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4351        edited_at: Some(updated_at.unix_timestamp() as u64),
4352    };
4353
4354    response.send(proto::Ack {})?;
4355
4356    let pool = &*session.connection_pool().await;
4357    broadcast(
4358        Some(session.connection_id),
4359        participant_connection_ids,
4360        |connection| {
4361            session.peer.send(
4362                connection,
4363                proto::ChannelMessageUpdate {
4364                    channel_id: channel_id.to_proto(),
4365                    message: Some(message.clone()),
4366                },
4367            )?;
4368
4369            for notification_id in &deleted_mention_notification_ids {
4370                session.peer.send(
4371                    connection,
4372                    proto::DeleteNotification {
4373                        notification_id: (*notification_id).to_proto(),
4374                    },
4375                )?;
4376            }
4377
4378            for notification in &updated_mention_notifications {
4379                session.peer.send(
4380                    connection,
4381                    proto::UpdateNotification {
4382                        notification: Some(notification.clone()),
4383                    },
4384                )?;
4385            }
4386
4387            Ok(())
4388        },
4389    );
4390
4391    send_notifications(pool, &session.peer, notifications);
4392
4393    Ok(())
4394}
4395
4396/// Mark a channel message as read
4397async fn acknowledge_channel_message(
4398    request: proto::AckChannelMessage,
4399    session: UserSession,
4400) -> Result<()> {
4401    let channel_id = ChannelId::from_proto(request.channel_id);
4402    let message_id = MessageId::from_proto(request.message_id);
4403    let notifications = session
4404        .db()
4405        .await
4406        .observe_channel_message(channel_id, session.user_id(), message_id)
4407        .await?;
4408    send_notifications(
4409        &*session.connection_pool().await,
4410        &session.peer,
4411        notifications,
4412    );
4413    Ok(())
4414}
4415
4416/// Mark a buffer version as synced
4417async fn acknowledge_buffer_version(
4418    request: proto::AckBufferOperation,
4419    session: UserSession,
4420) -> Result<()> {
4421    let buffer_id = BufferId::from_proto(request.buffer_id);
4422    session
4423        .db()
4424        .await
4425        .observe_buffer_version(
4426            buffer_id,
4427            session.user_id(),
4428            request.epoch as i32,
4429            &request.version,
4430        )
4431        .await?;
4432    Ok(())
4433}
4434
4435struct CompleteWithLanguageModelRateLimit;
4436
4437impl RateLimit for CompleteWithLanguageModelRateLimit {
4438    fn capacity() -> usize {
4439        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4440            .ok()
4441            .and_then(|v| v.parse().ok())
4442            .unwrap_or(120) // Picked arbitrarily
4443    }
4444
4445    fn refill_duration() -> chrono::Duration {
4446        chrono::Duration::hours(1)
4447    }
4448
4449    fn db_name() -> &'static str {
4450        "complete-with-language-model"
4451    }
4452}
4453
4454async fn complete_with_language_model(
4455    request: proto::CompleteWithLanguageModel,
4456    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4457    session: Session,
4458    open_ai_api_key: Option<Arc<str>>,
4459    google_ai_api_key: Option<Arc<str>>,
4460    anthropic_api_key: Option<Arc<str>>,
4461) -> Result<()> {
4462    let Some(session) = session.for_user() else {
4463        return Err(anyhow!("user not found"))?;
4464    };
4465    authorize_access_to_language_models(&session).await?;
4466    session
4467        .rate_limiter
4468        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4469        .await?;
4470
4471    if request.model.starts_with("gpt") {
4472        let api_key =
4473            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4474        complete_with_open_ai(request, response, session, api_key).await?;
4475    } else if request.model.starts_with("gemini") {
4476        let api_key = google_ai_api_key
4477            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4478        complete_with_google_ai(request, response, session, api_key).await?;
4479    } else if request.model.starts_with("claude") {
4480        let api_key = anthropic_api_key
4481            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4482        complete_with_anthropic(request, response, session, api_key).await?;
4483    }
4484
4485    Ok(())
4486}
4487
4488async fn complete_with_open_ai(
4489    request: proto::CompleteWithLanguageModel,
4490    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4491    session: UserSession,
4492    api_key: Arc<str>,
4493) -> Result<()> {
4494    let mut completion_stream = open_ai::stream_completion(
4495        session.http_client.as_ref(),
4496        OPEN_AI_API_URL,
4497        &api_key,
4498        crate::ai::language_model_request_to_open_ai(request)?,
4499        None,
4500    )
4501    .await
4502    .context("open_ai::stream_completion request failed within collab")?;
4503
4504    while let Some(event) = completion_stream.next().await {
4505        let event = event?;
4506        response.send(proto::LanguageModelResponse {
4507            choices: event
4508                .choices
4509                .into_iter()
4510                .map(|choice| proto::LanguageModelChoiceDelta {
4511                    index: choice.index,
4512                    delta: Some(proto::LanguageModelResponseMessage {
4513                        role: choice.delta.role.map(|role| match role {
4514                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4515                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4516                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4517                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4518                        } as i32),
4519                        content: choice.delta.content,
4520                        tool_calls: choice
4521                            .delta
4522                            .tool_calls
4523                            .unwrap_or_default()
4524                            .into_iter()
4525                            .map(|delta| proto::ToolCallDelta {
4526                                index: delta.index as u32,
4527                                id: delta.id,
4528                                variant: match delta.function {
4529                                    Some(function) => {
4530                                        let name = function.name;
4531                                        let arguments = function.arguments;
4532
4533                                        Some(proto::tool_call_delta::Variant::Function(
4534                                            proto::tool_call_delta::FunctionCallDelta {
4535                                                name,
4536                                                arguments,
4537                                            },
4538                                        ))
4539                                    }
4540                                    None => None,
4541                                },
4542                            })
4543                            .collect(),
4544                    }),
4545                    finish_reason: choice.finish_reason,
4546                })
4547                .collect(),
4548        })?;
4549    }
4550
4551    Ok(())
4552}
4553
4554async fn complete_with_google_ai(
4555    request: proto::CompleteWithLanguageModel,
4556    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4557    session: UserSession,
4558    api_key: Arc<str>,
4559) -> Result<()> {
4560    let mut stream = google_ai::stream_generate_content(
4561        session.http_client.clone(),
4562        google_ai::API_URL,
4563        api_key.as_ref(),
4564        &request.model.clone(),
4565        crate::ai::language_model_request_to_google_ai(request)?,
4566    )
4567    .await
4568    .context("google_ai::stream_generate_content request failed")?;
4569
4570    while let Some(event) = stream.next().await {
4571        let event = event?;
4572        response.send(proto::LanguageModelResponse {
4573            choices: event
4574                .candidates
4575                .unwrap_or_default()
4576                .into_iter()
4577                .map(|candidate| proto::LanguageModelChoiceDelta {
4578                    index: candidate.index as u32,
4579                    delta: Some(proto::LanguageModelResponseMessage {
4580                        role: Some(match candidate.content.role {
4581                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4582                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4583                        } as i32),
4584                        content: Some(
4585                            candidate
4586                                .content
4587                                .parts
4588                                .into_iter()
4589                                .filter_map(|part| match part {
4590                                    google_ai::Part::TextPart(part) => Some(part.text),
4591                                    google_ai::Part::InlineDataPart(_) => None,
4592                                })
4593                                .collect(),
4594                        ),
4595                        // Tool calls are not supported for Google
4596                        tool_calls: Vec::new(),
4597                    }),
4598                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4599                })
4600                .collect(),
4601        })?;
4602    }
4603
4604    Ok(())
4605}
4606
4607async fn complete_with_anthropic(
4608    request: proto::CompleteWithLanguageModel,
4609    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4610    session: UserSession,
4611    api_key: Arc<str>,
4612) -> Result<()> {
4613    let model = anthropic::Model::from_id(&request.model)?;
4614
4615    let mut system_message = String::new();
4616    let messages = request
4617        .messages
4618        .into_iter()
4619        .filter_map(|message| {
4620            match message.role() {
4621                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4622                    role: anthropic::Role::User,
4623                    content: message.content,
4624                }),
4625                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4626                    role: anthropic::Role::Assistant,
4627                    content: message.content,
4628                }),
4629                // Anthropic's API breaks system instructions out as a separate field rather
4630                // than having a system message role.
4631                LanguageModelRole::LanguageModelSystem => {
4632                    if !system_message.is_empty() {
4633                        system_message.push_str("\n\n");
4634                    }
4635                    system_message.push_str(&message.content);
4636
4637                    None
4638                }
4639                // We don't yet support tool calls for Anthropic
4640                LanguageModelRole::LanguageModelTool => None,
4641            }
4642        })
4643        .collect();
4644
4645    let mut stream = anthropic::stream_completion(
4646        session.http_client.as_ref(),
4647        anthropic::ANTHROPIC_API_URL,
4648        &api_key,
4649        anthropic::Request {
4650            model,
4651            messages,
4652            stream: true,
4653            system: system_message,
4654            max_tokens: 4092,
4655        },
4656        None,
4657    )
4658    .await?;
4659
4660    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4661
4662    while let Some(event) = stream.next().await {
4663        let event = event?;
4664
4665        match event {
4666            anthropic::ResponseEvent::MessageStart { message } => {
4667                if let Some(role) = message.role {
4668                    if role == "assistant" {
4669                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4670                    } else if role == "user" {
4671                        current_role = proto::LanguageModelRole::LanguageModelUser;
4672                    }
4673                }
4674            }
4675            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4676                match content_block {
4677                    anthropic::ContentBlock::Text { text } => {
4678                        if !text.is_empty() {
4679                            response.send(proto::LanguageModelResponse {
4680                                choices: vec![proto::LanguageModelChoiceDelta {
4681                                    index: 0,
4682                                    delta: Some(proto::LanguageModelResponseMessage {
4683                                        role: Some(current_role as i32),
4684                                        content: Some(text),
4685                                        tool_calls: Vec::new(),
4686                                    }),
4687                                    finish_reason: None,
4688                                }],
4689                            })?;
4690                        }
4691                    }
4692                }
4693            }
4694            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4695                anthropic::TextDelta::TextDelta { text } => {
4696                    response.send(proto::LanguageModelResponse {
4697                        choices: vec![proto::LanguageModelChoiceDelta {
4698                            index: 0,
4699                            delta: Some(proto::LanguageModelResponseMessage {
4700                                role: Some(current_role as i32),
4701                                content: Some(text),
4702                                tool_calls: Vec::new(),
4703                            }),
4704                            finish_reason: None,
4705                        }],
4706                    })?;
4707                }
4708            },
4709            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4710                if let Some(stop_reason) = delta.stop_reason {
4711                    response.send(proto::LanguageModelResponse {
4712                        choices: vec![proto::LanguageModelChoiceDelta {
4713                            index: 0,
4714                            delta: None,
4715                            finish_reason: Some(stop_reason),
4716                        }],
4717                    })?;
4718                }
4719            }
4720            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4721            anthropic::ResponseEvent::MessageStop {} => {}
4722            anthropic::ResponseEvent::Ping {} => {}
4723        }
4724    }
4725
4726    Ok(())
4727}
4728
4729struct CountTokensWithLanguageModelRateLimit;
4730
4731impl RateLimit for CountTokensWithLanguageModelRateLimit {
4732    fn capacity() -> usize {
4733        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4734            .ok()
4735            .and_then(|v| v.parse().ok())
4736            .unwrap_or(600) // Picked arbitrarily
4737    }
4738
4739    fn refill_duration() -> chrono::Duration {
4740        chrono::Duration::hours(1)
4741    }
4742
4743    fn db_name() -> &'static str {
4744        "count-tokens-with-language-model"
4745    }
4746}
4747
4748async fn count_tokens_with_language_model(
4749    request: proto::CountTokensWithLanguageModel,
4750    response: Response<proto::CountTokensWithLanguageModel>,
4751    session: UserSession,
4752    google_ai_api_key: Option<Arc<str>>,
4753) -> Result<()> {
4754    authorize_access_to_language_models(&session).await?;
4755
4756    if !request.model.starts_with("gemini") {
4757        return Err(anyhow!(
4758            "counting tokens for model: {:?} is not supported",
4759            request.model
4760        ))?;
4761    }
4762
4763    session
4764        .rate_limiter
4765        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4766        .await?;
4767
4768    let api_key = google_ai_api_key
4769        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4770    let tokens_response = google_ai::count_tokens(
4771        session.http_client.as_ref(),
4772        google_ai::API_URL,
4773        &api_key,
4774        crate::ai::count_tokens_request_to_google_ai(request)?,
4775    )
4776    .await?;
4777    response.send(proto::CountTokensResponse {
4778        token_count: tokens_response.total_tokens as u32,
4779    })?;
4780    Ok(())
4781}
4782
4783struct ComputeEmbeddingsRateLimit;
4784
4785impl RateLimit for ComputeEmbeddingsRateLimit {
4786    fn capacity() -> usize {
4787        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4788            .ok()
4789            .and_then(|v| v.parse().ok())
4790            .unwrap_or(5000) // Picked arbitrarily
4791    }
4792
4793    fn refill_duration() -> chrono::Duration {
4794        chrono::Duration::hours(1)
4795    }
4796
4797    fn db_name() -> &'static str {
4798        "compute-embeddings"
4799    }
4800}
4801
4802async fn compute_embeddings(
4803    request: proto::ComputeEmbeddings,
4804    response: Response<proto::ComputeEmbeddings>,
4805    session: UserSession,
4806    api_key: Option<Arc<str>>,
4807) -> Result<()> {
4808    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4809    authorize_access_to_language_models(&session).await?;
4810
4811    session
4812        .rate_limiter
4813        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4814        .await?;
4815
4816    let embeddings = match request.model.as_str() {
4817        "openai/text-embedding-3-small" => {
4818            open_ai::embed(
4819                session.http_client.as_ref(),
4820                OPEN_AI_API_URL,
4821                &api_key,
4822                OpenAiEmbeddingModel::TextEmbedding3Small,
4823                request.texts.iter().map(|text| text.as_str()),
4824            )
4825            .await?
4826        }
4827        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4828    };
4829
4830    let embeddings = request
4831        .texts
4832        .iter()
4833        .map(|text| {
4834            let mut hasher = sha2::Sha256::new();
4835            hasher.update(text.as_bytes());
4836            let result = hasher.finalize();
4837            result.to_vec()
4838        })
4839        .zip(
4840            embeddings
4841                .data
4842                .into_iter()
4843                .map(|embedding| embedding.embedding),
4844        )
4845        .collect::<HashMap<_, _>>();
4846
4847    let db = session.db().await;
4848    db.save_embeddings(&request.model, &embeddings)
4849        .await
4850        .context("failed to save embeddings")
4851        .trace_err();
4852
4853    response.send(proto::ComputeEmbeddingsResponse {
4854        embeddings: embeddings
4855            .into_iter()
4856            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4857            .collect(),
4858    })?;
4859    Ok(())
4860}
4861
4862async fn get_cached_embeddings(
4863    request: proto::GetCachedEmbeddings,
4864    response: Response<proto::GetCachedEmbeddings>,
4865    session: UserSession,
4866) -> Result<()> {
4867    authorize_access_to_language_models(&session).await?;
4868
4869    let db = session.db().await;
4870    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4871
4872    response.send(proto::GetCachedEmbeddingsResponse {
4873        embeddings: embeddings
4874            .into_iter()
4875            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4876            .collect(),
4877    })?;
4878    Ok(())
4879}
4880
4881async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4882    let db = session.db().await;
4883    let flags = db.get_user_flags(session.user_id()).await?;
4884    if flags.iter().any(|flag| flag == "language-models") {
4885        Ok(())
4886    } else {
4887        Err(anyhow!("permission denied"))?
4888    }
4889}
4890
4891/// Get a Supermaven API key for the user
4892async fn get_supermaven_api_key(
4893    _request: proto::GetSupermavenApiKey,
4894    response: Response<proto::GetSupermavenApiKey>,
4895    session: UserSession,
4896) -> Result<()> {
4897    let user_id: String = session.user_id().to_string();
4898    if !session.is_staff() {
4899        return Err(anyhow!("supermaven not enabled for this account"))?;
4900    }
4901
4902    let email = session
4903        .email()
4904        .ok_or_else(|| anyhow!("user must have an email"))?;
4905
4906    let supermaven_admin_api = session
4907        .supermaven_client
4908        .as_ref()
4909        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4910
4911    let result = supermaven_admin_api
4912        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4913        .await?;
4914
4915    response.send(proto::GetSupermavenApiKeyResponse {
4916        api_key: result.api_key,
4917    })?;
4918
4919    Ok(())
4920}
4921
4922/// Start receiving chat updates for a channel
4923async fn join_channel_chat(
4924    request: proto::JoinChannelChat,
4925    response: Response<proto::JoinChannelChat>,
4926    session: UserSession,
4927) -> Result<()> {
4928    let channel_id = ChannelId::from_proto(request.channel_id);
4929
4930    let db = session.db().await;
4931    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4932        .await?;
4933    let messages = db
4934        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4935        .await?;
4936    response.send(proto::JoinChannelChatResponse {
4937        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4938        messages,
4939    })?;
4940    Ok(())
4941}
4942
4943/// Stop receiving chat updates for a channel
4944async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4945    let channel_id = ChannelId::from_proto(request.channel_id);
4946    session
4947        .db()
4948        .await
4949        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4950        .await?;
4951    Ok(())
4952}
4953
4954/// Retrieve the chat history for a channel
4955async fn get_channel_messages(
4956    request: proto::GetChannelMessages,
4957    response: Response<proto::GetChannelMessages>,
4958    session: UserSession,
4959) -> Result<()> {
4960    let channel_id = ChannelId::from_proto(request.channel_id);
4961    let messages = session
4962        .db()
4963        .await
4964        .get_channel_messages(
4965            channel_id,
4966            session.user_id(),
4967            MESSAGE_COUNT_PER_PAGE,
4968            Some(MessageId::from_proto(request.before_message_id)),
4969        )
4970        .await?;
4971    response.send(proto::GetChannelMessagesResponse {
4972        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4973        messages,
4974    })?;
4975    Ok(())
4976}
4977
4978/// Retrieve specific chat messages
4979async fn get_channel_messages_by_id(
4980    request: proto::GetChannelMessagesById,
4981    response: Response<proto::GetChannelMessagesById>,
4982    session: UserSession,
4983) -> Result<()> {
4984    let message_ids = request
4985        .message_ids
4986        .iter()
4987        .map(|id| MessageId::from_proto(*id))
4988        .collect::<Vec<_>>();
4989    let messages = session
4990        .db()
4991        .await
4992        .get_channel_messages_by_id(session.user_id(), &message_ids)
4993        .await?;
4994    response.send(proto::GetChannelMessagesResponse {
4995        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4996        messages,
4997    })?;
4998    Ok(())
4999}
5000
5001/// Retrieve the current users notifications
5002async fn get_notifications(
5003    request: proto::GetNotifications,
5004    response: Response<proto::GetNotifications>,
5005    session: UserSession,
5006) -> Result<()> {
5007    let notifications = session
5008        .db()
5009        .await
5010        .get_notifications(
5011            session.user_id(),
5012            NOTIFICATION_COUNT_PER_PAGE,
5013            request
5014                .before_id
5015                .map(|id| db::NotificationId::from_proto(id)),
5016        )
5017        .await?;
5018    response.send(proto::GetNotificationsResponse {
5019        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5020        notifications,
5021    })?;
5022    Ok(())
5023}
5024
5025/// Mark notifications as read
5026async fn mark_notification_as_read(
5027    request: proto::MarkNotificationRead,
5028    response: Response<proto::MarkNotificationRead>,
5029    session: UserSession,
5030) -> Result<()> {
5031    let database = &session.db().await;
5032    let notifications = database
5033        .mark_notification_as_read_by_id(
5034            session.user_id(),
5035            NotificationId::from_proto(request.notification_id),
5036        )
5037        .await?;
5038    send_notifications(
5039        &*session.connection_pool().await,
5040        &session.peer,
5041        notifications,
5042    );
5043    response.send(proto::Ack {})?;
5044    Ok(())
5045}
5046
5047/// Get the current users information
5048async fn get_private_user_info(
5049    _request: proto::GetPrivateUserInfo,
5050    response: Response<proto::GetPrivateUserInfo>,
5051    session: UserSession,
5052) -> Result<()> {
5053    let db = session.db().await;
5054
5055    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5056    let user = db
5057        .get_user_by_id(session.user_id())
5058        .await?
5059        .ok_or_else(|| anyhow!("user not found"))?;
5060    let flags = db.get_user_flags(session.user_id()).await?;
5061
5062    response.send(proto::GetPrivateUserInfoResponse {
5063        metrics_id,
5064        staff: user.admin,
5065        flags,
5066    })?;
5067    Ok(())
5068}
5069
5070fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5071    match message {
5072        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5073        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5074        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5075        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5076        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5077            code: frame.code.into(),
5078            reason: frame.reason,
5079        })),
5080    }
5081}
5082
5083fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5084    match message {
5085        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5086        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5087        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5088        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5089        AxumMessage::Close(frame) => {
5090            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5091                code: frame.code.into(),
5092                reason: frame.reason,
5093            }))
5094        }
5095    }
5096}
5097
5098fn notify_membership_updated(
5099    connection_pool: &mut ConnectionPool,
5100    result: MembershipUpdated,
5101    user_id: UserId,
5102    peer: &Peer,
5103) {
5104    for membership in &result.new_channels.channel_memberships {
5105        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5106    }
5107    for channel_id in &result.removed_channels {
5108        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5109    }
5110
5111    let user_channels_update = proto::UpdateUserChannels {
5112        channel_memberships: result
5113            .new_channels
5114            .channel_memberships
5115            .iter()
5116            .map(|cm| proto::ChannelMembership {
5117                channel_id: cm.channel_id.to_proto(),
5118                role: cm.role.into(),
5119            })
5120            .collect(),
5121        ..Default::default()
5122    };
5123
5124    let mut update = build_channels_update(result.new_channels);
5125    update.delete_channels = result
5126        .removed_channels
5127        .into_iter()
5128        .map(|id| id.to_proto())
5129        .collect();
5130    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5131
5132    for connection_id in connection_pool.user_connection_ids(user_id) {
5133        peer.send(connection_id, user_channels_update.clone())
5134            .trace_err();
5135        peer.send(connection_id, update.clone()).trace_err();
5136    }
5137}
5138
5139fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5140    proto::UpdateUserChannels {
5141        channel_memberships: channels
5142            .channel_memberships
5143            .iter()
5144            .map(|m| proto::ChannelMembership {
5145                channel_id: m.channel_id.to_proto(),
5146                role: m.role.into(),
5147            })
5148            .collect(),
5149        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5150        observed_channel_message_id: channels.observed_channel_messages.clone(),
5151    }
5152}
5153
5154fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5155    let mut update = proto::UpdateChannels::default();
5156
5157    for channel in channels.channels {
5158        update.channels.push(channel.to_proto());
5159    }
5160
5161    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5162    update.latest_channel_message_ids = channels.latest_channel_messages;
5163
5164    for (channel_id, participants) in channels.channel_participants {
5165        update
5166            .channel_participants
5167            .push(proto::ChannelParticipants {
5168                channel_id: channel_id.to_proto(),
5169                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5170            });
5171    }
5172
5173    for channel in channels.invited_channels {
5174        update.channel_invitations.push(channel.to_proto());
5175    }
5176
5177    update.hosted_projects = channels.hosted_projects;
5178    update
5179}
5180
5181fn build_initial_contacts_update(
5182    contacts: Vec<db::Contact>,
5183    pool: &ConnectionPool,
5184) -> proto::UpdateContacts {
5185    let mut update = proto::UpdateContacts::default();
5186
5187    for contact in contacts {
5188        match contact {
5189            db::Contact::Accepted { user_id, busy } => {
5190                update.contacts.push(contact_for_user(user_id, busy, &pool));
5191            }
5192            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5193            db::Contact::Incoming { user_id } => {
5194                update
5195                    .incoming_requests
5196                    .push(proto::IncomingContactRequest {
5197                        requester_id: user_id.to_proto(),
5198                    })
5199            }
5200        }
5201    }
5202
5203    update
5204}
5205
5206fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5207    proto::Contact {
5208        user_id: user_id.to_proto(),
5209        online: pool.is_user_online(user_id),
5210        busy,
5211    }
5212}
5213
5214fn room_updated(room: &proto::Room, peer: &Peer) {
5215    broadcast(
5216        None,
5217        room.participants
5218            .iter()
5219            .filter_map(|participant| Some(participant.peer_id?.into())),
5220        |peer_id| {
5221            peer.send(
5222                peer_id,
5223                proto::RoomUpdated {
5224                    room: Some(room.clone()),
5225                },
5226            )
5227        },
5228    );
5229}
5230
5231fn channel_updated(
5232    channel: &db::channel::Model,
5233    room: &proto::Room,
5234    peer: &Peer,
5235    pool: &ConnectionPool,
5236) {
5237    let participants = room
5238        .participants
5239        .iter()
5240        .map(|p| p.user_id)
5241        .collect::<Vec<_>>();
5242
5243    broadcast(
5244        None,
5245        pool.channel_connection_ids(channel.root_id())
5246            .filter_map(|(channel_id, role)| {
5247                role.can_see_channel(channel.visibility).then(|| channel_id)
5248            }),
5249        |peer_id| {
5250            peer.send(
5251                peer_id,
5252                proto::UpdateChannels {
5253                    channel_participants: vec![proto::ChannelParticipants {
5254                        channel_id: channel.id.to_proto(),
5255                        participant_user_ids: participants.clone(),
5256                    }],
5257                    ..Default::default()
5258                },
5259            )
5260        },
5261    );
5262}
5263
5264async fn send_dev_server_projects_update(
5265    user_id: UserId,
5266    mut status: proto::DevServerProjectsUpdate,
5267    session: &Session,
5268) {
5269    let pool = session.connection_pool().await;
5270    for dev_server in &mut status.dev_servers {
5271        dev_server.status =
5272            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5273    }
5274    let connections = pool.user_connection_ids(user_id);
5275    for connection_id in connections {
5276        session.peer.send(connection_id, status.clone()).trace_err();
5277    }
5278}
5279
5280async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5281    let db = session.db().await;
5282
5283    let contacts = db.get_contacts(user_id).await?;
5284    let busy = db.is_user_busy(user_id).await?;
5285
5286    let pool = session.connection_pool().await;
5287    let updated_contact = contact_for_user(user_id, busy, &pool);
5288    for contact in contacts {
5289        if let db::Contact::Accepted {
5290            user_id: contact_user_id,
5291            ..
5292        } = contact
5293        {
5294            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5295                session
5296                    .peer
5297                    .send(
5298                        contact_conn_id,
5299                        proto::UpdateContacts {
5300                            contacts: vec![updated_contact.clone()],
5301                            remove_contacts: Default::default(),
5302                            incoming_requests: Default::default(),
5303                            remove_incoming_requests: Default::default(),
5304                            outgoing_requests: Default::default(),
5305                            remove_outgoing_requests: Default::default(),
5306                        },
5307                    )
5308                    .trace_err();
5309            }
5310        }
5311    }
5312    Ok(())
5313}
5314
5315async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5316    log::info!("lost dev server connection, unsharing projects");
5317    let project_ids = session
5318        .db()
5319        .await
5320        .get_stale_dev_server_projects(session.connection_id)
5321        .await?;
5322
5323    for project_id in project_ids {
5324        // not unshare re-checks the connection ids match, so we get away with no transaction
5325        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5326    }
5327
5328    let user_id = session.dev_server().user_id;
5329    let update = session
5330        .db()
5331        .await
5332        .dev_server_projects_update(user_id)
5333        .await?;
5334
5335    send_dev_server_projects_update(user_id, update, session).await;
5336
5337    Ok(())
5338}
5339
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// When the database's `leave_room` returns `None`, this returns `Ok(())` without
/// sending anything. Otherwise it, in order:
/// 1. notifies collaborators of each left project (`project_left`),
/// 2. broadcasts the updated room to its participants (`room_updated`),
/// 3. broadcasts the channel's participants if the room belongs to a channel,
/// 4. sends `CallCanceled` to every connection of each user whose call was canceled,
/// 5. refreshes the contact entry of the session's user and of each canceled user,
/// 6. removes the user from the LiveKit room, deleting that room when the database
///    marked it as deleted.
///
/// # Errors
///
/// Propagates errors from `leave_room` and from `update_user_contacts`. Peer send
/// failures and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry is re-broadcast near the end of this function.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the extracted values
    // remain usable after that block ends.
    // NOTE(review): presumably this keeps the value returned by `leave_room` (and the
    // db handle) scoped to the `if let` rather than the whole function — confirm
    // before restructuring into `let ... else`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out below, so `room` (and the `RoomUpdated`
        // broadcast built from it) carries the default value for `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room: nothing to notify.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Inner block: the pool handle is dropped before `update_user_contacts` runs,
    // which awaits `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and do not fail the call.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5413
5414async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5415    let left_channel_buffers = session
5416        .db()
5417        .await
5418        .leave_channel_buffers(session.connection_id)
5419        .await?;
5420
5421    for left_buffer in left_channel_buffers {
5422        channel_buffer_updated(
5423            session.connection_id,
5424            left_buffer.connections,
5425            &proto::UpdateChannelBufferCollaborators {
5426                channel_id: left_buffer.channel_id.to_proto(),
5427                collaborators: left_buffer.collaborators,
5428            },
5429            &session.peer,
5430        );
5431    }
5432
5433    Ok(())
5434}
5435
5436fn project_left(project: &db::LeftProject, session: &UserSession) {
5437    for connection_id in &project.connection_ids {
5438        if project.should_unshare {
5439            session
5440                .peer
5441                .send(
5442                    *connection_id,
5443                    proto::UnshareProject {
5444                        project_id: project.id.to_proto(),
5445                    },
5446                )
5447                .trace_err();
5448        } else {
5449            session
5450                .peer
5451                .send(
5452                    *connection_id,
5453                    proto::RemoveProjectCollaborator {
5454                        project_id: project.id.to_proto(),
5455                        peer_id: Some(session.connection_id.into()),
5456                    },
5457                )
5458                .trace_err();
5459        }
5460    }
5461}
5462
/// Extension trait for converting a fallible value into an `Option`, discarding
/// the failure case after reporting it.
pub trait ResultExt {
    /// The success value yielded when there is no error.
    type Ok;

    /// Consumes `self`, returning `Some` with the success value or `None` on
    /// failure. The implementation for `Result` in this file logs the error
    /// before returning `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5468
5469impl<T, E> ResultExt for Result<T, E>
5470where
5471    E: std::fmt::Debug,
5472{
5473    type Ok = T;
5474
5475    #[track_caller]
5476    fn trace_err(self) -> Option<T> {
5477        match self {
5478            Ok(value) => Some(value),
5479            Err(error) => {
5480                tracing::error!("{:?}", error);
5481                None
5482            }
5483        }
5484    }
5485}