rpc.rs

   1mod connection_pool;
   2
   3use crate::api::CloudflareIpCountryHeader;
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
  10        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
  11        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  12        ServerId, UpdatedChannelMessage, User, UserId,
  13    },
  14    executor::Executor,
  15    AppState, Config, Error, RateLimit, Result,
  16};
  17use anyhow::{anyhow, bail, Context as _};
  18use async_tungstenite::tungstenite::{
  19    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  20};
  21use axum::{
  22    body::Body,
  23    extract::{
  24        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  25        ConnectInfo, WebSocketUpgrade,
  26    },
  27    headers::{Header, HeaderName},
  28    http::StatusCode,
  29    middleware,
  30    response::IntoResponse,
  31    routing::get,
  32    Extension, Router, TypedHeader,
  33};
  34use chrono::Utc;
  35use collections::{HashMap, HashSet};
  36pub use connection_pool::{ConnectionPool, ZedVersion};
  37use core::fmt::{self, Debug, Formatter};
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use sha2::Digest;
  40use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  41
  42use futures::{
  43    channel::oneshot,
  44    future::{self, BoxFuture},
  45    stream::FuturesUnordered,
  46    FutureExt, SinkExt, StreamExt, TryStreamExt,
  47};
  48use http_client::IsahcHttpClient;
  49use prometheus::{register_int_gauge, IntGauge};
  50use rpc::{
  51    proto::{
  52        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  53        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  54    },
  55    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  56};
  57use semantic_version::SemanticVersion;
  58use serde::{Serialize, Serializer};
  59use std::{
  60    any::TypeId,
  61    future::Future,
  62    marker::PhantomData,
  63    mem,
  64    net::SocketAddr,
  65    ops::{Deref, DerefMut},
  66    rc::Rc,
  67    sync::{
  68        atomic::{AtomicBool, Ordering::SeqCst},
  69        Arc, OnceLock,
  70    },
  71    time::{Duration, Instant},
  72};
  73use time::OffsetDateTime;
  74use tokio::sync::{watch, MutexGuard, Semaphore};
  75use tower::ServiceBuilder;
  76use tracing::{
  77    field::{self},
  78    info_span, instrument, Instrument,
  79};
  80
  81pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  82
  83// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  84pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  85
  86const MESSAGE_COUNT_PER_PAGE: usize = 100;
  87const MAX_MESSAGE_LEN: usize = 1024;
  88const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  89
  90type MessageHandler =
  91    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  92
  93struct Response<R> {
  94    peer: Arc<Peer>,
  95    receipt: Receipt<R>,
  96    responded: Arc<AtomicBool>,
  97}
  98
  99impl<R: RequestMessage> Response<R> {
 100    fn send(self, payload: R::Response) -> Result<()> {
 101        self.responded.store(true, SeqCst);
 102        self.peer.respond(self.receipt, payload)?;
 103        Ok(())
 104    }
 105}
 106
 107#[derive(Clone, Debug)]
 108pub enum Principal {
 109    User(User),
 110    Impersonated { user: User, admin: User },
 111    DevServer(dev_server::Model),
 112}
 113
 114impl Principal {
 115    fn update_span(&self, span: &tracing::Span) {
 116        match &self {
 117            Principal::User(user) => {
 118                span.record("user_id", &user.id.0);
 119                span.record("login", &user.github_login);
 120            }
 121            Principal::Impersonated { user, admin } => {
 122                span.record("user_id", &user.id.0);
 123                span.record("login", &user.github_login);
 124                span.record("impersonator", &admin.github_login);
 125            }
 126            Principal::DevServer(dev_server) => {
 127                span.record("dev_server_id", &dev_server.id.0);
 128            }
 129        }
 130    }
 131}
 132
 133#[derive(Clone)]
 134struct Session {
 135    principal: Principal,
 136    connection_id: ConnectionId,
 137    db: Arc<tokio::sync::Mutex<DbHandle>>,
 138    peer: Arc<Peer>,
 139    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 140    app_state: Arc<AppState>,
 141    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 142    http_client: Arc<IsahcHttpClient>,
 143    /// The GeoIP country code for the user.
 144    #[allow(unused)]
 145    geoip_country_code: Option<String>,
 146    _executor: Executor,
 147}
 148
 149impl Session {
 150    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 151        #[cfg(test)]
 152        tokio::task::yield_now().await;
 153        let guard = self.db.lock().await;
 154        #[cfg(test)]
 155        tokio::task::yield_now().await;
 156        guard
 157    }
 158
 159    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 160        #[cfg(test)]
 161        tokio::task::yield_now().await;
 162        let guard = self.connection_pool.lock();
 163        ConnectionPoolGuard {
 164            guard,
 165            _not_send: PhantomData,
 166        }
 167    }
 168
 169    fn for_user(self) -> Option<UserSession> {
 170        UserSession::new(self)
 171    }
 172
 173    fn for_dev_server(self) -> Option<DevServerSession> {
 174        DevServerSession::new(self)
 175    }
 176
 177    fn user_id(&self) -> Option<UserId> {
 178        match &self.principal {
 179            Principal::User(user) => Some(user.id),
 180            Principal::Impersonated { user, .. } => Some(user.id),
 181            Principal::DevServer(_) => None,
 182        }
 183    }
 184
 185    fn is_staff(&self) -> bool {
 186        match &self.principal {
 187            Principal::User(user) => user.admin,
 188            Principal::Impersonated { .. } => true,
 189            Principal::DevServer(_) => false,
 190        }
 191    }
 192
 193    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
 194        if self.is_staff() {
 195            return Ok(proto::Plan::ZedPro);
 196        }
 197
 198        let Some(user_id) = self.user_id() else {
 199            return Ok(proto::Plan::Free);
 200        };
 201
 202        if db.has_active_billing_subscription(user_id).await? {
 203            Ok(proto::Plan::ZedPro)
 204        } else {
 205            Ok(proto::Plan::Free)
 206        }
 207    }
 208
 209    fn dev_server_id(&self) -> Option<DevServerId> {
 210        match &self.principal {
 211            Principal::User(_) | Principal::Impersonated { .. } => None,
 212            Principal::DevServer(dev_server) => Some(dev_server.id),
 213        }
 214    }
 215
 216    fn principal_id(&self) -> PrincipalId {
 217        match &self.principal {
 218            Principal::User(user) => PrincipalId::UserId(user.id),
 219            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 220            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 221        }
 222    }
 223}
 224
 225impl Debug for Session {
 226    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 227        let mut result = f.debug_struct("Session");
 228        match &self.principal {
 229            Principal::User(user) => {
 230                result.field("user", &user.github_login);
 231            }
 232            Principal::Impersonated { user, admin } => {
 233                result.field("user", &user.github_login);
 234                result.field("impersonator", &admin.github_login);
 235            }
 236            Principal::DevServer(dev_server) => {
 237                result.field("dev_server", &dev_server.id);
 238            }
 239        }
 240        result.field("connection_id", &self.connection_id).finish()
 241    }
 242}
 243
 244struct UserSession(Session);
 245
 246impl UserSession {
 247    pub fn new(s: Session) -> Option<Self> {
 248        s.user_id().map(|_| UserSession(s))
 249    }
 250    pub fn user_id(&self) -> UserId {
 251        self.0.user_id().unwrap()
 252    }
 253
 254    pub fn email(&self) -> Option<String> {
 255        match &self.0.principal {
 256            Principal::User(user) => user.email_address.clone(),
 257            Principal::Impersonated { user, .. } => user.email_address.clone(),
 258            Principal::DevServer(..) => None,
 259        }
 260    }
 261}
 262
 263impl Deref for UserSession {
 264    type Target = Session;
 265
 266    fn deref(&self) -> &Self::Target {
 267        &self.0
 268    }
 269}
 270impl DerefMut for UserSession {
 271    fn deref_mut(&mut self) -> &mut Self::Target {
 272        &mut self.0
 273    }
 274}
 275
 276struct DevServerSession(Session);
 277
 278impl DevServerSession {
 279    pub fn new(s: Session) -> Option<Self> {
 280        s.dev_server_id().map(|_| DevServerSession(s))
 281    }
 282    pub fn dev_server_id(&self) -> DevServerId {
 283        self.0.dev_server_id().unwrap()
 284    }
 285
 286    fn dev_server(&self) -> &dev_server::Model {
 287        match &self.0.principal {
 288            Principal::DevServer(dev_server) => dev_server,
 289            _ => unreachable!(),
 290        }
 291    }
 292}
 293
 294impl Deref for DevServerSession {
 295    type Target = Session;
 296
 297    fn deref(&self) -> &Self::Target {
 298        &self.0
 299    }
 300}
 301impl DerefMut for DevServerSession {
 302    fn deref_mut(&mut self) -> &mut Self::Target {
 303        &mut self.0
 304    }
 305}
 306
 307fn user_handler<M: RequestMessage, Fut>(
 308    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 309) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 310where
 311    Fut: Send + Future<Output = Result<()>>,
 312{
 313    let handler = Arc::new(handler);
 314    move |message, response, session| {
 315        let handler = handler.clone();
 316        Box::pin(async move {
 317            if let Some(user_session) = session.for_user() {
 318                Ok(handler(message, response, user_session).await?)
 319            } else {
 320                Err(Error::Internal(anyhow!(
 321                    "must be a user to call {}",
 322                    M::NAME
 323                )))
 324            }
 325        })
 326    }
 327}
 328
 329fn dev_server_handler<M: RequestMessage, Fut>(
 330    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 331) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 332where
 333    Fut: Send + Future<Output = Result<()>>,
 334{
 335    let handler = Arc::new(handler);
 336    move |message, response, session| {
 337        let handler = handler.clone();
 338        Box::pin(async move {
 339            if let Some(dev_server_session) = session.for_dev_server() {
 340                Ok(handler(message, response, dev_server_session).await?)
 341            } else {
 342                Err(Error::Internal(anyhow!(
 343                    "must be a dev server to call {}",
 344                    M::NAME
 345                )))
 346            }
 347        })
 348    }
 349}
 350
 351fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 352    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 353) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 354where
 355    InnertRetFut: Send + Future<Output = Result<()>>,
 356{
 357    let handler = Arc::new(handler);
 358    move |message, session| {
 359        let handler = handler.clone();
 360        Box::pin(async move {
 361            if let Some(user_session) = session.for_user() {
 362                Ok(handler(message, user_session).await?)
 363            } else {
 364                Err(Error::Internal(anyhow!(
 365                    "must be a user to call {}",
 366                    M::NAME
 367                )))
 368            }
 369        })
 370    }
 371}
 372
 373struct DbHandle(Arc<Database>);
 374
 375impl Deref for DbHandle {
 376    type Target = Database;
 377
 378    fn deref(&self) -> &Self::Target {
 379        self.0.as_ref()
 380    }
 381}
 382
 383pub struct Server {
 384    id: parking_lot::Mutex<ServerId>,
 385    peer: Arc<Peer>,
 386    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 387    app_state: Arc<AppState>,
 388    handlers: HashMap<TypeId, MessageHandler>,
 389    teardown: watch::Sender<bool>,
 390}
 391
 392pub(crate) struct ConnectionPoolGuard<'a> {
 393    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 394    _not_send: PhantomData<Rc<()>>,
 395}
 396
 397#[derive(Serialize)]
 398pub struct ServerSnapshot<'a> {
 399    peer: &'a Peer,
 400    #[serde(serialize_with = "serialize_deref")]
 401    connection_pool: ConnectionPoolGuard<'a>,
 402}
 403
 404pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 405where
 406    S: Serializer,
 407    T: Deref<Target = U>,
 408    U: Serialize,
 409{
 410    Serialize::serialize(value.deref(), serializer)
 411}
 412
 413impl Server {
    /// Builds a server with the given `id`, registers every RPC request and message handler,
    /// and returns the server wrapped in an `Arc`.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail with
    /// `Error::Internal` when the calling session's principal is not a user. Those wrapped in
    /// `dev_server_handler` fail when the principal is not a dev server. The remaining
    /// handlers are registered unwrapped and presumably receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` wraps silently if `id.0` is signed or wider than 32 bits;
            // assumed to hold for real server ids — confirm.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender half is kept; the initial receiver is dropped immediately.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms, calls, and participant location.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev servers and dev-server projects; the three `dev_server_handler` entries
            // may only be called by dev-server sessions.
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            // Project, worktree, and language-server state updates.
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests relayed through the `forward_*` helpers (defined outside this
            // view); their names distinguish owner-only, read-only, and mutating requests.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account info, LLM token, terms of service, acknowledgements, Supermaven key.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: the session is passed through unnarrowed. The
            // `app_state` clone per request lets the returned future own it while borrowing
            // `&app_state.config`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // Embeddings.
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                // The OpenAI API key is cloned from the config on every call.
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 650
 651    pub async fn start(&self) -> Result<()> {
 652        let server_id = *self.id.lock();
 653        let app_state = self.app_state.clone();
 654        let peer = self.peer.clone();
 655        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 656        let pool = self.connection_pool.clone();
 657        let live_kit_client = self.app_state.live_kit_client.clone();
 658
 659        let span = info_span!("start server");
 660        self.app_state.executor.spawn_detached(
 661            async move {
 662                tracing::info!("waiting for cleanup timeout");
 663                timeout.await;
 664                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 665                if let Some((room_ids, channel_ids)) = app_state
 666                    .db
 667                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 668                    .await
 669                    .trace_err()
 670                {
 671                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 672                    tracing::info!(
 673                        stale_channel_buffer_count = channel_ids.len(),
 674                        "retrieved stale channel buffers"
 675                    );
 676
 677                    for channel_id in channel_ids {
 678                        if let Some(refreshed_channel_buffer) = app_state
 679                            .db
 680                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 681                            .await
 682                            .trace_err()
 683                        {
 684                            for connection_id in refreshed_channel_buffer.connection_ids {
 685                                peer.send(
 686                                    connection_id,
 687                                    proto::UpdateChannelBufferCollaborators {
 688                                        channel_id: channel_id.to_proto(),
 689                                        collaborators: refreshed_channel_buffer
 690                                            .collaborators
 691                                            .clone(),
 692                                    },
 693                                )
 694                                .trace_err();
 695                            }
 696                        }
 697                    }
 698
 699                    for room_id in room_ids {
 700                        let mut contacts_to_update = HashSet::default();
 701                        let mut canceled_calls_to_user_ids = Vec::new();
 702                        let mut live_kit_room = String::new();
 703                        let mut delete_live_kit_room = false;
 704
 705                        if let Some(mut refreshed_room) = app_state
 706                            .db
 707                            .clear_stale_room_participants(room_id, server_id)
 708                            .await
 709                            .trace_err()
 710                        {
 711                            tracing::info!(
 712                                room_id = room_id.0,
 713                                new_participant_count = refreshed_room.room.participants.len(),
 714                                "refreshed room"
 715                            );
 716                            room_updated(&refreshed_room.room, &peer);
 717                            if let Some(channel) = refreshed_room.channel.as_ref() {
 718                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 719                            }
 720                            contacts_to_update
 721                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 722                            contacts_to_update
 723                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 724                            canceled_calls_to_user_ids =
 725                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 726                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 727                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 728                        }
 729
 730                        {
 731                            let pool = pool.lock();
 732                            for canceled_user_id in canceled_calls_to_user_ids {
 733                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 734                                    peer.send(
 735                                        connection_id,
 736                                        proto::CallCanceled {
 737                                            room_id: room_id.to_proto(),
 738                                        },
 739                                    )
 740                                    .trace_err();
 741                                }
 742                            }
 743                        }
 744
 745                        for user_id in contacts_to_update {
 746                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 747                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 748                            if let Some((busy, contacts)) = busy.zip(contacts) {
 749                                let pool = pool.lock();
 750                                let updated_contact = contact_for_user(user_id, busy, &pool);
 751                                for contact in contacts {
 752                                    if let db::Contact::Accepted {
 753                                        user_id: contact_user_id,
 754                                        ..
 755                                    } = contact
 756                                    {
 757                                        for contact_conn_id in
 758                                            pool.user_connection_ids(contact_user_id)
 759                                        {
 760                                            peer.send(
 761                                                contact_conn_id,
 762                                                proto::UpdateContacts {
 763                                                    contacts: vec![updated_contact.clone()],
 764                                                    remove_contacts: Default::default(),
 765                                                    incoming_requests: Default::default(),
 766                                                    remove_incoming_requests: Default::default(),
 767                                                    outgoing_requests: Default::default(),
 768                                                    remove_outgoing_requests: Default::default(),
 769                                                },
 770                                            )
 771                                            .trace_err();
 772                                        }
 773                                    }
 774                                }
 775                            }
 776                        }
 777
 778                        if let Some(live_kit) = live_kit_client.as_ref() {
 779                            if delete_live_kit_room {
 780                                live_kit.delete_room(live_kit_room).await.trace_err();
 781                            }
 782                        }
 783                    }
 784                }
 785
 786                app_state
 787                    .db
 788                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 789                    .await
 790                    .trace_err();
 791            }
 792            .instrument(span),
 793        );
 794        Ok(())
 795    }
 796
 797    pub fn teardown(&self) {
 798        self.peer.teardown();
 799        self.connection_pool.lock().reset();
 800        let _ = self.teardown.send(true);
 801    }
 802
 803    #[cfg(test)]
 804    pub fn reset(&self, id: ServerId) {
 805        self.teardown();
 806        *self.id.lock() = id;
 807        self.peer.reset(id.0 as u32);
 808        let _ = self.teardown.send(false);
 809    }
 810
 811    #[cfg(test)]
 812    pub fn id(&self) -> ServerId {
 813        *self.id.lock()
 814    }
 815
 816    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 817    where
 818        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 819        Fut: 'static + Send + Future<Output = Result<()>>,
 820        M: EnvelopedMessage,
 821    {
 822        let prev_handler = self.handlers.insert(
 823            TypeId::of::<M>(),
 824            Box::new(move |envelope, session| {
 825                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 826                let received_at = envelope.received_at;
 827                tracing::info!("message received");
 828                let start_time = Instant::now();
 829                let future = (handler)(*envelope, session);
 830                async move {
 831                    let result = future.await;
 832                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 833                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 834                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 835                    let payload_type = M::NAME;
 836
 837                    match result {
 838                        Err(error) => {
 839                            tracing::error!(
 840                                ?error,
 841                                total_duration_ms,
 842                                processing_duration_ms,
 843                                queue_duration_ms,
 844                                payload_type,
 845                                "error handling message"
 846                            )
 847                        }
 848                        Ok(()) => tracing::info!(
 849                            total_duration_ms,
 850                            processing_duration_ms,
 851                            queue_duration_ms,
 852                            "finished handling message"
 853                        ),
 854                    }
 855                }
 856                .boxed()
 857            }),
 858        );
 859        if prev_handler.is_some() {
 860            panic!("registered a handler for the same message twice");
 861        }
 862        self
 863    }
 864
 865    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 866    where
 867        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 868        Fut: 'static + Send + Future<Output = Result<()>>,
 869        M: EnvelopedMessage,
 870    {
 871        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 872        self
 873    }
 874
 875    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 876    where
 877        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 878        Fut: Send + Future<Output = Result<()>>,
 879        M: RequestMessage,
 880    {
 881        let handler = Arc::new(handler);
 882        self.add_handler(move |envelope, session| {
 883            let receipt = envelope.receipt();
 884            let handler = handler.clone();
 885            async move {
 886                let peer = session.peer.clone();
 887                let responded = Arc::new(AtomicBool::default());
 888                let response = Response {
 889                    peer: peer.clone(),
 890                    responded: responded.clone(),
 891                    receipt,
 892                };
 893                match (handler)(envelope.payload, response, session).await {
 894                    Ok(()) => {
 895                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 896                            Ok(())
 897                        } else {
 898                            Err(anyhow!("handler did not send a response"))?
 899                        }
 900                    }
 901                    Err(error) => {
 902                        let proto_err = match &error {
 903                            Error::Internal(err) => err.to_proto(),
 904                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 905                        };
 906                        peer.respond_with_error(receipt, proto_err)?;
 907                        Err(error)
 908                    }
 909                }
 910            }
 911        })
 912    }
 913
    /// Drives one client connection from handshake to sign-out.
    ///
    /// Creates a tracing span for the connection, lets `principal` fill in its span fields, and
    /// records the GeoIP country code when one is given. It then returns a future that:
    ///
    /// 1. bails out if the server is already tearing down;
    /// 2. registers `connection` with the peer and records the assigned connection id;
    /// 3. builds the per-connection `Session`, including an HTTP client and, when
    ///    `supermaven_admin_api_key` is configured, a Supermaven admin client;
    /// 4. sends the initial client update, which also forwards the connection id to
    ///    `send_connection_id` when one is provided;
    /// 5. dispatches incoming messages to the registered handlers until the connection closes,
    ///    its I/O fails, or the server starts tearing down.
    ///
    /// When the loop ends because the connection closed or its I/O failed, `connection_lost` is
    /// awaited to clean up. On teardown the future returns without doing so.
    ///
    /// NOTE(review): the early returns after `add_connection` also skip `connection_lost`.
    /// These are the HTTP client build failure and a failed initial client update. Any
    /// peer/pool registration made up to that point may be left behind; confirm this is
    /// intended.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown that happens later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client is only available when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) for this connection at 256.
            // A permit is taken before reading each message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased order: teardown first, then I/O completion, then finishing foreground
                // handlers, and only then reading the next message.
                futures::select_biased! {
                    // Server teardown: return without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before storing the future; it is re-attached
                                // below via `instrument`.
                                drop(span_enter);

                                // The permit moves into the handler future and is released
                                // only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1067
1068    async fn send_initial_client_update(
1069        &self,
1070        connection_id: ConnectionId,
1071        principal: &Principal,
1072        zed_version: ZedVersion,
1073        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1074        session: &Session,
1075    ) -> Result<()> {
1076        self.peer.send(
1077            connection_id,
1078            proto::Hello {
1079                peer_id: Some(connection_id.into()),
1080            },
1081        )?;
1082        tracing::info!("sent hello message");
1083        if let Some(send_connection_id) = send_connection_id.take() {
1084            let _ = send_connection_id.send(connection_id);
1085        }
1086
1087        match principal {
1088            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1089                if !user.connected_once {
1090                    self.peer.send(connection_id, proto::ShowContacts {})?;
1091                    self.app_state
1092                        .db
1093                        .set_user_connected_once(user.id, true)
1094                        .await?;
1095                }
1096
1097                update_user_plan(user.id, session).await?;
1098
1099                let (contacts, dev_server_projects) = future::try_join(
1100                    self.app_state.db.get_contacts(user.id),
1101                    self.app_state.db.dev_server_projects_update(user.id),
1102                )
1103                .await?;
1104
1105                {
1106                    let mut pool = self.connection_pool.lock();
1107                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1108                    self.peer.send(
1109                        connection_id,
1110                        build_initial_contacts_update(contacts, &pool),
1111                    )?;
1112                }
1113
1114                if should_auto_subscribe_to_channels(zed_version) {
1115                    subscribe_user_to_channels(user.id, session).await?;
1116                }
1117
1118                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1119
1120                if let Some(incoming_call) =
1121                    self.app_state.db.incoming_call_for_user(user.id).await?
1122                {
1123                    self.peer.send(connection_id, incoming_call)?;
1124                }
1125
1126                update_user_contacts(user.id, &session).await?;
1127            }
1128            Principal::DevServer(dev_server) => {
1129                {
1130                    let mut pool = self.connection_pool.lock();
1131                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1132                    {
1133                        self.peer.send(
1134                            stale_connection_id,
1135                            proto::ShutdownDevServer {
1136                                reason: Some(
1137                                    "another dev server connected with the same token".to_string(),
1138                                ),
1139                            },
1140                        )?;
1141                        pool.remove_connection(stale_connection_id)?;
1142                    };
1143                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1144                }
1145
1146                let projects = self
1147                    .app_state
1148                    .db
1149                    .get_projects_for_dev_server(dev_server.id)
1150                    .await?;
1151                self.peer
1152                    .send(connection_id, proto::DevServerInstructions { projects })?;
1153
1154                let status = self
1155                    .app_state
1156                    .db
1157                    .dev_server_projects_update(dev_server.user_id)
1158                    .await?;
1159                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1160            }
1161        }
1162
1163        Ok(())
1164    }
1165
1166    pub async fn invite_code_redeemed(
1167        self: &Arc<Self>,
1168        inviter_id: UserId,
1169        invitee_id: UserId,
1170    ) -> Result<()> {
1171        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1172            if let Some(code) = &user.invite_code {
1173                let pool = self.connection_pool.lock();
1174                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1175                for connection_id in pool.user_connection_ids(inviter_id) {
1176                    self.peer.send(
1177                        connection_id,
1178                        proto::UpdateContacts {
1179                            contacts: vec![invitee_contact.clone()],
1180                            ..Default::default()
1181                        },
1182                    )?;
1183                    self.peer.send(
1184                        connection_id,
1185                        proto::UpdateInviteInfo {
1186                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1187                            count: user.invite_count as u32,
1188                        },
1189                    )?;
1190                }
1191            }
1192        }
1193        Ok(())
1194    }
1195
1196    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1197        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1198            if let Some(invite_code) = &user.invite_code {
1199                let pool = self.connection_pool.lock();
1200                for connection_id in pool.user_connection_ids(user_id) {
1201                    self.peer.send(
1202                        connection_id,
1203                        proto::UpdateInviteInfo {
1204                            url: format!(
1205                                "{}{}",
1206                                self.app_state.config.invite_link_prefix, invite_code
1207                            ),
1208                            count: user.invite_count as u32,
1209                        },
1210                    )?;
1211                }
1212            }
1213        }
1214        Ok(())
1215    }
1216
1217    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1218        ServerSnapshot {
1219            connection_pool: ConnectionPoolGuard {
1220                guard: self.connection_pool.lock(),
1221                _not_send: PhantomData,
1222            },
1223            peer: &self.peer,
1224        }
1225    }
1226}
1227
1228impl<'a> Deref for ConnectionPoolGuard<'a> {
1229    type Target = ConnectionPool;
1230
1231    fn deref(&self) -> &Self::Target {
1232        &self.guard
1233    }
1234}
1235
1236impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1237    fn deref_mut(&mut self) -> &mut Self::Target {
1238        &mut self.guard
1239    }
1240}
1241
1242impl<'a> Drop for ConnectionPoolGuard<'a> {
1243    fn drop(&mut self) {
1244        #[cfg(test)]
1245        self.check_invariants();
1246    }
1247}
1248
1249fn broadcast<F>(
1250    sender_id: Option<ConnectionId>,
1251    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1252    mut f: F,
1253) where
1254    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1255{
1256    for receiver_id in receiver_ids {
1257        if Some(receiver_id) != sender_id {
1258            if let Err(error) = f(receiver_id) {
1259                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1260            }
1261        }
1262    }
1263}
1264
1265pub struct ProtocolVersion(u32);
1266
1267impl Header for ProtocolVersion {
1268    fn name() -> &'static HeaderName {
1269        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1270        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1271    }
1272
1273    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1274    where
1275        Self: Sized,
1276        I: Iterator<Item = &'i axum::http::HeaderValue>,
1277    {
1278        let version = values
1279            .next()
1280            .ok_or_else(axum::headers::Error::invalid)?
1281            .to_str()
1282            .map_err(|_| axum::headers::Error::invalid())?
1283            .parse()
1284            .map_err(|_| axum::headers::Error::invalid())?;
1285        Ok(Self(version))
1286    }
1287
1288    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1289        values.extend([self.0.to_string().parse().unwrap()]);
1290    }
1291}
1292
1293pub struct AppVersionHeader(SemanticVersion);
1294impl Header for AppVersionHeader {
1295    fn name() -> &'static HeaderName {
1296        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1297        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1298    }
1299
1300    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1301    where
1302        Self: Sized,
1303        I: Iterator<Item = &'i axum::http::HeaderValue>,
1304    {
1305        let version = values
1306            .next()
1307            .ok_or_else(axum::headers::Error::invalid)?
1308            .to_str()
1309            .map_err(|_| axum::headers::Error::invalid())?
1310            .parse()
1311            .map_err(|_| axum::headers::Error::invalid())?;
1312        Ok(Self(version))
1313    }
1314
1315    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1316        values.extend([self.0.to_string().parse().unwrap()]);
1317    }
1318}
1319
1320pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1321    Router::new()
1322        .route("/rpc", get(handle_websocket_request))
1323        .layer(
1324            ServiceBuilder::new()
1325                .layer(Extension(server.app_state.clone()))
1326                .layer(middleware::from_fn(auth::validate_header)),
1327        )
1328        .route("/metrics", get(handle_metrics))
1329        .layer(Extension(server))
1330}
1331
1332pub async fn handle_websocket_request(
1333    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1334    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1335    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1336    Extension(server): Extension<Arc<Server>>,
1337    Extension(principal): Extension<Principal>,
1338    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1339    ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341    if protocol_version != rpc::PROTOCOL_VERSION {
1342        return (
1343            StatusCode::UPGRADE_REQUIRED,
1344            "client must be upgraded".to_string(),
1345        )
1346            .into_response();
1347    }
1348
1349    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350        return (
1351            StatusCode::UPGRADE_REQUIRED,
1352            "no version header found".to_string(),
1353        )
1354            .into_response();
1355    };
1356
1357    if !version.can_collaborate() {
1358        return (
1359            StatusCode::UPGRADE_REQUIRED,
1360            "client must be upgraded".to_string(),
1361        )
1362            .into_response();
1363    }
1364
1365    let socket_address = socket_address.to_string();
1366    ws.on_upgrade(move |socket| {
1367        let socket = socket
1368            .map_ok(to_tungstenite_message)
1369            .err_into()
1370            .with(|message| async move { to_axum_message(message) });
1371        let connection = Connection::new(Box::pin(socket));
1372        async move {
1373            server
1374                .handle_connection(
1375                    connection,
1376                    socket_address,
1377                    principal,
1378                    version,
1379                    country_code_header.map(|header| header.to_string()),
1380                    None,
1381                    Executor::Production,
1382                )
1383                .await;
1384        }
1385    })
1386}
1387
1388pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1389    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1390    let connections_metric = CONNECTIONS_METRIC
1391        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1392
1393    let connections = server
1394        .connection_pool
1395        .lock()
1396        .connections()
1397        .filter(|connection| !connection.admin)
1398        .count();
1399    connections_metric.set(connections as _);
1400
1401    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1402    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1403        register_int_gauge!(
1404            "shared_projects",
1405            "number of open projects with one or more guests"
1406        )
1407        .unwrap()
1408    });
1409
1410    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1411    shared_projects_metric.set(shared_projects as _);
1412
1413    let encoder = prometheus::TextEncoder::new();
1414    let metric_families = prometheus::gather();
1415    let encoded_metrics = encoder
1416        .encode_to_string(&metric_families)
1417        .map_err(|err| anyhow!("{}", err))?;
1418    Ok(encoded_metrics)
1419}
1420
1421#[instrument(err, skip(executor))]
1422async fn connection_lost(
1423    session: Session,
1424    mut teardown: watch::Receiver<bool>,
1425    executor: Executor,
1426) -> Result<()> {
1427    session.peer.disconnect(session.connection_id);
1428    session
1429        .connection_pool()
1430        .await
1431        .remove_connection(session.connection_id)?;
1432
1433    session
1434        .db()
1435        .await
1436        .connection_lost(session.connection_id)
1437        .await
1438        .trace_err();
1439
1440    futures::select_biased! {
1441        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1442            match &session.principal {
1443                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1444                    let session = session.for_user().unwrap();
1445
1446                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1447                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1448                    leave_channel_buffers_for_session(&session)
1449                        .await
1450                        .trace_err();
1451
1452                    if !session
1453                        .connection_pool()
1454                        .await
1455                        .is_user_online(session.user_id())
1456                    {
1457                        let db = session.db().await;
1458                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1459                            room_updated(&room, &session.peer);
1460                        }
1461                    }
1462
1463                    update_user_contacts(session.user_id(), &session).await?;
1464                },
1465            Principal::DevServer(_) => {
1466                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1467            },
1468        }
1469        },
1470        _ = teardown.changed().fuse() => {}
1471    }
1472
1473    Ok(())
1474}
1475
1476/// Acknowledges a ping from a client, used to keep the connection alive.
1477async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1478    response.send(proto::Ack {})?;
1479    Ok(())
1480}
1481
1482/// Creates a new room for calling (outside of channels)
1483async fn create_room(
1484    _request: proto::CreateRoom,
1485    response: Response<proto::CreateRoom>,
1486    session: UserSession,
1487) -> Result<()> {
1488    let live_kit_room = nanoid::nanoid!(30);
1489
1490    let live_kit_connection_info = util::maybe!(async {
1491        let live_kit = session.app_state.live_kit_client.as_ref();
1492        let live_kit = live_kit?;
1493        let user_id = session.user_id().to_string();
1494
1495        let token = live_kit
1496            .room_token(&live_kit_room, &user_id.to_string())
1497            .trace_err()?;
1498
1499        Some(proto::LiveKitConnectionInfo {
1500            server_url: live_kit.url().into(),
1501            token,
1502            can_publish: true,
1503        })
1504    })
1505    .await;
1506
1507    let room = session
1508        .db()
1509        .await
1510        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1511        .await?;
1512
1513    response.send(proto::CreateRoomResponse {
1514        room: Some(room.clone()),
1515        live_kit_connection_info,
1516    })?;
1517
1518    update_user_contacts(session.user_id(), &session).await?;
1519    Ok(())
1520}
1521
1522/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1523async fn join_room(
1524    request: proto::JoinRoom,
1525    response: Response<proto::JoinRoom>,
1526    session: UserSession,
1527) -> Result<()> {
1528    let room_id = RoomId::from_proto(request.id);
1529
1530    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1531
1532    if let Some(channel_id) = channel_id {
1533        return join_channel_internal(channel_id, Box::new(response), session).await;
1534    }
1535
1536    let joined_room = {
1537        let room = session
1538            .db()
1539            .await
1540            .join_room(room_id, session.user_id(), session.connection_id)
1541            .await?;
1542        room_updated(&room.room, &session.peer);
1543        room.into_inner()
1544    };
1545
1546    for connection_id in session
1547        .connection_pool()
1548        .await
1549        .user_connection_ids(session.user_id())
1550    {
1551        session
1552            .peer
1553            .send(
1554                connection_id,
1555                proto::CallCanceled {
1556                    room_id: room_id.to_proto(),
1557                },
1558            )
1559            .trace_err();
1560    }
1561
1562    let live_kit_connection_info =
1563        if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1564            if let Some(token) = live_kit
1565                .room_token(
1566                    &joined_room.room.live_kit_room,
1567                    &session.user_id().to_string(),
1568                )
1569                .trace_err()
1570            {
1571                Some(proto::LiveKitConnectionInfo {
1572                    server_url: live_kit.url().into(),
1573                    token,
1574                    can_publish: true,
1575                })
1576            } else {
1577                None
1578            }
1579        } else {
1580            None
1581        };
1582
1583    response.send(proto::JoinRoomResponse {
1584        room: Some(joined_room.room),
1585        channel_id: None,
1586        live_kit_connection_info,
1587    })?;
1588
1589    update_user_contacts(session.user_id(), &session).await?;
1590    Ok(())
1591}
1592
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins the caller's current connection to the room and
/// reports two sets of projects:
/// - `reshared_projects`: each project's collaborators are told that the
///   caller's peer id moved from `old_connection_id` to the current connection.
///   They are then forwarded an `UpdateProject` carrying the project's
///   worktrees. Presumably these are projects the rejoining user hosts; TODO
///   confirm.
/// - `rejoined_projects`: handled by `notify_rejoined_projects`.
///
/// After the room-scoped work, the room's channel (if any) is updated and the
/// caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in from inside the block below so the rejoin result is released
    // before the channel update and contact refresh run.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the caller first, describing every project being restored.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator's view of this peer at the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to everyone but the caller,
            // forwarded as if it came from the caller's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Drains each rejoined project's worktrees while streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1684
