rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  38
  39use futures::{
  40    channel::oneshot,
  41    future::{self, BoxFuture},
  42    stream::FuturesUnordered,
  43    FutureExt, SinkExt, StreamExt, TryStreamExt,
  44};
  45use http::IsahcHttpClient;
  46use prometheus::{register_int_gauge, IntGauge};
  47use rpc::{
  48    proto::{
  49        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  50        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  51    },
  52    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  53};
  54use semantic_version::SemanticVersion;
  55use serde::{Serialize, Serializer};
  56use std::{
  57    any::TypeId,
  58    future::Future,
  59    marker::PhantomData,
  60    mem,
  61    net::SocketAddr,
  62    ops::{Deref, DerefMut},
  63    rc::Rc,
  64    sync::{
  65        atomic::{AtomicBool, Ordering::SeqCst},
  66        Arc, OnceLock,
  67    },
  68    time::{Duration, Instant},
  69};
  70use time::OffsetDateTime;
  71use tokio::sync::{watch, Semaphore};
  72use tower::ServiceBuilder;
  73use tracing::{
  74    field::{self},
  75    info_span, instrument, Instrument,
  76};
  77
  78use self::connection_pool::VersionedMessage;
  79
  80pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  81
  82// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  83pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  84
  85const MESSAGE_COUNT_PER_PAGE: usize = 100;
  86const MAX_MESSAGE_LEN: usize = 1024;
  87const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  88
  89type MessageHandler =
  90    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  91
  92struct Response<R> {
  93    peer: Arc<Peer>,
  94    receipt: Receipt<R>,
  95    responded: Arc<AtomicBool>,
  96}
  97
  98impl<R: RequestMessage> Response<R> {
  99    fn send(self, payload: R::Response) -> Result<()> {
 100        self.responded.store(true, SeqCst);
 101        self.peer.respond(self.receipt, payload)?;
 102        Ok(())
 103    }
 104}
 105
 106struct StreamingResponse<R: RequestMessage> {
 107    peer: Arc<Peer>,
 108    receipt: Receipt<R>,
 109}
 110
 111impl<R: RequestMessage> StreamingResponse<R> {
 112    fn send(&self, payload: R::Response) -> Result<()> {
 113        self.peer.respond(self.receipt, payload)?;
 114        Ok(())
 115    }
 116}
 117
 118#[derive(Clone, Debug)]
 119pub enum Principal {
 120    User(User),
 121    Impersonated { user: User, admin: User },
 122    DevServer(dev_server::Model),
 123}
 124
 125impl Principal {
 126    fn update_span(&self, span: &tracing::Span) {
 127        match &self {
 128            Principal::User(user) => {
 129                span.record("user_id", &user.id.0);
 130                span.record("login", &user.github_login);
 131            }
 132            Principal::Impersonated { user, admin } => {
 133                span.record("user_id", &user.id.0);
 134                span.record("login", &user.github_login);
 135                span.record("impersonator", &admin.github_login);
 136            }
 137            Principal::DevServer(dev_server) => {
 138                span.record("dev_server_id", &dev_server.id.0);
 139            }
 140        }
 141    }
 142}
 143
 144#[derive(Clone)]
 145struct Session {
 146    principal: Principal,
 147    connection_id: ConnectionId,
 148    db: Arc<tokio::sync::Mutex<DbHandle>>,
 149    peer: Arc<Peer>,
 150    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 151    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 152    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 153    http_client: Arc<IsahcHttpClient>,
 154    rate_limiter: Arc<RateLimiter>,
 155    _executor: Executor,
 156}
 157
 158impl Session {
 159    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.db.lock().await;
 163        #[cfg(test)]
 164        tokio::task::yield_now().await;
 165        guard
 166    }
 167
 168    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 169        #[cfg(test)]
 170        tokio::task::yield_now().await;
 171        let guard = self.connection_pool.lock();
 172        ConnectionPoolGuard {
 173            guard,
 174            _not_send: PhantomData,
 175        }
 176    }
 177
 178    fn for_user(self) -> Option<UserSession> {
 179        UserSession::new(self)
 180    }
 181
 182    fn for_dev_server(self) -> Option<DevServerSession> {
 183        DevServerSession::new(self)
 184    }
 185
 186    fn user_id(&self) -> Option<UserId> {
 187        match &self.principal {
 188            Principal::User(user) => Some(user.id),
 189            Principal::Impersonated { user, .. } => Some(user.id),
 190            Principal::DevServer(_) => None,
 191        }
 192    }
 193
 194    fn is_staff(&self) -> bool {
 195        match &self.principal {
 196            Principal::User(user) => user.admin,
 197            Principal::Impersonated { .. } => true,
 198            Principal::DevServer(_) => false,
 199        }
 200    }
 201
 202    fn dev_server_id(&self) -> Option<DevServerId> {
 203        match &self.principal {
 204            Principal::User(_) | Principal::Impersonated { .. } => None,
 205            Principal::DevServer(dev_server) => Some(dev_server.id),
 206        }
 207    }
 208
 209    fn principal_id(&self) -> PrincipalId {
 210        match &self.principal {
 211            Principal::User(user) => PrincipalId::UserId(user.id),
 212            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 213            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 214        }
 215    }
 216}
 217
 218impl Debug for Session {
 219    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 220        let mut result = f.debug_struct("Session");
 221        match &self.principal {
 222            Principal::User(user) => {
 223                result.field("user", &user.github_login);
 224            }
 225            Principal::Impersonated { user, admin } => {
 226                result.field("user", &user.github_login);
 227                result.field("impersonator", &admin.github_login);
 228            }
 229            Principal::DevServer(dev_server) => {
 230                result.field("dev_server", &dev_server.id);
 231            }
 232        }
 233        result.field("connection_id", &self.connection_id).finish()
 234    }
 235}
 236
 237struct UserSession(Session);
 238
 239impl UserSession {
 240    pub fn new(s: Session) -> Option<Self> {
 241        s.user_id().map(|_| UserSession(s))
 242    }
 243    pub fn user_id(&self) -> UserId {
 244        self.0.user_id().unwrap()
 245    }
 246
 247    pub fn email(&self) -> Option<String> {
 248        match &self.0.principal {
 249            Principal::User(user) => user.email_address.clone(),
 250            Principal::Impersonated { user, .. } => user.email_address.clone(),
 251            Principal::DevServer(..) => None,
 252        }
 253    }
 254}
 255
 256impl Deref for UserSession {
 257    type Target = Session;
 258
 259    fn deref(&self) -> &Self::Target {
 260        &self.0
 261    }
 262}
 263impl DerefMut for UserSession {
 264    fn deref_mut(&mut self) -> &mut Self::Target {
 265        &mut self.0
 266    }
 267}
 268
 269struct DevServerSession(Session);
 270
 271impl DevServerSession {
 272    pub fn new(s: Session) -> Option<Self> {
 273        s.dev_server_id().map(|_| DevServerSession(s))
 274    }
 275    pub fn dev_server_id(&self) -> DevServerId {
 276        self.0.dev_server_id().unwrap()
 277    }
 278
 279    fn dev_server(&self) -> &dev_server::Model {
 280        match &self.0.principal {
 281            Principal::DevServer(dev_server) => dev_server,
 282            _ => unreachable!(),
 283        }
 284    }
 285}
 286
 287impl Deref for DevServerSession {
 288    type Target = Session;
 289
 290    fn deref(&self) -> &Self::Target {
 291        &self.0
 292    }
 293}
 294impl DerefMut for DevServerSession {
 295    fn deref_mut(&mut self) -> &mut Self::Target {
 296        &mut self.0
 297    }
 298}
 299
 300fn user_handler<M: RequestMessage, Fut>(
 301    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 303where
 304    Fut: Send + Future<Output = Result<()>>,
 305{
 306    let handler = Arc::new(handler);
 307    move |message, response, session| {
 308        let handler = handler.clone();
 309        Box::pin(async move {
 310            if let Some(user_session) = session.for_user() {
 311                Ok(handler(message, response, user_session).await?)
 312            } else {
 313                Err(Error::Internal(anyhow!(
 314                    "must be a user to call {}",
 315                    M::NAME
 316                )))
 317            }
 318        })
 319    }
 320}
 321
 322fn dev_server_handler<M: RequestMessage, Fut>(
 323    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 325where
 326    Fut: Send + Future<Output = Result<()>>,
 327{
 328    let handler = Arc::new(handler);
 329    move |message, response, session| {
 330        let handler = handler.clone();
 331        Box::pin(async move {
 332            if let Some(dev_server_session) = session.for_dev_server() {
 333                Ok(handler(message, response, dev_server_session).await?)
 334            } else {
 335                Err(Error::Internal(anyhow!(
 336                    "must be a dev server to call {}",
 337                    M::NAME
 338                )))
 339            }
 340        })
 341    }
 342}
 343
 344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 345    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 347where
 348    InnertRetFut: Send + Future<Output = Result<()>>,
 349{
 350    let handler = Arc::new(handler);
 351    move |message, session| {
 352        let handler = handler.clone();
 353        Box::pin(async move {
 354            if let Some(user_session) = session.for_user() {
 355                Ok(handler(message, user_session).await?)
 356            } else {
 357                Err(Error::Internal(anyhow!(
 358                    "must be a user to call {}",
 359                    M::NAME
 360                )))
 361            }
 362        })
 363    }
 364}
 365
 366struct DbHandle(Arc<Database>);
 367
 368impl Deref for DbHandle {
 369    type Target = Database;
 370
 371    fn deref(&self) -> &Self::Target {
 372        self.0.as_ref()
 373    }
 374}
 375
 376pub struct Server {
 377    id: parking_lot::Mutex<ServerId>,
 378    peer: Arc<Peer>,
 379    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 380    app_state: Arc<AppState>,
 381    handlers: HashMap<TypeId, MessageHandler>,
 382    teardown: watch::Sender<bool>,
 383}
 384
 385pub(crate) struct ConnectionPoolGuard<'a> {
 386    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 387    _not_send: PhantomData<Rc<()>>,
 388}
 389
 390#[derive(Serialize)]
 391pub struct ServerSnapshot<'a> {
 392    peer: &'a Peer,
 393    #[serde(serialize_with = "serialize_deref")]
 394    connection_pool: ConnectionPoolGuard<'a>,
 395}
 396
 397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 398where
 399    S: Serializer,
 400    T: Deref<Target = U>,
 401    U: Serialize,
 402{
 403    Serialize::serialize(value.deref(), serializer)
 404}
 405
 406impl Server {
 407    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 408        let mut server = Self {
 409            id: parking_lot::Mutex::new(id),
 410            peer: Peer::new(id.0 as u32),
 411            app_state: app_state.clone(),
 412            connection_pool: Default::default(),
 413            handlers: Default::default(),
 414            teardown: watch::channel(false).0,
 415        };
 416
 417        server
 418            .add_request_handler(ping)
 419            .add_request_handler(user_handler(create_room))
 420            .add_request_handler(user_handler(join_room))
 421            .add_request_handler(user_handler(rejoin_room))
 422            .add_request_handler(user_handler(leave_room))
 423            .add_request_handler(user_handler(set_room_participant_role))
 424            .add_request_handler(user_handler(call))
 425            .add_request_handler(user_handler(cancel_call))
 426            .add_message_handler(user_message_handler(decline_call))
 427            .add_request_handler(user_handler(update_participant_location))
 428            .add_request_handler(user_handler(share_project))
 429            .add_message_handler(unshare_project)
 430            .add_request_handler(user_handler(join_project))
 431            .add_request_handler(user_handler(join_hosted_project))
 432            .add_request_handler(user_handler(rejoin_dev_server_projects))
 433            .add_request_handler(user_handler(create_dev_server_project))
 434            .add_request_handler(user_handler(delete_dev_server_project))
 435            .add_request_handler(user_handler(create_dev_server))
 436            .add_request_handler(user_handler(regenerate_dev_server_token))
 437            .add_request_handler(user_handler(rename_dev_server))
 438            .add_request_handler(user_handler(delete_dev_server))
 439            .add_request_handler(dev_server_handler(share_dev_server_project))
 440            .add_request_handler(dev_server_handler(shutdown_dev_server))
 441            .add_request_handler(dev_server_handler(reconnect_dev_server))
 442            .add_message_handler(user_message_handler(leave_project))
 443            .add_request_handler(update_project)
 444            .add_request_handler(update_worktree)
 445            .add_message_handler(start_language_server)
 446            .add_message_handler(update_language_server)
 447            .add_message_handler(update_diagnostic_summary)
 448            .add_message_handler(update_worktree_settings)
 449            .add_request_handler(user_handler(
 450                forward_project_request_for_owner::<proto::TaskContextForLocation>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_project_request_for_owner::<proto::TaskTemplates>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::GetHover>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::GetDefinition>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::GetTypeDefinition>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_read_only_project_request::<proto::GetReferences>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_read_only_project_request::<proto::SearchProject>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_read_only_project_request::<proto::GetProjectSymbols>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_read_only_project_request::<proto::OpenBufferById>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_read_only_project_request::<proto::InlayHints>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_read_only_project_request::<proto::OpenBufferByPath>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::GetCompletions>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::GetCodeActions>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ApplyCodeAction>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::PrepareRename>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_mutating_project_request::<proto::PerformRename>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::ReloadBuffers>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::FormatBuffers>,
 520            ))
 521            .add_request_handler(user_handler(
 522                forward_mutating_project_request::<proto::CreateProjectEntry>,
 523            ))
 524            .add_request_handler(user_handler(
 525                forward_mutating_project_request::<proto::RenameProjectEntry>,
 526            ))
 527            .add_request_handler(user_handler(
 528                forward_mutating_project_request::<proto::CopyProjectEntry>,
 529            ))
 530            .add_request_handler(user_handler(
 531                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 532            ))
 533            .add_request_handler(user_handler(
 534                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 535            ))
 536            .add_request_handler(user_handler(
 537                forward_mutating_project_request::<proto::OnTypeFormatting>,
 538            ))
 539            .add_request_handler(user_handler(
 540                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 541            ))
 542            .add_request_handler(user_handler(
 543                forward_mutating_project_request::<proto::BlameBuffer>,
 544            ))
 545            .add_request_handler(user_handler(
 546                forward_mutating_project_request::<proto::MultiLspQuery>,
 547            ))
 548            .add_request_handler(user_handler(
 549                forward_mutating_project_request::<proto::RestartLanguageServers>,
 550            ))
 551            .add_request_handler(user_handler(
 552                forward_mutating_project_request::<proto::LinkedEditingRange>,
 553            ))
 554            .add_message_handler(create_buffer_for_peer)
 555            .add_request_handler(update_buffer)
 556            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 557            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 558            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 559            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 560            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 561            .add_request_handler(get_users)
 562            .add_request_handler(user_handler(fuzzy_search_users))
 563            .add_request_handler(user_handler(request_contact))
 564            .add_request_handler(user_handler(remove_contact))
 565            .add_request_handler(user_handler(respond_to_contact_request))
 566            .add_message_handler(subscribe_to_channels)
 567            .add_request_handler(user_handler(create_channel))
 568            .add_request_handler(user_handler(delete_channel))
 569            .add_request_handler(user_handler(invite_channel_member))
 570            .add_request_handler(user_handler(remove_channel_member))
 571            .add_request_handler(user_handler(set_channel_member_role))
 572            .add_request_handler(user_handler(set_channel_visibility))
 573            .add_request_handler(user_handler(rename_channel))
 574            .add_request_handler(user_handler(join_channel_buffer))
 575            .add_request_handler(user_handler(leave_channel_buffer))
 576            .add_message_handler(user_message_handler(update_channel_buffer))
 577            .add_request_handler(user_handler(rejoin_channel_buffers))
 578            .add_request_handler(user_handler(get_channel_members))
 579            .add_request_handler(user_handler(respond_to_channel_invite))
 580            .add_request_handler(user_handler(join_channel))
 581            .add_request_handler(user_handler(join_channel_chat))
 582            .add_message_handler(user_message_handler(leave_channel_chat))
 583            .add_request_handler(user_handler(send_channel_message))
 584            .add_request_handler(user_handler(remove_channel_message))
 585            .add_request_handler(user_handler(update_channel_message))
 586            .add_request_handler(user_handler(get_channel_messages))
 587            .add_request_handler(user_handler(get_channel_messages_by_id))
 588            .add_request_handler(user_handler(get_notifications))
 589            .add_request_handler(user_handler(mark_notification_as_read))
 590            .add_request_handler(user_handler(move_channel))
 591            .add_request_handler(user_handler(follow))
 592            .add_message_handler(user_message_handler(unfollow))
 593            .add_message_handler(user_message_handler(update_followers))
 594            .add_request_handler(user_handler(get_private_user_info))
 595            .add_message_handler(user_message_handler(acknowledge_channel_message))
 596            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 597            .add_request_handler(user_handler(get_supermaven_api_key))
 598            .add_streaming_request_handler({
 599                let app_state = app_state.clone();
 600                move |request, response, session| {
 601                    complete_with_language_model(
 602                        request,
 603                        response,
 604                        session,
 605                        app_state.config.openai_api_key.clone(),
 606                        app_state.config.google_ai_api_key.clone(),
 607                        app_state.config.anthropic_api_key.clone(),
 608                    )
 609                }
 610            })
 611            .add_request_handler({
 612                let app_state = app_state.clone();
 613                user_handler(move |request, response, session| {
 614                    count_tokens_with_language_model(
 615                        request,
 616                        response,
 617                        session,
 618                        app_state.config.google_ai_api_key.clone(),
 619                    )
 620                })
 621            })
 622            .add_request_handler({
 623                user_handler(move |request, response, session| {
 624                    get_cached_embeddings(request, response, session)
 625                })
 626            })
 627            .add_request_handler({
 628                let app_state = app_state.clone();
 629                user_handler(move |request, response, session| {
 630                    compute_embeddings(
 631                        request,
 632                        response,
 633                        session,
 634                        app_state.config.openai_api_key.clone(),
 635                    )
 636                })
 637            });
 638
 639        Arc::new(server)
 640    }
 641
 642    pub async fn start(&self) -> Result<()> {
 643        let server_id = *self.id.lock();
 644        let app_state = self.app_state.clone();
 645        let peer = self.peer.clone();
 646        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 647        let pool = self.connection_pool.clone();
 648        let live_kit_client = self.app_state.live_kit_client.clone();
 649
 650        let span = info_span!("start server");
 651        self.app_state.executor.spawn_detached(
 652            async move {
 653                tracing::info!("waiting for cleanup timeout");
 654                timeout.await;
 655                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 656                if let Some((room_ids, channel_ids)) = app_state
 657                    .db
 658                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 659                    .await
 660                    .trace_err()
 661                {
 662                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 663                    tracing::info!(
 664                        stale_channel_buffer_count = channel_ids.len(),
 665                        "retrieved stale channel buffers"
 666                    );
 667
 668                    for channel_id in channel_ids {
 669                        if let Some(refreshed_channel_buffer) = app_state
 670                            .db
 671                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 672                            .await
 673                            .trace_err()
 674                        {
 675                            for connection_id in refreshed_channel_buffer.connection_ids {
 676                                peer.send(
 677                                    connection_id,
 678                                    proto::UpdateChannelBufferCollaborators {
 679                                        channel_id: channel_id.to_proto(),
 680                                        collaborators: refreshed_channel_buffer
 681                                            .collaborators
 682                                            .clone(),
 683                                    },
 684                                )
 685                                .trace_err();
 686                            }
 687                        }
 688                    }
 689
 690                    for room_id in room_ids {
 691                        let mut contacts_to_update = HashSet::default();
 692                        let mut canceled_calls_to_user_ids = Vec::new();
 693                        let mut live_kit_room = String::new();
 694                        let mut delete_live_kit_room = false;
 695
 696                        if let Some(mut refreshed_room) = app_state
 697                            .db
 698                            .clear_stale_room_participants(room_id, server_id)
 699                            .await
 700                            .trace_err()
 701                        {
 702                            tracing::info!(
 703                                room_id = room_id.0,
 704                                new_participant_count = refreshed_room.room.participants.len(),
 705                                "refreshed room"
 706                            );
 707                            room_updated(&refreshed_room.room, &peer);
 708                            if let Some(channel) = refreshed_room.channel.as_ref() {
 709                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 710                            }
 711                            contacts_to_update
 712                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 713                            contacts_to_update
 714                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 715                            canceled_calls_to_user_ids =
 716                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 717                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 718                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 719                        }
 720
 721                        {
 722                            let pool = pool.lock();
 723                            for canceled_user_id in canceled_calls_to_user_ids {
 724                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 725                                    peer.send(
 726                                        connection_id,
 727                                        proto::CallCanceled {
 728                                            room_id: room_id.to_proto(),
 729                                        },
 730                                    )
 731                                    .trace_err();
 732                                }
 733                            }
 734                        }
 735
 736                        for user_id in contacts_to_update {
 737                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 738                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 739                            if let Some((busy, contacts)) = busy.zip(contacts) {
 740                                let pool = pool.lock();
 741                                let updated_contact = contact_for_user(user_id, busy, &pool);
 742                                for contact in contacts {
 743                                    if let db::Contact::Accepted {
 744                                        user_id: contact_user_id,
 745                                        ..
 746                                    } = contact
 747                                    {
 748                                        for contact_conn_id in
 749                                            pool.user_connection_ids(contact_user_id)
 750                                        {
 751                                            peer.send(
 752                                                contact_conn_id,
 753                                                proto::UpdateContacts {
 754                                                    contacts: vec![updated_contact.clone()],
 755                                                    remove_contacts: Default::default(),
 756                                                    incoming_requests: Default::default(),
 757                                                    remove_incoming_requests: Default::default(),
 758                                                    outgoing_requests: Default::default(),
 759                                                    remove_outgoing_requests: Default::default(),
 760                                                },
 761                                            )
 762                                            .trace_err();
 763                                        }
 764                                    }
 765                                }
 766                            }
 767                        }
 768
 769                        if let Some(live_kit) = live_kit_client.as_ref() {
 770                            if delete_live_kit_room {
 771                                live_kit.delete_room(live_kit_room).await.trace_err();
 772                            }
 773                        }
 774                    }
 775                }
 776
 777                app_state
 778                    .db
 779                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 780                    .await
 781                    .trace_err();
 782            }
 783            .instrument(span),
 784        );
 785        Ok(())
 786    }
 787
 788    pub fn teardown(&self) {
 789        self.peer.teardown();
 790        self.connection_pool.lock().reset();
 791        let _ = self.teardown.send(true);
 792    }
 793
 794    #[cfg(test)]
 795    pub fn reset(&self, id: ServerId) {
 796        self.teardown();
 797        *self.id.lock() = id;
 798        self.peer.reset(id.0 as u32);
 799        let _ = self.teardown.send(false);
 800    }
 801
 802    #[cfg(test)]
 803    pub fn id(&self) -> ServerId {
 804        *self.id.lock()
 805    }
 806
 807    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 808    where
 809        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 810        Fut: 'static + Send + Future<Output = Result<()>>,
 811        M: EnvelopedMessage,
 812    {
 813        let prev_handler = self.handlers.insert(
 814            TypeId::of::<M>(),
 815            Box::new(move |envelope, session| {
 816                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 817                let received_at = envelope.received_at;
 818                tracing::info!("message received");
 819                let start_time = Instant::now();
 820                let future = (handler)(*envelope, session);
 821                async move {
 822                    let result = future.await;
 823                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 824                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 825                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 826                    let payload_type = M::NAME;
 827
 828                    match result {
 829                        Err(error) => {
 830                            tracing::error!(
 831                                ?error,
 832                                total_duration_ms,
 833                                processing_duration_ms,
 834                                queue_duration_ms,
 835                                payload_type,
 836                                "error handling message"
 837                            )
 838                        }
 839                        Ok(()) => tracing::info!(
 840                            total_duration_ms,
 841                            processing_duration_ms,
 842                            queue_duration_ms,
 843                            "finished handling message"
 844                        ),
 845                    }
 846                }
 847                .boxed()
 848            }),
 849        );
 850        if prev_handler.is_some() {
 851            panic!("registered a handler for the same message twice");
 852        }
 853        self
 854    }
 855
 856    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 857    where
 858        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 859        Fut: 'static + Send + Future<Output = Result<()>>,
 860        M: EnvelopedMessage,
 861    {
 862        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 863        self
 864    }
 865
 866    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 867    where
 868        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 869        Fut: Send + Future<Output = Result<()>>,
 870        M: RequestMessage,
 871    {
 872        let handler = Arc::new(handler);
 873        self.add_handler(move |envelope, session| {
 874            let receipt = envelope.receipt();
 875            let handler = handler.clone();
 876            async move {
 877                let peer = session.peer.clone();
 878                let responded = Arc::new(AtomicBool::default());
 879                let response = Response {
 880                    peer: peer.clone(),
 881                    responded: responded.clone(),
 882                    receipt,
 883                };
 884                match (handler)(envelope.payload, response, session).await {
 885                    Ok(()) => {
 886                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 887                            Ok(())
 888                        } else {
 889                            Err(anyhow!("handler did not send a response"))?
 890                        }
 891                    }
 892                    Err(error) => {
 893                        let proto_err = match &error {
 894                            Error::Internal(err) => err.to_proto(),
 895                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 896                        };
 897                        peer.respond_with_error(receipt, proto_err)?;
 898                        Err(error)
 899                    }
 900                }
 901            }
 902        })
 903    }
 904
 905    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 906    where
 907        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 908        Fut: Send + Future<Output = Result<()>>,
 909        M: RequestMessage,
 910    {
 911        let handler = Arc::new(handler);
 912        self.add_handler(move |envelope, session| {
 913            let receipt = envelope.receipt();
 914            let handler = handler.clone();
 915            async move {
 916                let peer = session.peer.clone();
 917                let response = StreamingResponse {
 918                    peer: peer.clone(),
 919                    receipt,
 920                };
 921                match (handler)(envelope.payload, response, session).await {
 922                    Ok(()) => {
 923                        peer.end_stream(receipt)?;
 924                        Ok(())
 925                    }
 926                    Err(error) => {
 927                        let proto_err = match &error {
 928                            Error::Internal(err) => err.to_proto(),
 929                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 930                        };
 931                        peer.respond_with_error(receipt, proto_err)?;
 932                        Err(error)
 933                    }
 934                }
 935            }
 936        })
 937    }
 938
    /// Drives one client connection from handshake to sign-out. The returned future runs
    /// inside a "handle connection" tracing span annotated with the principal's fields.
    ///
    /// Steps:
    /// 1. Returns immediately if the server's teardown flag is already set.
    /// 2. Registers `connection` with the peer and builds the per-connection [`Session`].
    ///    The session includes a Supermaven admin client when an admin API key is
    ///    configured.
    /// 3. Sends the initial client update via `send_initial_client_update`, which also
    ///    receives `send_connection_id`.
    /// 4. Dispatches incoming messages to the registered handlers. This continues until the
    ///    I/O future completes, the incoming stream ends, or teardown is signalled.
    ///    - At most 256 handlers are in flight per connection.
    ///    - Background messages are spawned detached on `executor`.
    ///    - Foreground messages are polled by this loop.
    /// 5. When the I/O future completes or the stream ends, `connection_lost` runs to clean
    ///    up. A teardown signal returns without running it.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribe before the future starts so a teardown signalled later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection if teardown has already been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // Register with the peer. The closure hands the peer a sleep function backed by
            // `executor` (presumably for its internal timers — confirm in `Peer::add_connection`).
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // One HTTP client per connection. It is stored on the session and shared with the
            // Supermaven admin client below.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    // NOTE(review): returns after `add_connection` without calling
                    // `connection_lost`; confirm the peer entry is cleaned up elsewhere.
                    return;
                }
            };

            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): `send_initial_client_update` may already have added this
                // connection to the connection pool before failing. Returning here skips
                // `connection_lost`, so neither the peer nor the pool entry is removed on this
                // path — confirm whether that is intended.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers (foreground and background) for this connection.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Acquire a permit *before* reading the next message, so back-pressure applies
                // once 256 handlers are running. The permit travels with the handler future and
                // is released when the handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown is checked first, then I/O, then in-flight handlers, then new messages.
                futures::select_biased! {
                    // Server teardown: exit immediately, skipping `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Make progress on queued foreground handlers.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1085
1086    async fn send_initial_client_update(
1087        &self,
1088        connection_id: ConnectionId,
1089        principal: &Principal,
1090        zed_version: ZedVersion,
1091        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1092        session: &Session,
1093    ) -> Result<()> {
1094        self.peer.send(
1095            connection_id,
1096            proto::Hello {
1097                peer_id: Some(connection_id.into()),
1098            },
1099        )?;
1100        tracing::info!("sent hello message");
1101        if let Some(send_connection_id) = send_connection_id.take() {
1102            let _ = send_connection_id.send(connection_id);
1103        }
1104
1105        match principal {
1106            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1107                if !user.connected_once {
1108                    self.peer.send(connection_id, proto::ShowContacts {})?;
1109                    self.app_state
1110                        .db
1111                        .set_user_connected_once(user.id, true)
1112                        .await?;
1113                }
1114
1115                let (contacts, dev_server_projects) = future::try_join(
1116                    self.app_state.db.get_contacts(user.id),
1117                    self.app_state.db.dev_server_projects_update(user.id),
1118                )
1119                .await?;
1120
1121                {
1122                    let mut pool = self.connection_pool.lock();
1123                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1124                    self.peer.send(
1125                        connection_id,
1126                        build_initial_contacts_update(contacts, &pool),
1127                    )?;
1128                }
1129
1130                if should_auto_subscribe_to_channels(zed_version) {
1131                    subscribe_user_to_channels(user.id, session).await?;
1132                }
1133
1134                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1135
1136                if let Some(incoming_call) =
1137                    self.app_state.db.incoming_call_for_user(user.id).await?
1138                {
1139                    self.peer.send(connection_id, incoming_call)?;
1140                }
1141
1142                update_user_contacts(user.id, &session).await?;
1143            }
1144            Principal::DevServer(dev_server) => {
1145                {
1146                    let mut pool = self.connection_pool.lock();
1147                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1148                    {
1149                        self.peer.send(
1150                            stale_connection_id,
1151                            proto::ShutdownDevServer {
1152                                reason: Some(
1153                                    "another dev server connected with the same token".to_string(),
1154                                ),
1155                            },
1156                        )?;
1157                        pool.remove_connection(stale_connection_id)?;
1158                    };
1159                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1160                }
1161
1162                let projects = self
1163                    .app_state
1164                    .db
1165                    .get_projects_for_dev_server(dev_server.id)
1166                    .await?;
1167                self.peer
1168                    .send(connection_id, proto::DevServerInstructions { projects })?;
1169
1170                let status = self
1171                    .app_state
1172                    .db
1173                    .dev_server_projects_update(dev_server.user_id)
1174                    .await?;
1175                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1176            }
1177        }
1178
1179        Ok(())
1180    }
1181
1182    pub async fn invite_code_redeemed(
1183        self: &Arc<Self>,
1184        inviter_id: UserId,
1185        invitee_id: UserId,
1186    ) -> Result<()> {
1187        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1188            if let Some(code) = &user.invite_code {
1189                let pool = self.connection_pool.lock();
1190                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1191                for connection_id in pool.user_connection_ids(inviter_id) {
1192                    self.peer.send(
1193                        connection_id,
1194                        proto::UpdateContacts {
1195                            contacts: vec![invitee_contact.clone()],
1196                            ..Default::default()
1197                        },
1198                    )?;
1199                    self.peer.send(
1200                        connection_id,
1201                        proto::UpdateInviteInfo {
1202                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1203                            count: user.invite_count as u32,
1204                        },
1205                    )?;
1206                }
1207            }
1208        }
1209        Ok(())
1210    }
1211
1212    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1213        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1214            if let Some(invite_code) = &user.invite_code {
1215                let pool = self.connection_pool.lock();
1216                for connection_id in pool.user_connection_ids(user_id) {
1217                    self.peer.send(
1218                        connection_id,
1219                        proto::UpdateInviteInfo {
1220                            url: format!(
1221                                "{}{}",
1222                                self.app_state.config.invite_link_prefix, invite_code
1223                            ),
1224                            count: user.invite_count as u32,
1225                        },
1226                    )?;
1227                }
1228            }
1229        }
1230        Ok(())
1231    }
1232
1233    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1234        ServerSnapshot {
1235            connection_pool: ConnectionPoolGuard {
1236                guard: self.connection_pool.lock(),
1237                _not_send: PhantomData,
1238            },
1239            peer: &self.peer,
1240        }
1241    }
1242}
1243
1244impl<'a> Deref for ConnectionPoolGuard<'a> {
1245    type Target = ConnectionPool;
1246
1247    fn deref(&self) -> &Self::Target {
1248        &self.guard
1249    }
1250}
1251
1252impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1253    fn deref_mut(&mut self) -> &mut Self::Target {
1254        &mut self.guard
1255    }
1256}
1257
1258impl<'a> Drop for ConnectionPoolGuard<'a> {
1259    fn drop(&mut self) {
1260        #[cfg(test)]
1261        self.check_invariants();
1262    }
1263}
1264
1265fn broadcast<F>(
1266    sender_id: Option<ConnectionId>,
1267    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1268    mut f: F,
1269) where
1270    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1271{
1272    for receiver_id in receiver_ids {
1273        if Some(receiver_id) != sender_id {
1274            if let Err(error) = f(receiver_id) {
1275                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1276            }
1277        }
1278    }
1279}
1280
1281pub struct ProtocolVersion(u32);
1282
1283impl Header for ProtocolVersion {
1284    fn name() -> &'static HeaderName {
1285        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1286        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1287    }
1288
1289    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1290    where
1291        Self: Sized,
1292        I: Iterator<Item = &'i axum::http::HeaderValue>,
1293    {
1294        let version = values
1295            .next()
1296            .ok_or_else(axum::headers::Error::invalid)?
1297            .to_str()
1298            .map_err(|_| axum::headers::Error::invalid())?
1299            .parse()
1300            .map_err(|_| axum::headers::Error::invalid())?;
1301        Ok(Self(version))
1302    }
1303
1304    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1305        values.extend([self.0.to_string().parse().unwrap()]);
1306    }
1307}
1308
1309pub struct AppVersionHeader(SemanticVersion);
1310impl Header for AppVersionHeader {
1311    fn name() -> &'static HeaderName {
1312        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1313        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1314    }
1315
1316    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1317    where
1318        Self: Sized,
1319        I: Iterator<Item = &'i axum::http::HeaderValue>,
1320    {
1321        let version = values
1322            .next()
1323            .ok_or_else(axum::headers::Error::invalid)?
1324            .to_str()
1325            .map_err(|_| axum::headers::Error::invalid())?
1326            .parse()
1327            .map_err(|_| axum::headers::Error::invalid())?;
1328        Ok(Self(version))
1329    }
1330
1331    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1332        values.extend([self.0.to_string().parse().unwrap()]);
1333    }
1334}
1335
1336pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1337    Router::new()
1338        .route("/rpc", get(handle_websocket_request))
1339        .layer(
1340            ServiceBuilder::new()
1341                .layer(Extension(server.app_state.clone()))
1342                .layer(middleware::from_fn(auth::validate_header)),
1343        )
1344        .route("/metrics", get(handle_metrics))
1345        .layer(Extension(server))
1346}
1347
1348pub async fn handle_websocket_request(
1349    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1350    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1351    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1352    Extension(server): Extension<Arc<Server>>,
1353    Extension(principal): Extension<Principal>,
1354    ws: WebSocketUpgrade,
1355) -> axum::response::Response {
1356    if protocol_version != rpc::PROTOCOL_VERSION {
1357        return (
1358            StatusCode::UPGRADE_REQUIRED,
1359            "client must be upgraded".to_string(),
1360        )
1361            .into_response();
1362    }
1363
1364    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1365        return (
1366            StatusCode::UPGRADE_REQUIRED,
1367            "no version header found".to_string(),
1368        )
1369            .into_response();
1370    };
1371
1372    if !version.can_collaborate() {
1373        return (
1374            StatusCode::UPGRADE_REQUIRED,
1375            "client must be upgraded".to_string(),
1376        )
1377            .into_response();
1378    }
1379
1380    let socket_address = socket_address.to_string();
1381    ws.on_upgrade(move |socket| {
1382        let socket = socket
1383            .map_ok(to_tungstenite_message)
1384            .err_into()
1385            .with(|message| async move { Ok(to_axum_message(message)) });
1386        let connection = Connection::new(Box::pin(socket));
1387        async move {
1388            server
1389                .handle_connection(
1390                    connection,
1391                    socket_address,
1392                    principal,
1393                    version,
1394                    None,
1395                    Executor::Production,
1396                )
1397                .await;
1398        }
1399    })
1400}
1401
1402pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1403    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1404    let connections_metric = CONNECTIONS_METRIC
1405        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1406
1407    let connections = server
1408        .connection_pool
1409        .lock()
1410        .connections()
1411        .filter(|connection| !connection.admin)
1412        .count();
1413    connections_metric.set(connections as _);
1414
1415    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1416    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1417        register_int_gauge!(
1418            "shared_projects",
1419            "number of open projects with one or more guests"
1420        )
1421        .unwrap()
1422    });
1423
1424    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1425    shared_projects_metric.set(shared_projects as _);
1426
1427    let encoder = prometheus::TextEncoder::new();
1428    let metric_families = prometheus::gather();
1429    let encoded_metrics = encoder
1430        .encode_to_string(&metric_families)
1431        .map_err(|err| anyhow!("{}", err))?;
1432    Ok(encoded_metrics)
1433}
1434
1435#[instrument(err, skip(executor))]
1436async fn connection_lost(
1437    session: Session,
1438    mut teardown: watch::Receiver<bool>,
1439    executor: Executor,
1440) -> Result<()> {
1441    session.peer.disconnect(session.connection_id);
1442    session
1443        .connection_pool()
1444        .await
1445        .remove_connection(session.connection_id)?;
1446
1447    session
1448        .db()
1449        .await
1450        .connection_lost(session.connection_id)
1451        .await
1452        .trace_err();
1453
1454    futures::select_biased! {
1455        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1456            match &session.principal {
1457                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1458                    let session = session.for_user().unwrap();
1459
1460                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1461                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1462                    leave_channel_buffers_for_session(&session)
1463                        .await
1464                        .trace_err();
1465
1466                    if !session
1467                        .connection_pool()
1468                        .await
1469                        .is_user_online(session.user_id())
1470                    {
1471                        let db = session.db().await;
1472                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1473                            room_updated(&room, &session.peer);
1474                        }
1475                    }
1476
1477                    update_user_contacts(session.user_id(), &session).await?;
1478                },
1479            Principal::DevServer(_) => {
1480                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1481            },
1482        }
1483        },
1484        _ = teardown.changed().fuse() => {}
1485    }
1486
1487    Ok(())
1488}
1489
1490/// Acknowledges a ping from a client, used to keep the connection alive.
1491async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1492    response.send(proto::Ack {})?;
1493    Ok(())
1494}
1495
1496/// Creates a new room for calling (outside of channels)
1497async fn create_room(
1498    _request: proto::CreateRoom,
1499    response: Response<proto::CreateRoom>,
1500    session: UserSession,
1501) -> Result<()> {
1502    let live_kit_room = nanoid::nanoid!(30);
1503
1504    let live_kit_connection_info = util::maybe!(async {
1505        let live_kit = session.live_kit_client.as_ref();
1506        let live_kit = live_kit?;
1507        let user_id = session.user_id().to_string();
1508
1509        let token = live_kit
1510            .room_token(&live_kit_room, &user_id.to_string())
1511            .trace_err()?;
1512
1513        Some(proto::LiveKitConnectionInfo {
1514            server_url: live_kit.url().into(),
1515            token,
1516            can_publish: true,
1517        })
1518    })
1519    .await;
1520
1521    let room = session
1522        .db()
1523        .await
1524        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1525        .await?;
1526
1527    response.send(proto::CreateRoomResponse {
1528        room: Some(room.clone()),
1529        live_kit_connection_info,
1530    })?;
1531
1532    update_user_contacts(session.user_id(), &session).await?;
1533    Ok(())
1534}
1535
1536/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1537async fn join_room(
1538    request: proto::JoinRoom,
1539    response: Response<proto::JoinRoom>,
1540    session: UserSession,
1541) -> Result<()> {
1542    let room_id = RoomId::from_proto(request.id);
1543
1544    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1545
1546    if let Some(channel_id) = channel_id {
1547        return join_channel_internal(channel_id, Box::new(response), session).await;
1548    }
1549
1550    let joined_room = {
1551        let room = session
1552            .db()
1553            .await
1554            .join_room(room_id, session.user_id(), session.connection_id)
1555            .await?;
1556        room_updated(&room.room, &session.peer);
1557        room.into_inner()
1558    };
1559
1560    for connection_id in session
1561        .connection_pool()
1562        .await
1563        .user_connection_ids(session.user_id())
1564    {
1565        session
1566            .peer
1567            .send(
1568                connection_id,
1569                proto::CallCanceled {
1570                    room_id: room_id.to_proto(),
1571                },
1572            )
1573            .trace_err();
1574    }
1575
1576    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1577        if let Some(token) = live_kit
1578            .room_token(
1579                &joined_room.room.live_kit_room,
1580                &session.user_id().to_string(),
1581            )
1582            .trace_err()
1583        {
1584            Some(proto::LiveKitConnectionInfo {
1585                server_url: live_kit.url().into(),
1586                token,
1587                can_publish: true,
1588            })
1589        } else {
1590            None
1591        }
1592    } else {
1593        None
1594    };
1595
1596    response.send(proto::JoinRoomResponse {
1597        room: Some(joined_room.room),
1598        channel_id: None,
1599        live_kit_connection_info,
1600    })?;
1601
1602    update_user_contacts(session.user_id(), &session).await?;
1603    Ok(())
1604}
1605
1606/// Rejoin room is used to reconnect to a room after connection errors.
1607async fn rejoin_room(
1608    request: proto::RejoinRoom,
1609    response: Response<proto::RejoinRoom>,
1610    session: UserSession,
1611) -> Result<()> {
1612    let room;
1613    let channel;
1614    {
1615        let mut rejoined_room = session
1616            .db()
1617            .await
1618            .rejoin_room(request, session.user_id(), session.connection_id)
1619            .await?;
1620
1621        response.send(proto::RejoinRoomResponse {
1622            room: Some(rejoined_room.room.clone()),
1623            reshared_projects: rejoined_room
1624                .reshared_projects
1625                .iter()
1626                .map(|project| proto::ResharedProject {
1627                    id: project.id.to_proto(),
1628                    collaborators: project
1629                        .collaborators
1630                        .iter()
1631                        .map(|collaborator| collaborator.to_proto())
1632                        .collect(),
1633                })
1634                .collect(),
1635            rejoined_projects: rejoined_room
1636                .rejoined_projects
1637                .iter()
1638                .map(|rejoined_project| rejoined_project.to_proto())
1639                .collect(),
1640        })?;
1641        room_updated(&rejoined_room.room, &session.peer);
1642
1643        for project in &rejoined_room.reshared_projects {
1644            for collaborator in &project.collaborators {
1645                session
1646                    .peer
1647                    .send(
1648                        collaborator.connection_id,
1649                        proto::UpdateProjectCollaborator {
1650                            project_id: project.id.to_proto(),
1651                            old_peer_id: Some(project.old_connection_id.into()),
1652                            new_peer_id: Some(session.connection_id.into()),
1653                        },
1654                    )
1655                    .trace_err();
1656            }
1657
1658            broadcast(
1659                Some(session.connection_id),
1660                project
1661                    .collaborators
1662                    .iter()
1663                    .map(|collaborator| collaborator.connection_id),
1664                |connection_id| {
1665                    session.peer.forward_send(
1666                        session.connection_id,
1667                        connection_id,
1668                        proto::UpdateProject {
1669                            project_id: project.id.to_proto(),
1670                            worktrees: project.worktrees.clone(),
1671                        },
1672                    )
1673                },
1674            );
1675        }
1676
1677        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1678
1679        let rejoined_room = rejoined_room.into_inner();
1680
1681        room = rejoined_room.room;
1682        channel = rejoined_room.channel;
1683    }
1684
1685    if let Some(channel) = channel {
1686        channel_updated(
1687            &channel,
1688            &room,
1689            &session.peer,
1690            &*session.connection_pool().await,
1691        );
1692    }
1693
1694    update_user_contacts(session.user_id(), &session).await?;
1695    Ok(())
1696}
1697
1698fn notify_rejoined_projects(
1699    rejoined_projects: &mut Vec<RejoinedProject>,
1700    session: &UserSession,
1701) -> Result<()> {
1702    for project in rejoined_projects.iter() {
1703        for collaborator in &project.collaborators {
1704            session
1705                .peer
1706                .send(
1707                    collaborator.connection_id,
1708                    proto::UpdateProjectCollaborator {
1709                        project_id: project.id.to_proto(),
1710                        old_peer_id: Some(project.old_connection_id.into()),
1711                        new_peer_id: Some(session.connection_id.into()),
1712                    },
1713                )
1714                .trace_err();
1715        }
1716    }
1717
1718    for project in rejoined_projects {
1719        for worktree in mem::take(&mut project.worktrees) {
1720            #[cfg(any(test, feature = "test-support"))]
1721            const MAX_CHUNK_SIZE: usize = 2;
1722            #[cfg(not(any(test, feature = "test-support")))]
1723            const MAX_CHUNK_SIZE: usize = 256;
1724
1725            // Stream this worktree's entries.
1726            let message = proto::UpdateWorktree {
1727                project_id: project.id.to_proto(),
1728                worktree_id: worktree.id,
1729                abs_path: worktree.abs_path.clone(),
1730                root_name: worktree.root_name,
1731                updated_entries: worktree.updated_entries,
1732                removed_entries: worktree.removed_entries,
1733                scan_id: worktree.scan_id,
1734                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1735                updated_repositories: worktree.updated_repositories,
1736                removed_repositories: worktree.removed_repositories,
1737            };
1738            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1739                session.peer.send(session.connection_id, update.clone())?;
1740            }
1741
1742            // Stream this worktree's diagnostics.
1743            for summary in worktree.diagnostic_summaries {
1744                session.peer.send(
1745                    session.connection_id,
1746                    proto::UpdateDiagnosticSummary {
1747                        project_id: project.id.to_proto(),
1748                        worktree_id: worktree.id,
1749                        summary: Some(summary),
1750                    },
1751                )?;
1752            }
1753
1754            for settings_file in worktree.settings_files {
1755                session.peer.send(
1756                    session.connection_id,
1757                    proto::UpdateWorktreeSettings {
1758                        project_id: project.id.to_proto(),
1759                        worktree_id: worktree.id,
1760                        path: settings_file.path,
1761                        content: Some(settings_file.content),
1762                    },
1763                )?;
1764            }
1765        }
1766
1767        for language_server in &project.language_servers {
1768            session.peer.send(
1769                session.connection_id,
1770                proto::UpdateLanguageServer {
1771                    project_id: project.id.to_proto(),
1772                    language_server_id: language_server.id,
1773                    variant: Some(
1774                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1775                            proto::LspDiskBasedDiagnosticsUpdated {},
1776                        ),
1777                    ),
1778                },
1779            )?;
1780        }
1781    }
1782    Ok(())
1783}
1784
1785/// leave room disconnects from the room.
1786async fn leave_room(
1787    _: proto::LeaveRoom,
1788    response: Response<proto::LeaveRoom>,
1789    session: UserSession,
1790) -> Result<()> {
1791    leave_room_for_session(&session, session.connection_id).await?;
1792    response.send(proto::Ack {})?;
1793    Ok(())
1794}
1795
1796/// Updates the permissions of someone else in the room.
1797async fn set_room_participant_role(
1798    request: proto::SetRoomParticipantRole,
1799    response: Response<proto::SetRoomParticipantRole>,
1800    session: UserSession,
1801) -> Result<()> {
1802    let user_id = UserId::from_proto(request.user_id);
1803    let role = ChannelRole::from(request.role());
1804
1805    let (live_kit_room, can_publish) = {
1806        let room = session
1807            .db()
1808            .await
1809            .set_room_participant_role(
1810                session.user_id(),
1811                RoomId::from_proto(request.room_id),
1812                user_id,
1813                role,
1814            )
1815            .await?;
1816
1817        let live_kit_room = room.live_kit_room.clone();
1818        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1819        room_updated(&room, &session.peer);
1820        (live_kit_room, can_publish)
1821    };
1822
1823    if let Some(live_kit) = session.live_kit_client.as_ref() {
1824        live_kit
1825            .update_participant(
1826                live_kit_room.clone(),
1827                request.user_id.to_string(),
1828                live_kit_server::proto::ParticipantPermission {
1829                    can_subscribe: true,
1830                    can_publish,
1831                    can_publish_data: can_publish,
1832                    hidden: false,
1833                    recorder: false,
1834                },
1835            )
1836            .await
1837            .trace_err();
1838    }
1839
1840    response.send(proto::Ack {})?;
1841    Ok(())
1842}
1843
1844/// Call someone else into the current room
1845async fn call(
1846    request: proto::Call,
1847    response: Response<proto::Call>,
1848    session: UserSession,
1849) -> Result<()> {
1850    let room_id = RoomId::from_proto(request.room_id);
1851    let calling_user_id = session.user_id();
1852    let calling_connection_id = session.connection_id;
1853    let called_user_id = UserId::from_proto(request.called_user_id);
1854    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1855    if !session
1856        .db()
1857        .await
1858        .has_contact(calling_user_id, called_user_id)
1859        .await?
1860    {
1861        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1862    }
1863
1864    let incoming_call = {
1865        let (room, incoming_call) = &mut *session
1866            .db()
1867            .await
1868            .call(
1869                room_id,
1870                calling_user_id,
1871                calling_connection_id,
1872                called_user_id,
1873                initial_project_id,
1874            )
1875            .await?;
1876        room_updated(&room, &session.peer);
1877        mem::take(incoming_call)
1878    };
1879    update_user_contacts(called_user_id, &session).await?;
1880
1881    let mut calls = session
1882        .connection_pool()
1883        .await
1884        .user_connection_ids(called_user_id)
1885        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1886        .collect::<FuturesUnordered<_>>();
1887
1888    while let Some(call_response) = calls.next().await {
1889        match call_response.as_ref() {
1890            Ok(_) => {
1891                response.send(proto::Ack {})?;
1892                return Ok(());
1893            }
1894            Err(_) => {
1895                call_response.trace_err();
1896            }
1897        }
1898    }
1899
1900    {
1901        let room = session
1902            .db()
1903            .await
1904            .call_failed(room_id, called_user_id)
1905            .await?;
1906        room_updated(&room, &session.peer);
1907    }
1908    update_user_contacts(called_user_id, &session).await?;
1909
1910    Err(anyhow!("failed to ring user"))?
1911}
1912
1913/// Cancel an outgoing call.
1914async fn cancel_call(
1915    request: proto::CancelCall,
1916    response: Response<proto::CancelCall>,
1917    session: UserSession,
1918) -> Result<()> {
1919    let called_user_id = UserId::from_proto(request.called_user_id);
1920    let room_id = RoomId::from_proto(request.room_id);
1921    {
1922        let room = session
1923            .db()
1924            .await
1925            .cancel_call(room_id, session.connection_id, called_user_id)
1926            .await?;
1927        room_updated(&room, &session.peer);
1928    }
1929
1930    for connection_id in session
1931        .connection_pool()
1932        .await
1933        .user_connection_ids(called_user_id)
1934    {
1935        session
1936            .peer
1937            .send(
1938                connection_id,
1939                proto::CallCanceled {
1940                    room_id: room_id.to_proto(),
1941                },
1942            )
1943            .trace_err();
1944    }
1945    response.send(proto::Ack {})?;
1946
1947    update_user_contacts(called_user_id, &session).await?;
1948    Ok(())
1949}
1950
1951/// Decline an incoming call.
1952async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1953    let room_id = RoomId::from_proto(message.room_id);
1954    {
1955        let room = session
1956            .db()
1957            .await
1958            .decline_call(Some(room_id), session.user_id())
1959            .await?
1960            .ok_or_else(|| anyhow!("failed to decline call"))?;
1961        room_updated(&room, &session.peer);
1962    }
1963
1964    for connection_id in session
1965        .connection_pool()
1966        .await
1967        .user_connection_ids(session.user_id())
1968    {
1969        session
1970            .peer
1971            .send(
1972                connection_id,
1973                proto::CallCanceled {
1974                    room_id: room_id.to_proto(),
1975                },
1976            )
1977            .trace_err();
1978    }
1979    update_user_contacts(session.user_id(), &session).await?;
1980    Ok(())
1981}
1982
1983/// Updates other participants in the room with your current location.
1984async fn update_participant_location(
1985    request: proto::UpdateParticipantLocation,
1986    response: Response<proto::UpdateParticipantLocation>,
1987    session: UserSession,
1988) -> Result<()> {
1989    let room_id = RoomId::from_proto(request.room_id);
1990    let location = request
1991        .location
1992        .ok_or_else(|| anyhow!("invalid location"))?;
1993
1994    let db = session.db().await;
1995    let room = db
1996        .update_room_participant_location(room_id, session.connection_id, location)
1997        .await?;
1998
1999    room_updated(&room, &session.peer);
2000    response.send(proto::Ack {})?;
2001    Ok(())
2002}
2003
2004/// Share a project into the room.
2005async fn share_project(
2006    request: proto::ShareProject,
2007    response: Response<proto::ShareProject>,
2008    session: UserSession,
2009) -> Result<()> {
2010    let (project_id, room) = &*session
2011        .db()
2012        .await
2013        .share_project(
2014            RoomId::from_proto(request.room_id),
2015            session.connection_id,
2016            &request.worktrees,
2017            request
2018                .dev_server_project_id
2019                .map(|id| DevServerProjectId::from_proto(id)),
2020        )
2021        .await?;
2022    response.send(proto::ShareProjectResponse {
2023        project_id: project_id.to_proto(),
2024    })?;
2025    room_updated(&room, &session.peer);
2026
2027    Ok(())
2028}
2029
2030/// Unshare a project from the room.
2031async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2032    let project_id = ProjectId::from_proto(message.project_id);
2033    unshare_project_internal(
2034        project_id,
2035        session.connection_id,
2036        session.user_id(),
2037        &session,
2038    )
2039    .await
2040}
2041
2042async fn unshare_project_internal(
2043    project_id: ProjectId,
2044    connection_id: ConnectionId,
2045    user_id: Option<UserId>,
2046    session: &Session,
2047) -> Result<()> {
2048    let delete = {
2049        let room_guard = session
2050            .db()
2051            .await
2052            .unshare_project(project_id, connection_id, user_id)
2053            .await?;
2054
2055        let (delete, room, guest_connection_ids) = &*room_guard;
2056
2057        let message = proto::UnshareProject {
2058            project_id: project_id.to_proto(),
2059        };
2060
2061        broadcast(
2062            Some(connection_id),
2063            guest_connection_ids.iter().copied(),
2064            |conn_id| session.peer.send(conn_id, message.clone()),
2065        );
2066        if let Some(room) = room {
2067            room_updated(room, &session.peer);
2068        }
2069
2070        *delete
2071    };
2072
2073    if delete {
2074        let db = session.db().await;
2075        db.delete_project(project_id).await?;
2076    }
2077
2078    Ok(())
2079}
2080
2081/// DevServer makes a project available online
2082async fn share_dev_server_project(
2083    request: proto::ShareDevServerProject,
2084    response: Response<proto::ShareDevServerProject>,
2085    session: DevServerSession,
2086) -> Result<()> {
2087    let (dev_server_project, user_id, status) = session
2088        .db()
2089        .await
2090        .share_dev_server_project(
2091            DevServerProjectId::from_proto(request.dev_server_project_id),
2092            session.dev_server_id(),
2093            session.connection_id,
2094            &request.worktrees,
2095        )
2096        .await?;
2097    let Some(project_id) = dev_server_project.project_id else {
2098        return Err(anyhow!("failed to share remote project"))?;
2099    };
2100
2101    send_dev_server_projects_update(user_id, status, &session).await;
2102
2103    response.send(proto::ShareProjectResponse { project_id })?;
2104
2105    Ok(())
2106}
2107
2108/// Join someone elses shared project.
2109async fn join_project(
2110    request: proto::JoinProject,
2111    response: Response<proto::JoinProject>,
2112    session: UserSession,
2113) -> Result<()> {
2114    let project_id = ProjectId::from_proto(request.project_id);
2115
2116    tracing::info!(%project_id, "join project");
2117
2118    let db = session.db().await;
2119    let (project, replica_id) = &mut *db
2120        .join_project(project_id, session.connection_id, session.user_id())
2121        .await?;
2122    drop(db);
2123    tracing::info!(%project_id, "join remote project");
2124    join_project_internal(response, session, project, replica_id)
2125}
2126
2127trait JoinProjectInternalResponse {
2128    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2129}
2130impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2131    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2132        Response::<proto::JoinProject>::send(self, result)
2133    }
2134}
2135impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2136    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2137        Response::<proto::JoinHostedProject>::send(self, result)
2138    }
2139}
2140
2141fn join_project_internal(
2142    response: impl JoinProjectInternalResponse,
2143    session: UserSession,
2144    project: &mut Project,
2145    replica_id: &ReplicaId,
2146) -> Result<()> {
2147    let collaborators = project
2148        .collaborators
2149        .iter()
2150        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2151        .map(|collaborator| collaborator.to_proto())
2152        .collect::<Vec<_>>();
2153    let project_id = project.id;
2154    let guest_user_id = session.user_id();
2155
2156    let worktrees = project
2157        .worktrees
2158        .iter()
2159        .map(|(id, worktree)| proto::WorktreeMetadata {
2160            id: *id,
2161            root_name: worktree.root_name.clone(),
2162            visible: worktree.visible,
2163            abs_path: worktree.abs_path.clone(),
2164        })
2165        .collect::<Vec<_>>();
2166
2167    let add_project_collaborator = proto::AddProjectCollaborator {
2168        project_id: project_id.to_proto(),
2169        collaborator: Some(proto::Collaborator {
2170            peer_id: Some(session.connection_id.into()),
2171            replica_id: replica_id.0 as u32,
2172            user_id: guest_user_id.to_proto(),
2173        }),
2174    };
2175
2176    for collaborator in &collaborators {
2177        session
2178            .peer
2179            .send(
2180                collaborator.peer_id.unwrap().into(),
2181                add_project_collaborator.clone(),
2182            )
2183            .trace_err();
2184    }
2185
2186    // First, we send the metadata associated with each worktree.
2187    response.send(proto::JoinProjectResponse {
2188        project_id: project.id.0 as u64,
2189        worktrees: worktrees.clone(),
2190        replica_id: replica_id.0 as u32,
2191        collaborators: collaborators.clone(),
2192        language_servers: project.language_servers.clone(),
2193        role: project.role.into(),
2194        dev_server_project_id: project
2195            .dev_server_project_id
2196            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2197    })?;
2198
2199    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2200        #[cfg(any(test, feature = "test-support"))]
2201        const MAX_CHUNK_SIZE: usize = 2;
2202        #[cfg(not(any(test, feature = "test-support")))]
2203        const MAX_CHUNK_SIZE: usize = 256;
2204
2205        // Stream this worktree's entries.
2206        let message = proto::UpdateWorktree {
2207            project_id: project_id.to_proto(),
2208            worktree_id,
2209            abs_path: worktree.abs_path.clone(),
2210            root_name: worktree.root_name,
2211            updated_entries: worktree.entries,
2212            removed_entries: Default::default(),
2213            scan_id: worktree.scan_id,
2214            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2215            updated_repositories: worktree.repository_entries.into_values().collect(),
2216            removed_repositories: Default::default(),
2217        };
2218        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2219            session.peer.send(session.connection_id, update.clone())?;
2220        }
2221
2222        // Stream this worktree's diagnostics.
2223        for summary in worktree.diagnostic_summaries {
2224            session.peer.send(
2225                session.connection_id,
2226                proto::UpdateDiagnosticSummary {
2227                    project_id: project_id.to_proto(),
2228                    worktree_id: worktree.id,
2229                    summary: Some(summary),
2230                },
2231            )?;
2232        }
2233
2234        for settings_file in worktree.settings_files {
2235            session.peer.send(
2236                session.connection_id,
2237                proto::UpdateWorktreeSettings {
2238                    project_id: project_id.to_proto(),
2239                    worktree_id: worktree.id,
2240                    path: settings_file.path,
2241                    content: Some(settings_file.content),
2242                },
2243            )?;
2244        }
2245    }
2246
2247    for language_server in &project.language_servers {
2248        session.peer.send(
2249            session.connection_id,
2250            proto::UpdateLanguageServer {
2251                project_id: project_id.to_proto(),
2252                language_server_id: language_server.id,
2253                variant: Some(
2254                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2255                        proto::LspDiskBasedDiagnosticsUpdated {},
2256                    ),
2257                ),
2258            },
2259        )?;
2260    }
2261
2262    Ok(())
2263}
2264
2265/// Leave someone elses shared project.
2266async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2267    let sender_id = session.connection_id;
2268    let project_id = ProjectId::from_proto(request.project_id);
2269    let db = session.db().await;
2270    if db.is_hosted_project(project_id).await? {
2271        let project = db.leave_hosted_project(project_id, sender_id).await?;
2272        project_left(&project, &session);
2273        return Ok(());
2274    }
2275
2276    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2277    tracing::info!(
2278        %project_id,
2279        "leave project"
2280    );
2281
2282    project_left(&project, &session);
2283    if let Some(room) = room {
2284        room_updated(&room, &session.peer);
2285    }
2286
2287    Ok(())
2288}
2289
2290async fn join_hosted_project(
2291    request: proto::JoinHostedProject,
2292    response: Response<proto::JoinHostedProject>,
2293    session: UserSession,
2294) -> Result<()> {
2295    let (mut project, replica_id) = session
2296        .db()
2297        .await
2298        .join_hosted_project(
2299            ProjectId(request.project_id as i32),
2300            session.user_id(),
2301            session.connection_id,
2302        )
2303        .await?;
2304
2305    join_project_internal(response, session, &mut project, &replica_id)
2306}
2307
2308async fn create_dev_server_project(
2309    request: proto::CreateDevServerProject,
2310    response: Response<proto::CreateDevServerProject>,
2311    session: UserSession,
2312) -> Result<()> {
2313    let dev_server_id = DevServerId(request.dev_server_id as i32);
2314    let dev_server_connection_id = session
2315        .connection_pool()
2316        .await
2317        .dev_server_connection_id(dev_server_id);
2318    let Some(dev_server_connection_id) = dev_server_connection_id else {
2319        Err(ErrorCode::DevServerOffline
2320            .message("Cannot create a remote project when the dev server is offline".to_string())
2321            .anyhow())?
2322    };
2323
2324    let path = request.path.clone();
2325    //Check that the path exists on the dev server
2326    session
2327        .peer
2328        .forward_request(
2329            session.connection_id,
2330            dev_server_connection_id,
2331            proto::ValidateDevServerProjectRequest { path: path.clone() },
2332        )
2333        .await?;
2334
2335    let (dev_server_project, update) = session
2336        .db()
2337        .await
2338        .create_dev_server_project(
2339            DevServerId(request.dev_server_id as i32),
2340            &request.path,
2341            session.user_id(),
2342        )
2343        .await?;
2344
2345    let projects = session
2346        .db()
2347        .await
2348        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2349        .await?;
2350
2351    session.peer.send(
2352        dev_server_connection_id,
2353        proto::DevServerInstructions { projects },
2354    )?;
2355
2356    send_dev_server_projects_update(session.user_id(), update, &session).await;
2357
2358    response.send(proto::CreateDevServerProjectResponse {
2359        dev_server_project: Some(dev_server_project.to_proto(None)),
2360    })?;
2361    Ok(())
2362}
2363
2364async fn create_dev_server(
2365    request: proto::CreateDevServer,
2366    response: Response<proto::CreateDevServer>,
2367    session: UserSession,
2368) -> Result<()> {
2369    let access_token = auth::random_token();
2370    let hashed_access_token = auth::hash_access_token(&access_token);
2371
2372    if request.name.is_empty() {
2373        return Err(proto::ErrorCode::Forbidden
2374            .message("Dev server name cannot be empty".to_string())
2375            .anyhow())?;
2376    }
2377
2378    let (dev_server, status) = session
2379        .db()
2380        .await
2381        .create_dev_server(
2382            &request.name,
2383            request.ssh_connection_string.as_deref(),
2384            &hashed_access_token,
2385            session.user_id(),
2386        )
2387        .await?;
2388
2389    send_dev_server_projects_update(session.user_id(), status, &session).await;
2390
2391    response.send(proto::CreateDevServerResponse {
2392        dev_server_id: dev_server.id.0 as u64,
2393        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2394        name: request.name,
2395    })?;
2396    Ok(())
2397}
2398
2399async fn regenerate_dev_server_token(
2400    request: proto::RegenerateDevServerToken,
2401    response: Response<proto::RegenerateDevServerToken>,
2402    session: UserSession,
2403) -> Result<()> {
2404    let dev_server_id = DevServerId(request.dev_server_id as i32);
2405    let access_token = auth::random_token();
2406    let hashed_access_token = auth::hash_access_token(&access_token);
2407
2408    let connection_id = session
2409        .connection_pool()
2410        .await
2411        .dev_server_connection_id(dev_server_id);
2412    if let Some(connection_id) = connection_id {
2413        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2414        session.peer.send(
2415            connection_id,
2416            proto::ShutdownDevServer {
2417                reason: Some("dev server token was regenerated".to_string()),
2418            },
2419        )?;
2420        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2421    }
2422
2423    let status = session
2424        .db()
2425        .await
2426        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2427        .await?;
2428
2429    send_dev_server_projects_update(session.user_id(), status, &session).await;
2430
2431    response.send(proto::RegenerateDevServerTokenResponse {
2432        dev_server_id: dev_server_id.to_proto(),
2433        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2434    })?;
2435    Ok(())
2436}
2437
2438async fn rename_dev_server(
2439    request: proto::RenameDevServer,
2440    response: Response<proto::RenameDevServer>,
2441    session: UserSession,
2442) -> Result<()> {
2443    if request.name.trim().is_empty() {
2444        return Err(proto::ErrorCode::Forbidden
2445            .message("Dev server name cannot be empty".to_string())
2446            .anyhow())?;
2447    }
2448
2449    let dev_server_id = DevServerId(request.dev_server_id as i32);
2450    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2451    if dev_server.user_id != session.user_id() {
2452        return Err(anyhow!(ErrorCode::Forbidden))?;
2453    }
2454
2455    let status = session
2456        .db()
2457        .await
2458        .rename_dev_server(
2459            dev_server_id,
2460            &request.name,
2461            request.ssh_connection_string.as_deref(),
2462            session.user_id(),
2463        )
2464        .await?;
2465
2466    send_dev_server_projects_update(session.user_id(), status, &session).await;
2467
2468    response.send(proto::Ack {})?;
2469    Ok(())
2470}
2471
2472async fn delete_dev_server(
2473    request: proto::DeleteDevServer,
2474    response: Response<proto::DeleteDevServer>,
2475    session: UserSession,
2476) -> Result<()> {
2477    let dev_server_id = DevServerId(request.dev_server_id as i32);
2478    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2479    if dev_server.user_id != session.user_id() {
2480        return Err(anyhow!(ErrorCode::Forbidden))?;
2481    }
2482
2483    let connection_id = session
2484        .connection_pool()
2485        .await
2486        .dev_server_connection_id(dev_server_id);
2487    if let Some(connection_id) = connection_id {
2488        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2489        session.peer.send(
2490            connection_id,
2491            proto::ShutdownDevServer {
2492                reason: Some("dev server was deleted".to_string()),
2493            },
2494        )?;
2495        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2496    }
2497
2498    let status = session
2499        .db()
2500        .await
2501        .delete_dev_server(dev_server_id, session.user_id())
2502        .await?;
2503
2504    send_dev_server_projects_update(session.user_id(), status, &session).await;
2505
2506    response.send(proto::Ack {})?;
2507    Ok(())
2508}
2509
2510async fn delete_dev_server_project(
2511    request: proto::DeleteDevServerProject,
2512    response: Response<proto::DeleteDevServerProject>,
2513    session: UserSession,
2514) -> Result<()> {
2515    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2516    let dev_server_project = session
2517        .db()
2518        .await
2519        .get_dev_server_project(dev_server_project_id)
2520        .await?;
2521
2522    let dev_server = session
2523        .db()
2524        .await
2525        .get_dev_server(dev_server_project.dev_server_id)
2526        .await?;
2527    if dev_server.user_id != session.user_id() {
2528        return Err(anyhow!(ErrorCode::Forbidden))?;
2529    }
2530
2531    let dev_server_connection_id = session
2532        .connection_pool()
2533        .await
2534        .dev_server_connection_id(dev_server.id);
2535
2536    if let Some(dev_server_connection_id) = dev_server_connection_id {
2537        let project = session
2538            .db()
2539            .await
2540            .find_dev_server_project(dev_server_project_id)
2541            .await;
2542        if let Ok(project) = project {
2543            unshare_project_internal(
2544                project.id,
2545                dev_server_connection_id,
2546                Some(session.user_id()),
2547                &session,
2548            )
2549            .await?;
2550        }
2551    }
2552
2553    let (projects, status) = session
2554        .db()
2555        .await
2556        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2557        .await?;
2558
2559    if let Some(dev_server_connection_id) = dev_server_connection_id {
2560        session.peer.send(
2561            dev_server_connection_id,
2562            proto::DevServerInstructions { projects },
2563        )?;
2564    }
2565
2566    send_dev_server_projects_update(session.user_id(), status, &session).await;
2567
2568    response.send(proto::Ack {})?;
2569    Ok(())
2570}
2571
2572async fn rejoin_dev_server_projects(
2573    request: proto::RejoinRemoteProjects,
2574    response: Response<proto::RejoinRemoteProjects>,
2575    session: UserSession,
2576) -> Result<()> {
2577    let mut rejoined_projects = {
2578        let db = session.db().await;
2579        db.rejoin_dev_server_projects(
2580            &request.rejoined_projects,
2581            session.user_id(),
2582            session.0.connection_id,
2583        )
2584        .await?
2585    };
2586    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2587
2588    response.send(proto::RejoinRemoteProjectsResponse {
2589        rejoined_projects: rejoined_projects
2590            .into_iter()
2591            .map(|project| project.to_proto())
2592            .collect(),
2593    })
2594}
2595
2596async fn reconnect_dev_server(
2597    request: proto::ReconnectDevServer,
2598    response: Response<proto::ReconnectDevServer>,
2599    session: DevServerSession,
2600) -> Result<()> {
2601    let reshared_projects = {
2602        let db = session.db().await;
2603        db.reshare_dev_server_projects(
2604            &request.reshared_projects,
2605            session.dev_server_id(),
2606            session.0.connection_id,
2607        )
2608        .await?
2609    };
2610
2611    for project in &reshared_projects {
2612        for collaborator in &project.collaborators {
2613            session
2614                .peer
2615                .send(
2616                    collaborator.connection_id,
2617                    proto::UpdateProjectCollaborator {
2618                        project_id: project.id.to_proto(),
2619                        old_peer_id: Some(project.old_connection_id.into()),
2620                        new_peer_id: Some(session.connection_id.into()),
2621                    },
2622                )
2623                .trace_err();
2624        }
2625
2626        broadcast(
2627            Some(session.connection_id),
2628            project
2629                .collaborators
2630                .iter()
2631                .map(|collaborator| collaborator.connection_id),
2632            |connection_id| {
2633                session.peer.forward_send(
2634                    session.connection_id,
2635                    connection_id,
2636                    proto::UpdateProject {
2637                        project_id: project.id.to_proto(),
2638                        worktrees: project.worktrees.clone(),
2639                    },
2640                )
2641            },
2642        );
2643    }
2644
2645    response.send(proto::ReconnectDevServerResponse {
2646        reshared_projects: reshared_projects
2647            .iter()
2648            .map(|project| proto::ResharedProject {
2649                id: project.id.to_proto(),
2650                collaborators: project
2651                    .collaborators
2652                    .iter()
2653                    .map(|collaborator| collaborator.to_proto())
2654                    .collect(),
2655            })
2656            .collect(),
2657    })?;
2658
2659    Ok(())
2660}
2661
2662async fn shutdown_dev_server(
2663    _: proto::ShutdownDevServer,
2664    response: Response<proto::ShutdownDevServer>,
2665    session: DevServerSession,
2666) -> Result<()> {
2667    response.send(proto::Ack {})?;
2668    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2669    remove_dev_server_connection(session.dev_server_id(), &session).await
2670}
2671
2672async fn shutdown_dev_server_internal(
2673    dev_server_id: DevServerId,
2674    connection_id: ConnectionId,
2675    session: &Session,
2676) -> Result<()> {
2677    let (dev_server_projects, dev_server) = {
2678        let db = session.db().await;
2679        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2680        let dev_server = db.get_dev_server(dev_server_id).await?;
2681        (dev_server_projects, dev_server)
2682    };
2683
2684    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2685        unshare_project_internal(
2686            ProjectId::from_proto(project_id),
2687            connection_id,
2688            None,
2689            session,
2690        )
2691        .await?;
2692    }
2693
2694    session
2695        .connection_pool()
2696        .await
2697        .set_dev_server_offline(dev_server_id);
2698
2699    let status = session
2700        .db()
2701        .await
2702        .dev_server_projects_update(dev_server.user_id)
2703        .await?;
2704    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2705
2706    Ok(())
2707}
2708
2709async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2710    let dev_server_connection = session
2711        .connection_pool()
2712        .await
2713        .dev_server_connection_id(dev_server_id);
2714
2715    if let Some(dev_server_connection) = dev_server_connection {
2716        session
2717            .connection_pool()
2718            .await
2719            .remove_connection(dev_server_connection)?;
2720    }
2721    Ok(())
2722}
2723
2724/// Updates other participants with changes to the project
2725async fn update_project(
2726    request: proto::UpdateProject,
2727    response: Response<proto::UpdateProject>,
2728    session: Session,
2729) -> Result<()> {
2730    let project_id = ProjectId::from_proto(request.project_id);
2731    let (room, guest_connection_ids) = &*session
2732        .db()
2733        .await
2734        .update_project(project_id, session.connection_id, &request.worktrees)
2735        .await?;
2736    broadcast(
2737        Some(session.connection_id),
2738        guest_connection_ids.iter().copied(),
2739        |connection_id| {
2740            session
2741                .peer
2742                .forward_send(session.connection_id, connection_id, request.clone())
2743        },
2744    );
2745    if let Some(room) = room {
2746        room_updated(&room, &session.peer);
2747    }
2748    response.send(proto::Ack {})?;
2749
2750    Ok(())
2751}
2752
2753/// Updates other participants with changes to the worktree
2754async fn update_worktree(
2755    request: proto::UpdateWorktree,
2756    response: Response<proto::UpdateWorktree>,
2757    session: Session,
2758) -> Result<()> {
2759    let guest_connection_ids = session
2760        .db()
2761        .await
2762        .update_worktree(&request, session.connection_id)
2763        .await?;
2764
2765    broadcast(
2766        Some(session.connection_id),
2767        guest_connection_ids.iter().copied(),
2768        |connection_id| {
2769            session
2770                .peer
2771                .forward_send(session.connection_id, connection_id, request.clone())
2772        },
2773    );
2774    response.send(proto::Ack {})?;
2775    Ok(())
2776}
2777
2778/// Updates other participants with changes to the diagnostics
2779async fn update_diagnostic_summary(
2780    message: proto::UpdateDiagnosticSummary,
2781    session: Session,
2782) -> Result<()> {
2783    let guest_connection_ids = session
2784        .db()
2785        .await
2786        .update_diagnostic_summary(&message, session.connection_id)
2787        .await?;
2788
2789    broadcast(
2790        Some(session.connection_id),
2791        guest_connection_ids.iter().copied(),
2792        |connection_id| {
2793            session
2794                .peer
2795                .forward_send(session.connection_id, connection_id, message.clone())
2796        },
2797    );
2798
2799    Ok(())
2800}
2801
2802/// Updates other participants with changes to the worktree settings
2803async fn update_worktree_settings(
2804    message: proto::UpdateWorktreeSettings,
2805    session: Session,
2806) -> Result<()> {
2807    let guest_connection_ids = session
2808        .db()
2809        .await
2810        .update_worktree_settings(&message, session.connection_id)
2811        .await?;
2812
2813    broadcast(
2814        Some(session.connection_id),
2815        guest_connection_ids.iter().copied(),
2816        |connection_id| {
2817            session
2818                .peer
2819                .forward_send(session.connection_id, connection_id, message.clone())
2820        },
2821    );
2822
2823    Ok(())
2824}
2825
2826/// Notify other participants that a  language server has started.
2827async fn start_language_server(
2828    request: proto::StartLanguageServer,
2829    session: Session,
2830) -> Result<()> {
2831    let guest_connection_ids = session
2832        .db()
2833        .await
2834        .start_language_server(&request, session.connection_id)
2835        .await?;
2836
2837    broadcast(
2838        Some(session.connection_id),
2839        guest_connection_ids.iter().copied(),
2840        |connection_id| {
2841            session
2842                .peer
2843                .forward_send(session.connection_id, connection_id, request.clone())
2844        },
2845    );
2846    Ok(())
2847}
2848
2849/// Notify other participants that a language server has changed.
2850async fn update_language_server(
2851    request: proto::UpdateLanguageServer,
2852    session: Session,
2853) -> Result<()> {
2854    let project_id = ProjectId::from_proto(request.project_id);
2855    let project_connection_ids = session
2856        .db()
2857        .await
2858        .project_connection_ids(project_id, session.connection_id, true)
2859        .await?;
2860    broadcast(
2861        Some(session.connection_id),
2862        project_connection_ids.iter().copied(),
2863        |connection_id| {
2864            session
2865                .peer
2866                .forward_send(session.connection_id, connection_id, request.clone())
2867        },
2868    );
2869    Ok(())
2870}
2871
2872/// forward a project request to the host. These requests should be read only
2873/// as guests are allowed to send them.
2874async fn forward_read_only_project_request<T>(
2875    request: T,
2876    response: Response<T>,
2877    session: UserSession,
2878) -> Result<()>
2879where
2880    T: EntityMessage + RequestMessage,
2881{
2882    let project_id = ProjectId::from_proto(request.remote_entity_id());
2883    let host_connection_id = session
2884        .db()
2885        .await
2886        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2887        .await?;
2888    let payload = session
2889        .peer
2890        .forward_request(session.connection_id, host_connection_id, request)
2891        .await?;
2892    response.send(payload)?;
2893    Ok(())
2894}
2895
2896/// forward a project request to the dev server. Only allowed
2897/// if it's your dev server.
2898async fn forward_project_request_for_owner<T>(
2899    request: T,
2900    response: Response<T>,
2901    session: UserSession,
2902) -> Result<()>
2903where
2904    T: EntityMessage + RequestMessage,
2905{
2906    let project_id = ProjectId::from_proto(request.remote_entity_id());
2907
2908    let host_connection_id = session
2909        .db()
2910        .await
2911        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2912        .await?;
2913    let payload = session
2914        .peer
2915        .forward_request(session.connection_id, host_connection_id, request)
2916        .await?;
2917    response.send(payload)?;
2918    Ok(())
2919}
2920
2921/// forward a project request to the host. These requests are disallowed
2922/// for guests.
2923async fn forward_mutating_project_request<T>(
2924    request: T,
2925    response: Response<T>,
2926    session: UserSession,
2927) -> Result<()>
2928where
2929    T: EntityMessage + RequestMessage,
2930{
2931    let project_id = ProjectId::from_proto(request.remote_entity_id());
2932
2933    let host_connection_id = session
2934        .db()
2935        .await
2936        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2937        .await?;
2938    let payload = session
2939        .peer
2940        .forward_request(session.connection_id, host_connection_id, request)
2941        .await?;
2942    response.send(payload)?;
2943    Ok(())
2944}
2945
2946/// forward a project request to the host. These requests are disallowed
2947/// for guests.
2948async fn forward_versioned_mutating_project_request<T>(
2949    request: T,
2950    response: Response<T>,
2951    session: UserSession,
2952) -> Result<()>
2953where
2954    T: EntityMessage + RequestMessage + VersionedMessage,
2955{
2956    let project_id = ProjectId::from_proto(request.remote_entity_id());
2957
2958    let host_connection_id = session
2959        .db()
2960        .await
2961        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2962        .await?;
2963    if let Some(host_version) = session
2964        .connection_pool()
2965        .await
2966        .connection(host_connection_id)
2967        .map(|c| c.zed_version)
2968    {
2969        if let Some(min_required_version) = request.required_host_version() {
2970            if min_required_version > host_version {
2971                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2972                    .with_tag("required", &min_required_version.to_string())))?;
2973            }
2974        }
2975    }
2976
2977    let payload = session
2978        .peer
2979        .forward_request(session.connection_id, host_connection_id, request)
2980        .await?;
2981    response.send(payload)?;
2982    Ok(())
2983}
2984
2985/// Notify other participants that a new buffer has been created
2986async fn create_buffer_for_peer(
2987    request: proto::CreateBufferForPeer,
2988    session: Session,
2989) -> Result<()> {
2990    session
2991        .db()
2992        .await
2993        .check_user_is_project_host(
2994            ProjectId::from_proto(request.project_id),
2995            session.connection_id,
2996        )
2997        .await?;
2998    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2999    session
3000        .peer
3001        .forward_send(session.connection_id, peer_id.into(), request)?;
3002    Ok(())
3003}
3004
3005/// Notify other participants that a buffer has been updated. This is
3006/// allowed for guests as long as the update is limited to selections.
3007async fn update_buffer(
3008    request: proto::UpdateBuffer,
3009    response: Response<proto::UpdateBuffer>,
3010    session: Session,
3011) -> Result<()> {
3012    let project_id = ProjectId::from_proto(request.project_id);
3013    let mut capability = Capability::ReadOnly;
3014
3015    for op in request.operations.iter() {
3016        match op.variant {
3017            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3018            Some(_) => capability = Capability::ReadWrite,
3019        }
3020    }
3021
3022    let host = {
3023        let guard = session
3024            .db()
3025            .await
3026            .connections_for_buffer_update(
3027                project_id,
3028                session.principal_id(),
3029                session.connection_id,
3030                capability,
3031            )
3032            .await?;
3033
3034        let (host, guests) = &*guard;
3035
3036        broadcast(
3037            Some(session.connection_id),
3038            guests.clone(),
3039            |connection_id| {
3040                session
3041                    .peer
3042                    .forward_send(session.connection_id, connection_id, request.clone())
3043            },
3044        );
3045
3046        *host
3047    };
3048
3049    if host != session.connection_id {
3050        session
3051            .peer
3052            .forward_request(session.connection_id, host, request.clone())
3053            .await?;
3054    }
3055
3056    response.send(proto::Ack {})?;
3057    Ok(())
3058}
3059
3060/// Notify other participants that a project has been updated.
3061async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3062    request: T,
3063    session: Session,
3064) -> Result<()> {
3065    let project_id = ProjectId::from_proto(request.remote_entity_id());
3066    let project_connection_ids = session
3067        .db()
3068        .await
3069        .project_connection_ids(project_id, session.connection_id, false)
3070        .await?;
3071
3072    broadcast(
3073        Some(session.connection_id),
3074        project_connection_ids.iter().copied(),
3075        |connection_id| {
3076            session
3077                .peer
3078                .forward_send(session.connection_id, connection_id, request.clone())
3079        },
3080    );
3081    Ok(())
3082}
3083
3084/// Start following another user in a call.
3085async fn follow(
3086    request: proto::Follow,
3087    response: Response<proto::Follow>,
3088    session: UserSession,
3089) -> Result<()> {
3090    let room_id = RoomId::from_proto(request.room_id);
3091    let project_id = request.project_id.map(ProjectId::from_proto);
3092    let leader_id = request
3093        .leader_id
3094        .ok_or_else(|| anyhow!("invalid leader id"))?
3095        .into();
3096    let follower_id = session.connection_id;
3097
3098    session
3099        .db()
3100        .await
3101        .check_room_participants(room_id, leader_id, session.connection_id)
3102        .await?;
3103
3104    let response_payload = session
3105        .peer
3106        .forward_request(session.connection_id, leader_id, request)
3107        .await?;
3108    response.send(response_payload)?;
3109
3110    if let Some(project_id) = project_id {
3111        let room = session
3112            .db()
3113            .await
3114            .follow(room_id, project_id, leader_id, follower_id)
3115            .await?;
3116        room_updated(&room, &session.peer);
3117    }
3118
3119    Ok(())
3120}
3121
3122/// Stop following another user in a call.
3123async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3124    let room_id = RoomId::from_proto(request.room_id);
3125    let project_id = request.project_id.map(ProjectId::from_proto);
3126    let leader_id = request
3127        .leader_id
3128        .ok_or_else(|| anyhow!("invalid leader id"))?
3129        .into();
3130    let follower_id = session.connection_id;
3131
3132    session
3133        .db()
3134        .await
3135        .check_room_participants(room_id, leader_id, session.connection_id)
3136        .await?;
3137
3138    session
3139        .peer
3140        .forward_send(session.connection_id, leader_id, request)?;
3141
3142    if let Some(project_id) = project_id {
3143        let room = session
3144            .db()
3145            .await
3146            .unfollow(room_id, project_id, leader_id, follower_id)
3147            .await?;
3148        room_updated(&room, &session.peer);
3149    }
3150
3151    Ok(())
3152}
3153
3154/// Notify everyone following you of your current location.
3155async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3156    let room_id = RoomId::from_proto(request.room_id);
3157    let database = session.db.lock().await;
3158
3159    let connection_ids = if let Some(project_id) = request.project_id {
3160        let project_id = ProjectId::from_proto(project_id);
3161        database
3162            .project_connection_ids(project_id, session.connection_id, true)
3163            .await?
3164    } else {
3165        database
3166            .room_connection_ids(room_id, session.connection_id)
3167            .await?
3168    };
3169
3170    // For now, don't send view update messages back to that view's current leader.
3171    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3172        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3173        _ => None,
3174    });
3175
3176    for connection_id in connection_ids.iter().cloned() {
3177        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3178            session
3179                .peer
3180                .forward_send(session.connection_id, connection_id, request.clone())?;
3181        }
3182    }
3183    Ok(())
3184}
3185
3186/// Get public data about users.
3187async fn get_users(
3188    request: proto::GetUsers,
3189    response: Response<proto::GetUsers>,
3190    session: Session,
3191) -> Result<()> {
3192    let user_ids = request
3193        .user_ids
3194        .into_iter()
3195        .map(UserId::from_proto)
3196        .collect();
3197    let users = session
3198        .db()
3199        .await
3200        .get_users_by_ids(user_ids)
3201        .await?
3202        .into_iter()
3203        .map(|user| proto::User {
3204            id: user.id.to_proto(),
3205            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3206            github_login: user.github_login,
3207        })
3208        .collect();
3209    response.send(proto::UsersResponse { users })?;
3210    Ok(())
3211}
3212
3213/// Search for users (to invite) buy Github login
3214async fn fuzzy_search_users(
3215    request: proto::FuzzySearchUsers,
3216    response: Response<proto::FuzzySearchUsers>,
3217    session: UserSession,
3218) -> Result<()> {
3219    let query = request.query;
3220    let users = match query.len() {
3221        0 => vec![],
3222        1 | 2 => session
3223            .db()
3224            .await
3225            .get_user_by_github_login(&query)
3226            .await?
3227            .into_iter()
3228            .collect(),
3229        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3230    };
3231    let users = users
3232        .into_iter()
3233        .filter(|user| user.id != session.user_id())
3234        .map(|user| proto::User {
3235            id: user.id.to_proto(),
3236            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3237            github_login: user.github_login,
3238        })
3239        .collect();
3240    response.send(proto::UsersResponse { users })?;
3241    Ok(())
3242}
3243
3244/// Send a contact request to another user.
3245async fn request_contact(
3246    request: proto::RequestContact,
3247    response: Response<proto::RequestContact>,
3248    session: UserSession,
3249) -> Result<()> {
3250    let requester_id = session.user_id();
3251    let responder_id = UserId::from_proto(request.responder_id);
3252    if requester_id == responder_id {
3253        return Err(anyhow!("cannot add yourself as a contact"))?;
3254    }
3255
3256    let notifications = session
3257        .db()
3258        .await
3259        .send_contact_request(requester_id, responder_id)
3260        .await?;
3261
3262    // Update outgoing contact requests of requester
3263    let mut update = proto::UpdateContacts::default();
3264    update.outgoing_requests.push(responder_id.to_proto());
3265    for connection_id in session
3266        .connection_pool()
3267        .await
3268        .user_connection_ids(requester_id)
3269    {
3270        session.peer.send(connection_id, update.clone())?;
3271    }
3272
3273    // Update incoming contact requests of responder
3274    let mut update = proto::UpdateContacts::default();
3275    update
3276        .incoming_requests
3277        .push(proto::IncomingContactRequest {
3278            requester_id: requester_id.to_proto(),
3279        });
3280    let connection_pool = session.connection_pool().await;
3281    for connection_id in connection_pool.user_connection_ids(responder_id) {
3282        session.peer.send(connection_id, update.clone())?;
3283    }
3284
3285    send_notifications(&connection_pool, &session.peer, notifications);
3286
3287    response.send(proto::Ack {})?;
3288    Ok(())
3289}
3290
3291/// Accept or decline a contact request
3292async fn respond_to_contact_request(
3293    request: proto::RespondToContactRequest,
3294    response: Response<proto::RespondToContactRequest>,
3295    session: UserSession,
3296) -> Result<()> {
3297    let responder_id = session.user_id();
3298    let requester_id = UserId::from_proto(request.requester_id);
3299    let db = session.db().await;
3300    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3301        db.dismiss_contact_notification(responder_id, requester_id)
3302            .await?;
3303    } else {
3304        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3305
3306        let notifications = db
3307            .respond_to_contact_request(responder_id, requester_id, accept)
3308            .await?;
3309        let requester_busy = db.is_user_busy(requester_id).await?;
3310        let responder_busy = db.is_user_busy(responder_id).await?;
3311
3312        let pool = session.connection_pool().await;
3313        // Update responder with new contact
3314        let mut update = proto::UpdateContacts::default();
3315        if accept {
3316            update
3317                .contacts
3318                .push(contact_for_user(requester_id, requester_busy, &pool));
3319        }
3320        update
3321            .remove_incoming_requests
3322            .push(requester_id.to_proto());
3323        for connection_id in pool.user_connection_ids(responder_id) {
3324            session.peer.send(connection_id, update.clone())?;
3325        }
3326
3327        // Update requester with new contact
3328        let mut update = proto::UpdateContacts::default();
3329        if accept {
3330            update
3331                .contacts
3332                .push(contact_for_user(responder_id, responder_busy, &pool));
3333        }
3334        update
3335            .remove_outgoing_requests
3336            .push(responder_id.to_proto());
3337
3338        for connection_id in pool.user_connection_ids(requester_id) {
3339            session.peer.send(connection_id, update.clone())?;
3340        }
3341
3342        send_notifications(&pool, &session.peer, notifications);
3343    }
3344
3345    response.send(proto::Ack {})?;
3346    Ok(())
3347}
3348
3349/// Remove a contact.
3350async fn remove_contact(
3351    request: proto::RemoveContact,
3352    response: Response<proto::RemoveContact>,
3353    session: UserSession,
3354) -> Result<()> {
3355    let requester_id = session.user_id();
3356    let responder_id = UserId::from_proto(request.user_id);
3357    let db = session.db().await;
3358    let (contact_accepted, deleted_notification_id) =
3359        db.remove_contact(requester_id, responder_id).await?;
3360
3361    let pool = session.connection_pool().await;
3362    // Update outgoing contact requests of requester
3363    let mut update = proto::UpdateContacts::default();
3364    if contact_accepted {
3365        update.remove_contacts.push(responder_id.to_proto());
3366    } else {
3367        update
3368            .remove_outgoing_requests
3369            .push(responder_id.to_proto());
3370    }
3371    for connection_id in pool.user_connection_ids(requester_id) {
3372        session.peer.send(connection_id, update.clone())?;
3373    }
3374
3375    // Update incoming contact requests of responder
3376    let mut update = proto::UpdateContacts::default();
3377    if contact_accepted {
3378        update.remove_contacts.push(requester_id.to_proto());
3379    } else {
3380        update
3381            .remove_incoming_requests
3382            .push(requester_id.to_proto());
3383    }
3384    for connection_id in pool.user_connection_ids(responder_id) {
3385        session.peer.send(connection_id, update.clone())?;
3386        if let Some(notification_id) = deleted_notification_id {
3387            session.peer.send(
3388                connection_id,
3389                proto::DeleteNotification {
3390                    notification_id: notification_id.to_proto(),
3391                },
3392            )?;
3393        }
3394    }
3395
3396    response.send(proto::Ack {})?;
3397    Ok(())
3398}
3399
3400fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3401    version.0.minor() < 139
3402}
3403
3404async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3405    subscribe_user_to_channels(
3406        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3407        &session,
3408    )
3409    .await?;
3410    Ok(())
3411}
3412
3413async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3414    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3415    let mut pool = session.connection_pool().await;
3416    for membership in &channels_for_user.channel_memberships {
3417        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3418    }
3419    session.peer.send(
3420        session.connection_id,
3421        build_update_user_channels(&channels_for_user),
3422    )?;
3423    session.peer.send(
3424        session.connection_id,
3425        build_channels_update(channels_for_user),
3426    )?;
3427    Ok(())
3428}
3429
3430/// Creates a new channel.
3431async fn create_channel(
3432    request: proto::CreateChannel,
3433    response: Response<proto::CreateChannel>,
3434    session: UserSession,
3435) -> Result<()> {
3436    let db = session.db().await;
3437
3438    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3439    let (channel, membership) = db
3440        .create_channel(&request.name, parent_id, session.user_id())
3441        .await?;
3442
3443    let root_id = channel.root_id();
3444    let channel = Channel::from_model(channel);
3445
3446    response.send(proto::CreateChannelResponse {
3447        channel: Some(channel.to_proto()),
3448        parent_id: request.parent_id,
3449    })?;
3450
3451    let mut connection_pool = session.connection_pool().await;
3452    if let Some(membership) = membership {
3453        connection_pool.subscribe_to_channel(
3454            membership.user_id,
3455            membership.channel_id,
3456            membership.role,
3457        );
3458        let update = proto::UpdateUserChannels {
3459            channel_memberships: vec![proto::ChannelMembership {
3460                channel_id: membership.channel_id.to_proto(),
3461                role: membership.role.into(),
3462            }],
3463            ..Default::default()
3464        };
3465        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3466            session.peer.send(connection_id, update.clone())?;
3467        }
3468    }
3469
3470    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3471        if !role.can_see_channel(channel.visibility) {
3472            continue;
3473        }
3474
3475        let update = proto::UpdateChannels {
3476            channels: vec![channel.to_proto()],
3477            ..Default::default()
3478        };
3479        session.peer.send(connection_id, update.clone())?;
3480    }
3481
3482    Ok(())
3483}
3484
3485/// Delete a channel
3486async fn delete_channel(
3487    request: proto::DeleteChannel,
3488    response: Response<proto::DeleteChannel>,
3489    session: UserSession,
3490) -> Result<()> {
3491    let db = session.db().await;
3492
3493    let channel_id = request.channel_id;
3494    let (root_channel, removed_channels) = db
3495        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3496        .await?;
3497    response.send(proto::Ack {})?;
3498
3499    // Notify members of removed channels
3500    let mut update = proto::UpdateChannels::default();
3501    update
3502        .delete_channels
3503        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3504
3505    let connection_pool = session.connection_pool().await;
3506    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3507        session.peer.send(connection_id, update.clone())?;
3508    }
3509
3510    Ok(())
3511}
3512
3513/// Invite someone to join a channel.
3514async fn invite_channel_member(
3515    request: proto::InviteChannelMember,
3516    response: Response<proto::InviteChannelMember>,
3517    session: UserSession,
3518) -> Result<()> {
3519    let db = session.db().await;
3520    let channel_id = ChannelId::from_proto(request.channel_id);
3521    let invitee_id = UserId::from_proto(request.user_id);
3522    let InviteMemberResult {
3523        channel,
3524        notifications,
3525    } = db
3526        .invite_channel_member(
3527            channel_id,
3528            invitee_id,
3529            session.user_id(),
3530            request.role().into(),
3531        )
3532        .await?;
3533
3534    let update = proto::UpdateChannels {
3535        channel_invitations: vec![channel.to_proto()],
3536        ..Default::default()
3537    };
3538
3539    let connection_pool = session.connection_pool().await;
3540    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3541        session.peer.send(connection_id, update.clone())?;
3542    }
3543
3544    send_notifications(&connection_pool, &session.peer, notifications);
3545
3546    response.send(proto::Ack {})?;
3547    Ok(())
3548}
3549
3550/// remove someone from a channel
3551async fn remove_channel_member(
3552    request: proto::RemoveChannelMember,
3553    response: Response<proto::RemoveChannelMember>,
3554    session: UserSession,
3555) -> Result<()> {
3556    let db = session.db().await;
3557    let channel_id = ChannelId::from_proto(request.channel_id);
3558    let member_id = UserId::from_proto(request.user_id);
3559
3560    let RemoveChannelMemberResult {
3561        membership_update,
3562        notification_id,
3563    } = db
3564        .remove_channel_member(channel_id, member_id, session.user_id())
3565        .await?;
3566
3567    let mut connection_pool = session.connection_pool().await;
3568    notify_membership_updated(
3569        &mut connection_pool,
3570        membership_update,
3571        member_id,
3572        &session.peer,
3573    );
3574    for connection_id in connection_pool.user_connection_ids(member_id) {
3575        if let Some(notification_id) = notification_id {
3576            session
3577                .peer
3578                .send(
3579                    connection_id,
3580                    proto::DeleteNotification {
3581                        notification_id: notification_id.to_proto(),
3582                    },
3583                )
3584                .trace_err();
3585        }
3586    }
3587
3588    response.send(proto::Ack {})?;
3589    Ok(())
3590}
3591
3592/// Toggle the channel between public and private.
3593/// Care is taken to maintain the invariant that public channels only descend from public channels,
3594/// (though members-only channels can appear at any point in the hierarchy).
3595async fn set_channel_visibility(
3596    request: proto::SetChannelVisibility,
3597    response: Response<proto::SetChannelVisibility>,
3598    session: UserSession,
3599) -> Result<()> {
3600    let db = session.db().await;
3601    let channel_id = ChannelId::from_proto(request.channel_id);
3602    let visibility = request.visibility().into();
3603
3604    let channel_model = db
3605        .set_channel_visibility(channel_id, visibility, session.user_id())
3606        .await?;
3607    let root_id = channel_model.root_id();
3608    let channel = Channel::from_model(channel_model);
3609
3610    let mut connection_pool = session.connection_pool().await;
3611    for (user_id, role) in connection_pool
3612        .channel_user_ids(root_id)
3613        .collect::<Vec<_>>()
3614        .into_iter()
3615    {
3616        let update = if role.can_see_channel(channel.visibility) {
3617            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3618            proto::UpdateChannels {
3619                channels: vec![channel.to_proto()],
3620                ..Default::default()
3621            }
3622        } else {
3623            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3624            proto::UpdateChannels {
3625                delete_channels: vec![channel.id.to_proto()],
3626                ..Default::default()
3627            }
3628        };
3629
3630        for connection_id in connection_pool.user_connection_ids(user_id) {
3631            session.peer.send(connection_id, update.clone())?;
3632        }
3633    }
3634
3635    response.send(proto::Ack {})?;
3636    Ok(())
3637}
3638
3639/// Alter the role for a user in the channel.
3640async fn set_channel_member_role(
3641    request: proto::SetChannelMemberRole,
3642    response: Response<proto::SetChannelMemberRole>,
3643    session: UserSession,
3644) -> Result<()> {
3645    let db = session.db().await;
3646    let channel_id = ChannelId::from_proto(request.channel_id);
3647    let member_id = UserId::from_proto(request.user_id);
3648    let result = db
3649        .set_channel_member_role(
3650            channel_id,
3651            session.user_id(),
3652            member_id,
3653            request.role().into(),
3654        )
3655        .await?;
3656
3657    match result {
3658        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3659            let mut connection_pool = session.connection_pool().await;
3660            notify_membership_updated(
3661                &mut connection_pool,
3662                membership_update,
3663                member_id,
3664                &session.peer,
3665            )
3666        }
3667        db::SetMemberRoleResult::InviteUpdated(channel) => {
3668            let update = proto::UpdateChannels {
3669                channel_invitations: vec![channel.to_proto()],
3670                ..Default::default()
3671            };
3672
3673            for connection_id in session
3674                .connection_pool()
3675                .await
3676                .user_connection_ids(member_id)
3677            {
3678                session.peer.send(connection_id, update.clone())?;
3679            }
3680        }
3681    }
3682
3683    response.send(proto::Ack {})?;
3684    Ok(())
3685}
3686
3687/// Change the name of a channel
3688async fn rename_channel(
3689    request: proto::RenameChannel,
3690    response: Response<proto::RenameChannel>,
3691    session: UserSession,
3692) -> Result<()> {
3693    let db = session.db().await;
3694    let channel_id = ChannelId::from_proto(request.channel_id);
3695    let channel_model = db
3696        .rename_channel(channel_id, session.user_id(), &request.name)
3697        .await?;
3698    let root_id = channel_model.root_id();
3699    let channel = Channel::from_model(channel_model);
3700
3701    response.send(proto::RenameChannelResponse {
3702        channel: Some(channel.to_proto()),
3703    })?;
3704
3705    let connection_pool = session.connection_pool().await;
3706    let update = proto::UpdateChannels {
3707        channels: vec![channel.to_proto()],
3708        ..Default::default()
3709    };
3710    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3711        if role.can_see_channel(channel.visibility) {
3712            session.peer.send(connection_id, update.clone())?;
3713        }
3714    }
3715
3716    Ok(())
3717}
3718
3719/// Move a channel to a new parent.
3720async fn move_channel(
3721    request: proto::MoveChannel,
3722    response: Response<proto::MoveChannel>,
3723    session: UserSession,
3724) -> Result<()> {
3725    let channel_id = ChannelId::from_proto(request.channel_id);
3726    let to = ChannelId::from_proto(request.to);
3727
3728    let (root_id, channels) = session
3729        .db()
3730        .await
3731        .move_channel(channel_id, to, session.user_id())
3732        .await?;
3733
3734    let connection_pool = session.connection_pool().await;
3735    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3736        let channels = channels
3737            .iter()
3738            .filter_map(|channel| {
3739                if role.can_see_channel(channel.visibility) {
3740                    Some(channel.to_proto())
3741                } else {
3742                    None
3743                }
3744            })
3745            .collect::<Vec<_>>();
3746        if channels.is_empty() {
3747            continue;
3748        }
3749
3750        let update = proto::UpdateChannels {
3751            channels,
3752            ..Default::default()
3753        };
3754
3755        session.peer.send(connection_id, update.clone())?;
3756    }
3757
3758    response.send(Ack {})?;
3759    Ok(())
3760}
3761
3762/// Get the list of channel members
3763async fn get_channel_members(
3764    request: proto::GetChannelMembers,
3765    response: Response<proto::GetChannelMembers>,
3766    session: UserSession,
3767) -> Result<()> {
3768    let db = session.db().await;
3769    let channel_id = ChannelId::from_proto(request.channel_id);
3770    let limit = if request.limit == 0 {
3771        u16::MAX as u64
3772    } else {
3773        request.limit
3774    };
3775    let (members, users) = db
3776        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3777        .await?;
3778    response.send(proto::GetChannelMembersResponse { members, users })?;
3779    Ok(())
3780}
3781
3782/// Accept or decline a channel invitation.
3783async fn respond_to_channel_invite(
3784    request: proto::RespondToChannelInvite,
3785    response: Response<proto::RespondToChannelInvite>,
3786    session: UserSession,
3787) -> Result<()> {
3788    let db = session.db().await;
3789    let channel_id = ChannelId::from_proto(request.channel_id);
3790    let RespondToChannelInvite {
3791        membership_update,
3792        notifications,
3793    } = db
3794        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3795        .await?;
3796
3797    let mut connection_pool = session.connection_pool().await;
3798    if let Some(membership_update) = membership_update {
3799        notify_membership_updated(
3800            &mut connection_pool,
3801            membership_update,
3802            session.user_id(),
3803            &session.peer,
3804        );
3805    } else {
3806        let update = proto::UpdateChannels {
3807            remove_channel_invitations: vec![channel_id.to_proto()],
3808            ..Default::default()
3809        };
3810
3811        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3812            session.peer.send(connection_id, update.clone())?;
3813        }
3814    };
3815
3816    send_notifications(&connection_pool, &session.peer, notifications);
3817
3818    response.send(proto::Ack {})?;
3819
3820    Ok(())
3821}
3822
3823/// Join the channels' room
3824async fn join_channel(
3825    request: proto::JoinChannel,
3826    response: Response<proto::JoinChannel>,
3827    session: UserSession,
3828) -> Result<()> {
3829    let channel_id = ChannelId::from_proto(request.channel_id);
3830    join_channel_internal(channel_id, Box::new(response), session).await
3831}
3832
/// Abstracts over the two response handles that can answer a channel join
/// (`JoinChannel` and `JoinRoom`). Both reply with a `proto::JoinRoomResponse`,
/// so `join_channel_internal` can serve either request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The fully-qualified `Response::<...>::send` calls below name `Response`'s own
// `send` explicitly. Path syntax prefers inherent items, so this avoids any
// confusion with (or recursion into) this trait's identically named method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3846
/// Shared implementation of joining a channel's room. The reply goes through
/// either a `JoinChannel` or a `JoinRoom` response handle.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database handle is dropped while leaving and re-acquired
///    afterwards.
/// 2. Join the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token for the room:
///    - guests get a guest token and `can_publish = false`;
///    - every other role gets a room token and `can_publish = true`.
///    A token error goes through `trace_err()?`, which makes the closure yield
///    `None`, so the join proceeds without LiveKit info.
/// 4. Reply to the requester, apply any membership update to the connection
///    pool, and broadcast the room update.
/// 5. Once the inner block has ended (releasing its database handle and pool
///    guard), broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error in any of these cases:
/// - a database call fails;
/// - leaving the stale room fails;
/// - sending the response fails;
/// - the database did not return a channel for the joined room;
/// - updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle before leaving the old room, then
            // re-acquire it. Presumably leave_room_for_session acquires it itself
            // — TODO confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or when minting a token fails.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join but not publish; every other role may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The inner block's database handle and pool guard are gone at this point;
    // a fresh pool guard is taken just for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3937
3938/// Start editing the channel notes
3939async fn join_channel_buffer(
3940    request: proto::JoinChannelBuffer,
3941    response: Response<proto::JoinChannelBuffer>,
3942    session: UserSession,
3943) -> Result<()> {
3944    let db = session.db().await;
3945    let channel_id = ChannelId::from_proto(request.channel_id);
3946
3947    let open_response = db
3948        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3949        .await?;
3950
3951    let collaborators = open_response.collaborators.clone();
3952    response.send(open_response)?;
3953
3954    let update = UpdateChannelBufferCollaborators {
3955        channel_id: channel_id.to_proto(),
3956        collaborators: collaborators.clone(),
3957    };
3958    channel_buffer_updated(
3959        session.connection_id,
3960        collaborators
3961            .iter()
3962            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3963        &update,
3964        &session.peer,
3965    );
3966
3967    Ok(())
3968}
3969
3970/// Edit the channel notes
3971async fn update_channel_buffer(
3972    request: proto::UpdateChannelBuffer,
3973    session: UserSession,
3974) -> Result<()> {
3975    let db = session.db().await;
3976    let channel_id = ChannelId::from_proto(request.channel_id);
3977
3978    let (collaborators, epoch, version) = db
3979        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3980        .await?;
3981
3982    channel_buffer_updated(
3983        session.connection_id,
3984        collaborators.clone(),
3985        &proto::UpdateChannelBuffer {
3986            channel_id: channel_id.to_proto(),
3987            operations: request.operations,
3988        },
3989        &session.peer,
3990    );
3991
3992    let pool = &*session.connection_pool().await;
3993
3994    let non_collaborators =
3995        pool.channel_connection_ids(channel_id)
3996            .filter_map(|(connection_id, _)| {
3997                if collaborators.contains(&connection_id) {
3998                    None
3999                } else {
4000                    Some(connection_id)
4001                }
4002            });
4003
4004    broadcast(None, non_collaborators, |peer_id| {
4005        session.peer.send(
4006            peer_id,
4007            proto::UpdateChannels {
4008                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4009                    channel_id: channel_id.to_proto(),
4010                    epoch: epoch as u64,
4011                    version: version.clone(),
4012                }],
4013                ..Default::default()
4014            },
4015        )
4016    });
4017
4018    Ok(())
4019}
4020
4021/// Rejoin the channel notes after a connection blip
4022async fn rejoin_channel_buffers(
4023    request: proto::RejoinChannelBuffers,
4024    response: Response<proto::RejoinChannelBuffers>,
4025    session: UserSession,
4026) -> Result<()> {
4027    let db = session.db().await;
4028    let buffers = db
4029        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4030        .await?;
4031
4032    for rejoined_buffer in &buffers {
4033        let collaborators_to_notify = rejoined_buffer
4034            .buffer
4035            .collaborators
4036            .iter()
4037            .filter_map(|c| Some(c.peer_id?.into()));
4038        channel_buffer_updated(
4039            session.connection_id,
4040            collaborators_to_notify,
4041            &proto::UpdateChannelBufferCollaborators {
4042                channel_id: rejoined_buffer.buffer.channel_id,
4043                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4044            },
4045            &session.peer,
4046        );
4047    }
4048
4049    response.send(proto::RejoinChannelBuffersResponse {
4050        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4051    })?;
4052
4053    Ok(())
4054}
4055
4056/// Stop editing the channel notes
4057async fn leave_channel_buffer(
4058    request: proto::LeaveChannelBuffer,
4059    response: Response<proto::LeaveChannelBuffer>,
4060    session: UserSession,
4061) -> Result<()> {
4062    let db = session.db().await;
4063    let channel_id = ChannelId::from_proto(request.channel_id);
4064
4065    let left_buffer = db
4066        .leave_channel_buffer(channel_id, session.connection_id)
4067        .await?;
4068
4069    response.send(Ack {})?;
4070
4071    channel_buffer_updated(
4072        session.connection_id,
4073        left_buffer.connections,
4074        &proto::UpdateChannelBufferCollaborators {
4075            channel_id: channel_id.to_proto(),
4076            collaborators: left_buffer.collaborators,
4077        },
4078        &session.peer,
4079    );
4080
4081    Ok(())
4082}
4083
4084fn channel_buffer_updated<T: EnvelopedMessage>(
4085    sender_id: ConnectionId,
4086    collaborators: impl IntoIterator<Item = ConnectionId>,
4087    message: &T,
4088    peer: &Peer,
4089) {
4090    broadcast(Some(sender_id), collaborators, |peer_id| {
4091        peer.send(peer_id, message.clone())
4092    });
4093}
4094
4095fn send_notifications(
4096    connection_pool: &ConnectionPool,
4097    peer: &Peer,
4098    notifications: db::NotificationBatch,
4099) {
4100    for (user_id, notification) in notifications {
4101        for connection_id in connection_pool.user_connection_ids(user_id) {
4102            if let Err(error) = peer.send(
4103                connection_id,
4104                proto::AddNotification {
4105                    notification: Some(notification.clone()),
4106                },
4107            ) {
4108                tracing::error!(
4109                    "failed to send notification to {:?} {}",
4110                    connection_id,
4111                    error
4112                );
4113            }
4114        }
4115    }
4116}
4117
4118/// Send a message to the channel
4119async fn send_channel_message(
4120    request: proto::SendChannelMessage,
4121    response: Response<proto::SendChannelMessage>,
4122    session: UserSession,
4123) -> Result<()> {
4124    // Validate the message body.
4125    let body = request.body.trim().to_string();
4126    if body.len() > MAX_MESSAGE_LEN {
4127        return Err(anyhow!("message is too long"))?;
4128    }
4129    if body.is_empty() {
4130        return Err(anyhow!("message can't be blank"))?;
4131    }
4132
4133    // TODO: adjust mentions if body is trimmed
4134
4135    let timestamp = OffsetDateTime::now_utc();
4136    let nonce = request
4137        .nonce
4138        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4139
4140    let channel_id = ChannelId::from_proto(request.channel_id);
4141    let CreatedChannelMessage {
4142        message_id,
4143        participant_connection_ids,
4144        notifications,
4145    } = session
4146        .db()
4147        .await
4148        .create_channel_message(
4149            channel_id,
4150            session.user_id(),
4151            &body,
4152            &request.mentions,
4153            timestamp,
4154            nonce.clone().into(),
4155            match request.reply_to_message_id {
4156                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4157                None => None,
4158            },
4159        )
4160        .await?;
4161
4162    let message = proto::ChannelMessage {
4163        sender_id: session.user_id().to_proto(),
4164        id: message_id.to_proto(),
4165        body,
4166        mentions: request.mentions,
4167        timestamp: timestamp.unix_timestamp() as u64,
4168        nonce: Some(nonce),
4169        reply_to_message_id: request.reply_to_message_id,
4170        edited_at: None,
4171    };
4172    broadcast(
4173        Some(session.connection_id),
4174        participant_connection_ids.clone(),
4175        |connection| {
4176            session.peer.send(
4177                connection,
4178                proto::ChannelMessageSent {
4179                    channel_id: channel_id.to_proto(),
4180                    message: Some(message.clone()),
4181                },
4182            )
4183        },
4184    );
4185    response.send(proto::SendChannelMessageResponse {
4186        message: Some(message),
4187    })?;
4188
4189    let pool = &*session.connection_pool().await;
4190    let non_participants =
4191        pool.channel_connection_ids(channel_id)
4192            .filter_map(|(connection_id, _)| {
4193                if participant_connection_ids.contains(&connection_id) {
4194                    None
4195                } else {
4196                    Some(connection_id)
4197                }
4198            });
4199    broadcast(None, non_participants, |peer_id| {
4200        session.peer.send(
4201            peer_id,
4202            proto::UpdateChannels {
4203                latest_channel_message_ids: vec![proto::ChannelMessageId {
4204                    channel_id: channel_id.to_proto(),
4205                    message_id: message_id.to_proto(),
4206                }],
4207                ..Default::default()
4208            },
4209        )
4210    });
4211    send_notifications(pool, &session.peer, notifications);
4212
4213    Ok(())
4214}
4215
4216/// Delete a channel message
4217async fn remove_channel_message(
4218    request: proto::RemoveChannelMessage,
4219    response: Response<proto::RemoveChannelMessage>,
4220    session: UserSession,
4221) -> Result<()> {
4222    let channel_id = ChannelId::from_proto(request.channel_id);
4223    let message_id = MessageId::from_proto(request.message_id);
4224    let (connection_ids, existing_notification_ids) = session
4225        .db()
4226        .await
4227        .remove_channel_message(channel_id, message_id, session.user_id())
4228        .await?;
4229
4230    broadcast(
4231        Some(session.connection_id),
4232        connection_ids,
4233        move |connection| {
4234            session.peer.send(connection, request.clone())?;
4235
4236            for notification_id in &existing_notification_ids {
4237                session.peer.send(
4238                    connection,
4239                    proto::DeleteNotification {
4240                        notification_id: (*notification_id).to_proto(),
4241                    },
4242                )?;
4243            }
4244
4245            Ok(())
4246        },
4247    );
4248    response.send(proto::Ack {})?;
4249    Ok(())
4250}
4251
4252async fn update_channel_message(
4253    request: proto::UpdateChannelMessage,
4254    response: Response<proto::UpdateChannelMessage>,
4255    session: UserSession,
4256) -> Result<()> {
4257    let channel_id = ChannelId::from_proto(request.channel_id);
4258    let message_id = MessageId::from_proto(request.message_id);
4259    let updated_at = OffsetDateTime::now_utc();
4260    let UpdatedChannelMessage {
4261        message_id,
4262        participant_connection_ids,
4263        notifications,
4264        reply_to_message_id,
4265        timestamp,
4266        deleted_mention_notification_ids,
4267        updated_mention_notifications,
4268    } = session
4269        .db()
4270        .await
4271        .update_channel_message(
4272            channel_id,
4273            message_id,
4274            session.user_id(),
4275            request.body.as_str(),
4276            &request.mentions,
4277            updated_at,
4278        )
4279        .await?;
4280
4281    let nonce = request
4282        .nonce
4283        .clone()
4284        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4285
4286    let message = proto::ChannelMessage {
4287        sender_id: session.user_id().to_proto(),
4288        id: message_id.to_proto(),
4289        body: request.body.clone(),
4290        mentions: request.mentions.clone(),
4291        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4292        nonce: Some(nonce),
4293        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4294        edited_at: Some(updated_at.unix_timestamp() as u64),
4295    };
4296
4297    response.send(proto::Ack {})?;
4298
4299    let pool = &*session.connection_pool().await;
4300    broadcast(
4301        Some(session.connection_id),
4302        participant_connection_ids,
4303        |connection| {
4304            session.peer.send(
4305                connection,
4306                proto::ChannelMessageUpdate {
4307                    channel_id: channel_id.to_proto(),
4308                    message: Some(message.clone()),
4309                },
4310            )?;
4311
4312            for notification_id in &deleted_mention_notification_ids {
4313                session.peer.send(
4314                    connection,
4315                    proto::DeleteNotification {
4316                        notification_id: (*notification_id).to_proto(),
4317                    },
4318                )?;
4319            }
4320
4321            for notification in &updated_mention_notifications {
4322                session.peer.send(
4323                    connection,
4324                    proto::UpdateNotification {
4325                        notification: Some(notification.clone()),
4326                    },
4327                )?;
4328            }
4329
4330            Ok(())
4331        },
4332    );
4333
4334    send_notifications(pool, &session.peer, notifications);
4335
4336    Ok(())
4337}
4338
4339/// Mark a channel message as read
4340async fn acknowledge_channel_message(
4341    request: proto::AckChannelMessage,
4342    session: UserSession,
4343) -> Result<()> {
4344    let channel_id = ChannelId::from_proto(request.channel_id);
4345    let message_id = MessageId::from_proto(request.message_id);
4346    let notifications = session
4347        .db()
4348        .await
4349        .observe_channel_message(channel_id, session.user_id(), message_id)
4350        .await?;
4351    send_notifications(
4352        &*session.connection_pool().await,
4353        &session.peer,
4354        notifications,
4355    );
4356    Ok(())
4357}
4358
4359/// Mark a buffer version as synced
4360async fn acknowledge_buffer_version(
4361    request: proto::AckBufferOperation,
4362    session: UserSession,
4363) -> Result<()> {
4364    let buffer_id = BufferId::from_proto(request.buffer_id);
4365    session
4366        .db()
4367        .await
4368        .observe_buffer_version(
4369            buffer_id,
4370            session.user_id(),
4371            request.epoch as i32,
4372            &request.version,
4373        )
4374        .await?;
4375    Ok(())
4376}
4377
4378struct CompleteWithLanguageModelRateLimit;
4379
4380impl RateLimit for CompleteWithLanguageModelRateLimit {
4381    fn capacity() -> usize {
4382        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4383            .ok()
4384            .and_then(|v| v.parse().ok())
4385            .unwrap_or(120) // Picked arbitrarily
4386    }
4387
4388    fn refill_duration() -> chrono::Duration {
4389        chrono::Duration::hours(1)
4390    }
4391
4392    fn db_name() -> &'static str {
4393        "complete-with-language-model"
4394    }
4395}
4396
4397async fn complete_with_language_model(
4398    request: proto::CompleteWithLanguageModel,
4399    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4400    session: Session,
4401    open_ai_api_key: Option<Arc<str>>,
4402    google_ai_api_key: Option<Arc<str>>,
4403    anthropic_api_key: Option<Arc<str>>,
4404) -> Result<()> {
4405    let Some(session) = session.for_user() else {
4406        return Err(anyhow!("user not found"))?;
4407    };
4408    authorize_access_to_language_models(&session).await?;
4409    session
4410        .rate_limiter
4411        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4412        .await?;
4413
4414    if request.model.starts_with("gpt") {
4415        let api_key =
4416            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4417        complete_with_open_ai(request, response, session, api_key).await?;
4418    } else if request.model.starts_with("gemini") {
4419        let api_key = google_ai_api_key
4420            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4421        complete_with_google_ai(request, response, session, api_key).await?;
4422    } else if request.model.starts_with("claude") {
4423        let api_key = anthropic_api_key
4424            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4425        complete_with_anthropic(request, response, session, api_key).await?;
4426    }
4427
4428    Ok(())
4429}
4430
4431async fn complete_with_open_ai(
4432    request: proto::CompleteWithLanguageModel,
4433    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4434    session: UserSession,
4435    api_key: Arc<str>,
4436) -> Result<()> {
4437    let mut completion_stream = open_ai::stream_completion(
4438        session.http_client.as_ref(),
4439        OPEN_AI_API_URL,
4440        &api_key,
4441        crate::ai::language_model_request_to_open_ai(request)?,
4442        None,
4443    )
4444    .await
4445    .context("open_ai::stream_completion request failed within collab")?;
4446
4447    while let Some(event) = completion_stream.next().await {
4448        let event = event?;
4449        response.send(proto::LanguageModelResponse {
4450            choices: event
4451                .choices
4452                .into_iter()
4453                .map(|choice| proto::LanguageModelChoiceDelta {
4454                    index: choice.index,
4455                    delta: Some(proto::LanguageModelResponseMessage {
4456                        role: choice.delta.role.map(|role| match role {
4457                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4458                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4459                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4460                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4461                        } as i32),
4462                        content: choice.delta.content,
4463                        tool_calls: choice
4464                            .delta
4465                            .tool_calls
4466                            .into_iter()
4467                            .map(|delta| proto::ToolCallDelta {
4468                                index: delta.index as u32,
4469                                id: delta.id,
4470                                variant: match delta.function {
4471                                    Some(function) => {
4472                                        let name = function.name;
4473                                        let arguments = function.arguments;
4474
4475                                        Some(proto::tool_call_delta::Variant::Function(
4476                                            proto::tool_call_delta::FunctionCallDelta {
4477                                                name,
4478                                                arguments,
4479                                            },
4480                                        ))
4481                                    }
4482                                    None => None,
4483                                },
4484                            })
4485                            .collect(),
4486                    }),
4487                    finish_reason: choice.finish_reason,
4488                })
4489                .collect(),
4490        })?;
4491    }
4492
4493    Ok(())
4494}
4495
4496async fn complete_with_google_ai(
4497    request: proto::CompleteWithLanguageModel,
4498    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4499    session: UserSession,
4500    api_key: Arc<str>,
4501) -> Result<()> {
4502    let mut stream = google_ai::stream_generate_content(
4503        session.http_client.clone(),
4504        google_ai::API_URL,
4505        api_key.as_ref(),
4506        crate::ai::language_model_request_to_google_ai(request)?,
4507    )
4508    .await
4509    .context("google_ai::stream_generate_content request failed")?;
4510
4511    while let Some(event) = stream.next().await {
4512        let event = event?;
4513        response.send(proto::LanguageModelResponse {
4514            choices: event
4515                .candidates
4516                .unwrap_or_default()
4517                .into_iter()
4518                .map(|candidate| proto::LanguageModelChoiceDelta {
4519                    index: candidate.index as u32,
4520                    delta: Some(proto::LanguageModelResponseMessage {
4521                        role: Some(match candidate.content.role {
4522                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4523                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4524                        } as i32),
4525                        content: Some(
4526                            candidate
4527                                .content
4528                                .parts
4529                                .into_iter()
4530                                .filter_map(|part| match part {
4531                                    google_ai::Part::TextPart(part) => Some(part.text),
4532                                    google_ai::Part::InlineDataPart(_) => None,
4533                                })
4534                                .collect(),
4535                        ),
4536                        // Tool calls are not supported for Google
4537                        tool_calls: Vec::new(),
4538                    }),
4539                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4540                })
4541                .collect(),
4542        })?;
4543    }
4544
4545    Ok(())
4546}
4547
4548async fn complete_with_anthropic(
4549    request: proto::CompleteWithLanguageModel,
4550    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4551    session: UserSession,
4552    api_key: Arc<str>,
4553) -> Result<()> {
4554    let model = anthropic::Model::from_id(&request.model)?;
4555
4556    let mut system_message = String::new();
4557    let messages = request
4558        .messages
4559        .into_iter()
4560        .filter_map(|message| {
4561            match message.role() {
4562                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4563                    role: anthropic::Role::User,
4564                    content: message.content,
4565                }),
4566                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4567                    role: anthropic::Role::Assistant,
4568                    content: message.content,
4569                }),
4570                // Anthropic's API breaks system instructions out as a separate field rather
4571                // than having a system message role.
4572                LanguageModelRole::LanguageModelSystem => {
4573                    if !system_message.is_empty() {
4574                        system_message.push_str("\n\n");
4575                    }
4576                    system_message.push_str(&message.content);
4577
4578                    None
4579                }
4580                // We don't yet support tool calls for Anthropic
4581                LanguageModelRole::LanguageModelTool => None,
4582            }
4583        })
4584        .collect();
4585
4586    let mut stream = anthropic::stream_completion(
4587        session.http_client.as_ref(),
4588        anthropic::ANTHROPIC_API_URL,
4589        &api_key,
4590        anthropic::Request {
4591            model,
4592            messages,
4593            stream: true,
4594            system: system_message,
4595            max_tokens: 4092,
4596        },
4597        None,
4598    )
4599    .await?;
4600
4601    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4602
4603    while let Some(event) = stream.next().await {
4604        let event = event?;
4605
4606        match event {
4607            anthropic::ResponseEvent::MessageStart { message } => {
4608                if let Some(role) = message.role {
4609                    if role == "assistant" {
4610                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4611                    } else if role == "user" {
4612                        current_role = proto::LanguageModelRole::LanguageModelUser;
4613                    }
4614                }
4615            }
4616            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4617                match content_block {
4618                    anthropic::ContentBlock::Text { text } => {
4619                        if !text.is_empty() {
4620                            response.send(proto::LanguageModelResponse {
4621                                choices: vec![proto::LanguageModelChoiceDelta {
4622                                    index: 0,
4623                                    delta: Some(proto::LanguageModelResponseMessage {
4624                                        role: Some(current_role as i32),
4625                                        content: Some(text),
4626                                        tool_calls: Vec::new(),
4627                                    }),
4628                                    finish_reason: None,
4629                                }],
4630                            })?;
4631                        }
4632                    }
4633                }
4634            }
4635            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4636                anthropic::TextDelta::TextDelta { text } => {
4637                    response.send(proto::LanguageModelResponse {
4638                        choices: vec![proto::LanguageModelChoiceDelta {
4639                            index: 0,
4640                            delta: Some(proto::LanguageModelResponseMessage {
4641                                role: Some(current_role as i32),
4642                                content: Some(text),
4643                                tool_calls: Vec::new(),
4644                            }),
4645                            finish_reason: None,
4646                        }],
4647                    })?;
4648                }
4649            },
4650            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4651                if let Some(stop_reason) = delta.stop_reason {
4652                    response.send(proto::LanguageModelResponse {
4653                        choices: vec![proto::LanguageModelChoiceDelta {
4654                            index: 0,
4655                            delta: None,
4656                            finish_reason: Some(stop_reason),
4657                        }],
4658                    })?;
4659                }
4660            }
4661            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4662            anthropic::ResponseEvent::MessageStop {} => {}
4663            anthropic::ResponseEvent::Ping {} => {}
4664        }
4665    }
4666
4667    Ok(())
4668}
4669
4670struct CountTokensWithLanguageModelRateLimit;
4671
4672impl RateLimit for CountTokensWithLanguageModelRateLimit {
4673    fn capacity() -> usize {
4674        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4675            .ok()
4676            .and_then(|v| v.parse().ok())
4677            .unwrap_or(600) // Picked arbitrarily
4678    }
4679
4680    fn refill_duration() -> chrono::Duration {
4681        chrono::Duration::hours(1)
4682    }
4683
4684    fn db_name() -> &'static str {
4685        "count-tokens-with-language-model"
4686    }
4687}
4688
4689async fn count_tokens_with_language_model(
4690    request: proto::CountTokensWithLanguageModel,
4691    response: Response<proto::CountTokensWithLanguageModel>,
4692    session: UserSession,
4693    google_ai_api_key: Option<Arc<str>>,
4694) -> Result<()> {
4695    authorize_access_to_language_models(&session).await?;
4696
4697    if !request.model.starts_with("gemini") {
4698        return Err(anyhow!(
4699            "counting tokens for model: {:?} is not supported",
4700            request.model
4701        ))?;
4702    }
4703
4704    session
4705        .rate_limiter
4706        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4707        .await?;
4708
4709    let api_key = google_ai_api_key
4710        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4711    let tokens_response = google_ai::count_tokens(
4712        session.http_client.as_ref(),
4713        google_ai::API_URL,
4714        &api_key,
4715        crate::ai::count_tokens_request_to_google_ai(request)?,
4716    )
4717    .await?;
4718    response.send(proto::CountTokensResponse {
4719        token_count: tokens_response.total_tokens as u32,
4720    })?;
4721    Ok(())
4722}
4723
4724struct ComputeEmbeddingsRateLimit;
4725
4726impl RateLimit for ComputeEmbeddingsRateLimit {
4727    fn capacity() -> usize {
4728        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4729            .ok()
4730            .and_then(|v| v.parse().ok())
4731            .unwrap_or(5000) // Picked arbitrarily
4732    }
4733
4734    fn refill_duration() -> chrono::Duration {
4735        chrono::Duration::hours(1)
4736    }
4737
4738    fn db_name() -> &'static str {
4739        "compute-embeddings"
4740    }
4741}
4742
4743async fn compute_embeddings(
4744    request: proto::ComputeEmbeddings,
4745    response: Response<proto::ComputeEmbeddings>,
4746    session: UserSession,
4747    api_key: Option<Arc<str>>,
4748) -> Result<()> {
4749    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4750    authorize_access_to_language_models(&session).await?;
4751
4752    session
4753        .rate_limiter
4754        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4755        .await?;
4756
4757    let embeddings = match request.model.as_str() {
4758        "openai/text-embedding-3-small" => {
4759            open_ai::embed(
4760                session.http_client.as_ref(),
4761                OPEN_AI_API_URL,
4762                &api_key,
4763                OpenAiEmbeddingModel::TextEmbedding3Small,
4764                request.texts.iter().map(|text| text.as_str()),
4765            )
4766            .await?
4767        }
4768        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4769    };
4770
4771    let embeddings = request
4772        .texts
4773        .iter()
4774        .map(|text| {
4775            let mut hasher = sha2::Sha256::new();
4776            hasher.update(text.as_bytes());
4777            let result = hasher.finalize();
4778            result.to_vec()
4779        })
4780        .zip(
4781            embeddings
4782                .data
4783                .into_iter()
4784                .map(|embedding| embedding.embedding),
4785        )
4786        .collect::<HashMap<_, _>>();
4787
4788    let db = session.db().await;
4789    db.save_embeddings(&request.model, &embeddings)
4790        .await
4791        .context("failed to save embeddings")
4792        .trace_err();
4793
4794    response.send(proto::ComputeEmbeddingsResponse {
4795        embeddings: embeddings
4796            .into_iter()
4797            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4798            .collect(),
4799    })?;
4800    Ok(())
4801}
4802
4803async fn get_cached_embeddings(
4804    request: proto::GetCachedEmbeddings,
4805    response: Response<proto::GetCachedEmbeddings>,
4806    session: UserSession,
4807) -> Result<()> {
4808    authorize_access_to_language_models(&session).await?;
4809
4810    let db = session.db().await;
4811    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4812
4813    response.send(proto::GetCachedEmbeddingsResponse {
4814        embeddings: embeddings
4815            .into_iter()
4816            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4817            .collect(),
4818    })?;
4819    Ok(())
4820}
4821
4822async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4823    let db = session.db().await;
4824    let flags = db.get_user_flags(session.user_id()).await?;
4825    if flags.iter().any(|flag| flag == "language-models") {
4826        Ok(())
4827    } else {
4828        Err(anyhow!("permission denied"))?
4829    }
4830}
4831
4832/// Get a Supermaven API key for the user
4833async fn get_supermaven_api_key(
4834    _request: proto::GetSupermavenApiKey,
4835    response: Response<proto::GetSupermavenApiKey>,
4836    session: UserSession,
4837) -> Result<()> {
4838    let user_id: String = session.user_id().to_string();
4839    if !session.is_staff() {
4840        return Err(anyhow!("supermaven not enabled for this account"))?;
4841    }
4842
4843    let email = session
4844        .email()
4845        .ok_or_else(|| anyhow!("user must have an email"))?;
4846
4847    let supermaven_admin_api = session
4848        .supermaven_client
4849        .as_ref()
4850        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4851
4852    let result = supermaven_admin_api
4853        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4854        .await?;
4855
4856    response.send(proto::GetSupermavenApiKeyResponse {
4857        api_key: result.api_key,
4858    })?;
4859
4860    Ok(())
4861}
4862
4863/// Start receiving chat updates for a channel
4864async fn join_channel_chat(
4865    request: proto::JoinChannelChat,
4866    response: Response<proto::JoinChannelChat>,
4867    session: UserSession,
4868) -> Result<()> {
4869    let channel_id = ChannelId::from_proto(request.channel_id);
4870
4871    let db = session.db().await;
4872    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4873        .await?;
4874    let messages = db
4875        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4876        .await?;
4877    response.send(proto::JoinChannelChatResponse {
4878        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4879        messages,
4880    })?;
4881    Ok(())
4882}
4883
4884/// Stop receiving chat updates for a channel
4885async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4886    let channel_id = ChannelId::from_proto(request.channel_id);
4887    session
4888        .db()
4889        .await
4890        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4891        .await?;
4892    Ok(())
4893}
4894
4895/// Retrieve the chat history for a channel
4896async fn get_channel_messages(
4897    request: proto::GetChannelMessages,
4898    response: Response<proto::GetChannelMessages>,
4899    session: UserSession,
4900) -> Result<()> {
4901    let channel_id = ChannelId::from_proto(request.channel_id);
4902    let messages = session
4903        .db()
4904        .await
4905        .get_channel_messages(
4906            channel_id,
4907            session.user_id(),
4908            MESSAGE_COUNT_PER_PAGE,
4909            Some(MessageId::from_proto(request.before_message_id)),
4910        )
4911        .await?;
4912    response.send(proto::GetChannelMessagesResponse {
4913        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4914        messages,
4915    })?;
4916    Ok(())
4917}
4918
4919/// Retrieve specific chat messages
4920async fn get_channel_messages_by_id(
4921    request: proto::GetChannelMessagesById,
4922    response: Response<proto::GetChannelMessagesById>,
4923    session: UserSession,
4924) -> Result<()> {
4925    let message_ids = request
4926        .message_ids
4927        .iter()
4928        .map(|id| MessageId::from_proto(*id))
4929        .collect::<Vec<_>>();
4930    let messages = session
4931        .db()
4932        .await
4933        .get_channel_messages_by_id(session.user_id(), &message_ids)
4934        .await?;
4935    response.send(proto::GetChannelMessagesResponse {
4936        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4937        messages,
4938    })?;
4939    Ok(())
4940}
4941
4942/// Retrieve the current users notifications
4943async fn get_notifications(
4944    request: proto::GetNotifications,
4945    response: Response<proto::GetNotifications>,
4946    session: UserSession,
4947) -> Result<()> {
4948    let notifications = session
4949        .db()
4950        .await
4951        .get_notifications(
4952            session.user_id(),
4953            NOTIFICATION_COUNT_PER_PAGE,
4954            request
4955                .before_id
4956                .map(|id| db::NotificationId::from_proto(id)),
4957        )
4958        .await?;
4959    response.send(proto::GetNotificationsResponse {
4960        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4961        notifications,
4962    })?;
4963    Ok(())
4964}
4965
4966/// Mark notifications as read
4967async fn mark_notification_as_read(
4968    request: proto::MarkNotificationRead,
4969    response: Response<proto::MarkNotificationRead>,
4970    session: UserSession,
4971) -> Result<()> {
4972    let database = &session.db().await;
4973    let notifications = database
4974        .mark_notification_as_read_by_id(
4975            session.user_id(),
4976            NotificationId::from_proto(request.notification_id),
4977        )
4978        .await?;
4979    send_notifications(
4980        &*session.connection_pool().await,
4981        &session.peer,
4982        notifications,
4983    );
4984    response.send(proto::Ack {})?;
4985    Ok(())
4986}
4987
4988/// Get the current users information
4989async fn get_private_user_info(
4990    _request: proto::GetPrivateUserInfo,
4991    response: Response<proto::GetPrivateUserInfo>,
4992    session: UserSession,
4993) -> Result<()> {
4994    let db = session.db().await;
4995
4996    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4997    let user = db
4998        .get_user_by_id(session.user_id())
4999        .await?
5000        .ok_or_else(|| anyhow!("user not found"))?;
5001    let flags = db.get_user_flags(session.user_id()).await?;
5002
5003    response.send(proto::GetPrivateUserInfoResponse {
5004        metrics_id,
5005        staff: user.admin,
5006        flags,
5007    })?;
5008    Ok(())
5009}
5010
5011fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
5012    match message {
5013        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5014        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5015        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5016        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5017        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5018            code: frame.code.into(),
5019            reason: frame.reason,
5020        })),
5021    }
5022}
5023
5024fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5025    match message {
5026        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5027        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5028        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5029        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5030        AxumMessage::Close(frame) => {
5031            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5032                code: frame.code.into(),
5033                reason: frame.reason,
5034            }))
5035        }
5036    }
5037}
5038
5039fn notify_membership_updated(
5040    connection_pool: &mut ConnectionPool,
5041    result: MembershipUpdated,
5042    user_id: UserId,
5043    peer: &Peer,
5044) {
5045    for membership in &result.new_channels.channel_memberships {
5046        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5047    }
5048    for channel_id in &result.removed_channels {
5049        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5050    }
5051
5052    let user_channels_update = proto::UpdateUserChannels {
5053        channel_memberships: result
5054            .new_channels
5055            .channel_memberships
5056            .iter()
5057            .map(|cm| proto::ChannelMembership {
5058                channel_id: cm.channel_id.to_proto(),
5059                role: cm.role.into(),
5060            })
5061            .collect(),
5062        ..Default::default()
5063    };
5064
5065    let mut update = build_channels_update(result.new_channels);
5066    update.delete_channels = result
5067        .removed_channels
5068        .into_iter()
5069        .map(|id| id.to_proto())
5070        .collect();
5071    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5072
5073    for connection_id in connection_pool.user_connection_ids(user_id) {
5074        peer.send(connection_id, user_channels_update.clone())
5075            .trace_err();
5076        peer.send(connection_id, update.clone()).trace_err();
5077    }
5078}
5079
5080fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5081    proto::UpdateUserChannels {
5082        channel_memberships: channels
5083            .channel_memberships
5084            .iter()
5085            .map(|m| proto::ChannelMembership {
5086                channel_id: m.channel_id.to_proto(),
5087                role: m.role.into(),
5088            })
5089            .collect(),
5090        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5091        observed_channel_message_id: channels.observed_channel_messages.clone(),
5092    }
5093}
5094
5095fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5096    let mut update = proto::UpdateChannels::default();
5097
5098    for channel in channels.channels {
5099        update.channels.push(channel.to_proto());
5100    }
5101
5102    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5103    update.latest_channel_message_ids = channels.latest_channel_messages;
5104
5105    for (channel_id, participants) in channels.channel_participants {
5106        update
5107            .channel_participants
5108            .push(proto::ChannelParticipants {
5109                channel_id: channel_id.to_proto(),
5110                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5111            });
5112    }
5113
5114    for channel in channels.invited_channels {
5115        update.channel_invitations.push(channel.to_proto());
5116    }
5117
5118    update.hosted_projects = channels.hosted_projects;
5119    update
5120}
5121
5122fn build_initial_contacts_update(
5123    contacts: Vec<db::Contact>,
5124    pool: &ConnectionPool,
5125) -> proto::UpdateContacts {
5126    let mut update = proto::UpdateContacts::default();
5127
5128    for contact in contacts {
5129        match contact {
5130            db::Contact::Accepted { user_id, busy } => {
5131                update.contacts.push(contact_for_user(user_id, busy, &pool));
5132            }
5133            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5134            db::Contact::Incoming { user_id } => {
5135                update
5136                    .incoming_requests
5137                    .push(proto::IncomingContactRequest {
5138                        requester_id: user_id.to_proto(),
5139                    })
5140            }
5141        }
5142    }
5143
5144    update
5145}
5146
5147fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5148    proto::Contact {
5149        user_id: user_id.to_proto(),
5150        online: pool.is_user_online(user_id),
5151        busy,
5152    }
5153}
5154
5155fn room_updated(room: &proto::Room, peer: &Peer) {
5156    broadcast(
5157        None,
5158        room.participants
5159            .iter()
5160            .filter_map(|participant| Some(participant.peer_id?.into())),
5161        |peer_id| {
5162            peer.send(
5163                peer_id,
5164                proto::RoomUpdated {
5165                    room: Some(room.clone()),
5166                },
5167            )
5168        },
5169    );
5170}
5171
5172fn channel_updated(
5173    channel: &db::channel::Model,
5174    room: &proto::Room,
5175    peer: &Peer,
5176    pool: &ConnectionPool,
5177) {
5178    let participants = room
5179        .participants
5180        .iter()
5181        .map(|p| p.user_id)
5182        .collect::<Vec<_>>();
5183
5184    broadcast(
5185        None,
5186        pool.channel_connection_ids(channel.root_id())
5187            .filter_map(|(channel_id, role)| {
5188                role.can_see_channel(channel.visibility).then(|| channel_id)
5189            }),
5190        |peer_id| {
5191            peer.send(
5192                peer_id,
5193                proto::UpdateChannels {
5194                    channel_participants: vec![proto::ChannelParticipants {
5195                        channel_id: channel.id.to_proto(),
5196                        participant_user_ids: participants.clone(),
5197                    }],
5198                    ..Default::default()
5199                },
5200            )
5201        },
5202    );
5203}
5204
5205async fn send_dev_server_projects_update(
5206    user_id: UserId,
5207    mut status: proto::DevServerProjectsUpdate,
5208    session: &Session,
5209) {
5210    let pool = session.connection_pool().await;
5211    for dev_server in &mut status.dev_servers {
5212        dev_server.status =
5213            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5214    }
5215    let connections = pool.user_connection_ids(user_id);
5216    for connection_id in connections {
5217        session.peer.send(connection_id, status.clone()).trace_err();
5218    }
5219}
5220
5221async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5222    let db = session.db().await;
5223
5224    let contacts = db.get_contacts(user_id).await?;
5225    let busy = db.is_user_busy(user_id).await?;
5226
5227    let pool = session.connection_pool().await;
5228    let updated_contact = contact_for_user(user_id, busy, &pool);
5229    for contact in contacts {
5230        if let db::Contact::Accepted {
5231            user_id: contact_user_id,
5232            ..
5233        } = contact
5234        {
5235            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5236                session
5237                    .peer
5238                    .send(
5239                        contact_conn_id,
5240                        proto::UpdateContacts {
5241                            contacts: vec![updated_contact.clone()],
5242                            remove_contacts: Default::default(),
5243                            incoming_requests: Default::default(),
5244                            remove_incoming_requests: Default::default(),
5245                            outgoing_requests: Default::default(),
5246                            remove_outgoing_requests: Default::default(),
5247                        },
5248                    )
5249                    .trace_err();
5250            }
5251        }
5252    }
5253    Ok(())
5254}
5255
5256async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5257    log::info!("lost dev server connection, unsharing projects");
5258    let project_ids = session
5259        .db()
5260        .await
5261        .get_stale_dev_server_projects(session.connection_id)
5262        .await?;
5263
5264    for project_id in project_ids {
5265        // not unshare re-checks the connection ids match, so we get away with no transaction
5266        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5267    }
5268
5269    let user_id = session.dev_server().user_id;
5270    let update = session
5271        .db()
5272        .await
5273        .dev_server_projects_update(user_id)
5274        .await?;
5275
5276    send_dev_server_projects_update(user_id, update, session).await;
5277
5278    Ok(())
5279}
5280
/// Removes `connection_id` from the room it is in and propagates the
/// consequences: collaborator/unshare notices for projects it left, room and
/// channel participant broadcasts, `CallCanceled` for calls that were canceled,
/// contact-status refreshes, and LiveKit participant/room cleanup.
///
/// Returns `Ok(())` without doing anything further when `leave_room` reports
/// no room for this connection. Database and contact-update errors are
/// propagated; message-send and LiveKit errors are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries are re-broadcast at the end: the leaving
    // user plus every user whose pending call was canceled.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so the values are
    // still available after that scope ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in
    // the `if let` scrutinee, so under Rust 2021 temporary rules it stays alive
    // for the whole `if let`/`else` — confirm that scope is intended if it is a lock guard.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken before `room` is, so the `room`
        // broadcast below carries the default value in its `live_kit_room` field.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Rooms attached to a channel also refresh that channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which acquires
    // the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` may be needed again for `delete_room`.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5354