1685fn notify_rejoined_projects(
1686    rejoined_projects: &mut Vec<RejoinedProject>,
1687    session: &UserSession,
1688) -> Result<()> {
1689    for project in rejoined_projects.iter() {
1690        for collaborator in &project.collaborators {
1691            session
1692                .peer
1693                .send(
1694                    collaborator.connection_id,
1695                    proto::UpdateProjectCollaborator {
1696                        project_id: project.id.to_proto(),
1697                        old_peer_id: Some(project.old_connection_id.into()),
1698                        new_peer_id: Some(session.connection_id.into()),
1699                    },
1700                )
1701                .trace_err();
1702        }
1703    }
1704
1705    for project in rejoined_projects {
1706        for worktree in mem::take(&mut project.worktrees) {
1707            #[cfg(any(test, feature = "test-support"))]
1708            const MAX_CHUNK_SIZE: usize = 2;
1709            #[cfg(not(any(test, feature = "test-support")))]
1710            const MAX_CHUNK_SIZE: usize = 256;
1711
1712            // Stream this worktree's entries.
1713            let message = proto::UpdateWorktree {
1714                project_id: project.id.to_proto(),
1715                worktree_id: worktree.id,
1716                abs_path: worktree.abs_path.clone(),
1717                root_name: worktree.root_name,
1718                updated_entries: worktree.updated_entries,
1719                removed_entries: worktree.removed_entries,
1720                scan_id: worktree.scan_id,
1721                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722                updated_repositories: worktree.updated_repositories,
1723                removed_repositories: worktree.removed_repositories,
1724            };
1725            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726                session.peer.send(session.connection_id, update.clone())?;
1727            }
1728
1729            // Stream this worktree's diagnostics.
1730            for summary in worktree.diagnostic_summaries {
1731                session.peer.send(
1732                    session.connection_id,
1733                    proto::UpdateDiagnosticSummary {
1734                        project_id: project.id.to_proto(),
1735                        worktree_id: worktree.id,
1736                        summary: Some(summary),
1737                    },
1738                )?;
1739            }
1740
1741            for settings_file in worktree.settings_files {
1742                session.peer.send(
1743                    session.connection_id,
1744                    proto::UpdateWorktreeSettings {
1745                        project_id: project.id.to_proto(),
1746                        worktree_id: worktree.id,
1747                        path: settings_file.path,
1748                        content: Some(settings_file.content),
1749                    },
1750                )?;
1751            }
1752        }
1753
1754        for language_server in &project.language_servers {
1755            session.peer.send(
1756                session.connection_id,
1757                proto::UpdateLanguageServer {
1758                    project_id: project.id.to_proto(),
1759                    language_server_id: language_server.id,
1760                    variant: Some(
1761                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1762                            proto::LspDiskBasedDiagnosticsUpdated {},
1763                        ),
1764                    ),
1765                },
1766            )?;
1767        }
1768    }
1769    Ok(())
1770}
1771
1772/// leave room disconnects from the room.
1773async fn leave_room(
1774    _: proto::LeaveRoom,
1775    response: Response<proto::LeaveRoom>,
1776    session: UserSession,
1777) -> Result<()> {
1778    leave_room_for_session(&session, session.connection_id).await?;
1779    response.send(proto::Ack {})?;
1780    Ok(())
1781}
1782
1783/// Updates the permissions of someone else in the room.
1784async fn set_room_participant_role(
1785    request: proto::SetRoomParticipantRole,
1786    response: Response<proto::SetRoomParticipantRole>,
1787    session: UserSession,
1788) -> Result<()> {
1789    let user_id = UserId::from_proto(request.user_id);
1790    let role = ChannelRole::from(request.role());
1791
1792    let (live_kit_room, can_publish) = {
1793        let room = session
1794            .db()
1795            .await
1796            .set_room_participant_role(
1797                session.user_id(),
1798                RoomId::from_proto(request.room_id),
1799                user_id,
1800                role,
1801            )
1802            .await?;
1803
1804        let live_kit_room = room.live_kit_room.clone();
1805        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1806        room_updated(&room, &session.peer);
1807        (live_kit_room, can_publish)
1808    };
1809
1810    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1811        live_kit
1812            .update_participant(
1813                live_kit_room.clone(),
1814                request.user_id.to_string(),
1815                live_kit_server::proto::ParticipantPermission {
1816                    can_subscribe: true,
1817                    can_publish,
1818                    can_publish_data: can_publish,
1819                    hidden: false,
1820                    recorder: false,
1821                },
1822            )
1823            .await
1824            .trace_err();
1825    }
1826
1827    response.send(proto::Ack {})?;
1828    Ok(())
1829}
1830
1831/// Call someone else into the current room
1832async fn call(
1833    request: proto::Call,
1834    response: Response<proto::Call>,
1835    session: UserSession,
1836) -> Result<()> {
1837    let room_id = RoomId::from_proto(request.room_id);
1838    let calling_user_id = session.user_id();
1839    let calling_connection_id = session.connection_id;
1840    let called_user_id = UserId::from_proto(request.called_user_id);
1841    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1842    if !session
1843        .db()
1844        .await
1845        .has_contact(calling_user_id, called_user_id)
1846        .await?
1847    {
1848        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1849    }
1850
1851    let incoming_call = {
1852        let (room, incoming_call) = &mut *session
1853            .db()
1854            .await
1855            .call(
1856                room_id,
1857                calling_user_id,
1858                calling_connection_id,
1859                called_user_id,
1860                initial_project_id,
1861            )
1862            .await?;
1863        room_updated(&room, &session.peer);
1864        mem::take(incoming_call)
1865    };
1866    update_user_contacts(called_user_id, &session).await?;
1867
1868    let mut calls = session
1869        .connection_pool()
1870        .await
1871        .user_connection_ids(called_user_id)
1872        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1873        .collect::<FuturesUnordered<_>>();
1874
1875    while let Some(call_response) = calls.next().await {
1876        match call_response.as_ref() {
1877            Ok(_) => {
1878                response.send(proto::Ack {})?;
1879                return Ok(());
1880            }
1881            Err(_) => {
1882                call_response.trace_err();
1883            }
1884        }
1885    }
1886
1887    {
1888        let room = session
1889            .db()
1890            .await
1891            .call_failed(room_id, called_user_id)
1892            .await?;
1893        room_updated(&room, &session.peer);
1894    }
1895    update_user_contacts(called_user_id, &session).await?;
1896
1897    Err(anyhow!("failed to ring user"))?
1898}
1899
1900/// Cancel an outgoing call.
1901async fn cancel_call(
1902    request: proto::CancelCall,
1903    response: Response<proto::CancelCall>,
1904    session: UserSession,
1905) -> Result<()> {
1906    let called_user_id = UserId::from_proto(request.called_user_id);
1907    let room_id = RoomId::from_proto(request.room_id);
1908    {
1909        let room = session
1910            .db()
1911            .await
1912            .cancel_call(room_id, session.connection_id, called_user_id)
1913            .await?;
1914        room_updated(&room, &session.peer);
1915    }
1916
1917    for connection_id in session
1918        .connection_pool()
1919        .await
1920        .user_connection_ids(called_user_id)
1921    {
1922        session
1923            .peer
1924            .send(
1925                connection_id,
1926                proto::CallCanceled {
1927                    room_id: room_id.to_proto(),
1928                },
1929            )
1930            .trace_err();
1931    }
1932    response.send(proto::Ack {})?;
1933
1934    update_user_contacts(called_user_id, &session).await?;
1935    Ok(())
1936}
1937
1938/// Decline an incoming call.
1939async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1940    let room_id = RoomId::from_proto(message.room_id);
1941    {
1942        let room = session
1943            .db()
1944            .await
1945            .decline_call(Some(room_id), session.user_id())
1946            .await?
1947            .ok_or_else(|| anyhow!("failed to decline call"))?;
1948        room_updated(&room, &session.peer);
1949    }
1950
1951    for connection_id in session
1952        .connection_pool()
1953        .await
1954        .user_connection_ids(session.user_id())
1955    {
1956        session
1957            .peer
1958            .send(
1959                connection_id,
1960                proto::CallCanceled {
1961                    room_id: room_id.to_proto(),
1962                },
1963            )
1964            .trace_err();
1965    }
1966    update_user_contacts(session.user_id(), &session).await?;
1967    Ok(())
1968}
1969
1970/// Updates other participants in the room with your current location.
1971async fn update_participant_location(
1972    request: proto::UpdateParticipantLocation,
1973    response: Response<proto::UpdateParticipantLocation>,
1974    session: UserSession,
1975) -> Result<()> {
1976    let room_id = RoomId::from_proto(request.room_id);
1977    let location = request
1978        .location
1979        .ok_or_else(|| anyhow!("invalid location"))?;
1980
1981    let db = session.db().await;
1982    let room = db
1983        .update_room_participant_location(room_id, session.connection_id, location)
1984        .await?;
1985
1986    room_updated(&room, &session.peer);
1987    response.send(proto::Ack {})?;
1988    Ok(())
1989}
1990
1991/// Share a project into the room.
1992async fn share_project(
1993    request: proto::ShareProject,
1994    response: Response<proto::ShareProject>,
1995    session: UserSession,
1996) -> Result<()> {
1997    let (project_id, room) = &*session
1998        .db()
1999        .await
2000        .share_project(
2001            RoomId::from_proto(request.room_id),
2002            session.connection_id,
2003            &request.worktrees,
2004            request
2005                .dev_server_project_id
2006                .map(|id| DevServerProjectId::from_proto(id)),
2007        )
2008        .await?;
2009    response.send(proto::ShareProjectResponse {
2010        project_id: project_id.to_proto(),
2011    })?;
2012    room_updated(&room, &session.peer);
2013
2014    Ok(())
2015}
2016
2017/// Unshare a project from the room.
2018async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2019    let project_id = ProjectId::from_proto(message.project_id);
2020    unshare_project_internal(
2021        project_id,
2022        session.connection_id,
2023        session.user_id(),
2024        &session,
2025    )
2026    .await
2027}
2028
2029async fn unshare_project_internal(
2030    project_id: ProjectId,
2031    connection_id: ConnectionId,
2032    user_id: Option<UserId>,
2033    session: &Session,
2034) -> Result<()> {
2035    let delete = {
2036        let room_guard = session
2037            .db()
2038            .await
2039            .unshare_project(project_id, connection_id, user_id)
2040            .await?;
2041
2042        let (delete, room, guest_connection_ids) = &*room_guard;
2043
2044        let message = proto::UnshareProject {
2045            project_id: project_id.to_proto(),
2046        };
2047
2048        broadcast(
2049            Some(connection_id),
2050            guest_connection_ids.iter().copied(),
2051            |conn_id| session.peer.send(conn_id, message.clone()),
2052        );
2053        if let Some(room) = room {
2054            room_updated(room, &session.peer);
2055        }
2056
2057        *delete
2058    };
2059
2060    if delete {
2061        let db = session.db().await;
2062        db.delete_project(project_id).await?;
2063    }
2064
2065    Ok(())
2066}
2067
2068/// DevServer makes a project available online
2069async fn share_dev_server_project(
2070    request: proto::ShareDevServerProject,
2071    response: Response<proto::ShareDevServerProject>,
2072    session: DevServerSession,
2073) -> Result<()> {
2074    let (dev_server_project, user_id, status) = session
2075        .db()
2076        .await
2077        .share_dev_server_project(
2078            DevServerProjectId::from_proto(request.dev_server_project_id),
2079            session.dev_server_id(),
2080            session.connection_id,
2081            &request.worktrees,
2082        )
2083        .await?;
2084    let Some(project_id) = dev_server_project.project_id else {
2085        return Err(anyhow!("failed to share remote project"))?;
2086    };
2087
2088    send_dev_server_projects_update(user_id, status, &session).await;
2089
2090    response.send(proto::ShareProjectResponse { project_id })?;
2091
2092    Ok(())
2093}
2094
2095/// Join someone elses shared project.
2096async fn join_project(
2097    request: proto::JoinProject,
2098    response: Response<proto::JoinProject>,
2099    session: UserSession,
2100) -> Result<()> {
2101    let project_id = ProjectId::from_proto(request.project_id);
2102
2103    tracing::info!(%project_id, "join project");
2104
2105    let db = session.db().await;
2106    let (project, replica_id) = &mut *db
2107        .join_project(project_id, session.connection_id, session.user_id())
2108        .await?;
2109    drop(db);
2110    tracing::info!(%project_id, "join remote project");
2111    join_project_internal(response, session, project, replica_id)
2112}
2113
2114trait JoinProjectInternalResponse {
2115    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2116}
2117impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2118    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2119        Response::<proto::JoinProject>::send(self, result)
2120    }
2121}
2122impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2123    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2124        Response::<proto::JoinHostedProject>::send(self, result)
2125    }
2126}
2127
2128fn join_project_internal(
2129    response: impl JoinProjectInternalResponse,
2130    session: UserSession,
2131    project: &mut Project,
2132    replica_id: &ReplicaId,
2133) -> Result<()> {
2134    let collaborators = project
2135        .collaborators
2136        .iter()
2137        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2138        .map(|collaborator| collaborator.to_proto())
2139        .collect::<Vec<_>>();
2140    let project_id = project.id;
2141    let guest_user_id = session.user_id();
2142
2143    let worktrees = project
2144        .worktrees
2145        .iter()
2146        .map(|(id, worktree)| proto::WorktreeMetadata {
2147            id: *id,
2148            root_name: worktree.root_name.clone(),
2149            visible: worktree.visible,
2150            abs_path: worktree.abs_path.clone(),
2151        })
2152        .collect::<Vec<_>>();
2153
2154    let add_project_collaborator = proto::AddProjectCollaborator {
2155        project_id: project_id.to_proto(),
2156        collaborator: Some(proto::Collaborator {
2157            peer_id: Some(session.connection_id.into()),
2158            replica_id: replica_id.0 as u32,
2159            user_id: guest_user_id.to_proto(),
2160        }),
2161    };
2162
2163    for collaborator in &collaborators {
2164        session
2165            .peer
2166            .send(
2167                collaborator.peer_id.unwrap().into(),
2168                add_project_collaborator.clone(),
2169            )
2170            .trace_err();
2171    }
2172
2173    // First, we send the metadata associated with each worktree.
2174    response.send(proto::JoinProjectResponse {
2175        project_id: project.id.0 as u64,
2176        worktrees: worktrees.clone(),
2177        replica_id: replica_id.0 as u32,
2178        collaborators: collaborators.clone(),
2179        language_servers: project.language_servers.clone(),
2180        role: project.role.into(),
2181        dev_server_project_id: project
2182            .dev_server_project_id
2183            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2184    })?;
2185
2186    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2187        #[cfg(any(test, feature = "test-support"))]
2188        const MAX_CHUNK_SIZE: usize = 2;
2189        #[cfg(not(any(test, feature = "test-support")))]
2190        const MAX_CHUNK_SIZE: usize = 256;
2191
2192        // Stream this worktree's entries.
2193        let message = proto::UpdateWorktree {
2194            project_id: project_id.to_proto(),
2195            worktree_id,
2196            abs_path: worktree.abs_path.clone(),
2197            root_name: worktree.root_name,
2198            updated_entries: worktree.entries,
2199            removed_entries: Default::default(),
2200            scan_id: worktree.scan_id,
2201            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2202            updated_repositories: worktree.repository_entries.into_values().collect(),
2203            removed_repositories: Default::default(),
2204        };
2205        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2206            session.peer.send(session.connection_id, update.clone())?;
2207        }
2208
2209        // Stream this worktree's diagnostics.
2210        for summary in worktree.diagnostic_summaries {
2211            session.peer.send(
2212                session.connection_id,
2213                proto::UpdateDiagnosticSummary {
2214                    project_id: project_id.to_proto(),
2215                    worktree_id: worktree.id,
2216                    summary: Some(summary),
2217                },
2218            )?;
2219        }
2220
2221        for settings_file in worktree.settings_files {
2222            session.peer.send(
2223                session.connection_id,
2224                proto::UpdateWorktreeSettings {
2225                    project_id: project_id.to_proto(),
2226                    worktree_id: worktree.id,
2227                    path: settings_file.path,
2228                    content: Some(settings_file.content),
2229                },
2230            )?;
2231        }
2232    }
2233
2234    for language_server in &project.language_servers {
2235        session.peer.send(
2236            session.connection_id,
2237            proto::UpdateLanguageServer {
2238                project_id: project_id.to_proto(),
2239                language_server_id: language_server.id,
2240                variant: Some(
2241                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2242                        proto::LspDiskBasedDiagnosticsUpdated {},
2243                    ),
2244                ),
2245            },
2246        )?;
2247    }
2248
2249    Ok(())
2250}
2251
2252/// Leave someone elses shared project.
2253async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2254    let sender_id = session.connection_id;
2255    let project_id = ProjectId::from_proto(request.project_id);
2256    let db = session.db().await;
2257    if db.is_hosted_project(project_id).await? {
2258        let project = db.leave_hosted_project(project_id, sender_id).await?;
2259        project_left(&project, &session);
2260        return Ok(());
2261    }
2262
2263    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2264    tracing::info!(
2265        %project_id,
2266        "leave project"
2267    );
2268
2269    project_left(&project, &session);
2270    if let Some(room) = room {
2271        room_updated(&room, &session.peer);
2272    }
2273
2274    Ok(())
2275}
2276
2277async fn join_hosted_project(
2278    request: proto::JoinHostedProject,
2279    response: Response<proto::JoinHostedProject>,
2280    session: UserSession,
2281) -> Result<()> {
2282    let (mut project, replica_id) = session
2283        .db()
2284        .await
2285        .join_hosted_project(
2286            ProjectId(request.project_id as i32),
2287            session.user_id(),
2288            session.connection_id,
2289        )
2290        .await?;
2291
2292    join_project_internal(response, session, &mut project, &replica_id)
2293}
2294
2295async fn list_remote_directory(
2296    request: proto::ListRemoteDirectory,
2297    response: Response<proto::ListRemoteDirectory>,
2298    session: UserSession,
2299) -> Result<()> {
2300    let dev_server_id = DevServerId(request.dev_server_id as i32);
2301    let dev_server_connection_id = session
2302        .connection_pool()
2303        .await
2304        .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2305
2306    session
2307        .db()
2308        .await
2309        .get_dev_server_for_user(dev_server_id, session.user_id())
2310        .await?;
2311
2312    response.send(
2313        session
2314            .peer
2315            .forward_request(session.connection_id, dev_server_connection_id, request)
2316            .await?,
2317    )?;
2318    Ok(())
2319}
2320
2321async fn update_dev_server_project(
2322    request: proto::UpdateDevServerProject,
2323    response: Response<proto::UpdateDevServerProject>,
2324    session: UserSession,
2325) -> Result<()> {
2326    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2327
2328    let (dev_server_project, update) = session
2329        .db()
2330        .await
2331        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2332        .await?;
2333
2334    let projects = session
2335        .db()
2336        .await
2337        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2338        .await?;
2339
2340    let dev_server_connection_id = session
2341        .connection_pool()
2342        .await
2343        .dev_server_connection_id_supporting(
2344            dev_server_project.dev_server_id,
2345            ZedVersion::with_list_directory(),
2346        )?;
2347
2348    session.peer.send(
2349        dev_server_connection_id,
2350        proto::DevServerInstructions { projects },
2351    )?;
2352
2353    send_dev_server_projects_update(session.user_id(), update, &session).await;
2354
2355    response.send(proto::Ack {})
2356}
2357
2358async fn create_dev_server_project(
2359    request: proto::CreateDevServerProject,
2360    response: Response<proto::CreateDevServerProject>,
2361    session: UserSession,
2362) -> Result<()> {
2363    let dev_server_id = DevServerId(request.dev_server_id as i32);
2364    let dev_server_connection_id = session
2365        .connection_pool()
2366        .await
2367        .dev_server_connection_id(dev_server_id);
2368    let Some(dev_server_connection_id) = dev_server_connection_id else {
2369        Err(ErrorCode::DevServerOffline
2370            .message("Cannot create a remote project when the dev server is offline".to_string())
2371            .anyhow())?
2372    };
2373
2374    let path = request.path.clone();
2375    //Check that the path exists on the dev server
2376    session
2377        .peer
2378        .forward_request(
2379            session.connection_id,
2380            dev_server_connection_id,
2381            proto::ValidateDevServerProjectRequest { path: path.clone() },
2382        )
2383        .await?;
2384
2385    let (dev_server_project, update) = session
2386        .db()
2387        .await
2388        .create_dev_server_project(
2389            DevServerId(request.dev_server_id as i32),
2390            &request.path,
2391            session.user_id(),
2392        )
2393        .await?;
2394
2395    let projects = session
2396        .db()
2397        .await
2398        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2399        .await?;
2400
2401    session.peer.send(
2402        dev_server_connection_id,
2403        proto::DevServerInstructions { projects },
2404    )?;
2405
2406    send_dev_server_projects_update(session.user_id(), update, &session).await;
2407
2408    response.send(proto::CreateDevServerProjectResponse {
2409        dev_server_project: Some(dev_server_project.to_proto(None)),
2410    })?;
2411    Ok(())
2412}
2413
2414async fn create_dev_server(
2415    request: proto::CreateDevServer,
2416    response: Response<proto::CreateDevServer>,
2417    session: UserSession,
2418) -> Result<()> {
2419    let access_token = auth::random_token();
2420    let hashed_access_token = auth::hash_access_token(&access_token);
2421
2422    if request.name.is_empty() {
2423        return Err(proto::ErrorCode::Forbidden
2424            .message("Dev server name cannot be empty".to_string())
2425            .anyhow())?;
2426    }
2427
2428    let (dev_server, status) = session
2429        .db()
2430        .await
2431        .create_dev_server(
2432            &request.name,
2433            request.ssh_connection_string.as_deref(),
2434            &hashed_access_token,
2435            session.user_id(),
2436        )
2437        .await?;
2438
2439    send_dev_server_projects_update(session.user_id(), status, &session).await;
2440
2441    response.send(proto::CreateDevServerResponse {
2442        dev_server_id: dev_server.id.0 as u64,
2443        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2444        name: request.name,
2445    })?;
2446    Ok(())
2447}
2448
2449async fn regenerate_dev_server_token(
2450    request: proto::RegenerateDevServerToken,
2451    response: Response<proto::RegenerateDevServerToken>,
2452    session: UserSession,
2453) -> Result<()> {
2454    let dev_server_id = DevServerId(request.dev_server_id as i32);
2455    let access_token = auth::random_token();
2456    let hashed_access_token = auth::hash_access_token(&access_token);
2457
2458    let connection_id = session
2459        .connection_pool()
2460        .await
2461        .dev_server_connection_id(dev_server_id);
2462    if let Some(connection_id) = connection_id {
2463        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2464        session.peer.send(
2465            connection_id,
2466            proto::ShutdownDevServer {
2467                reason: Some("dev server token was regenerated".to_string()),
2468            },
2469        )?;
2470        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2471    }
2472
2473    let status = session
2474        .db()
2475        .await
2476        .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2477        .await?;
2478
2479    send_dev_server_projects_update(session.user_id(), status, &session).await;
2480
2481    response.send(proto::RegenerateDevServerTokenResponse {
2482        dev_server_id: dev_server_id.to_proto(),
2483        access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2484    })?;
2485    Ok(())
2486}
2487
2488async fn rename_dev_server(
2489    request: proto::RenameDevServer,
2490    response: Response<proto::RenameDevServer>,
2491    session: UserSession,
2492) -> Result<()> {
2493    if request.name.trim().is_empty() {
2494        return Err(proto::ErrorCode::Forbidden
2495            .message("Dev server name cannot be empty".to_string())
2496            .anyhow())?;
2497    }
2498
2499    let dev_server_id = DevServerId(request.dev_server_id as i32);
2500    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2501    if dev_server.user_id != session.user_id() {
2502        return Err(anyhow!(ErrorCode::Forbidden))?;
2503    }
2504
2505    let status = session
2506        .db()
2507        .await
2508        .rename_dev_server(
2509            dev_server_id,
2510            &request.name,
2511            request.ssh_connection_string.as_deref(),
2512            session.user_id(),
2513        )
2514        .await?;
2515
2516    send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518    response.send(proto::Ack {})?;
2519    Ok(())
2520}
2521
2522async fn delete_dev_server(
2523    request: proto::DeleteDevServer,
2524    response: Response<proto::DeleteDevServer>,
2525    session: UserSession,
2526) -> Result<()> {
2527    let dev_server_id = DevServerId(request.dev_server_id as i32);
2528    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2529    if dev_server.user_id != session.user_id() {
2530        return Err(anyhow!(ErrorCode::Forbidden))?;
2531    }
2532
2533    let connection_id = session
2534        .connection_pool()
2535        .await
2536        .dev_server_connection_id(dev_server_id);
2537    if let Some(connection_id) = connection_id {
2538        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2539        session.peer.send(
2540            connection_id,
2541            proto::ShutdownDevServer {
2542                reason: Some("dev server was deleted".to_string()),
2543            },
2544        )?;
2545        let _ = remove_dev_server_connection(dev_server_id, &session).await;
2546    }
2547
2548    let status = session
2549        .db()
2550        .await
2551        .delete_dev_server(dev_server_id, session.user_id())
2552        .await?;
2553
2554    send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556    response.send(proto::Ack {})?;
2557    Ok(())
2558}
2559
2560async fn delete_dev_server_project(
2561    request: proto::DeleteDevServerProject,
2562    response: Response<proto::DeleteDevServerProject>,
2563    session: UserSession,
2564) -> Result<()> {
2565    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2566    let dev_server_project = session
2567        .db()
2568        .await
2569        .get_dev_server_project(dev_server_project_id)
2570        .await?;
2571
2572    let dev_server = session
2573        .db()
2574        .await
2575        .get_dev_server(dev_server_project.dev_server_id)
2576        .await?;
2577    if dev_server.user_id != session.user_id() {
2578        return Err(anyhow!(ErrorCode::Forbidden))?;
2579    }
2580
2581    let dev_server_connection_id = session
2582        .connection_pool()
2583        .await
2584        .dev_server_connection_id(dev_server.id);
2585
2586    if let Some(dev_server_connection_id) = dev_server_connection_id {
2587        let project = session
2588            .db()
2589            .await
2590            .find_dev_server_project(dev_server_project_id)
2591            .await;
2592        if let Ok(project) = project {
2593            unshare_project_internal(
2594                project.id,
2595                dev_server_connection_id,
2596                Some(session.user_id()),
2597                &session,
2598            )
2599            .await?;
2600        }
2601    }
2602
2603    let (projects, status) = session
2604        .db()
2605        .await
2606        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2607        .await?;
2608
2609    if let Some(dev_server_connection_id) = dev_server_connection_id {
2610        session.peer.send(
2611            dev_server_connection_id,
2612            proto::DevServerInstructions { projects },
2613        )?;
2614    }
2615
2616    send_dev_server_projects_update(session.user_id(), status, &session).await;
2617
2618    response.send(proto::Ack {})?;
2619    Ok(())
2620}
2621
2622async fn rejoin_dev_server_projects(
2623    request: proto::RejoinRemoteProjects,
2624    response: Response<proto::RejoinRemoteProjects>,
2625    session: UserSession,
2626) -> Result<()> {
2627    let mut rejoined_projects = {
2628        let db = session.db().await;
2629        db.rejoin_dev_server_projects(
2630            &request.rejoined_projects,
2631            session.user_id(),
2632            session.0.connection_id,
2633        )
2634        .await?
2635    };
2636    response.send(proto::RejoinRemoteProjectsResponse {
2637        rejoined_projects: rejoined_projects
2638            .iter()
2639            .map(|project| project.to_proto())
2640            .collect(),
2641    })?;
2642    notify_rejoined_projects(&mut rejoined_projects, &session)
2643}
2644
2645async fn reconnect_dev_server(
2646    request: proto::ReconnectDevServer,
2647    response: Response<proto::ReconnectDevServer>,
2648    session: DevServerSession,
2649) -> Result<()> {
2650    let reshared_projects = {
2651        let db = session.db().await;
2652        db.reshare_dev_server_projects(
2653            &request.reshared_projects,
2654            session.dev_server_id(),
2655            session.0.connection_id,
2656        )
2657        .await?
2658    };
2659
2660    for project in &reshared_projects {
2661        for collaborator in &project.collaborators {
2662            session
2663                .peer
2664                .send(
2665                    collaborator.connection_id,
2666                    proto::UpdateProjectCollaborator {
2667                        project_id: project.id.to_proto(),
2668                        old_peer_id: Some(project.old_connection_id.into()),
2669                        new_peer_id: Some(session.connection_id.into()),
2670                    },
2671                )
2672                .trace_err();
2673        }
2674
2675        broadcast(
2676            Some(session.connection_id),
2677            project
2678                .collaborators
2679                .iter()
2680                .map(|collaborator| collaborator.connection_id),
2681            |connection_id| {
2682                session.peer.forward_send(
2683                    session.connection_id,
2684                    connection_id,
2685                    proto::UpdateProject {
2686                        project_id: project.id.to_proto(),
2687                        worktrees: project.worktrees.clone(),
2688                    },
2689                )
2690            },
2691        );
2692    }
2693
2694    response.send(proto::ReconnectDevServerResponse {
2695        reshared_projects: reshared_projects
2696            .iter()
2697            .map(|project| proto::ResharedProject {
2698                id: project.id.to_proto(),
2699                collaborators: project
2700                    .collaborators
2701                    .iter()
2702                    .map(|collaborator| collaborator.to_proto())
2703                    .collect(),
2704            })
2705            .collect(),
2706    })?;
2707
2708    Ok(())
2709}
2710
2711async fn shutdown_dev_server(
2712    _: proto::ShutdownDevServer,
2713    response: Response<proto::ShutdownDevServer>,
2714    session: DevServerSession,
2715) -> Result<()> {
2716    response.send(proto::Ack {})?;
2717    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2718    remove_dev_server_connection(session.dev_server_id(), &session).await
2719}
2720
2721async fn shutdown_dev_server_internal(
2722    dev_server_id: DevServerId,
2723    connection_id: ConnectionId,
2724    session: &Session,
2725) -> Result<()> {
2726    let (dev_server_projects, dev_server) = {
2727        let db = session.db().await;
2728        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2729        let dev_server = db.get_dev_server(dev_server_id).await?;
2730        (dev_server_projects, dev_server)
2731    };
2732
2733    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2734        unshare_project_internal(
2735            ProjectId::from_proto(project_id),
2736            connection_id,
2737            None,
2738            session,
2739        )
2740        .await?;
2741    }
2742
2743    session
2744        .connection_pool()
2745        .await
2746        .set_dev_server_offline(dev_server_id);
2747
2748    let status = session
2749        .db()
2750        .await
2751        .dev_server_projects_update(dev_server.user_id)
2752        .await?;
2753    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2754
2755    Ok(())
2756}
2757
2758async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2759    let dev_server_connection = session
2760        .connection_pool()
2761        .await
2762        .dev_server_connection_id(dev_server_id);
2763
2764    if let Some(dev_server_connection) = dev_server_connection {
2765        session
2766            .connection_pool()
2767            .await
2768            .remove_connection(dev_server_connection)?;
2769    }
2770    Ok(())
2771}
2772
2773/// Updates other participants with changes to the project
2774async fn update_project(
2775    request: proto::UpdateProject,
2776    response: Response<proto::UpdateProject>,
2777    session: Session,
2778) -> Result<()> {
2779    let project_id = ProjectId::from_proto(request.project_id);
2780    let (room, guest_connection_ids) = &*session
2781        .db()
2782        .await
2783        .update_project(project_id, session.connection_id, &request.worktrees)
2784        .await?;
2785    broadcast(
2786        Some(session.connection_id),
2787        guest_connection_ids.iter().copied(),
2788        |connection_id| {
2789            session
2790                .peer
2791                .forward_send(session.connection_id, connection_id, request.clone())
2792        },
2793    );
2794    if let Some(room) = room {
2795        room_updated(&room, &session.peer);
2796    }
2797    response.send(proto::Ack {})?;
2798
2799    Ok(())
2800}
2801
2802/// Updates other participants with changes to the worktree
2803async fn update_worktree(
2804    request: proto::UpdateWorktree,
2805    response: Response<proto::UpdateWorktree>,
2806    session: Session,
2807) -> Result<()> {
2808    let guest_connection_ids = session
2809        .db()
2810        .await
2811        .update_worktree(&request, session.connection_id)
2812        .await?;
2813
2814    broadcast(
2815        Some(session.connection_id),
2816        guest_connection_ids.iter().copied(),
2817        |connection_id| {
2818            session
2819                .peer
2820                .forward_send(session.connection_id, connection_id, request.clone())
2821        },
2822    );
2823    response.send(proto::Ack {})?;
2824    Ok(())
2825}
2826
2827/// Updates other participants with changes to the diagnostics
2828async fn update_diagnostic_summary(
2829    message: proto::UpdateDiagnosticSummary,
2830    session: Session,
2831) -> Result<()> {
2832    let guest_connection_ids = session
2833        .db()
2834        .await
2835        .update_diagnostic_summary(&message, session.connection_id)
2836        .await?;
2837
2838    broadcast(
2839        Some(session.connection_id),
2840        guest_connection_ids.iter().copied(),
2841        |connection_id| {
2842            session
2843                .peer
2844                .forward_send(session.connection_id, connection_id, message.clone())
2845        },
2846    );
2847
2848    Ok(())
2849}
2850
2851/// Updates other participants with changes to the worktree settings
2852async fn update_worktree_settings(
2853    message: proto::UpdateWorktreeSettings,
2854    session: Session,
2855) -> Result<()> {
2856    let guest_connection_ids = session
2857        .db()
2858        .await
2859        .update_worktree_settings(&message, session.connection_id)
2860        .await?;
2861
2862    broadcast(
2863        Some(session.connection_id),
2864        guest_connection_ids.iter().copied(),
2865        |connection_id| {
2866            session
2867                .peer
2868                .forward_send(session.connection_id, connection_id, message.clone())
2869        },
2870    );
2871
2872    Ok(())
2873}
2874
2875/// Notify other participants that a  language server has started.
2876async fn start_language_server(
2877    request: proto::StartLanguageServer,
2878    session: Session,
2879) -> Result<()> {
2880    let guest_connection_ids = session
2881        .db()
2882        .await
2883        .start_language_server(&request, session.connection_id)
2884        .await?;
2885
2886    broadcast(
2887        Some(session.connection_id),
2888        guest_connection_ids.iter().copied(),
2889        |connection_id| {
2890            session
2891                .peer
2892                .forward_send(session.connection_id, connection_id, request.clone())
2893        },
2894    );
2895    Ok(())
2896}
2897
2898/// Notify other participants that a language server has changed.
2899async fn update_language_server(
2900    request: proto::UpdateLanguageServer,
2901    session: Session,
2902) -> Result<()> {
2903    let project_id = ProjectId::from_proto(request.project_id);
2904    let project_connection_ids = session
2905        .db()
2906        .await
2907        .project_connection_ids(project_id, session.connection_id, true)
2908        .await?;
2909    broadcast(
2910        Some(session.connection_id),
2911        project_connection_ids.iter().copied(),
2912        |connection_id| {
2913            session
2914                .peer
2915                .forward_send(session.connection_id, connection_id, request.clone())
2916        },
2917    );
2918    Ok(())
2919}
2920
2921/// forward a project request to the host. These requests should be read only
2922/// as guests are allowed to send them.
2923async fn forward_read_only_project_request<T>(
2924    request: T,
2925    response: Response<T>,
2926    session: UserSession,
2927) -> Result<()>
2928where
2929    T: EntityMessage + RequestMessage,
2930{
2931    let project_id = ProjectId::from_proto(request.remote_entity_id());
2932    let host_connection_id = session
2933        .db()
2934        .await
2935        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2936        .await?;
2937    let payload = session
2938        .peer
2939        .forward_request(session.connection_id, host_connection_id, request)
2940        .await?;
2941    response.send(payload)?;
2942    Ok(())
2943}
2944
2945async fn forward_find_search_candidates_request(
2946    request: proto::FindSearchCandidates,
2947    response: Response<proto::FindSearchCandidates>,
2948    session: UserSession,
2949) -> Result<()> {
2950    let project_id = ProjectId::from_proto(request.remote_entity_id());
2951    let host_connection_id = session
2952        .db()
2953        .await
2954        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2955        .await?;
2956
2957    let host_version = session
2958        .connection_pool()
2959        .await
2960        .connection(host_connection_id)
2961        .map(|c| c.zed_version);
2962
2963    if host_version.is_some_and(|host_version| host_version < ZedVersion::with_search_candidates())
2964    {
2965        let query = request.query.ok_or_else(|| anyhow!("missing query"))?;
2966        let search = proto::SearchProject {
2967            project_id: project_id.to_proto(),
2968            query: query.query,
2969            regex: query.regex,
2970            whole_word: query.whole_word,
2971            case_sensitive: query.case_sensitive,
2972            files_to_include: query.files_to_include,
2973            files_to_exclude: query.files_to_exclude,
2974            include_ignored: query.include_ignored,
2975        };
2976
2977        let payload = session
2978            .peer
2979            .forward_request(session.connection_id, host_connection_id, search)
2980            .await?;
2981        return response.send(proto::FindSearchCandidatesResponse {
2982            buffer_ids: payload
2983                .locations
2984                .into_iter()
2985                .map(|loc| loc.buffer_id)
2986                .collect(),
2987        });
2988    }
2989
2990    let payload = session
2991        .peer
2992        .forward_request(session.connection_id, host_connection_id, request)
2993        .await?;
2994    response.send(payload)?;
2995    Ok(())
2996}
2997
2998/// forward a project request to the dev server. Only allowed
2999/// if it's your dev server.
3000async fn forward_project_request_for_owner<T>(
3001    request: T,
3002    response: Response<T>,
3003    session: UserSession,
3004) -> Result<()>
3005where
3006    T: EntityMessage + RequestMessage,
3007{
3008    let project_id = ProjectId::from_proto(request.remote_entity_id());
3009
3010    let host_connection_id = session
3011        .db()
3012        .await
3013        .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3014        .await?;
3015    let payload = session
3016        .peer
3017        .forward_request(session.connection_id, host_connection_id, request)
3018        .await?;
3019    response.send(payload)?;
3020    Ok(())
3021}
3022
3023/// forward a project request to the host. These requests are disallowed
3024/// for guests.
3025async fn forward_mutating_project_request<T>(
3026    request: T,
3027    response: Response<T>,
3028    session: UserSession,
3029) -> Result<()>
3030where
3031    T: EntityMessage + RequestMessage,
3032{
3033    let project_id = ProjectId::from_proto(request.remote_entity_id());
3034
3035    let host_connection_id = session
3036        .db()
3037        .await
3038        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3039        .await?;
3040    let payload = session
3041        .peer
3042        .forward_request(session.connection_id, host_connection_id, request)
3043        .await?;
3044    response.send(payload)?;
3045    Ok(())
3046}
3047
3048/// Notify other participants that a new buffer has been created
3049async fn create_buffer_for_peer(
3050    request: proto::CreateBufferForPeer,
3051    session: Session,
3052) -> Result<()> {
3053    session
3054        .db()
3055        .await
3056        .check_user_is_project_host(
3057            ProjectId::from_proto(request.project_id),
3058            session.connection_id,
3059        )
3060        .await?;
3061    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3062    session
3063        .peer
3064        .forward_send(session.connection_id, peer_id.into(), request)?;
3065    Ok(())
3066}
3067
3068/// Notify other participants that a buffer has been updated. This is
3069/// allowed for guests as long as the update is limited to selections.
3070async fn update_buffer(
3071    request: proto::UpdateBuffer,
3072    response: Response<proto::UpdateBuffer>,
3073    session: Session,
3074) -> Result<()> {
3075    let project_id = ProjectId::from_proto(request.project_id);
3076    let mut capability = Capability::ReadOnly;
3077
3078    for op in request.operations.iter() {
3079        match op.variant {
3080            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3081            Some(_) => capability = Capability::ReadWrite,
3082        }
3083    }
3084
3085    let host = {
3086        let guard = session
3087            .db()
3088            .await
3089            .connections_for_buffer_update(
3090                project_id,
3091                session.principal_id(),
3092                session.connection_id,
3093                capability,
3094            )
3095            .await?;
3096
3097        let (host, guests) = &*guard;
3098
3099        broadcast(
3100            Some(session.connection_id),
3101            guests.clone(),
3102            |connection_id| {
3103                session
3104                    .peer
3105                    .forward_send(session.connection_id, connection_id, request.clone())
3106            },
3107        );
3108
3109        *host
3110    };
3111
3112    if host != session.connection_id {
3113        session
3114            .peer
3115            .forward_request(session.connection_id, host, request.clone())
3116            .await?;
3117    }
3118
3119    response.send(proto::Ack {})?;
3120    Ok(())
3121}
3122
3123async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3124    let project_id = ProjectId::from_proto(message.project_id);
3125
3126    let operation = message.operation.as_ref().context("invalid operation")?;
3127    let capability = match operation.variant.as_ref() {
3128        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3129            if let Some(buffer_op) = buffer_op.operation.as_ref() {
3130                match buffer_op.variant {
3131                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3132                        Capability::ReadOnly
3133                    }
3134                    _ => Capability::ReadWrite,
3135                }
3136            } else {
3137                Capability::ReadWrite
3138            }
3139        }
3140        Some(_) => Capability::ReadWrite,
3141        None => Capability::ReadOnly,
3142    };
3143
3144    let guard = session
3145        .db()
3146        .await
3147        .connections_for_buffer_update(
3148            project_id,
3149            session.principal_id(),
3150            session.connection_id,
3151            capability,
3152        )
3153        .await?;
3154
3155    let (host, guests) = &*guard;
3156
3157    broadcast(
3158        Some(session.connection_id),
3159        guests.iter().chain([host]).copied(),
3160        |connection_id| {
3161            session
3162                .peer
3163                .forward_send(session.connection_id, connection_id, message.clone())
3164        },
3165    );
3166
3167    Ok(())
3168}
3169
3170/// Notify other participants that a project has been updated.
3171async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3172    request: T,
3173    session: Session,
3174) -> Result<()> {
3175    let project_id = ProjectId::from_proto(request.remote_entity_id());
3176    let project_connection_ids = session
3177        .db()
3178        .await
3179        .project_connection_ids(project_id, session.connection_id, false)
3180        .await?;
3181
3182    broadcast(
3183        Some(session.connection_id),
3184        project_connection_ids.iter().copied(),
3185        |connection_id| {
3186            session
3187                .peer
3188                .forward_send(session.connection_id, connection_id, request.clone())
3189        },
3190    );
3191    Ok(())
3192}
3193
3194/// Start following another user in a call.
3195async fn follow(
3196    request: proto::Follow,
3197    response: Response<proto::Follow>,
3198    session: UserSession,
3199) -> Result<()> {
3200    let room_id = RoomId::from_proto(request.room_id);
3201    let project_id = request.project_id.map(ProjectId::from_proto);
3202    let leader_id = request
3203        .leader_id
3204        .ok_or_else(|| anyhow!("invalid leader id"))?
3205        .into();
3206    let follower_id = session.connection_id;
3207
3208    session
3209        .db()
3210        .await
3211        .check_room_participants(room_id, leader_id, session.connection_id)
3212        .await?;
3213
3214    let response_payload = session
3215        .peer
3216        .forward_request(session.connection_id, leader_id, request)
3217        .await?;
3218    response.send(response_payload)?;
3219
3220    if let Some(project_id) = project_id {
3221        let room = session
3222            .db()
3223            .await
3224            .follow(room_id, project_id, leader_id, follower_id)
3225            .await?;
3226        room_updated(&room, &session.peer);
3227    }
3228
3229    Ok(())
3230}
3231
3232/// Stop following another user in a call.
3233async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3234    let room_id = RoomId::from_proto(request.room_id);
3235    let project_id = request.project_id.map(ProjectId::from_proto);
3236    let leader_id = request
3237        .leader_id
3238        .ok_or_else(|| anyhow!("invalid leader id"))?
3239        .into();
3240    let follower_id = session.connection_id;
3241
3242    session
3243        .db()
3244        .await
3245        .check_room_participants(room_id, leader_id, session.connection_id)
3246        .await?;
3247
3248    session
3249        .peer
3250        .forward_send(session.connection_id, leader_id, request)?;
3251
3252    if let Some(project_id) = project_id {
3253        let room = session
3254            .db()
3255            .await
3256            .unfollow(room_id, project_id, leader_id, follower_id)
3257            .await?;
3258        room_updated(&room, &session.peer);
3259    }
3260
3261    Ok(())
3262}
3263
3264/// Notify everyone following you of your current location.
3265async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3266    let room_id = RoomId::from_proto(request.room_id);
3267    let database = session.db.lock().await;
3268
3269    let connection_ids = if let Some(project_id) = request.project_id {
3270        let project_id = ProjectId::from_proto(project_id);
3271        database
3272            .project_connection_ids(project_id, session.connection_id, true)
3273            .await?
3274    } else {
3275        database
3276            .room_connection_ids(room_id, session.connection_id)
3277            .await?
3278    };
3279
3280    // For now, don't send view update messages back to that view's current leader.
3281    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3282        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3283        _ => None,
3284    });
3285
3286    for connection_id in connection_ids.iter().cloned() {
3287        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3288            session
3289                .peer
3290                .forward_send(session.connection_id, connection_id, request.clone())?;
3291        }
3292    }
3293    Ok(())
3294}
3295
3296/// Get public data about users.
3297async fn get_users(
3298    request: proto::GetUsers,
3299    response: Response<proto::GetUsers>,
3300    session: Session,
3301) -> Result<()> {
3302    let user_ids = request
3303        .user_ids
3304        .into_iter()
3305        .map(UserId::from_proto)
3306        .collect();
3307    let users = session
3308        .db()
3309        .await
3310        .get_users_by_ids(user_ids)
3311        .await?
3312        .into_iter()
3313        .map(|user| proto::User {
3314            id: user.id.to_proto(),
3315            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3316            github_login: user.github_login,
3317        })
3318        .collect();
3319    response.send(proto::UsersResponse { users })?;
3320    Ok(())
3321}
3322
3323/// Search for users (to invite) buy Github login
3324async fn fuzzy_search_users(
3325    request: proto::FuzzySearchUsers,
3326    response: Response<proto::FuzzySearchUsers>,
3327    session: UserSession,
3328) -> Result<()> {
3329    let query = request.query;
3330    let users = match query.len() {
3331        0 => vec![],
3332        1 | 2 => session
3333            .db()
3334            .await
3335            .get_user_by_github_login(&query)
3336            .await?
3337            .into_iter()
3338            .collect(),
3339        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3340    };
3341    let users = users
3342        .into_iter()
3343        .filter(|user| user.id != session.user_id())
3344        .map(|user| proto::User {
3345            id: user.id.to_proto(),
3346            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3347            github_login: user.github_login,
3348        })
3349        .collect();
3350    response.send(proto::UsersResponse { users })?;
3351    Ok(())
3352}
3353
3354/// Send a contact request to another user.
3355async fn request_contact(
3356    request: proto::RequestContact,
3357    response: Response<proto::RequestContact>,
3358    session: UserSession,
3359) -> Result<()> {
3360    let requester_id = session.user_id();
3361    let responder_id = UserId::from_proto(request.responder_id);
3362    if requester_id == responder_id {
3363        return Err(anyhow!("cannot add yourself as a contact"))?;
3364    }
3365
3366    let notifications = session
3367        .db()
3368        .await
3369        .send_contact_request(requester_id, responder_id)
3370        .await?;
3371
3372    // Update outgoing contact requests of requester
3373    let mut update = proto::UpdateContacts::default();
3374    update.outgoing_requests.push(responder_id.to_proto());
3375    for connection_id in session
3376        .connection_pool()
3377        .await
3378        .user_connection_ids(requester_id)
3379    {
3380        session.peer.send(connection_id, update.clone())?;
3381    }
3382
3383    // Update incoming contact requests of responder
3384    let mut update = proto::UpdateContacts::default();
3385    update
3386        .incoming_requests
3387        .push(proto::IncomingContactRequest {
3388            requester_id: requester_id.to_proto(),
3389        });
3390    let connection_pool = session.connection_pool().await;
3391    for connection_id in connection_pool.user_connection_ids(responder_id) {
3392        session.peer.send(connection_id, update.clone())?;
3393    }
3394
3395    send_notifications(&connection_pool, &session.peer, notifications);
3396
3397    response.send(proto::Ack {})?;
3398    Ok(())
3399}
3400
3401/// Accept or decline a contact request
3402async fn respond_to_contact_request(
3403    request: proto::RespondToContactRequest,
3404    response: Response<proto::RespondToContactRequest>,
3405    session: UserSession,
3406) -> Result<()> {
3407    let responder_id = session.user_id();
3408    let requester_id = UserId::from_proto(request.requester_id);
3409    let db = session.db().await;
3410    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3411        db.dismiss_contact_notification(responder_id, requester_id)
3412            .await?;
3413    } else {
3414        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3415
3416        let notifications = db
3417            .respond_to_contact_request(responder_id, requester_id, accept)
3418            .await?;
3419        let requester_busy = db.is_user_busy(requester_id).await?;
3420        let responder_busy = db.is_user_busy(responder_id).await?;
3421
3422        let pool = session.connection_pool().await;
3423        // Update responder with new contact
3424        let mut update = proto::UpdateContacts::default();
3425        if accept {
3426            update
3427                .contacts
3428                .push(contact_for_user(requester_id, requester_busy, &pool));
3429        }
3430        update
3431            .remove_incoming_requests
3432            .push(requester_id.to_proto());
3433        for connection_id in pool.user_connection_ids(responder_id) {
3434            session.peer.send(connection_id, update.clone())?;
3435        }
3436
3437        // Update requester with new contact
3438        let mut update = proto::UpdateContacts::default();
3439        if accept {
3440            update
3441                .contacts
3442                .push(contact_for_user(responder_id, responder_busy, &pool));
3443        }
3444        update
3445            .remove_outgoing_requests
3446            .push(responder_id.to_proto());
3447
3448        for connection_id in pool.user_connection_ids(requester_id) {
3449            session.peer.send(connection_id, update.clone())?;
3450        }
3451
3452        send_notifications(&pool, &session.peer, notifications);
3453    }
3454
3455    response.send(proto::Ack {})?;
3456    Ok(())
3457}
3458
3459/// Remove a contact.
3460async fn remove_contact(
3461    request: proto::RemoveContact,
3462    response: Response<proto::RemoveContact>,
3463    session: UserSession,
3464) -> Result<()> {
3465    let requester_id = session.user_id();
3466    let responder_id = UserId::from_proto(request.user_id);
3467    let db = session.db().await;
3468    let (contact_accepted, deleted_notification_id) =
3469        db.remove_contact(requester_id, responder_id).await?;
3470
3471    let pool = session.connection_pool().await;
3472    // Update outgoing contact requests of requester
3473    let mut update = proto::UpdateContacts::default();
3474    if contact_accepted {
3475        update.remove_contacts.push(responder_id.to_proto());
3476    } else {
3477        update
3478            .remove_outgoing_requests
3479            .push(responder_id.to_proto());
3480    }
3481    for connection_id in pool.user_connection_ids(requester_id) {
3482        session.peer.send(connection_id, update.clone())?;
3483    }
3484
3485    // Update incoming contact requests of responder
3486    let mut update = proto::UpdateContacts::default();
3487    if contact_accepted {
3488        update.remove_contacts.push(requester_id.to_proto());
3489    } else {
3490        update
3491            .remove_incoming_requests
3492            .push(requester_id.to_proto());
3493    }
3494    for connection_id in pool.user_connection_ids(responder_id) {
3495        session.peer.send(connection_id, update.clone())?;
3496        if let Some(notification_id) = deleted_notification_id {
3497            session.peer.send(
3498                connection_id,
3499                proto::DeleteNotification {
3500                    notification_id: notification_id.to_proto(),
3501                },
3502            )?;
3503        }
3504    }
3505
3506    response.send(proto::Ack {})?;
3507    Ok(())
3508}
3509
3510fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3511    version.0.minor() < 139
3512}
3513
3514async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3515    let plan = session.current_plan(session.db().await).await?;
3516
3517    session
3518        .peer
3519        .send(
3520            session.connection_id,
3521            proto::UpdateUserPlan { plan: plan.into() },
3522        )
3523        .trace_err();
3524
3525    Ok(())
3526}
3527
3528async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3529    subscribe_user_to_channels(
3530        session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3531        &session,
3532    )
3533    .await?;
3534    Ok(())
3535}
3536
3537async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3538    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3539    let mut pool = session.connection_pool().await;
3540    for membership in &channels_for_user.channel_memberships {
3541        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3542    }
3543    session.peer.send(
3544        session.connection_id,
3545        build_update_user_channels(&channels_for_user),
3546    )?;
3547    session.peer.send(
3548        session.connection_id,
3549        build_channels_update(channels_for_user),
3550    )?;
3551    Ok(())
3552}
3553
3554/// Creates a new channel.
3555async fn create_channel(
3556    request: proto::CreateChannel,
3557    response: Response<proto::CreateChannel>,
3558    session: UserSession,
3559) -> Result<()> {
3560    let db = session.db().await;
3561
3562    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3563    let (channel, membership) = db
3564        .create_channel(&request.name, parent_id, session.user_id())
3565        .await?;
3566
3567    let root_id = channel.root_id();
3568    let channel = Channel::from_model(channel);
3569
3570    response.send(proto::CreateChannelResponse {
3571        channel: Some(channel.to_proto()),
3572        parent_id: request.parent_id,
3573    })?;
3574
3575    let mut connection_pool = session.connection_pool().await;
3576    if let Some(membership) = membership {
3577        connection_pool.subscribe_to_channel(
3578            membership.user_id,
3579            membership.channel_id,
3580            membership.role,
3581        );
3582        let update = proto::UpdateUserChannels {
3583            channel_memberships: vec![proto::ChannelMembership {
3584                channel_id: membership.channel_id.to_proto(),
3585                role: membership.role.into(),
3586            }],
3587            ..Default::default()
3588        };
3589        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3590            session.peer.send(connection_id, update.clone())?;
3591        }
3592    }
3593
3594    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3595        if !role.can_see_channel(channel.visibility) {
3596            continue;
3597        }
3598
3599        let update = proto::UpdateChannels {
3600            channels: vec![channel.to_proto()],
3601            ..Default::default()
3602        };
3603        session.peer.send(connection_id, update.clone())?;
3604    }
3605
3606    Ok(())
3607}
3608
3609/// Delete a channel
3610async fn delete_channel(
3611    request: proto::DeleteChannel,
3612    response: Response<proto::DeleteChannel>,
3613    session: UserSession,
3614) -> Result<()> {
3615    let db = session.db().await;
3616
3617    let channel_id = request.channel_id;
3618    let (root_channel, removed_channels) = db
3619        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3620        .await?;
3621    response.send(proto::Ack {})?;
3622
3623    // Notify members of removed channels
3624    let mut update = proto::UpdateChannels::default();
3625    update
3626        .delete_channels
3627        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3628
3629    let connection_pool = session.connection_pool().await;
3630    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3631        session.peer.send(connection_id, update.clone())?;
3632    }
3633
3634    Ok(())
3635}
3636
3637/// Invite someone to join a channel.
3638async fn invite_channel_member(
3639    request: proto::InviteChannelMember,
3640    response: Response<proto::InviteChannelMember>,
3641    session: UserSession,
3642) -> Result<()> {
3643    let db = session.db().await;
3644    let channel_id = ChannelId::from_proto(request.channel_id);
3645    let invitee_id = UserId::from_proto(request.user_id);
3646    let InviteMemberResult {
3647        channel,
3648        notifications,
3649    } = db
3650        .invite_channel_member(
3651            channel_id,
3652            invitee_id,
3653            session.user_id(),
3654            request.role().into(),
3655        )
3656        .await?;
3657
3658    let update = proto::UpdateChannels {
3659        channel_invitations: vec![channel.to_proto()],
3660        ..Default::default()
3661    };
3662
3663    let connection_pool = session.connection_pool().await;
3664    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3665        session.peer.send(connection_id, update.clone())?;
3666    }
3667
3668    send_notifications(&connection_pool, &session.peer, notifications);
3669
3670    response.send(proto::Ack {})?;
3671    Ok(())
3672}
3673
3674/// remove someone from a channel
3675async fn remove_channel_member(
3676    request: proto::RemoveChannelMember,
3677    response: Response<proto::RemoveChannelMember>,
3678    session: UserSession,
3679) -> Result<()> {
3680    let db = session.db().await;
3681    let channel_id = ChannelId::from_proto(request.channel_id);
3682    let member_id = UserId::from_proto(request.user_id);
3683
3684    let RemoveChannelMemberResult {
3685        membership_update,
3686        notification_id,
3687    } = db
3688        .remove_channel_member(channel_id, member_id, session.user_id())
3689        .await?;
3690
3691    let mut connection_pool = session.connection_pool().await;
3692    notify_membership_updated(
3693        &mut connection_pool,
3694        membership_update,
3695        member_id,
3696        &session.peer,
3697    );
3698    for connection_id in connection_pool.user_connection_ids(member_id) {
3699        if let Some(notification_id) = notification_id {
3700            session
3701                .peer
3702                .send(
3703                    connection_id,
3704                    proto::DeleteNotification {
3705                        notification_id: notification_id.to_proto(),
3706                    },
3707                )
3708                .trace_err();
3709        }
3710    }
3711
3712    response.send(proto::Ack {})?;
3713    Ok(())
3714}
3715
3716/// Toggle the channel between public and private.
3717/// Care is taken to maintain the invariant that public channels only descend from public channels,
3718/// (though members-only channels can appear at any point in the hierarchy).
3719async fn set_channel_visibility(
3720    request: proto::SetChannelVisibility,
3721    response: Response<proto::SetChannelVisibility>,
3722    session: UserSession,
3723) -> Result<()> {
3724    let db = session.db().await;
3725    let channel_id = ChannelId::from_proto(request.channel_id);
3726    let visibility = request.visibility().into();
3727
3728    let channel_model = db
3729        .set_channel_visibility(channel_id, visibility, session.user_id())
3730        .await?;
3731    let root_id = channel_model.root_id();
3732    let channel = Channel::from_model(channel_model);
3733
3734    let mut connection_pool = session.connection_pool().await;
3735    for (user_id, role) in connection_pool
3736        .channel_user_ids(root_id)
3737        .collect::<Vec<_>>()
3738        .into_iter()
3739    {
3740        let update = if role.can_see_channel(channel.visibility) {
3741            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3742            proto::UpdateChannels {
3743                channels: vec![channel.to_proto()],
3744                ..Default::default()
3745            }
3746        } else {
3747            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3748            proto::UpdateChannels {
3749                delete_channels: vec![channel.id.to_proto()],
3750                ..Default::default()
3751            }
3752        };
3753
3754        for connection_id in connection_pool.user_connection_ids(user_id) {
3755            session.peer.send(connection_id, update.clone())?;
3756        }
3757    }
3758
3759    response.send(proto::Ack {})?;
3760    Ok(())
3761}
3762
3763/// Alter the role for a user in the channel.
3764async fn set_channel_member_role(
3765    request: proto::SetChannelMemberRole,
3766    response: Response<proto::SetChannelMemberRole>,
3767    session: UserSession,
3768) -> Result<()> {
3769    let db = session.db().await;
3770    let channel_id = ChannelId::from_proto(request.channel_id);
3771    let member_id = UserId::from_proto(request.user_id);
3772    let result = db
3773        .set_channel_member_role(
3774            channel_id,
3775            session.user_id(),
3776            member_id,
3777            request.role().into(),
3778        )
3779        .await?;
3780
3781    match result {
3782        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3783            let mut connection_pool = session.connection_pool().await;
3784            notify_membership_updated(
3785                &mut connection_pool,
3786                membership_update,
3787                member_id,
3788                &session.peer,
3789            )
3790        }
3791        db::SetMemberRoleResult::InviteUpdated(channel) => {
3792            let update = proto::UpdateChannels {
3793                channel_invitations: vec![channel.to_proto()],
3794                ..Default::default()
3795            };
3796
3797            for connection_id in session
3798                .connection_pool()
3799                .await
3800                .user_connection_ids(member_id)
3801            {
3802                session.peer.send(connection_id, update.clone())?;
3803            }
3804        }
3805    }
3806
3807    response.send(proto::Ack {})?;
3808    Ok(())
3809}
3810
3811/// Change the name of a channel
3812async fn rename_channel(
3813    request: proto::RenameChannel,
3814    response: Response<proto::RenameChannel>,
3815    session: UserSession,
3816) -> Result<()> {
3817    let db = session.db().await;
3818    let channel_id = ChannelId::from_proto(request.channel_id);
3819    let channel_model = db
3820        .rename_channel(channel_id, session.user_id(), &request.name)
3821        .await?;
3822    let root_id = channel_model.root_id();
3823    let channel = Channel::from_model(channel_model);
3824
3825    response.send(proto::RenameChannelResponse {
3826        channel: Some(channel.to_proto()),
3827    })?;
3828
3829    let connection_pool = session.connection_pool().await;
3830    let update = proto::UpdateChannels {
3831        channels: vec![channel.to_proto()],
3832        ..Default::default()
3833    };
3834    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3835        if role.can_see_channel(channel.visibility) {
3836            session.peer.send(connection_id, update.clone())?;
3837        }
3838    }
3839
3840    Ok(())
3841}
3842
3843/// Move a channel to a new parent.
3844async fn move_channel(
3845    request: proto::MoveChannel,
3846    response: Response<proto::MoveChannel>,
3847    session: UserSession,
3848) -> Result<()> {
3849    let channel_id = ChannelId::from_proto(request.channel_id);
3850    let to = ChannelId::from_proto(request.to);
3851
3852    let (root_id, channels) = session
3853        .db()
3854        .await
3855        .move_channel(channel_id, to, session.user_id())
3856        .await?;
3857
3858    let connection_pool = session.connection_pool().await;
3859    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3860        let channels = channels
3861            .iter()
3862            .filter_map(|channel| {
3863                if role.can_see_channel(channel.visibility) {
3864                    Some(channel.to_proto())
3865                } else {
3866                    None
3867                }
3868            })
3869            .collect::<Vec<_>>();
3870        if channels.is_empty() {
3871            continue;
3872        }
3873
3874        let update = proto::UpdateChannels {
3875            channels,
3876            ..Default::default()
3877        };
3878
3879        session.peer.send(connection_id, update.clone())?;
3880    }
3881
3882    response.send(Ack {})?;
3883    Ok(())
3884}
3885
3886/// Get the list of channel members
3887async fn get_channel_members(
3888    request: proto::GetChannelMembers,
3889    response: Response<proto::GetChannelMembers>,
3890    session: UserSession,
3891) -> Result<()> {
3892    let db = session.db().await;
3893    let channel_id = ChannelId::from_proto(request.channel_id);
3894    let limit = if request.limit == 0 {
3895        u16::MAX as u64
3896    } else {
3897        request.limit
3898    };
3899    let (members, users) = db
3900        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3901        .await?;
3902    response.send(proto::GetChannelMembersResponse { members, users })?;
3903    Ok(())
3904}
3905
3906/// Accept or decline a channel invitation.
3907async fn respond_to_channel_invite(
3908    request: proto::RespondToChannelInvite,
3909    response: Response<proto::RespondToChannelInvite>,
3910    session: UserSession,
3911) -> Result<()> {
3912    let db = session.db().await;
3913    let channel_id = ChannelId::from_proto(request.channel_id);
3914    let RespondToChannelInvite {
3915        membership_update,
3916        notifications,
3917    } = db
3918        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3919        .await?;
3920
3921    let mut connection_pool = session.connection_pool().await;
3922    if let Some(membership_update) = membership_update {
3923        notify_membership_updated(
3924            &mut connection_pool,
3925            membership_update,
3926            session.user_id(),
3927            &session.peer,
3928        );
3929    } else {
3930        let update = proto::UpdateChannels {
3931            remove_channel_invitations: vec![channel_id.to_proto()],
3932            ..Default::default()
3933        };
3934
3935        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3936            session.peer.send(connection_id, update.clone())?;
3937        }
3938    };
3939
3940    send_notifications(&connection_pool, &session.peer, notifications);
3941
3942    response.send(proto::Ack {})?;
3943
3944    Ok(())
3945}
3946
3947/// Join the channels' room
3948async fn join_channel(
3949    request: proto::JoinChannel,
3950    response: Response<proto::JoinChannel>,
3951    session: UserSession,
3952) -> Result<()> {
3953    let channel_id = ChannelId::from_proto(request.channel_id);
3954    join_channel_internal(channel_id, Box::new(response), session).await
3955}
3956
3957trait JoinChannelInternalResponse {
3958    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3959}
3960impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3961    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3962        Response::<proto::JoinChannel>::send(self, result)
3963    }
3964}
3965impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3966    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3967        Response::<proto::JoinRoom>::send(self, result)
3968    }
3969}
3970
/// Join the session's user to the room of `channel_id`, replying through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database handle is dropped around `leave_room_for_session`
///    and re-acquired afterwards.
/// 2. Join the channel in the database. When a LiveKit client is configured,
///    mint a token:
///    - `Guest` role: a guest token with `can_publish: false`.
///    - Any other role: a room token with `can_publish: true`.
///
///    If token creation fails, the error goes to `trace_err` and the reply
///    carries no LiveKit connection info.
/// 3. Send the `JoinRoomResponse`, apply any membership update to the
///    connection pool, and call `room_updated`.
/// 4. Once the inner scope (and its `db`/`connection_pool` bindings) has ended,
///    call `channel_updated` with a freshly locked pool, then
///    `update_user_contacts`.
///
/// # Errors
/// Returns an error if:
/// - any database call fails;
/// - leaving the stale room fails;
/// - sending the reply fails;
/// - the joined room carries no channel ("channel not returned").
///
/// NOTE(review): the last check happens after the reply has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then take it
            // again. Presumably `leave_room_for_session` needs the database itself;
            // confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the token fails.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish; all other roles
                    // receive a regular room token and may publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so its `db` and `connection_pool` bindings are
    // gone; the pool is locked again here just for the channel broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4066