5355async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5356    let left_channel_buffers = session
5357        .db()
5358        .await
5359        .leave_channel_buffers(session.connection_id)
5360        .await?;
5361
5362    for left_buffer in left_channel_buffers {
5363        channel_buffer_updated(
5364            session.connection_id,
5365            left_buffer.connections,
5366            &proto::UpdateChannelBufferCollaborators {
5367                channel_id: left_buffer.channel_id.to_proto(),
5368                collaborators: left_buffer.collaborators,
5369            },
5370            &session.peer,
5371        );
5372    }
5373
5374    Ok(())
5375}
5376
5377fn project_left(project: &db::LeftProject, session: &UserSession) {
5378    for connection_id in &project.connection_ids {
5379        if project.should_unshare {
5380            session
5381                .peer
5382                .send(
5383                    *connection_id,
5384                    proto::UnshareProject {
5385                        project_id: project.id.to_proto(),
5386                    },
5387                )
5388                .trace_err();
5389        } else {
5390            session
5391                .peer
5392                .send(
5393                    *connection_id,
5394                    proto::RemoveProjectCollaborator {
5395                        project_id: project.id.to_proto(),
5396                        peer_id: Some(session.connection_id.into()),
5397                    },
5398                )
5399                .trace_err();
5400        }
5401    }
5402}
5403
5404pub trait ResultExt {
5405    type Ok;
5406
5407    fn trace_err(self) -> Option<Self::Ok>;
5408}
5409
5410impl<T, E> ResultExt for Result<T, E>
5411where
5412    E: std::fmt::Debug,
5413{
5414    type Ok = T;
5415
5416    #[track_caller]
5417    fn trace_err(self) -> Option<T> {
5418        match self {
5419            Ok(value) => Some(value),
5420            Err(error) => {
5421                tracing::error!("{:?}", error);
5422                None
5423            }
5424        }
5425    }
5426}