4067/// Start editing the channel notes
4068async fn join_channel_buffer(
4069    request: proto::JoinChannelBuffer,
4070    response: Response<proto::JoinChannelBuffer>,
4071    session: UserSession,
4072) -> Result<()> {
4073    let db = session.db().await;
4074    let channel_id = ChannelId::from_proto(request.channel_id);
4075
4076    let open_response = db
4077        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4078        .await?;
4079
4080    let collaborators = open_response.collaborators.clone();
4081    response.send(open_response)?;
4082
4083    let update = UpdateChannelBufferCollaborators {
4084        channel_id: channel_id.to_proto(),
4085        collaborators: collaborators.clone(),
4086    };
4087    channel_buffer_updated(
4088        session.connection_id,
4089        collaborators
4090            .iter()
4091            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4092        &update,
4093        &session.peer,
4094    );
4095
4096    Ok(())
4097}
4098
4099/// Edit the channel notes
4100async fn update_channel_buffer(
4101    request: proto::UpdateChannelBuffer,
4102    session: UserSession,
4103) -> Result<()> {
4104    let db = session.db().await;
4105    let channel_id = ChannelId::from_proto(request.channel_id);
4106
4107    let (collaborators, epoch, version) = db
4108        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4109        .await?;
4110
4111    channel_buffer_updated(
4112        session.connection_id,
4113        collaborators.clone(),
4114        &proto::UpdateChannelBuffer {
4115            channel_id: channel_id.to_proto(),
4116            operations: request.operations,
4117        },
4118        &session.peer,
4119    );
4120
4121    let pool = &*session.connection_pool().await;
4122
4123    let non_collaborators =
4124        pool.channel_connection_ids(channel_id)
4125            .filter_map(|(connection_id, _)| {
4126                if collaborators.contains(&connection_id) {
4127                    None
4128                } else {
4129                    Some(connection_id)
4130                }
4131            });
4132
4133    broadcast(None, non_collaborators, |peer_id| {
4134        session.peer.send(
4135            peer_id,
4136            proto::UpdateChannels {
4137                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4138                    channel_id: channel_id.to_proto(),
4139                    epoch: epoch as u64,
4140                    version: version.clone(),
4141                }],
4142                ..Default::default()
4143            },
4144        )
4145    });
4146
4147    Ok(())
4148}
4149
4150/// Rejoin the channel notes after a connection blip
4151async fn rejoin_channel_buffers(
4152    request: proto::RejoinChannelBuffers,
4153    response: Response<proto::RejoinChannelBuffers>,
4154    session: UserSession,
4155) -> Result<()> {
4156    let db = session.db().await;
4157    let buffers = db
4158        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4159        .await?;
4160
4161    for rejoined_buffer in &buffers {
4162        let collaborators_to_notify = rejoined_buffer
4163            .buffer
4164            .collaborators
4165            .iter()
4166            .filter_map(|c| Some(c.peer_id?.into()));
4167        channel_buffer_updated(
4168            session.connection_id,
4169            collaborators_to_notify,
4170            &proto::UpdateChannelBufferCollaborators {
4171                channel_id: rejoined_buffer.buffer.channel_id,
4172                collaborators: rejoined_buffer.buffer.collaborators.clone(),
4173            },
4174            &session.peer,
4175        );
4176    }
4177
4178    response.send(proto::RejoinChannelBuffersResponse {
4179        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4180    })?;
4181
4182    Ok(())
4183}
4184
4185/// Stop editing the channel notes
4186async fn leave_channel_buffer(
4187    request: proto::LeaveChannelBuffer,
4188    response: Response<proto::LeaveChannelBuffer>,
4189    session: UserSession,
4190) -> Result<()> {
4191    let db = session.db().await;
4192    let channel_id = ChannelId::from_proto(request.channel_id);
4193
4194    let left_buffer = db
4195        .leave_channel_buffer(channel_id, session.connection_id)
4196        .await?;
4197
4198    response.send(Ack {})?;
4199
4200    channel_buffer_updated(
4201        session.connection_id,
4202        left_buffer.connections,
4203        &proto::UpdateChannelBufferCollaborators {
4204            channel_id: channel_id.to_proto(),
4205            collaborators: left_buffer.collaborators,
4206        },
4207        &session.peer,
4208    );
4209
4210    Ok(())
4211}
4212
4213fn channel_buffer_updated<T: EnvelopedMessage>(
4214    sender_id: ConnectionId,
4215    collaborators: impl IntoIterator<Item = ConnectionId>,
4216    message: &T,
4217    peer: &Peer,
4218) {
4219    broadcast(Some(sender_id), collaborators, |peer_id| {
4220        peer.send(peer_id, message.clone())
4221    });
4222}
4223
4224fn send_notifications(
4225    connection_pool: &ConnectionPool,
4226    peer: &Peer,
4227    notifications: db::NotificationBatch,
4228) {
4229    for (user_id, notification) in notifications {
4230        for connection_id in connection_pool.user_connection_ids(user_id) {
4231            if let Err(error) = peer.send(
4232                connection_id,
4233                proto::AddNotification {
4234                    notification: Some(notification.clone()),
4235                },
4236            ) {
4237                tracing::error!(
4238                    "failed to send notification to {:?} {}",
4239                    connection_id,
4240                    error
4241                );
4242            }
4243        }
4244    }
4245}
4246
4247/// Send a message to the channel
4248async fn send_channel_message(
4249    request: proto::SendChannelMessage,
4250    response: Response<proto::SendChannelMessage>,
4251    session: UserSession,
4252) -> Result<()> {
4253    // Validate the message body.
4254    let body = request.body.trim().to_string();
4255    if body.len() > MAX_MESSAGE_LEN {
4256        return Err(anyhow!("message is too long"))?;
4257    }
4258    if body.is_empty() {
4259        return Err(anyhow!("message can't be blank"))?;
4260    }
4261
4262    // TODO: adjust mentions if body is trimmed
4263
4264    let timestamp = OffsetDateTime::now_utc();
4265    let nonce = request
4266        .nonce
4267        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4268
4269    let channel_id = ChannelId::from_proto(request.channel_id);
4270    let CreatedChannelMessage {
4271        message_id,
4272        participant_connection_ids,
4273        notifications,
4274    } = session
4275        .db()
4276        .await
4277        .create_channel_message(
4278            channel_id,
4279            session.user_id(),
4280            &body,
4281            &request.mentions,
4282            timestamp,
4283            nonce.clone().into(),
4284            match request.reply_to_message_id {
4285                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4286                None => None,
4287            },
4288        )
4289        .await?;
4290
4291    let message = proto::ChannelMessage {
4292        sender_id: session.user_id().to_proto(),
4293        id: message_id.to_proto(),
4294        body,
4295        mentions: request.mentions,
4296        timestamp: timestamp.unix_timestamp() as u64,
4297        nonce: Some(nonce),
4298        reply_to_message_id: request.reply_to_message_id,
4299        edited_at: None,
4300    };
4301    broadcast(
4302        Some(session.connection_id),
4303        participant_connection_ids.clone(),
4304        |connection| {
4305            session.peer.send(
4306                connection,
4307                proto::ChannelMessageSent {
4308                    channel_id: channel_id.to_proto(),
4309                    message: Some(message.clone()),
4310                },
4311            )
4312        },
4313    );
4314    response.send(proto::SendChannelMessageResponse {
4315        message: Some(message),
4316    })?;
4317
4318    let pool = &*session.connection_pool().await;
4319    let non_participants =
4320        pool.channel_connection_ids(channel_id)
4321            .filter_map(|(connection_id, _)| {
4322                if participant_connection_ids.contains(&connection_id) {
4323                    None
4324                } else {
4325                    Some(connection_id)
4326                }
4327            });
4328    broadcast(None, non_participants, |peer_id| {
4329        session.peer.send(
4330            peer_id,
4331            proto::UpdateChannels {
4332                latest_channel_message_ids: vec![proto::ChannelMessageId {
4333                    channel_id: channel_id.to_proto(),
4334                    message_id: message_id.to_proto(),
4335                }],
4336                ..Default::default()
4337            },
4338        )
4339    });
4340    send_notifications(pool, &session.peer, notifications);
4341
4342    Ok(())
4343}
4344
4345/// Delete a channel message
4346async fn remove_channel_message(
4347    request: proto::RemoveChannelMessage,
4348    response: Response<proto::RemoveChannelMessage>,
4349    session: UserSession,
4350) -> Result<()> {
4351    let channel_id = ChannelId::from_proto(request.channel_id);
4352    let message_id = MessageId::from_proto(request.message_id);
4353    let (connection_ids, existing_notification_ids) = session
4354        .db()
4355        .await
4356        .remove_channel_message(channel_id, message_id, session.user_id())
4357        .await?;
4358
4359    broadcast(
4360        Some(session.connection_id),
4361        connection_ids,
4362        move |connection| {
4363            session.peer.send(connection, request.clone())?;
4364
4365            for notification_id in &existing_notification_ids {
4366                session.peer.send(
4367                    connection,
4368                    proto::DeleteNotification {
4369                        notification_id: (*notification_id).to_proto(),
4370                    },
4371                )?;
4372            }
4373
4374            Ok(())
4375        },
4376    );
4377    response.send(proto::Ack {})?;
4378    Ok(())
4379}
4380
4381async fn update_channel_message(
4382    request: proto::UpdateChannelMessage,
4383    response: Response<proto::UpdateChannelMessage>,
4384    session: UserSession,
4385) -> Result<()> {
4386    let channel_id = ChannelId::from_proto(request.channel_id);
4387    let message_id = MessageId::from_proto(request.message_id);
4388    let updated_at = OffsetDateTime::now_utc();
4389    let UpdatedChannelMessage {
4390        message_id,
4391        participant_connection_ids,
4392        notifications,
4393        reply_to_message_id,
4394        timestamp,
4395        deleted_mention_notification_ids,
4396        updated_mention_notifications,
4397    } = session
4398        .db()
4399        .await
4400        .update_channel_message(
4401            channel_id,
4402            message_id,
4403            session.user_id(),
4404            request.body.as_str(),
4405            &request.mentions,
4406            updated_at,
4407        )
4408        .await?;
4409
4410    let nonce = request
4411        .nonce
4412        .clone()
4413        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4414
4415    let message = proto::ChannelMessage {
4416        sender_id: session.user_id().to_proto(),
4417        id: message_id.to_proto(),
4418        body: request.body.clone(),
4419        mentions: request.mentions.clone(),
4420        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4421        nonce: Some(nonce),
4422        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4423        edited_at: Some(updated_at.unix_timestamp() as u64),
4424    };
4425
4426    response.send(proto::Ack {})?;
4427
4428    let pool = &*session.connection_pool().await;
4429    broadcast(
4430        Some(session.connection_id),
4431        participant_connection_ids,
4432        |connection| {
4433            session.peer.send(
4434                connection,
4435                proto::ChannelMessageUpdate {
4436                    channel_id: channel_id.to_proto(),
4437                    message: Some(message.clone()),
4438                },
4439            )?;
4440
4441            for notification_id in &deleted_mention_notification_ids {
4442                session.peer.send(
4443                    connection,
4444                    proto::DeleteNotification {
4445                        notification_id: (*notification_id).to_proto(),
4446                    },
4447                )?;
4448            }
4449
4450            for notification in &updated_mention_notifications {
4451                session.peer.send(
4452                    connection,
4453                    proto::UpdateNotification {
4454                        notification: Some(notification.clone()),
4455                    },
4456                )?;
4457            }
4458
4459            Ok(())
4460        },
4461    );
4462
4463    send_notifications(pool, &session.peer, notifications);
4464
4465    Ok(())
4466}
4467
4468/// Mark a channel message as read
4469async fn acknowledge_channel_message(
4470    request: proto::AckChannelMessage,
4471    session: UserSession,
4472) -> Result<()> {
4473    let channel_id = ChannelId::from_proto(request.channel_id);
4474    let message_id = MessageId::from_proto(request.message_id);
4475    let notifications = session
4476        .db()
4477        .await
4478        .observe_channel_message(channel_id, session.user_id(), message_id)
4479        .await?;
4480    send_notifications(
4481        &*session.connection_pool().await,
4482        &session.peer,
4483        notifications,
4484    );
4485    Ok(())
4486}
4487
4488/// Mark a buffer version as synced
4489async fn acknowledge_buffer_version(
4490    request: proto::AckBufferOperation,
4491    session: UserSession,
4492) -> Result<()> {
4493    let buffer_id = BufferId::from_proto(request.buffer_id);
4494    session
4495        .db()
4496        .await
4497        .observe_buffer_version(
4498            buffer_id,
4499            session.user_id(),
4500            request.epoch as i32,
4501            &request.version,
4502        )
4503        .await?;
4504    Ok(())
4505}
4506
4507async fn count_language_model_tokens(
4508    request: proto::CountLanguageModelTokens,
4509    response: Response<proto::CountLanguageModelTokens>,
4510    session: Session,
4511    config: &Config,
4512) -> Result<()> {
4513    let Some(session) = session.for_user() else {
4514        return Err(anyhow!("user not found"))?;
4515    };
4516    authorize_access_to_legacy_llm_endpoints(&session).await?;
4517
4518    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4519        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4520        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4521    };
4522
4523    session
4524        .app_state
4525        .rate_limiter
4526        .check(&*rate_limit, session.user_id())
4527        .await?;
4528
4529    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4530        Some(proto::LanguageModelProvider::Google) => {
4531            let api_key = config
4532                .google_ai_api_key
4533                .as_ref()
4534                .context("no Google AI API key configured on the server")?;
4535            google_ai::count_tokens(
4536                session.http_client.as_ref(),
4537                google_ai::API_URL,
4538                api_key,
4539                serde_json::from_str(&request.request)?,
4540            )
4541            .await?
4542        }
4543        _ => return Err(anyhow!("unsupported provider"))?,
4544    };
4545
4546    response.send(proto::CountLanguageModelTokensResponse {
4547        token_count: result.total_tokens as u32,
4548    })?;
4549
4550    Ok(())
4551}
4552
4553struct ZedProCountLanguageModelTokensRateLimit;
4554
4555impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4556    fn capacity(&self) -> usize {
4557        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4558            .ok()
4559            .and_then(|v| v.parse().ok())
4560            .unwrap_or(600) // Picked arbitrarily
4561    }
4562
4563    fn refill_duration(&self) -> chrono::Duration {
4564        chrono::Duration::hours(1)
4565    }
4566
4567    fn db_name(&self) -> &'static str {
4568        "zed-pro:count-language-model-tokens"
4569    }
4570}
4571
4572struct FreeCountLanguageModelTokensRateLimit;
4573
4574impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4575    fn capacity(&self) -> usize {
4576        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4577            .ok()
4578            .and_then(|v| v.parse().ok())
4579            .unwrap_or(600 / 10) // Picked arbitrarily
4580    }
4581
4582    fn refill_duration(&self) -> chrono::Duration {
4583        chrono::Duration::hours(1)
4584    }
4585
4586    fn db_name(&self) -> &'static str {
4587        "free:count-language-model-tokens"
4588    }
4589}
4590
4591struct ZedProComputeEmbeddingsRateLimit;
4592
4593impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4594    fn capacity(&self) -> usize {
4595        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4596            .ok()
4597            .and_then(|v| v.parse().ok())
4598            .unwrap_or(5000) // Picked arbitrarily
4599    }
4600
4601    fn refill_duration(&self) -> chrono::Duration {
4602        chrono::Duration::hours(1)
4603    }
4604
4605    fn db_name(&self) -> &'static str {
4606        "zed-pro:compute-embeddings"
4607    }
4608}
4609
4610struct FreeComputeEmbeddingsRateLimit;
4611
4612impl RateLimit for FreeComputeEmbeddingsRateLimit {
4613    fn capacity(&self) -> usize {
4614        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4615            .ok()
4616            .and_then(|v| v.parse().ok())
4617            .unwrap_or(5000 / 10) // Picked arbitrarily
4618    }
4619
4620    fn refill_duration(&self) -> chrono::Duration {
4621        chrono::Duration::hours(1)
4622    }
4623
4624    fn db_name(&self) -> &'static str {
4625        "free:compute-embeddings"
4626    }
4627}
4628
4629async fn compute_embeddings(
4630    request: proto::ComputeEmbeddings,
4631    response: Response<proto::ComputeEmbeddings>,
4632    session: UserSession,
4633    api_key: Option<Arc<str>>,
4634) -> Result<()> {
4635    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4636    authorize_access_to_legacy_llm_endpoints(&session).await?;
4637
4638    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4639        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4640        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4641    };
4642
4643    session
4644        .app_state
4645        .rate_limiter
4646        .check(&*rate_limit, session.user_id())
4647        .await?;
4648
4649    let embeddings = match request.model.as_str() {
4650        "openai/text-embedding-3-small" => {
4651            open_ai::embed(
4652                session.http_client.as_ref(),
4653                OPEN_AI_API_URL,
4654                &api_key,
4655                OpenAiEmbeddingModel::TextEmbedding3Small,
4656                request.texts.iter().map(|text| text.as_str()),
4657            )
4658            .await?
4659        }
4660        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4661    };
4662
4663    let embeddings = request
4664        .texts
4665        .iter()
4666        .map(|text| {
4667            let mut hasher = sha2::Sha256::new();
4668            hasher.update(text.as_bytes());
4669            let result = hasher.finalize();
4670            result.to_vec()
4671        })
4672        .zip(
4673            embeddings
4674                .data
4675                .into_iter()
4676                .map(|embedding| embedding.embedding),
4677        )
4678        .collect::<HashMap<_, _>>();
4679
4680    let db = session.db().await;
4681    db.save_embeddings(&request.model, &embeddings)
4682        .await
4683        .context("failed to save embeddings")
4684        .trace_err();
4685
4686    response.send(proto::ComputeEmbeddingsResponse {
4687        embeddings: embeddings
4688            .into_iter()
4689            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4690            .collect(),
4691    })?;
4692    Ok(())
4693}
4694
4695async fn get_cached_embeddings(
4696    request: proto::GetCachedEmbeddings,
4697    response: Response<proto::GetCachedEmbeddings>,
4698    session: UserSession,
4699) -> Result<()> {
4700    authorize_access_to_legacy_llm_endpoints(&session).await?;
4701
4702    let db = session.db().await;
4703    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4704
4705    response.send(proto::GetCachedEmbeddingsResponse {
4706        embeddings: embeddings
4707            .into_iter()
4708            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4709            .collect(),
4710    })?;
4711    Ok(())
4712}
4713
4714/// This is leftover from before the LLM service.
4715///
4716/// The endpoints protected by this check will be moved there eventually.
4717async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4718    if session.is_staff() {
4719        Ok(())
4720    } else {
4721        Err(anyhow!("permission denied"))?
4722    }
4723}
4724
4725/// Get a Supermaven API key for the user
4726async fn get_supermaven_api_key(
4727    _request: proto::GetSupermavenApiKey,
4728    response: Response<proto::GetSupermavenApiKey>,
4729    session: UserSession,
4730) -> Result<()> {
4731    let user_id: String = session.user_id().to_string();
4732    if !session.is_staff() {
4733        return Err(anyhow!("supermaven not enabled for this account"))?;
4734    }
4735
4736    let email = session
4737        .email()
4738        .ok_or_else(|| anyhow!("user must have an email"))?;
4739
4740    let supermaven_admin_api = session
4741        .supermaven_client
4742        .as_ref()
4743        .ok_or_else(|| anyhow!("supermaven not configured"))?;
4744
4745    let result = supermaven_admin_api
4746        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4747        .await?;
4748
4749    response.send(proto::GetSupermavenApiKeyResponse {
4750        api_key: result.api_key,
4751    })?;
4752
4753    Ok(())
4754}
4755
4756/// Start receiving chat updates for a channel
4757async fn join_channel_chat(
4758    request: proto::JoinChannelChat,
4759    response: Response<proto::JoinChannelChat>,
4760    session: UserSession,
4761) -> Result<()> {
4762    let channel_id = ChannelId::from_proto(request.channel_id);
4763
4764    let db = session.db().await;
4765    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4766        .await?;
4767    let messages = db
4768        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4769        .await?;
4770    response.send(proto::JoinChannelChatResponse {
4771        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4772        messages,
4773    })?;
4774    Ok(())
4775}
4776
4777/// Stop receiving chat updates for a channel
4778async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4779    let channel_id = ChannelId::from_proto(request.channel_id);
4780    session
4781        .db()
4782        .await
4783        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4784        .await?;
4785    Ok(())
4786}
4787
4788/// Retrieve the chat history for a channel
4789async fn get_channel_messages(
4790    request: proto::GetChannelMessages,
4791    response: Response<proto::GetChannelMessages>,
4792    session: UserSession,
4793) -> Result<()> {
4794    let channel_id = ChannelId::from_proto(request.channel_id);
4795    let messages = session
4796        .db()
4797        .await
4798        .get_channel_messages(
4799            channel_id,
4800            session.user_id(),
4801            MESSAGE_COUNT_PER_PAGE,
4802            Some(MessageId::from_proto(request.before_message_id)),
4803        )
4804        .await?;
4805    response.send(proto::GetChannelMessagesResponse {
4806        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4807        messages,
4808    })?;
4809    Ok(())
4810}
4811
4812/// Retrieve specific chat messages
4813async fn get_channel_messages_by_id(
4814    request: proto::GetChannelMessagesById,
4815    response: Response<proto::GetChannelMessagesById>,
4816    session: UserSession,
4817) -> Result<()> {
4818    let message_ids = request
4819        .message_ids
4820        .iter()
4821        .map(|id| MessageId::from_proto(*id))
4822        .collect::<Vec<_>>();
4823    let messages = session
4824        .db()
4825        .await
4826        .get_channel_messages_by_id(session.user_id(), &message_ids)
4827        .await?;
4828    response.send(proto::GetChannelMessagesResponse {
4829        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4830        messages,
4831    })?;
4832    Ok(())
4833}
4834
4835/// Retrieve the current users notifications
4836async fn get_notifications(
4837    request: proto::GetNotifications,
4838    response: Response<proto::GetNotifications>,
4839    session: UserSession,
4840) -> Result<()> {
4841    let notifications = session
4842        .db()
4843        .await
4844        .get_notifications(
4845            session.user_id(),
4846            NOTIFICATION_COUNT_PER_PAGE,
4847            request
4848                .before_id
4849                .map(|id| db::NotificationId::from_proto(id)),
4850        )
4851        .await?;
4852    response.send(proto::GetNotificationsResponse {
4853        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4854        notifications,
4855    })?;
4856    Ok(())
4857}
4858
4859/// Mark notifications as read
4860async fn mark_notification_as_read(
4861    request: proto::MarkNotificationRead,
4862    response: Response<proto::MarkNotificationRead>,
4863    session: UserSession,
4864) -> Result<()> {
4865    let database = &session.db().await;
4866    let notifications = database
4867        .mark_notification_as_read_by_id(
4868            session.user_id(),
4869            NotificationId::from_proto(request.notification_id),
4870        )
4871        .await?;
4872    send_notifications(
4873        &*session.connection_pool().await,
4874        &session.peer,
4875        notifications,
4876    );
4877    response.send(proto::Ack {})?;
4878    Ok(())
4879}
4880
4881/// Get the current users information
4882async fn get_private_user_info(
4883    _request: proto::GetPrivateUserInfo,
4884    response: Response<proto::GetPrivateUserInfo>,
4885    session: UserSession,
4886) -> Result<()> {
4887    let db = session.db().await;
4888
4889    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4890    let user = db
4891        .get_user_by_id(session.user_id())
4892        .await?
4893        .ok_or_else(|| anyhow!("user not found"))?;
4894    let flags = db.get_user_flags(session.user_id()).await?;
4895
4896    response.send(proto::GetPrivateUserInfoResponse {
4897        metrics_id,
4898        staff: user.admin,
4899        flags,
4900        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4901    })?;
4902    Ok(())
4903}
4904
4905/// Accept the terms of service (tos) on behalf of the current user
4906async fn accept_terms_of_service(
4907    _request: proto::AcceptTermsOfService,
4908    response: Response<proto::AcceptTermsOfService>,
4909    session: UserSession,
4910) -> Result<()> {
4911    let db = session.db().await;
4912
4913    let accepted_tos_at = Utc::now();
4914    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4915        .await?;
4916
4917    response.send(proto::AcceptTermsOfServiceResponse {
4918        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4919    })?;
4920    Ok(())
4921}
4922
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` measures the age from the earlier of the Zed account's
/// `created_at` and the linked GitHub account's creation time (when known).
/// Younger accounts are refused a token.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4925
4926async fn get_llm_api_token(
4927    _request: proto::GetLlmToken,
4928    response: Response<proto::GetLlmToken>,
4929    session: UserSession,
4930) -> Result<()> {
4931    let db = session.db().await;
4932
4933    let flags = db.get_user_flags(session.user_id()).await?;
4934    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4935    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4936
4937    if !session.is_staff() && !has_language_models_feature_flag {
4938        Err(anyhow!("permission denied"))?
4939    }
4940
4941    let user_id = session.user_id();
4942    let user = db
4943        .get_user_by_id(user_id)
4944        .await?
4945        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4946
4947    if user.accepted_tos_at.is_none() {
4948        Err(anyhow!("terms of service not accepted"))?
4949    }
4950
4951    let mut account_created_at = user.created_at;
4952    if let Some(github_created_at) = user.github_user_created_at {
4953        account_created_at = account_created_at.min(github_created_at);
4954    }
4955    if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4956        Err(anyhow!("account too young"))?
4957    }
4958    let token = LlmTokenClaims::create(
4959        user.id,
4960        user.github_login.clone(),
4961        session.is_staff(),
4962        has_llm_closed_beta_feature_flag,
4963        session.current_plan(db).await?,
4964        &session.app_state.config,
4965    )?;
4966    response.send(proto::GetLlmTokenResponse { token })?;
4967    Ok(())
4968}
4969
4970fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4971    let message = match message {
4972        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4973        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4974        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4975        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4976        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4977            code: frame.code.into(),
4978            reason: frame.reason,
4979        })),
4980        // We should never receive a frame while reading the message, according
4981        // to the `tungstenite` maintainers:
4982        //
4983        // > It cannot occur when you read messages from the WebSocket, but it
4984        // > can be used when you want to send the raw frames (e.g. you want to
4985        // > send the frames to the WebSocket without composing the full message first).
4986        // >
4987        // > — https://github.com/snapview/tungstenite-rs/issues/268
4988        TungsteniteMessage::Frame(_) => {
4989            bail!("received an unexpected frame while reading the message")
4990        }
4991    };
4992
4993    Ok(message)
4994}
4995
4996fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4997    match message {
4998        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4999        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5000        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5001        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5002        AxumMessage::Close(frame) => {
5003            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5004                code: frame.code.into(),
5005                reason: frame.reason,
5006            }))
5007        }
5008    }
5009}
5010
5011fn notify_membership_updated(
5012    connection_pool: &mut ConnectionPool,
5013    result: MembershipUpdated,
5014    user_id: UserId,
5015    peer: &Peer,
5016) {
5017    for membership in &result.new_channels.channel_memberships {
5018        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5019    }
5020    for channel_id in &result.removed_channels {
5021        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5022    }
5023
5024    let user_channels_update = proto::UpdateUserChannels {
5025        channel_memberships: result
5026            .new_channels
5027            .channel_memberships
5028            .iter()
5029            .map(|cm| proto::ChannelMembership {
5030                channel_id: cm.channel_id.to_proto(),
5031                role: cm.role.into(),
5032            })
5033            .collect(),
5034        ..Default::default()
5035    };
5036
5037    let mut update = build_channels_update(result.new_channels);
5038    update.delete_channels = result
5039        .removed_channels
5040        .into_iter()
5041        .map(|id| id.to_proto())
5042        .collect();
5043    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5044
5045    for connection_id in connection_pool.user_connection_ids(user_id) {
5046        peer.send(connection_id, user_channels_update.clone())
5047            .trace_err();
5048        peer.send(connection_id, update.clone()).trace_err();
5049    }
5050}
5051
5052fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5053    proto::UpdateUserChannels {
5054        channel_memberships: channels
5055            .channel_memberships
5056            .iter()
5057            .map(|m| proto::ChannelMembership {
5058                channel_id: m.channel_id.to_proto(),
5059                role: m.role.into(),
5060            })
5061            .collect(),
5062        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5063        observed_channel_message_id: channels.observed_channel_messages.clone(),
5064    }
5065}
5066
5067fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5068    let mut update = proto::UpdateChannels::default();
5069
5070    for channel in channels.channels {
5071        update.channels.push(channel.to_proto());
5072    }
5073
5074    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5075    update.latest_channel_message_ids = channels.latest_channel_messages;
5076
5077    for (channel_id, participants) in channels.channel_participants {
5078        update
5079            .channel_participants
5080            .push(proto::ChannelParticipants {
5081                channel_id: channel_id.to_proto(),
5082                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5083            });
5084    }
5085
5086    for channel in channels.invited_channels {
5087        update.channel_invitations.push(channel.to_proto());
5088    }
5089
5090    update.hosted_projects = channels.hosted_projects;
5091    update
5092}
5093
5094fn build_initial_contacts_update(
5095    contacts: Vec<db::Contact>,
5096    pool: &ConnectionPool,
5097) -> proto::UpdateContacts {
5098    let mut update = proto::UpdateContacts::default();
5099
5100    for contact in contacts {
5101        match contact {
5102            db::Contact::Accepted { user_id, busy } => {
5103                update.contacts.push(contact_for_user(user_id, busy, &pool));
5104            }
5105            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5106            db::Contact::Incoming { user_id } => {
5107                update
5108                    .incoming_requests
5109                    .push(proto::IncomingContactRequest {
5110                        requester_id: user_id.to_proto(),
5111                    })
5112            }
5113        }
5114    }
5115
5116    update
5117}
5118
5119fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5120    proto::Contact {
5121        user_id: user_id.to_proto(),
5122        online: pool.is_user_online(user_id),
5123        busy,
5124    }
5125}
5126
5127fn room_updated(room: &proto::Room, peer: &Peer) {
5128    broadcast(
5129        None,
5130        room.participants
5131            .iter()
5132            .filter_map(|participant| Some(participant.peer_id?.into())),
5133        |peer_id| {
5134            peer.send(
5135                peer_id,
5136                proto::RoomUpdated {
5137                    room: Some(room.clone()),
5138                },
5139            )
5140        },
5141    );
5142}
5143
5144fn channel_updated(
5145    channel: &db::channel::Model,
5146    room: &proto::Room,
5147    peer: &Peer,
5148    pool: &ConnectionPool,
5149) {
5150    let participants = room
5151        .participants
5152        .iter()
5153        .map(|p| p.user_id)
5154        .collect::<Vec<_>>();
5155
5156    broadcast(
5157        None,
5158        pool.channel_connection_ids(channel.root_id())
5159            .filter_map(|(channel_id, role)| {
5160                role.can_see_channel(channel.visibility).then(|| channel_id)
5161            }),
5162        |peer_id| {
5163            peer.send(
5164                peer_id,
5165                proto::UpdateChannels {
5166                    channel_participants: vec![proto::ChannelParticipants {
5167                        channel_id: channel.id.to_proto(),
5168                        participant_user_ids: participants.clone(),
5169                    }],
5170                    ..Default::default()
5171                },
5172            )
5173        },
5174    );
5175}
5176
5177async fn send_dev_server_projects_update(
5178    user_id: UserId,
5179    mut status: proto::DevServerProjectsUpdate,
5180    session: &Session,
5181) {
5182    let pool = session.connection_pool().await;
5183    for dev_server in &mut status.dev_servers {
5184        dev_server.status =
5185            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5186    }
5187    let connections = pool.user_connection_ids(user_id);
5188    for connection_id in connections {
5189        session.peer.send(connection_id, status.clone()).trace_err();
5190    }
5191}
5192
5193async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5194    let db = session.db().await;
5195
5196    let contacts = db.get_contacts(user_id).await?;
5197    let busy = db.is_user_busy(user_id).await?;
5198
5199    let pool = session.connection_pool().await;
5200    let updated_contact = contact_for_user(user_id, busy, &pool);
5201    for contact in contacts {
5202        if let db::Contact::Accepted {
5203            user_id: contact_user_id,
5204            ..
5205        } = contact
5206        {
5207            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5208                session
5209                    .peer
5210                    .send(
5211                        contact_conn_id,
5212                        proto::UpdateContacts {
5213                            contacts: vec![updated_contact.clone()],
5214                            remove_contacts: Default::default(),
5215                            incoming_requests: Default::default(),
5216                            remove_incoming_requests: Default::default(),
5217                            outgoing_requests: Default::default(),
5218                            remove_outgoing_requests: Default::default(),
5219                        },
5220                    )
5221                    .trace_err();
5222            }
5223        }
5224    }
5225    Ok(())
5226}
5227
5228async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5229    log::info!("lost dev server connection, unsharing projects");
5230    let project_ids = session
5231        .db()
5232        .await
5233        .get_stale_dev_server_projects(session.connection_id)
5234        .await?;
5235
5236    for project_id in project_ids {
5237        // not unshare re-checks the connection ids match, so we get away with no transaction
5238        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5239    }
5240
5241    let user_id = session.dev_server().user_id;
5242    let update = session
5243        .db()
5244        .await
5245        .dev_server_projects_update(user_id)
5246        .await?;
5247
5248    send_dev_server_projects_update(user_id, update, session).await;
5249
5250    Ok(())
5251}
5252
5253async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5254    let mut contacts_to_update = HashSet::default();
5255
5256    let room_id;
5257    let canceled_calls_to_user_ids;
5258    let live_kit_room;
5259    let delete_live_kit_room;
5260    let room;
5261    let channel;
5262
5263    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5264        contacts_to_update.insert(session.user_id());
5265
5266        for project in left_room.left_projects.values() {
5267            project_left(project, session);
5268        }
5269
5270        room_id = RoomId::from_proto(left_room.room.id);
5271        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5272        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5273        delete_live_kit_room = left_room.deleted;
5274        room = mem::take(&mut left_room.room);
5275        channel = mem::take(&mut left_room.channel);
5276
5277        room_updated(&room, &session.peer);
5278    } else {
5279        return Ok(());
5280    }
5281
5282    if let Some(channel) = channel {
5283        channel_updated(
5284            &channel,
5285            &room,
5286            &session.peer,
5287            &*session.connection_pool().await,
5288        );
5289    }
5290
5291    {
5292        let pool = session.connection_pool().await;
5293        for canceled_user_id in canceled_calls_to_user_ids {
5294            for connection_id in pool.user_connection_ids(canceled_user_id) {
5295                session
5296                    .peer
5297                    .send(
5298                        connection_id,
5299                        proto::CallCanceled {
5300                            room_id: room_id.to_proto(),
5301                        },
5302                    )
5303                    .trace_err();
5304            }
5305            contacts_to_update.insert(canceled_user_id);
5306        }
5307    }
5308
5309    for contact_user_id in contacts_to_update {
5310        update_user_contacts(contact_user_id, &session).await?;
5311    }
5312
5313    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
5314        live_kit
5315            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5316            .await
5317            .trace_err();
5318
5319        if delete_live_kit_room {
5320            live_kit.delete_room(live_kit_room).await.trace_err();
5321        }
5322    }
5323
5324    Ok(())
5325}
5326
5327async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5328    let left_channel_buffers = session
5329        .db()
5330        .await
5331        .leave_channel_buffers(session.connection_id)
5332        .await?;
5333
5334    for left_buffer in left_channel_buffers {
5335        channel_buffer_updated(
5336            session.connection_id,
5337            left_buffer.connections,
5338            &proto::UpdateChannelBufferCollaborators {
5339                channel_id: left_buffer.channel_id.to_proto(),
5340                collaborators: left_buffer.collaborators,
5341            },
5342            &session.peer,
5343        );
5344    }
5345
5346    Ok(())
5347}
5348
5349fn project_left(project: &db::LeftProject, session: &UserSession) {
5350    for connection_id in &project.connection_ids {
5351        if project.should_unshare {
5352            session
5353                .peer
5354                .send(
5355                    *connection_id,
5356                    proto::UnshareProject {
5357                        project_id: project.id.to_proto(),
5358                    },
5359                )
5360                .trace_err();
5361        } else {
5362            session
5363                .peer
5364                .send(
5365                    *connection_id,
5366                    proto::RemoveProjectCollaborator {
5367                        project_id: project.id.to_proto(),
5368                        peer_id: Some(session.connection_id.into()),
5369                    },
5370                )
5371                .trace_err();
5372        }
5373    }
5374}
5375
/// Extension trait for turning a fallible value into an `Option`, recording the failure
/// instead of propagating it.
///
/// The implementation for `Result` logs an `Err` at error level via `tracing` and yields
/// `None`. Callers in this file use it for best-effort operations such as message sends.
pub trait ResultExt {
    /// The success type produced by `trace_err`.
    type Ok;

    /// Returns the success value as `Some`, or records the failure and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5381
5382impl<T, E> ResultExt for Result<T, E>
5383where
5384    E: std::fmt::Debug,
5385{
5386    type Ok = T;
5387
5388    #[track_caller]
5389    fn trace_err(self) -> Option<T> {
5390        match self {
5391            Ok(value) => Some(value),
5392            Err(error) => {
5393                tracing::error!("{:?}", error);
5394                None
5395            }
5396        }
5397    }
5398}