rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
   8        MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
   9        RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37
  38use futures::{
  39    channel::oneshot,
  40    future::{self, BoxFuture},
  41    stream::FuturesUnordered,
  42    FutureExt, SinkExt, StreamExt, TryStreamExt,
  43};
  44use prometheus::{register_int_gauge, IntGauge};
  45use rpc::{
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  48        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        atomic::{AtomicBool, Ordering::SeqCst},
  64        Arc, OnceLock,
  65    },
  66    time::{Duration, Instant},
  67};
  68use time::OffsetDateTime;
  69use tokio::sync::{watch, Semaphore};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    field::{self},
  73    info_span, instrument, Instrument,
  74};
  75use util::http::IsahcHttpClient;
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
  83const MAX_MESSAGE_LEN: usize = 1024;
  84const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
 103struct StreamingResponse<R: RequestMessage> {
 104    peer: Arc<Peer>,
 105    receipt: Receipt<R>,
 106}
 107
 108impl<R: RequestMessage> StreamingResponse<R> {
 109    fn send(&self, payload: R::Response) -> Result<()> {
 110        self.peer.respond(self.receipt, payload)?;
 111        Ok(())
 112    }
 113}
 114
 115#[derive(Clone, Debug)]
 116pub enum Principal {
 117    User(User),
 118    Impersonated { user: User, admin: User },
 119    DevServer(dev_server::Model),
 120}
 121
 122impl Principal {
 123    fn update_span(&self, span: &tracing::Span) {
 124        match &self {
 125            Principal::User(user) => {
 126                span.record("user_id", &user.id.0);
 127                span.record("login", &user.github_login);
 128            }
 129            Principal::Impersonated { user, admin } => {
 130                span.record("user_id", &user.id.0);
 131                span.record("login", &user.github_login);
 132                span.record("impersonator", &admin.github_login);
 133            }
 134            Principal::DevServer(dev_server) => {
 135                span.record("dev_server_id", &dev_server.id.0);
 136            }
 137        }
 138    }
 139}
 140
/// State for one client connection; this is what every registered message handler receives.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Async (tokio) mutex: `Session::db` hands its guard out of an `async fn`, so callers can
    /// hold it across `.await` points.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Synchronous (parking_lot) mutex. `Session::connection_pool` wraps its guard in the
    /// `!Send` `ConnectionPoolGuard`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    // NOTE(review): presumably `None` when no LiveKit server is configured — confirm.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    // Underscore-prefixed: not read through this field in the code visible here.
    _executor: Executor,
}
 153
 154impl Session {
 155    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.db.lock().await;
 159        #[cfg(test)]
 160        tokio::task::yield_now().await;
 161        guard
 162    }
 163
 164    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 165        #[cfg(test)]
 166        tokio::task::yield_now().await;
 167        let guard = self.connection_pool.lock();
 168        ConnectionPoolGuard {
 169            guard,
 170            _not_send: PhantomData,
 171        }
 172    }
 173
 174    fn for_user(self) -> Option<UserSession> {
 175        UserSession::new(self)
 176    }
 177
 178    fn for_dev_server(self) -> Option<DevServerSession> {
 179        DevServerSession::new(self)
 180    }
 181
 182    fn user_id(&self) -> Option<UserId> {
 183        match &self.principal {
 184            Principal::User(user) => Some(user.id),
 185            Principal::Impersonated { user, .. } => Some(user.id),
 186            Principal::DevServer(_) => None,
 187        }
 188    }
 189
 190    fn dev_server_id(&self) -> Option<DevServerId> {
 191        match &self.principal {
 192            Principal::User(_) | Principal::Impersonated { .. } => None,
 193            Principal::DevServer(dev_server) => Some(dev_server.id),
 194        }
 195    }
 196
 197    fn principal_id(&self) -> PrincipalId {
 198        match &self.principal {
 199            Principal::User(user) => PrincipalId::UserId(user.id),
 200            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 201            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 202        }
 203    }
 204}
 205
 206impl Debug for Session {
 207    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 208        let mut result = f.debug_struct("Session");
 209        match &self.principal {
 210            Principal::User(user) => {
 211                result.field("user", &user.github_login);
 212            }
 213            Principal::Impersonated { user, admin } => {
 214                result.field("user", &user.github_login);
 215                result.field("impersonator", &admin.github_login);
 216            }
 217            Principal::DevServer(dev_server) => {
 218                result.field("dev_server", &dev_server.id);
 219            }
 220        }
 221        result.field("connection_id", &self.connection_id).finish()
 222    }
 223}
 224
 225struct UserSession(Session);
 226
 227impl UserSession {
 228    pub fn new(s: Session) -> Option<Self> {
 229        s.user_id().map(|_| UserSession(s))
 230    }
 231    pub fn user_id(&self) -> UserId {
 232        self.0.user_id().unwrap()
 233    }
 234}
 235
 236impl Deref for UserSession {
 237    type Target = Session;
 238
 239    fn deref(&self) -> &Self::Target {
 240        &self.0
 241    }
 242}
 243impl DerefMut for UserSession {
 244    fn deref_mut(&mut self) -> &mut Self::Target {
 245        &mut self.0
 246    }
 247}
 248
 249struct DevServerSession(Session);
 250
 251impl DevServerSession {
 252    pub fn new(s: Session) -> Option<Self> {
 253        s.dev_server_id().map(|_| DevServerSession(s))
 254    }
 255    pub fn dev_server_id(&self) -> DevServerId {
 256        self.0.dev_server_id().unwrap()
 257    }
 258
 259    fn dev_server(&self) -> &dev_server::Model {
 260        match &self.0.principal {
 261            Principal::DevServer(dev_server) => dev_server,
 262            _ => unreachable!(),
 263        }
 264    }
 265}
 266
 267impl Deref for DevServerSession {
 268    type Target = Session;
 269
 270    fn deref(&self) -> &Self::Target {
 271        &self.0
 272    }
 273}
 274impl DerefMut for DevServerSession {
 275    fn deref_mut(&mut self) -> &mut Self::Target {
 276        &mut self.0
 277    }
 278}
 279
 280fn user_handler<M: RequestMessage, Fut>(
 281    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 282) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 283where
 284    Fut: Send + Future<Output = Result<()>>,
 285{
 286    let handler = Arc::new(handler);
 287    move |message, response, session| {
 288        let handler = handler.clone();
 289        Box::pin(async move {
 290            if let Some(user_session) = session.for_user() {
 291                Ok(handler(message, response, user_session).await?)
 292            } else {
 293                Err(Error::Internal(anyhow!(
 294                    "must be a user to call {}",
 295                    M::NAME
 296                )))
 297            }
 298        })
 299    }
 300}
 301
 302fn dev_server_handler<M: RequestMessage, Fut>(
 303    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 304) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 305where
 306    Fut: Send + Future<Output = Result<()>>,
 307{
 308    let handler = Arc::new(handler);
 309    move |message, response, session| {
 310        let handler = handler.clone();
 311        Box::pin(async move {
 312            if let Some(dev_server_session) = session.for_dev_server() {
 313                Ok(handler(message, response, dev_server_session).await?)
 314            } else {
 315                Err(Error::Internal(anyhow!(
 316                    "must be a dev server to call {}",
 317                    M::NAME
 318                )))
 319            }
 320        })
 321    }
 322}
 323
 324fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 325    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 326) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 327where
 328    InnertRetFut: Send + Future<Output = Result<()>>,
 329{
 330    let handler = Arc::new(handler);
 331    move |message, session| {
 332        let handler = handler.clone();
 333        Box::pin(async move {
 334            if let Some(user_session) = session.for_user() {
 335                Ok(handler(message, user_session).await?)
 336            } else {
 337                Err(Error::Internal(anyhow!(
 338                    "must be a user to call {}",
 339                    M::NAME
 340                )))
 341            }
 342        })
 343    }
 344}
 345
 346struct DbHandle(Arc<Database>);
 347
 348impl Deref for DbHandle {
 349    type Target = Database;
 350
 351    fn deref(&self) -> &Self::Target {
 352        self.0.as_ref()
 353    }
 354}
 355
/// The collaboration RPC server. It owns the peer, the connection pool, and the table of
/// registered message handlers.
pub struct Server {
    /// Current server id, kept behind a mutex so the test-only `reset` can swap it via `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; the test-only `reset` sets it back to `false`.
    teardown: watch::Sender<bool>,
}
 364
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that holds it across an
/// `.await` is therefore not `Send`, so it cannot be used as a (`Send`) message handler future.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 369
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot owns a connection-pool guard, so the pool stays locked until it is dropped.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself) rather than the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 376
 377pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 378where
 379    S: Serializer,
 380    T: Deref<Target = U>,
 381    U: Serialize,
 382{
 383    Serialize::serialize(value.deref(), serializer)
 384}
 385
 386impl Server {
 387    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 388        let mut server = Self {
 389            id: parking_lot::Mutex::new(id),
 390            peer: Peer::new(id.0 as u32),
 391            app_state: app_state.clone(),
 392            connection_pool: Default::default(),
 393            handlers: Default::default(),
 394            teardown: watch::channel(false).0,
 395        };
 396
 397        server
 398            .add_request_handler(ping)
 399            .add_request_handler(user_handler(create_room))
 400            .add_request_handler(user_handler(join_room))
 401            .add_request_handler(user_handler(rejoin_room))
 402            .add_request_handler(user_handler(leave_room))
 403            .add_request_handler(user_handler(set_room_participant_role))
 404            .add_request_handler(user_handler(call))
 405            .add_request_handler(user_handler(cancel_call))
 406            .add_message_handler(user_message_handler(decline_call))
 407            .add_request_handler(user_handler(update_participant_location))
 408            .add_request_handler(user_handler(share_project))
 409            .add_message_handler(unshare_project)
 410            .add_request_handler(user_handler(join_project))
 411            .add_request_handler(user_handler(join_hosted_project))
 412            .add_request_handler(user_handler(rejoin_remote_projects))
 413            .add_request_handler(user_handler(create_remote_project))
 414            .add_request_handler(user_handler(create_dev_server))
 415            .add_request_handler(user_handler(delete_dev_server))
 416            .add_request_handler(dev_server_handler(share_remote_project))
 417            .add_request_handler(dev_server_handler(shutdown_dev_server))
 418            .add_request_handler(dev_server_handler(reconnect_dev_server))
 419            .add_message_handler(user_message_handler(leave_project))
 420            .add_request_handler(update_project)
 421            .add_request_handler(update_worktree)
 422            .add_message_handler(start_language_server)
 423            .add_message_handler(update_language_server)
 424            .add_message_handler(update_diagnostic_summary)
 425            .add_message_handler(update_worktree_settings)
 426            .add_request_handler(user_handler(
 427                forward_read_only_project_request::<proto::GetHover>,
 428            ))
 429            .add_request_handler(user_handler(
 430                forward_read_only_project_request::<proto::GetDefinition>,
 431            ))
 432            .add_request_handler(user_handler(
 433                forward_read_only_project_request::<proto::GetTypeDefinition>,
 434            ))
 435            .add_request_handler(user_handler(
 436                forward_read_only_project_request::<proto::GetReferences>,
 437            ))
 438            .add_request_handler(user_handler(
 439                forward_read_only_project_request::<proto::SearchProject>,
 440            ))
 441            .add_request_handler(user_handler(
 442                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 443            ))
 444            .add_request_handler(user_handler(
 445                forward_read_only_project_request::<proto::GetProjectSymbols>,
 446            ))
 447            .add_request_handler(user_handler(
 448                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 449            ))
 450            .add_request_handler(user_handler(
 451                forward_read_only_project_request::<proto::OpenBufferById>,
 452            ))
 453            .add_request_handler(user_handler(
 454                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 455            ))
 456            .add_request_handler(user_handler(
 457                forward_read_only_project_request::<proto::InlayHints>,
 458            ))
 459            .add_request_handler(user_handler(
 460                forward_read_only_project_request::<proto::OpenBufferByPath>,
 461            ))
 462            .add_request_handler(user_handler(
 463                forward_mutating_project_request::<proto::GetCompletions>,
 464            ))
 465            .add_request_handler(user_handler(
 466                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 467            ))
 468            .add_request_handler(user_handler(
 469                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 470            ))
 471            .add_request_handler(user_handler(
 472                forward_mutating_project_request::<proto::GetCodeActions>,
 473            ))
 474            .add_request_handler(user_handler(
 475                forward_mutating_project_request::<proto::ApplyCodeAction>,
 476            ))
 477            .add_request_handler(user_handler(
 478                forward_mutating_project_request::<proto::PrepareRename>,
 479            ))
 480            .add_request_handler(user_handler(
 481                forward_mutating_project_request::<proto::PerformRename>,
 482            ))
 483            .add_request_handler(user_handler(
 484                forward_mutating_project_request::<proto::ReloadBuffers>,
 485            ))
 486            .add_request_handler(user_handler(
 487                forward_mutating_project_request::<proto::FormatBuffers>,
 488            ))
 489            .add_request_handler(user_handler(
 490                forward_mutating_project_request::<proto::CreateProjectEntry>,
 491            ))
 492            .add_request_handler(user_handler(
 493                forward_mutating_project_request::<proto::RenameProjectEntry>,
 494            ))
 495            .add_request_handler(user_handler(
 496                forward_mutating_project_request::<proto::CopyProjectEntry>,
 497            ))
 498            .add_request_handler(user_handler(
 499                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 500            ))
 501            .add_request_handler(user_handler(
 502                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 503            ))
 504            .add_request_handler(user_handler(
 505                forward_mutating_project_request::<proto::OnTypeFormatting>,
 506            ))
 507            .add_request_handler(user_handler(
 508                forward_mutating_project_request::<proto::SaveBuffer>,
 509            ))
 510            .add_request_handler(user_handler(
 511                forward_mutating_project_request::<proto::BlameBuffer>,
 512            ))
 513            .add_request_handler(user_handler(
 514                forward_mutating_project_request::<proto::MultiLspQuery>,
 515            ))
 516            .add_message_handler(create_buffer_for_peer)
 517            .add_request_handler(update_buffer)
 518            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 519            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 520            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 521            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 522            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 523            .add_request_handler(get_users)
 524            .add_request_handler(user_handler(fuzzy_search_users))
 525            .add_request_handler(user_handler(request_contact))
 526            .add_request_handler(user_handler(remove_contact))
 527            .add_request_handler(user_handler(respond_to_contact_request))
 528            .add_request_handler(user_handler(create_channel))
 529            .add_request_handler(user_handler(delete_channel))
 530            .add_request_handler(user_handler(invite_channel_member))
 531            .add_request_handler(user_handler(remove_channel_member))
 532            .add_request_handler(user_handler(set_channel_member_role))
 533            .add_request_handler(user_handler(set_channel_visibility))
 534            .add_request_handler(user_handler(rename_channel))
 535            .add_request_handler(user_handler(join_channel_buffer))
 536            .add_request_handler(user_handler(leave_channel_buffer))
 537            .add_message_handler(user_message_handler(update_channel_buffer))
 538            .add_request_handler(user_handler(rejoin_channel_buffers))
 539            .add_request_handler(user_handler(get_channel_members))
 540            .add_request_handler(user_handler(respond_to_channel_invite))
 541            .add_request_handler(user_handler(join_channel))
 542            .add_request_handler(user_handler(join_channel_chat))
 543            .add_message_handler(user_message_handler(leave_channel_chat))
 544            .add_request_handler(user_handler(send_channel_message))
 545            .add_request_handler(user_handler(remove_channel_message))
 546            .add_request_handler(user_handler(update_channel_message))
 547            .add_request_handler(user_handler(get_channel_messages))
 548            .add_request_handler(user_handler(get_channel_messages_by_id))
 549            .add_request_handler(user_handler(get_notifications))
 550            .add_request_handler(user_handler(mark_notification_as_read))
 551            .add_request_handler(user_handler(move_channel))
 552            .add_request_handler(user_handler(follow))
 553            .add_message_handler(user_message_handler(unfollow))
 554            .add_message_handler(user_message_handler(update_followers))
 555            .add_request_handler(user_handler(get_private_user_info))
 556            .add_message_handler(user_message_handler(acknowledge_channel_message))
 557            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 558            .add_streaming_request_handler({
 559                let app_state = app_state.clone();
 560                move |request, response, session| {
 561                    complete_with_language_model(
 562                        request,
 563                        response,
 564                        session,
 565                        app_state.config.openai_api_key.clone(),
 566                        app_state.config.google_ai_api_key.clone(),
 567                        app_state.config.anthropic_api_key.clone(),
 568                    )
 569                }
 570            })
 571            .add_request_handler({
 572                let app_state = app_state.clone();
 573                user_handler(move |request, response, session| {
 574                    count_tokens_with_language_model(
 575                        request,
 576                        response,
 577                        session,
 578                        app_state.config.google_ai_api_key.clone(),
 579                    )
 580                })
 581            })
 582            .add_request_handler({
 583                user_handler(move |request, response, session| {
 584                    get_cached_embeddings(request, response, session)
 585                })
 586            })
 587            .add_request_handler({
 588                let app_state = app_state.clone();
 589                user_handler(move |request, response, session| {
 590                    compute_embeddings(
 591                        request,
 592                        response,
 593                        session,
 594                        app_state.config.openai_api_key.clone(),
 595                    )
 596                })
 597            });
 598
 599        Arc::new(server)
 600    }
 601
    /// Spawns a detached cleanup task and returns `Ok(())` immediately.
    ///
    /// The task sleeps for [`CLEANUP_TIMEOUT`], then asks the database which rooms and channel
    /// buffers are stale for this environment and server id, and for each one:
    /// - channel buffers: clears stale collaborators and pushes the refreshed collaborator
    ///   list to every connection still on that buffer;
    /// - rooms: clears stale participants, broadcasts the room (and channel) update, sends
    ///   `CallCanceled` to users whose pending calls were canceled, re-sends the contact entry
    ///   of every affected user to their accepted contacts, and deletes the LiveKit room once
    ///   it has no participants left (only when a LiveKit client is configured).
    ///
    /// Finally it calls `delete_stale_servers`. Every fallible step is passed through
    /// `trace_err` (which yields an `Option`) and skipped on failure instead of aborting the
    /// cleanup.
    pub async fn start(&self) -> Result<()> {
        // Read the id once; the spawned task uses this snapshot.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults mean "nothing to notify or delete" if clearing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user unless both lookups succeeded; the pool is locked
                            // only after the awaits, for synchronous sends.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 747
 748    pub fn teardown(&self) {
 749        self.peer.teardown();
 750        self.connection_pool.lock().reset();
 751        let _ = self.teardown.send(true);
 752    }
 753
 754    #[cfg(test)]
 755    pub fn reset(&self, id: ServerId) {
 756        self.teardown();
 757        *self.id.lock() = id;
 758        self.peer.reset(id.0 as u32);
 759        let _ = self.teardown.send(false);
 760    }
 761
 762    #[cfg(test)]
 763    pub fn id(&self) -> ServerId {
 764        *self.id.lock()
 765    }
 766
 767    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 768    where
 769        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 770        Fut: 'static + Send + Future<Output = Result<()>>,
 771        M: EnvelopedMessage,
 772    {
 773        let prev_handler = self.handlers.insert(
 774            TypeId::of::<M>(),
 775            Box::new(move |envelope, session| {
 776                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 777                let received_at = envelope.received_at;
 778                tracing::info!("message received");
 779                let start_time = Instant::now();
 780                let future = (handler)(*envelope, session);
 781                async move {
 782                    let result = future.await;
 783                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 784                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 785                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 786                    let payload_type = M::NAME;
 787
 788                    match result {
 789                        Err(error) => {
 790                            tracing::error!(
 791                                ?error,
 792                                total_duration_ms,
 793                                processing_duration_ms,
 794                                queue_duration_ms,
 795                                payload_type,
 796                                "error handling message"
 797                            )
 798                        }
 799                        Ok(()) => tracing::info!(
 800                            total_duration_ms,
 801                            processing_duration_ms,
 802                            queue_duration_ms,
 803                            "finished handling message"
 804                        ),
 805                    }
 806                }
 807                .boxed()
 808            }),
 809        );
 810        if prev_handler.is_some() {
 811            panic!("registered a handler for the same message twice");
 812        }
 813        self
 814    }
 815
 816    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 817    where
 818        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 819        Fut: 'static + Send + Future<Output = Result<()>>,
 820        M: EnvelopedMessage,
 821    {
 822        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 823        self
 824    }
 825
 826    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 827    where
 828        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 829        Fut: Send + Future<Output = Result<()>>,
 830        M: RequestMessage,
 831    {
 832        let handler = Arc::new(handler);
 833        self.add_handler(move |envelope, session| {
 834            let receipt = envelope.receipt();
 835            let handler = handler.clone();
 836            async move {
 837                let peer = session.peer.clone();
 838                let responded = Arc::new(AtomicBool::default());
 839                let response = Response {
 840                    peer: peer.clone(),
 841                    responded: responded.clone(),
 842                    receipt,
 843                };
 844                match (handler)(envelope.payload, response, session).await {
 845                    Ok(()) => {
 846                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 847                            Ok(())
 848                        } else {
 849                            Err(anyhow!("handler did not send a response"))?
 850                        }
 851                    }
 852                    Err(error) => {
 853                        let proto_err = match &error {
 854                            Error::Internal(err) => err.to_proto(),
 855                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 856                        };
 857                        peer.respond_with_error(receipt, proto_err)?;
 858                        Err(error)
 859                    }
 860                }
 861            }
 862        })
 863    }
 864
 865    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 866    where
 867        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 868        Fut: Send + Future<Output = Result<()>>,
 869        M: RequestMessage,
 870    {
 871        let handler = Arc::new(handler);
 872        self.add_handler(move |envelope, session| {
 873            let receipt = envelope.receipt();
 874            let handler = handler.clone();
 875            async move {
 876                let peer = session.peer.clone();
 877                let response = StreamingResponse {
 878                    peer: peer.clone(),
 879                    receipt,
 880                };
 881                match (handler)(envelope.payload, response, session).await {
 882                    Ok(()) => {
 883                        peer.end_stream(receipt)?;
 884                        Ok(())
 885                    }
 886                    Err(error) => {
 887                        let proto_err = match &error {
 888                            Error::Internal(err) => err.to_proto(),
 889                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 890                        };
 891                        peer.respond_with_error(receipt, proto_err)?;
 892                        Err(error)
 893                    }
 894                }
 895            }
 896        })
 897    }
 898
 899    #[allow(clippy::too_many_arguments)]
 900    pub fn handle_connection(
 901        self: &Arc<Self>,
 902        connection: Connection,
 903        address: String,
 904        principal: Principal,
 905        zed_version: ZedVersion,
 906        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 907        executor: Executor,
 908    ) -> impl Future<Output = ()> {
 909        let this = self.clone();
 910        let span = info_span!("handle connection", %address,
 911            connection_id=field::Empty,
 912            user_id=field::Empty,
 913            login=field::Empty,
 914            impersonator=field::Empty,
 915            dev_server_id=field::Empty
 916        );
 917        principal.update_span(&span);
 918
 919        let mut teardown = self.teardown.subscribe();
 920        async move {
 921            if *teardown.borrow() {
 922                tracing::error!("server is tearing down");
 923                return
 924            }
 925            let (connection_id, handle_io, mut incoming_rx) = this
 926                .peer
 927                .add_connection(connection, {
 928                    let executor = executor.clone();
 929                    move |duration| executor.sleep(duration)
 930                });
 931            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 932            tracing::info!("connection opened");
 933
 934            let http_client = match IsahcHttpClient::new() {
 935                Ok(http_client) => http_client,
 936                Err(error) => {
 937                    tracing::error!(?error, "failed to create HTTP client");
 938                    return;
 939                }
 940            };
 941
 942            let session = Session {
 943                principal: principal.clone(),
 944                connection_id,
 945                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 946                peer: this.peer.clone(),
 947                connection_pool: this.connection_pool.clone(),
 948                live_kit_client: this.app_state.live_kit_client.clone(),
 949                http_client,
 950                rate_limiter: this.app_state.rate_limiter.clone(),
 951                _executor: executor.clone(),
 952            };
 953
 954            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 955                tracing::error!(?error, "failed to send initial client update");
 956                return;
 957            }
 958
 959            let handle_io = handle_io.fuse();
 960            futures::pin_mut!(handle_io);
 961
 962            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 963            // This prevents deadlocks when e.g., client A performs a request to client B and
 964            // client B performs a request to client A. If both clients stop processing further
 965            // messages until their respective request completes, they won't have a chance to
 966            // respond to the other client's request and cause a deadlock.
 967            //
 968            // This arrangement ensures we will attempt to process earlier messages first, but fall
 969            // back to processing messages arrived later in the spirit of making progress.
 970            let mut foreground_message_handlers = FuturesUnordered::new();
 971            let concurrent_handlers = Arc::new(Semaphore::new(256));
 972            loop {
 973                let next_message = async {
 974                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 975                    let message = incoming_rx.next().await;
 976                    (permit, message)
 977                }.fuse();
 978                futures::pin_mut!(next_message);
 979                futures::select_biased! {
 980                    _ = teardown.changed().fuse() => return,
 981                    result = handle_io => {
 982                        if let Err(error) = result {
 983                            tracing::error!(?error, "error handling I/O");
 984                        }
 985                        break;
 986                    }
 987                    _ = foreground_message_handlers.next() => {}
 988                    next_message = next_message => {
 989                        let (permit, message) = next_message;
 990                        if let Some(message) = message {
 991                            let type_name = message.payload_type_name();
 992                            // note: we copy all the fields from the parent span so we can query them in the logs.
 993                            // (https://github.com/tokio-rs/tracing/issues/2670).
 994                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 995                                user_id=field::Empty,
 996                                login=field::Empty,
 997                                impersonator=field::Empty,
 998                                dev_server_id=field::Empty
 999                            );
1000                            principal.update_span(&span);
1001                            let span_enter = span.enter();
1002                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1003                                let is_background = message.is_background();
1004                                let handle_message = (handler)(message, session.clone());
1005                                drop(span_enter);
1006
1007                                let handle_message = async move {
1008                                    handle_message.await;
1009                                    drop(permit);
1010                                }.instrument(span);
1011                                if is_background {
1012                                    executor.spawn_detached(handle_message);
1013                                } else {
1014                                    foreground_message_handlers.push(handle_message);
1015                                }
1016                            } else {
1017                                tracing::error!("no message handler");
1018                            }
1019                        } else {
1020                            tracing::info!("connection closed");
1021                            break;
1022                        }
1023                    }
1024                }
1025            }
1026
1027            drop(foreground_message_handlers);
1028            tracing::info!("signing out");
1029            if let Err(error) = connection_lost(session, teardown, executor).await {
1030                tracing::error!(?error, "error signing out");
1031            }
1032
1033        }.instrument(span)
1034    }
1035
1036    async fn send_initial_client_update(
1037        &self,
1038        connection_id: ConnectionId,
1039        principal: &Principal,
1040        zed_version: ZedVersion,
1041        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1042        session: &Session,
1043    ) -> Result<()> {
1044        self.peer.send(
1045            connection_id,
1046            proto::Hello {
1047                peer_id: Some(connection_id.into()),
1048            },
1049        )?;
1050        tracing::info!("sent hello message");
1051        if let Some(send_connection_id) = send_connection_id.take() {
1052            let _ = send_connection_id.send(connection_id);
1053        }
1054
1055        match principal {
1056            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1057                if !user.connected_once {
1058                    self.peer.send(connection_id, proto::ShowContacts {})?;
1059                    self.app_state
1060                        .db
1061                        .set_user_connected_once(user.id, true)
1062                        .await?;
1063                }
1064
1065                let (contacts, channels_for_user, channel_invites, remote_projects) =
1066                    future::try_join4(
1067                        self.app_state.db.get_contacts(user.id),
1068                        self.app_state.db.get_channels_for_user(user.id),
1069                        self.app_state.db.get_channel_invites_for_user(user.id),
1070                        self.app_state.db.remote_projects_update(user.id),
1071                    )
1072                    .await?;
1073
1074                {
1075                    let mut pool = self.connection_pool.lock();
1076                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1077                    for membership in &channels_for_user.channel_memberships {
1078                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1079                    }
1080                    self.peer.send(
1081                        connection_id,
1082                        build_initial_contacts_update(contacts, &pool),
1083                    )?;
1084                    self.peer.send(
1085                        connection_id,
1086                        build_update_user_channels(&channels_for_user),
1087                    )?;
1088                    self.peer.send(
1089                        connection_id,
1090                        build_channels_update(channels_for_user, channel_invites),
1091                    )?;
1092                }
1093                send_remote_projects_update(user.id, remote_projects, session).await;
1094
1095                if let Some(incoming_call) =
1096                    self.app_state.db.incoming_call_for_user(user.id).await?
1097                {
1098                    self.peer.send(connection_id, incoming_call)?;
1099                }
1100
1101                update_user_contacts(user.id, &session).await?;
1102            }
1103            Principal::DevServer(dev_server) => {
1104                {
1105                    let mut pool = self.connection_pool.lock();
1106                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1107                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1108                    };
1109                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1110                }
1111
1112                let projects = self
1113                    .app_state
1114                    .db
1115                    .get_remote_projects_for_dev_server(dev_server.id)
1116                    .await?;
1117                self.peer
1118                    .send(connection_id, proto::DevServerInstructions { projects })?;
1119
1120                let status = self
1121                    .app_state
1122                    .db
1123                    .remote_projects_update(dev_server.user_id)
1124                    .await?;
1125                send_remote_projects_update(dev_server.user_id, status, &session).await;
1126            }
1127        }
1128
1129        Ok(())
1130    }
1131
1132    pub async fn invite_code_redeemed(
1133        self: &Arc<Self>,
1134        inviter_id: UserId,
1135        invitee_id: UserId,
1136    ) -> Result<()> {
1137        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1138            if let Some(code) = &user.invite_code {
1139                let pool = self.connection_pool.lock();
1140                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1141                for connection_id in pool.user_connection_ids(inviter_id) {
1142                    self.peer.send(
1143                        connection_id,
1144                        proto::UpdateContacts {
1145                            contacts: vec![invitee_contact.clone()],
1146                            ..Default::default()
1147                        },
1148                    )?;
1149                    self.peer.send(
1150                        connection_id,
1151                        proto::UpdateInviteInfo {
1152                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1153                            count: user.invite_count as u32,
1154                        },
1155                    )?;
1156                }
1157            }
1158        }
1159        Ok(())
1160    }
1161
1162    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1163        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1164            if let Some(invite_code) = &user.invite_code {
1165                let pool = self.connection_pool.lock();
1166                for connection_id in pool.user_connection_ids(user_id) {
1167                    self.peer.send(
1168                        connection_id,
1169                        proto::UpdateInviteInfo {
1170                            url: format!(
1171                                "{}{}",
1172                                self.app_state.config.invite_link_prefix, invite_code
1173                            ),
1174                            count: user.invite_count as u32,
1175                        },
1176                    )?;
1177                }
1178            }
1179        }
1180        Ok(())
1181    }
1182
1183    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1184        ServerSnapshot {
1185            connection_pool: ConnectionPoolGuard {
1186                guard: self.connection_pool.lock(),
1187                _not_send: PhantomData,
1188            },
1189            peer: &self.peer,
1190        }
1191    }
1192}
1193
1194impl<'a> Deref for ConnectionPoolGuard<'a> {
1195    type Target = ConnectionPool;
1196
1197    fn deref(&self) -> &Self::Target {
1198        &self.guard
1199    }
1200}
1201
1202impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1203    fn deref_mut(&mut self) -> &mut Self::Target {
1204        &mut self.guard
1205    }
1206}
1207
1208impl<'a> Drop for ConnectionPoolGuard<'a> {
1209    fn drop(&mut self) {
1210        #[cfg(test)]
1211        self.check_invariants();
1212    }
1213}
1214
1215fn broadcast<F>(
1216    sender_id: Option<ConnectionId>,
1217    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1218    mut f: F,
1219) where
1220    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1221{
1222    for receiver_id in receiver_ids {
1223        if Some(receiver_id) != sender_id {
1224            if let Err(error) = f(receiver_id) {
1225                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1226            }
1227        }
1228    }
1229}
1230
1231pub struct ProtocolVersion(u32);
1232
1233impl Header for ProtocolVersion {
1234    fn name() -> &'static HeaderName {
1235        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1236        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1237    }
1238
1239    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1240    where
1241        Self: Sized,
1242        I: Iterator<Item = &'i axum::http::HeaderValue>,
1243    {
1244        let version = values
1245            .next()
1246            .ok_or_else(axum::headers::Error::invalid)?
1247            .to_str()
1248            .map_err(|_| axum::headers::Error::invalid())?
1249            .parse()
1250            .map_err(|_| axum::headers::Error::invalid())?;
1251        Ok(Self(version))
1252    }
1253
1254    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1255        values.extend([self.0.to_string().parse().unwrap()]);
1256    }
1257}
1258
1259pub struct AppVersionHeader(SemanticVersion);
1260impl Header for AppVersionHeader {
1261    fn name() -> &'static HeaderName {
1262        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1263        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1264    }
1265
1266    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1267    where
1268        Self: Sized,
1269        I: Iterator<Item = &'i axum::http::HeaderValue>,
1270    {
1271        let version = values
1272            .next()
1273            .ok_or_else(axum::headers::Error::invalid)?
1274            .to_str()
1275            .map_err(|_| axum::headers::Error::invalid())?
1276            .parse()
1277            .map_err(|_| axum::headers::Error::invalid())?;
1278        Ok(Self(version))
1279    }
1280
1281    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1282        values.extend([self.0.to_string().parse().unwrap()]);
1283    }
1284}
1285
1286pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1287    Router::new()
1288        .route("/rpc", get(handle_websocket_request))
1289        .layer(
1290            ServiceBuilder::new()
1291                .layer(Extension(server.app_state.clone()))
1292                .layer(middleware::from_fn(auth::validate_header)),
1293        )
1294        .route("/metrics", get(handle_metrics))
1295        .layer(Extension(server))
1296}
1297
1298pub async fn handle_websocket_request(
1299    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1300    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1301    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1302    Extension(server): Extension<Arc<Server>>,
1303    Extension(principal): Extension<Principal>,
1304    ws: WebSocketUpgrade,
1305) -> axum::response::Response {
1306    if protocol_version != rpc::PROTOCOL_VERSION {
1307        return (
1308            StatusCode::UPGRADE_REQUIRED,
1309            "client must be upgraded".to_string(),
1310        )
1311            .into_response();
1312    }
1313
1314    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1315        return (
1316            StatusCode::UPGRADE_REQUIRED,
1317            "no version header found".to_string(),
1318        )
1319            .into_response();
1320    };
1321
1322    if !version.can_collaborate() {
1323        return (
1324            StatusCode::UPGRADE_REQUIRED,
1325            "client must be upgraded".to_string(),
1326        )
1327            .into_response();
1328    }
1329
1330    let socket_address = socket_address.to_string();
1331    ws.on_upgrade(move |socket| {
1332        let socket = socket
1333            .map_ok(to_tungstenite_message)
1334            .err_into()
1335            .with(|message| async move { Ok(to_axum_message(message)) });
1336        let connection = Connection::new(Box::pin(socket));
1337        async move {
1338            server
1339                .handle_connection(
1340                    connection,
1341                    socket_address,
1342                    principal,
1343                    version,
1344                    None,
1345                    Executor::Production,
1346                )
1347                .await;
1348        }
1349    })
1350}
1351
1352pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1353    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1354    let connections_metric = CONNECTIONS_METRIC
1355        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1356
1357    let connections = server
1358        .connection_pool
1359        .lock()
1360        .connections()
1361        .filter(|connection| !connection.admin)
1362        .count();
1363    connections_metric.set(connections as _);
1364
1365    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1366    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1367        register_int_gauge!(
1368            "shared_projects",
1369            "number of open projects with one or more guests"
1370        )
1371        .unwrap()
1372    });
1373
1374    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1375    shared_projects_metric.set(shared_projects as _);
1376
1377    let encoder = prometheus::TextEncoder::new();
1378    let metric_families = prometheus::gather();
1379    let encoded_metrics = encoder
1380        .encode_to_string(&metric_families)
1381        .map_err(|err| anyhow!("{}", err))?;
1382    Ok(encoded_metrics)
1383}
1384
/// Cleans up after a connection that has gone away.
///
/// First, it releases the connection right away:
/// - disconnects it from the peer;
/// - removes it from the connection pool (a failure here aborts the cleanup);
/// - tells the database the connection was lost (a failure here is only logged).
///
/// It then waits `RECONNECT_TIMEOUT`, unless the teardown channel changes first,
/// in which case it returns without further cleanup. After the timeout, it
/// releases the principal's resources:
/// - users: leave the room and channel buffers; decline any pending call if no
///   other connection of the user is online; refresh the user's contacts;
/// - dev servers: run `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the reconnect timer before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline calls when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        // Server teardown: skip per-principal cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1439
1440/// Acknowledges a ping from a client, used to keep the connection alive.
1441async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1442    response.send(proto::Ack {})?;
1443    Ok(())
1444}
1445
1446/// Creates a new room for calling (outside of channels)
1447async fn create_room(
1448    _request: proto::CreateRoom,
1449    response: Response<proto::CreateRoom>,
1450    session: UserSession,
1451) -> Result<()> {
1452    let live_kit_room = nanoid::nanoid!(30);
1453
1454    let live_kit_connection_info = util::maybe!(async {
1455        let live_kit = session.live_kit_client.as_ref();
1456        let live_kit = live_kit?;
1457        let user_id = session.user_id().to_string();
1458
1459        let token = live_kit
1460            .room_token(&live_kit_room, &user_id.to_string())
1461            .trace_err()?;
1462
1463        Some(proto::LiveKitConnectionInfo {
1464            server_url: live_kit.url().into(),
1465            token,
1466            can_publish: true,
1467        })
1468    })
1469    .await;
1470
1471    let room = session
1472        .db()
1473        .await
1474        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1475        .await?;
1476
1477    response.send(proto::CreateRoomResponse {
1478        room: Some(room.clone()),
1479        live_kit_connection_info,
1480    })?;
1481
1482    update_user_contacts(session.user_id(), &session).await?;
1483    Ok(())
1484}
1485
1486/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1487async fn join_room(
1488    request: proto::JoinRoom,
1489    response: Response<proto::JoinRoom>,
1490    session: UserSession,
1491) -> Result<()> {
1492    let room_id = RoomId::from_proto(request.id);
1493
1494    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1495
1496    if let Some(channel_id) = channel_id {
1497        return join_channel_internal(channel_id, Box::new(response), session).await;
1498    }
1499
1500    let joined_room = {
1501        let room = session
1502            .db()
1503            .await
1504            .join_room(room_id, session.user_id(), session.connection_id)
1505            .await?;
1506        room_updated(&room.room, &session.peer);
1507        room.into_inner()
1508    };
1509
1510    for connection_id in session
1511        .connection_pool()
1512        .await
1513        .user_connection_ids(session.user_id())
1514    {
1515        session
1516            .peer
1517            .send(
1518                connection_id,
1519                proto::CallCanceled {
1520                    room_id: room_id.to_proto(),
1521                },
1522            )
1523            .trace_err();
1524    }
1525
1526    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1527        if let Some(token) = live_kit
1528            .room_token(
1529                &joined_room.room.live_kit_room,
1530                &session.user_id().to_string(),
1531            )
1532            .trace_err()
1533        {
1534            Some(proto::LiveKitConnectionInfo {
1535                server_url: live_kit.url().into(),
1536                token,
1537                can_publish: true,
1538            })
1539        } else {
1540            None
1541        }
1542    } else {
1543        None
1544    };
1545
1546    response.send(proto::JoinRoomResponse {
1547        room: Some(joined_room.room),
1548        channel_id: None,
1549        live_kit_connection_info,
1550    })?;
1551
1552    update_user_contacts(session.user_id(), &session).await?;
1553    Ok(())
1554}
1555
1556/// Rejoin room is used to reconnect to a room after connection errors.
1557async fn rejoin_room(
1558    request: proto::RejoinRoom,
1559    response: Response<proto::RejoinRoom>,
1560    session: UserSession,
1561) -> Result<()> {
1562    let room;
1563    let channel;
1564    {
1565        let mut rejoined_room = session
1566            .db()
1567            .await
1568            .rejoin_room(request, session.user_id(), session.connection_id)
1569            .await?;
1570
1571        response.send(proto::RejoinRoomResponse {
1572            room: Some(rejoined_room.room.clone()),
1573            reshared_projects: rejoined_room
1574                .reshared_projects
1575                .iter()
1576                .map(|project| proto::ResharedProject {
1577                    id: project.id.to_proto(),
1578                    collaborators: project
1579                        .collaborators
1580                        .iter()
1581                        .map(|collaborator| collaborator.to_proto())
1582                        .collect(),
1583                })
1584                .collect(),
1585            rejoined_projects: rejoined_room
1586                .rejoined_projects
1587                .iter()
1588                .map(|rejoined_project| rejoined_project.to_proto())
1589                .collect(),
1590        })?;
1591        room_updated(&rejoined_room.room, &session.peer);
1592
1593        for project in &rejoined_room.reshared_projects {
1594            for collaborator in &project.collaborators {
1595                session
1596                    .peer
1597                    .send(
1598                        collaborator.connection_id,
1599                        proto::UpdateProjectCollaborator {
1600                            project_id: project.id.to_proto(),
1601                            old_peer_id: Some(project.old_connection_id.into()),
1602                            new_peer_id: Some(session.connection_id.into()),
1603                        },
1604                    )
1605                    .trace_err();
1606            }
1607
1608            broadcast(
1609                Some(session.connection_id),
1610                project
1611                    .collaborators
1612                    .iter()
1613                    .map(|collaborator| collaborator.connection_id),
1614                |connection_id| {
1615                    session.peer.forward_send(
1616                        session.connection_id,
1617                        connection_id,
1618                        proto::UpdateProject {
1619                            project_id: project.id.to_proto(),
1620                            worktrees: project.worktrees.clone(),
1621                        },
1622                    )
1623                },
1624            );
1625        }
1626
1627        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1628
1629        let rejoined_room = rejoined_room.into_inner();
1630
1631        room = rejoined_room.room;
1632        channel = rejoined_room.channel;
1633    }
1634
1635    if let Some(channel) = channel {
1636        channel_updated(
1637            &channel,
1638            &room,
1639            &session.peer,
1640            &*session.connection_pool().await,
1641        );
1642    }
1643
1644    update_user_contacts(session.user_id(), &session).await?;
1645    Ok(())
1646}
1647
1648fn notify_rejoined_projects(
1649    rejoined_projects: &mut Vec<RejoinedProject>,
1650    session: &UserSession,
1651) -> Result<()> {
1652    for project in rejoined_projects.iter() {
1653        for collaborator in &project.collaborators {
1654            session
1655                .peer
1656                .send(
1657                    collaborator.connection_id,
1658                    proto::UpdateProjectCollaborator {
1659                        project_id: project.id.to_proto(),
1660                        old_peer_id: Some(project.old_connection_id.into()),
1661                        new_peer_id: Some(session.connection_id.into()),
1662                    },
1663                )
1664                .trace_err();
1665        }
1666    }
1667
1668    for project in rejoined_projects {
1669        for worktree in mem::take(&mut project.worktrees) {
1670            #[cfg(any(test, feature = "test-support"))]
1671            const MAX_CHUNK_SIZE: usize = 2;
1672            #[cfg(not(any(test, feature = "test-support")))]
1673            const MAX_CHUNK_SIZE: usize = 256;
1674
1675            // Stream this worktree's entries.
1676            let message = proto::UpdateWorktree {
1677                project_id: project.id.to_proto(),
1678                worktree_id: worktree.id,
1679                abs_path: worktree.abs_path.clone(),
1680                root_name: worktree.root_name,
1681                updated_entries: worktree.updated_entries,
1682                removed_entries: worktree.removed_entries,
1683                scan_id: worktree.scan_id,
1684                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1685                updated_repositories: worktree.updated_repositories,
1686                removed_repositories: worktree.removed_repositories,
1687            };
1688            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1689                session.peer.send(session.connection_id, update.clone())?;
1690            }
1691
1692            // Stream this worktree's diagnostics.
1693            for summary in worktree.diagnostic_summaries {
1694                session.peer.send(
1695                    session.connection_id,
1696                    proto::UpdateDiagnosticSummary {
1697                        project_id: project.id.to_proto(),
1698                        worktree_id: worktree.id,
1699                        summary: Some(summary),
1700                    },
1701                )?;
1702            }
1703
1704            for settings_file in worktree.settings_files {
1705                session.peer.send(
1706                    session.connection_id,
1707                    proto::UpdateWorktreeSettings {
1708                        project_id: project.id.to_proto(),
1709                        worktree_id: worktree.id,
1710                        path: settings_file.path,
1711                        content: Some(settings_file.content),
1712                    },
1713                )?;
1714            }
1715        }
1716
1717        for language_server in &project.language_servers {
1718            session.peer.send(
1719                session.connection_id,
1720                proto::UpdateLanguageServer {
1721                    project_id: project.id.to_proto(),
1722                    language_server_id: language_server.id,
1723                    variant: Some(
1724                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1725                            proto::LspDiskBasedDiagnosticsUpdated {},
1726                        ),
1727                    ),
1728                },
1729            )?;
1730        }
1731    }
1732    Ok(())
1733}
1734
1735/// leave room disconnects from the room.
1736async fn leave_room(
1737    _: proto::LeaveRoom,
1738    response: Response<proto::LeaveRoom>,
1739    session: UserSession,
1740) -> Result<()> {
1741    leave_room_for_session(&session, session.connection_id).await?;
1742    response.send(proto::Ack {})?;
1743    Ok(())
1744}
1745
1746/// Updates the permissions of someone else in the room.
1747async fn set_room_participant_role(
1748    request: proto::SetRoomParticipantRole,
1749    response: Response<proto::SetRoomParticipantRole>,
1750    session: UserSession,
1751) -> Result<()> {
1752    let user_id = UserId::from_proto(request.user_id);
1753    let role = ChannelRole::from(request.role());
1754
1755    let (live_kit_room, can_publish) = {
1756        let room = session
1757            .db()
1758            .await
1759            .set_room_participant_role(
1760                session.user_id(),
1761                RoomId::from_proto(request.room_id),
1762                user_id,
1763                role,
1764            )
1765            .await?;
1766
1767        let live_kit_room = room.live_kit_room.clone();
1768        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1769        room_updated(&room, &session.peer);
1770        (live_kit_room, can_publish)
1771    };
1772
1773    if let Some(live_kit) = session.live_kit_client.as_ref() {
1774        live_kit
1775            .update_participant(
1776                live_kit_room.clone(),
1777                request.user_id.to_string(),
1778                live_kit_server::proto::ParticipantPermission {
1779                    can_subscribe: true,
1780                    can_publish,
1781                    can_publish_data: can_publish,
1782                    hidden: false,
1783                    recorder: false,
1784                },
1785            )
1786            .await
1787            .trace_err();
1788    }
1789
1790    response.send(proto::Ack {})?;
1791    Ok(())
1792}
1793
1794/// Call someone else into the current room
1795async fn call(
1796    request: proto::Call,
1797    response: Response<proto::Call>,
1798    session: UserSession,
1799) -> Result<()> {
1800    let room_id = RoomId::from_proto(request.room_id);
1801    let calling_user_id = session.user_id();
1802    let calling_connection_id = session.connection_id;
1803    let called_user_id = UserId::from_proto(request.called_user_id);
1804    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1805    if !session
1806        .db()
1807        .await
1808        .has_contact(calling_user_id, called_user_id)
1809        .await?
1810    {
1811        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1812    }
1813
1814    let incoming_call = {
1815        let (room, incoming_call) = &mut *session
1816            .db()
1817            .await
1818            .call(
1819                room_id,
1820                calling_user_id,
1821                calling_connection_id,
1822                called_user_id,
1823                initial_project_id,
1824            )
1825            .await?;
1826        room_updated(&room, &session.peer);
1827        mem::take(incoming_call)
1828    };
1829    update_user_contacts(called_user_id, &session).await?;
1830
1831    let mut calls = session
1832        .connection_pool()
1833        .await
1834        .user_connection_ids(called_user_id)
1835        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1836        .collect::<FuturesUnordered<_>>();
1837
1838    while let Some(call_response) = calls.next().await {
1839        match call_response.as_ref() {
1840            Ok(_) => {
1841                response.send(proto::Ack {})?;
1842                return Ok(());
1843            }
1844            Err(_) => {
1845                call_response.trace_err();
1846            }
1847        }
1848    }
1849
1850    {
1851        let room = session
1852            .db()
1853            .await
1854            .call_failed(room_id, called_user_id)
1855            .await?;
1856        room_updated(&room, &session.peer);
1857    }
1858    update_user_contacts(called_user_id, &session).await?;
1859
1860    Err(anyhow!("failed to ring user"))?
1861}
1862
1863/// Cancel an outgoing call.
1864async fn cancel_call(
1865    request: proto::CancelCall,
1866    response: Response<proto::CancelCall>,
1867    session: UserSession,
1868) -> Result<()> {
1869    let called_user_id = UserId::from_proto(request.called_user_id);
1870    let room_id = RoomId::from_proto(request.room_id);
1871    {
1872        let room = session
1873            .db()
1874            .await
1875            .cancel_call(room_id, session.connection_id, called_user_id)
1876            .await?;
1877        room_updated(&room, &session.peer);
1878    }
1879
1880    for connection_id in session
1881        .connection_pool()
1882        .await
1883        .user_connection_ids(called_user_id)
1884    {
1885        session
1886            .peer
1887            .send(
1888                connection_id,
1889                proto::CallCanceled {
1890                    room_id: room_id.to_proto(),
1891                },
1892            )
1893            .trace_err();
1894    }
1895    response.send(proto::Ack {})?;
1896
1897    update_user_contacts(called_user_id, &session).await?;
1898    Ok(())
1899}
1900
1901/// Decline an incoming call.
1902async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1903    let room_id = RoomId::from_proto(message.room_id);
1904    {
1905        let room = session
1906            .db()
1907            .await
1908            .decline_call(Some(room_id), session.user_id())
1909            .await?
1910            .ok_or_else(|| anyhow!("failed to decline call"))?;
1911        room_updated(&room, &session.peer);
1912    }
1913
1914    for connection_id in session
1915        .connection_pool()
1916        .await
1917        .user_connection_ids(session.user_id())
1918    {
1919        session
1920            .peer
1921            .send(
1922                connection_id,
1923                proto::CallCanceled {
1924                    room_id: room_id.to_proto(),
1925                },
1926            )
1927            .trace_err();
1928    }
1929    update_user_contacts(session.user_id(), &session).await?;
1930    Ok(())
1931}
1932
1933/// Updates other participants in the room with your current location.
1934async fn update_participant_location(
1935    request: proto::UpdateParticipantLocation,
1936    response: Response<proto::UpdateParticipantLocation>,
1937    session: UserSession,
1938) -> Result<()> {
1939    let room_id = RoomId::from_proto(request.room_id);
1940    let location = request
1941        .location
1942        .ok_or_else(|| anyhow!("invalid location"))?;
1943
1944    let db = session.db().await;
1945    let room = db
1946        .update_room_participant_location(room_id, session.connection_id, location)
1947        .await?;
1948
1949    room_updated(&room, &session.peer);
1950    response.send(proto::Ack {})?;
1951    Ok(())
1952}
1953
1954/// Share a project into the room.
1955async fn share_project(
1956    request: proto::ShareProject,
1957    response: Response<proto::ShareProject>,
1958    session: UserSession,
1959) -> Result<()> {
1960    let (project_id, room) = &*session
1961        .db()
1962        .await
1963        .share_project(
1964            RoomId::from_proto(request.room_id),
1965            session.connection_id,
1966            &request.worktrees,
1967            request
1968                .remote_project_id
1969                .map(|id| RemoteProjectId::from_proto(id)),
1970        )
1971        .await?;
1972    response.send(proto::ShareProjectResponse {
1973        project_id: project_id.to_proto(),
1974    })?;
1975    room_updated(&room, &session.peer);
1976
1977    Ok(())
1978}
1979
1980/// Unshare a project from the room.
1981async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1982    let project_id = ProjectId::from_proto(message.project_id);
1983    unshare_project_internal(
1984        project_id,
1985        session.connection_id,
1986        session.user_id(),
1987        &session,
1988    )
1989    .await
1990}
1991
1992async fn unshare_project_internal(
1993    project_id: ProjectId,
1994    connection_id: ConnectionId,
1995    user_id: Option<UserId>,
1996    session: &Session,
1997) -> Result<()> {
1998    let (room, guest_connection_ids) = &*session
1999        .db()
2000        .await
2001        .unshare_project(project_id, connection_id, user_id)
2002        .await?;
2003
2004    let message = proto::UnshareProject {
2005        project_id: project_id.to_proto(),
2006    };
2007
2008    broadcast(
2009        Some(connection_id),
2010        guest_connection_ids.iter().copied(),
2011        |conn_id| session.peer.send(conn_id, message.clone()),
2012    );
2013    if let Some(room) = room {
2014        room_updated(room, &session.peer);
2015    }
2016
2017    Ok(())
2018}
2019
2020/// DevServer makes a project available online
2021async fn share_remote_project(
2022    request: proto::ShareRemoteProject,
2023    response: Response<proto::ShareRemoteProject>,
2024    session: DevServerSession,
2025) -> Result<()> {
2026    let (remote_project, user_id, status) = session
2027        .db()
2028        .await
2029        .share_remote_project(
2030            RemoteProjectId::from_proto(request.remote_project_id),
2031            session.dev_server_id(),
2032            session.connection_id,
2033            &request.worktrees,
2034        )
2035        .await?;
2036    let Some(project_id) = remote_project.project_id else {
2037        return Err(anyhow!("failed to share remote project"))?;
2038    };
2039
2040    send_remote_projects_update(user_id, status, &session).await;
2041
2042    response.send(proto::ShareProjectResponse { project_id })?;
2043
2044    Ok(())
2045}
2046
2047/// Join someone elses shared project.
2048async fn join_project(
2049    request: proto::JoinProject,
2050    response: Response<proto::JoinProject>,
2051    session: UserSession,
2052) -> Result<()> {
2053    let project_id = ProjectId::from_proto(request.project_id);
2054
2055    tracing::info!(%project_id, "join project");
2056
2057    let db = session.db().await;
2058    let (project, replica_id) = &mut *db
2059        .join_project(project_id, session.connection_id, session.user_id())
2060        .await?;
2061    drop(db);
2062    tracing::info!(%project_id, "join remote project");
2063    join_project_internal(response, session, project, replica_id)
2064}
2065
/// Abstracts over the response handles that can complete a project join, so that
/// `join_project_internal` can reply to either a `JoinProject` or a
/// `JoinHostedProject` request with the same `JoinProjectResponse` payload.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully-qualified `Response::<_>::send` calls below resolve to the inherent
// method on `Response`, not to this trait method of the same name. Keep them
// qualified to avoid accidental self-recursion.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2079
2080fn join_project_internal(
2081    response: impl JoinProjectInternalResponse,
2082    session: UserSession,
2083    project: &mut Project,
2084    replica_id: &ReplicaId,
2085) -> Result<()> {
2086    let collaborators = project
2087        .collaborators
2088        .iter()
2089        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2090        .map(|collaborator| collaborator.to_proto())
2091        .collect::<Vec<_>>();
2092    let project_id = project.id;
2093    let guest_user_id = session.user_id();
2094
2095    let worktrees = project
2096        .worktrees
2097        .iter()
2098        .map(|(id, worktree)| proto::WorktreeMetadata {
2099            id: *id,
2100            root_name: worktree.root_name.clone(),
2101            visible: worktree.visible,
2102            abs_path: worktree.abs_path.clone(),
2103        })
2104        .collect::<Vec<_>>();
2105
2106    let add_project_collaborator = proto::AddProjectCollaborator {
2107        project_id: project_id.to_proto(),
2108        collaborator: Some(proto::Collaborator {
2109            peer_id: Some(session.connection_id.into()),
2110            replica_id: replica_id.0 as u32,
2111            user_id: guest_user_id.to_proto(),
2112        }),
2113    };
2114
2115    for collaborator in &collaborators {
2116        session
2117            .peer
2118            .send(
2119                collaborator.peer_id.unwrap().into(),
2120                add_project_collaborator.clone(),
2121            )
2122            .trace_err();
2123    }
2124
2125    // First, we send the metadata associated with each worktree.
2126    response.send(proto::JoinProjectResponse {
2127        project_id: project.id.0 as u64,
2128        worktrees: worktrees.clone(),
2129        replica_id: replica_id.0 as u32,
2130        collaborators: collaborators.clone(),
2131        language_servers: project.language_servers.clone(),
2132        role: project.role.into(),
2133        remote_project_id: project
2134            .remote_project_id
2135            .map(|remote_project_id| remote_project_id.0 as u64),
2136    })?;
2137
2138    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2139        #[cfg(any(test, feature = "test-support"))]
2140        const MAX_CHUNK_SIZE: usize = 2;
2141        #[cfg(not(any(test, feature = "test-support")))]
2142        const MAX_CHUNK_SIZE: usize = 256;
2143
2144        // Stream this worktree's entries.
2145        let message = proto::UpdateWorktree {
2146            project_id: project_id.to_proto(),
2147            worktree_id,
2148            abs_path: worktree.abs_path.clone(),
2149            root_name: worktree.root_name,
2150            updated_entries: worktree.entries,
2151            removed_entries: Default::default(),
2152            scan_id: worktree.scan_id,
2153            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2154            updated_repositories: worktree.repository_entries.into_values().collect(),
2155            removed_repositories: Default::default(),
2156        };
2157        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2158            session.peer.send(session.connection_id, update.clone())?;
2159        }
2160
2161        // Stream this worktree's diagnostics.
2162        for summary in worktree.diagnostic_summaries {
2163            session.peer.send(
2164                session.connection_id,
2165                proto::UpdateDiagnosticSummary {
2166                    project_id: project_id.to_proto(),
2167                    worktree_id: worktree.id,
2168                    summary: Some(summary),
2169                },
2170            )?;
2171        }
2172
2173        for settings_file in worktree.settings_files {
2174            session.peer.send(
2175                session.connection_id,
2176                proto::UpdateWorktreeSettings {
2177                    project_id: project_id.to_proto(),
2178                    worktree_id: worktree.id,
2179                    path: settings_file.path,
2180                    content: Some(settings_file.content),
2181                },
2182            )?;
2183        }
2184    }
2185
2186    for language_server in &project.language_servers {
2187        session.peer.send(
2188            session.connection_id,
2189            proto::UpdateLanguageServer {
2190                project_id: project_id.to_proto(),
2191                language_server_id: language_server.id,
2192                variant: Some(
2193                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2194                        proto::LspDiskBasedDiagnosticsUpdated {},
2195                    ),
2196                ),
2197            },
2198        )?;
2199    }
2200
2201    Ok(())
2202}
2203
2204/// Leave someone elses shared project.
2205async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2206    let sender_id = session.connection_id;
2207    let project_id = ProjectId::from_proto(request.project_id);
2208    let db = session.db().await;
2209    if db.is_hosted_project(project_id).await? {
2210        let project = db.leave_hosted_project(project_id, sender_id).await?;
2211        project_left(&project, &session);
2212        return Ok(());
2213    }
2214
2215    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2216    tracing::info!(
2217        %project_id,
2218        "leave project"
2219    );
2220
2221    project_left(&project, &session);
2222    if let Some(room) = room {
2223        room_updated(&room, &session.peer);
2224    }
2225
2226    Ok(())
2227}
2228
2229async fn join_hosted_project(
2230    request: proto::JoinHostedProject,
2231    response: Response<proto::JoinHostedProject>,
2232    session: UserSession,
2233) -> Result<()> {
2234    let (mut project, replica_id) = session
2235        .db()
2236        .await
2237        .join_hosted_project(
2238            ProjectId(request.project_id as i32),
2239            session.user_id(),
2240            session.connection_id,
2241        )
2242        .await?;
2243
2244    join_project_internal(response, session, &mut project, &replica_id)
2245}
2246
2247async fn create_remote_project(
2248    request: proto::CreateRemoteProject,
2249    response: Response<proto::CreateRemoteProject>,
2250    session: UserSession,
2251) -> Result<()> {
2252    let dev_server_id = DevServerId(request.dev_server_id as i32);
2253    let dev_server_connection_id = session
2254        .connection_pool()
2255        .await
2256        .dev_server_connection_id(dev_server_id);
2257    let Some(dev_server_connection_id) = dev_server_connection_id else {
2258        Err(ErrorCode::DevServerOffline
2259            .message("Cannot create a remote project when the dev server is offline".to_string())
2260            .anyhow())?
2261    };
2262
2263    let path = request.path.clone();
2264    //Check that the path exists on the dev server
2265    session
2266        .peer
2267        .forward_request(
2268            session.connection_id,
2269            dev_server_connection_id,
2270            proto::ValidateRemoteProjectRequest { path: path.clone() },
2271        )
2272        .await?;
2273
2274    let (remote_project, update) = session
2275        .db()
2276        .await
2277        .create_remote_project(
2278            DevServerId(request.dev_server_id as i32),
2279            &request.path,
2280            session.user_id(),
2281        )
2282        .await?;
2283
2284    let projects = session
2285        .db()
2286        .await
2287        .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2288        .await?;
2289
2290    session.peer.send(
2291        dev_server_connection_id,
2292        proto::DevServerInstructions { projects },
2293    )?;
2294
2295    send_remote_projects_update(session.user_id(), update, &session).await;
2296
2297    response.send(proto::CreateRemoteProjectResponse {
2298        remote_project: Some(remote_project.to_proto(None)),
2299    })?;
2300    Ok(())
2301}
2302
2303async fn create_dev_server(
2304    request: proto::CreateDevServer,
2305    response: Response<proto::CreateDevServer>,
2306    session: UserSession,
2307) -> Result<()> {
2308    let access_token = auth::random_token();
2309    let hashed_access_token = auth::hash_access_token(&access_token);
2310
2311    let (dev_server, status) = session
2312        .db()
2313        .await
2314        .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2315        .await?;
2316
2317    send_remote_projects_update(session.user_id(), status, &session).await;
2318
2319    response.send(proto::CreateDevServerResponse {
2320        dev_server_id: dev_server.id.0 as u64,
2321        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2322        name: request.name.clone(),
2323    })?;
2324    Ok(())
2325}
2326
2327async fn delete_dev_server(
2328    request: proto::DeleteDevServer,
2329    response: Response<proto::DeleteDevServer>,
2330    session: UserSession,
2331) -> Result<()> {
2332    let dev_server_id = DevServerId(request.dev_server_id as i32);
2333    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2334    if dev_server.user_id != session.user_id() {
2335        return Err(anyhow!(ErrorCode::Forbidden))?;
2336    }
2337
2338    let connection_id = session
2339        .connection_pool()
2340        .await
2341        .dev_server_connection_id(dev_server_id);
2342    if let Some(connection_id) = connection_id {
2343        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2344        session
2345            .peer
2346            .send(connection_id, proto::ShutdownDevServer {})?;
2347    }
2348
2349    let status = session
2350        .db()
2351        .await
2352        .delete_dev_server(dev_server_id, session.user_id())
2353        .await?;
2354
2355    send_remote_projects_update(session.user_id(), status, &session).await;
2356
2357    response.send(proto::Ack {})?;
2358    Ok(())
2359}
2360
2361async fn rejoin_remote_projects(
2362    request: proto::RejoinRemoteProjects,
2363    response: Response<proto::RejoinRemoteProjects>,
2364    session: UserSession,
2365) -> Result<()> {
2366    let mut rejoined_projects = {
2367        let db = session.db().await;
2368        db.rejoin_remote_projects(
2369            &request.rejoined_projects,
2370            session.user_id(),
2371            session.0.connection_id,
2372        )
2373        .await?
2374    };
2375    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2376
2377    response.send(proto::RejoinRemoteProjectsResponse {
2378        rejoined_projects: rejoined_projects
2379            .into_iter()
2380            .map(|project| project.to_proto())
2381            .collect(),
2382    })
2383}
2384
2385async fn reconnect_dev_server(
2386    request: proto::ReconnectDevServer,
2387    response: Response<proto::ReconnectDevServer>,
2388    session: DevServerSession,
2389) -> Result<()> {
2390    let reshared_projects = {
2391        let db = session.db().await;
2392        db.reshare_remote_projects(
2393            &request.reshared_projects,
2394            session.dev_server_id(),
2395            session.0.connection_id,
2396        )
2397        .await?
2398    };
2399
2400    for project in &reshared_projects {
2401        for collaborator in &project.collaborators {
2402            session
2403                .peer
2404                .send(
2405                    collaborator.connection_id,
2406                    proto::UpdateProjectCollaborator {
2407                        project_id: project.id.to_proto(),
2408                        old_peer_id: Some(project.old_connection_id.into()),
2409                        new_peer_id: Some(session.connection_id.into()),
2410                    },
2411                )
2412                .trace_err();
2413        }
2414
2415        broadcast(
2416            Some(session.connection_id),
2417            project
2418                .collaborators
2419                .iter()
2420                .map(|collaborator| collaborator.connection_id),
2421            |connection_id| {
2422                session.peer.forward_send(
2423                    session.connection_id,
2424                    connection_id,
2425                    proto::UpdateProject {
2426                        project_id: project.id.to_proto(),
2427                        worktrees: project.worktrees.clone(),
2428                    },
2429                )
2430            },
2431        );
2432    }
2433
2434    response.send(proto::ReconnectDevServerResponse {
2435        reshared_projects: reshared_projects
2436            .iter()
2437            .map(|project| proto::ResharedProject {
2438                id: project.id.to_proto(),
2439                collaborators: project
2440                    .collaborators
2441                    .iter()
2442                    .map(|collaborator| collaborator.to_proto())
2443                    .collect(),
2444            })
2445            .collect(),
2446    })?;
2447
2448    Ok(())
2449}
2450
2451async fn shutdown_dev_server(
2452    _: proto::ShutdownDevServer,
2453    response: Response<proto::ShutdownDevServer>,
2454    session: DevServerSession,
2455) -> Result<()> {
2456    response.send(proto::Ack {})?;
2457    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2458}
2459
2460async fn shutdown_dev_server_internal(
2461    dev_server_id: DevServerId,
2462    connection_id: ConnectionId,
2463    session: &Session,
2464) -> Result<()> {
2465    let (remote_projects, dev_server) = {
2466        let db = session.db().await;
2467        let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2468        let dev_server = db.get_dev_server(dev_server_id).await?;
2469        (remote_projects, dev_server)
2470    };
2471
2472    for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2473        unshare_project_internal(
2474            ProjectId::from_proto(project_id),
2475            connection_id,
2476            None,
2477            session,
2478        )
2479        .await?;
2480    }
2481
2482    session
2483        .connection_pool()
2484        .await
2485        .set_dev_server_offline(dev_server_id);
2486
2487    let status = session
2488        .db()
2489        .await
2490        .remote_projects_update(dev_server.user_id)
2491        .await?;
2492    send_remote_projects_update(dev_server.user_id, status, &session).await;
2493
2494    Ok(())
2495}
2496
2497/// Updates other participants with changes to the project
2498async fn update_project(
2499    request: proto::UpdateProject,
2500    response: Response<proto::UpdateProject>,
2501    session: Session,
2502) -> Result<()> {
2503    let project_id = ProjectId::from_proto(request.project_id);
2504    let (room, guest_connection_ids) = &*session
2505        .db()
2506        .await
2507        .update_project(project_id, session.connection_id, &request.worktrees)
2508        .await?;
2509    broadcast(
2510        Some(session.connection_id),
2511        guest_connection_ids.iter().copied(),
2512        |connection_id| {
2513            session
2514                .peer
2515                .forward_send(session.connection_id, connection_id, request.clone())
2516        },
2517    );
2518    if let Some(room) = room {
2519        room_updated(&room, &session.peer);
2520    }
2521    response.send(proto::Ack {})?;
2522
2523    Ok(())
2524}
2525
2526/// Updates other participants with changes to the worktree
2527async fn update_worktree(
2528    request: proto::UpdateWorktree,
2529    response: Response<proto::UpdateWorktree>,
2530    session: Session,
2531) -> Result<()> {
2532    let guest_connection_ids = session
2533        .db()
2534        .await
2535        .update_worktree(&request, session.connection_id)
2536        .await?;
2537
2538    broadcast(
2539        Some(session.connection_id),
2540        guest_connection_ids.iter().copied(),
2541        |connection_id| {
2542            session
2543                .peer
2544                .forward_send(session.connection_id, connection_id, request.clone())
2545        },
2546    );
2547    response.send(proto::Ack {})?;
2548    Ok(())
2549}
2550
2551/// Updates other participants with changes to the diagnostics
2552async fn update_diagnostic_summary(
2553    message: proto::UpdateDiagnosticSummary,
2554    session: Session,
2555) -> Result<()> {
2556    let guest_connection_ids = session
2557        .db()
2558        .await
2559        .update_diagnostic_summary(&message, session.connection_id)
2560        .await?;
2561
2562    broadcast(
2563        Some(session.connection_id),
2564        guest_connection_ids.iter().copied(),
2565        |connection_id| {
2566            session
2567                .peer
2568                .forward_send(session.connection_id, connection_id, message.clone())
2569        },
2570    );
2571
2572    Ok(())
2573}
2574
2575/// Updates other participants with changes to the worktree settings
2576async fn update_worktree_settings(
2577    message: proto::UpdateWorktreeSettings,
2578    session: Session,
2579) -> Result<()> {
2580    let guest_connection_ids = session
2581        .db()
2582        .await
2583        .update_worktree_settings(&message, session.connection_id)
2584        .await?;
2585
2586    broadcast(
2587        Some(session.connection_id),
2588        guest_connection_ids.iter().copied(),
2589        |connection_id| {
2590            session
2591                .peer
2592                .forward_send(session.connection_id, connection_id, message.clone())
2593        },
2594    );
2595
2596    Ok(())
2597}
2598
2599/// Notify other participants that a  language server has started.
2600async fn start_language_server(
2601    request: proto::StartLanguageServer,
2602    session: Session,
2603) -> Result<()> {
2604    let guest_connection_ids = session
2605        .db()
2606        .await
2607        .start_language_server(&request, session.connection_id)
2608        .await?;
2609
2610    broadcast(
2611        Some(session.connection_id),
2612        guest_connection_ids.iter().copied(),
2613        |connection_id| {
2614            session
2615                .peer
2616                .forward_send(session.connection_id, connection_id, request.clone())
2617        },
2618    );
2619    Ok(())
2620}
2621
2622/// Notify other participants that a language server has changed.
2623async fn update_language_server(
2624    request: proto::UpdateLanguageServer,
2625    session: Session,
2626) -> Result<()> {
2627    let project_id = ProjectId::from_proto(request.project_id);
2628    let project_connection_ids = session
2629        .db()
2630        .await
2631        .project_connection_ids(project_id, session.connection_id, true)
2632        .await?;
2633    broadcast(
2634        Some(session.connection_id),
2635        project_connection_ids.iter().copied(),
2636        |connection_id| {
2637            session
2638                .peer
2639                .forward_send(session.connection_id, connection_id, request.clone())
2640        },
2641    );
2642    Ok(())
2643}
2644
2645/// forward a project request to the host. These requests should be read only
2646/// as guests are allowed to send them.
2647async fn forward_read_only_project_request<T>(
2648    request: T,
2649    response: Response<T>,
2650    session: UserSession,
2651) -> Result<()>
2652where
2653    T: EntityMessage + RequestMessage,
2654{
2655    let project_id = ProjectId::from_proto(request.remote_entity_id());
2656    let host_connection_id = session
2657        .db()
2658        .await
2659        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2660        .await?;
2661    let payload = session
2662        .peer
2663        .forward_request(session.connection_id, host_connection_id, request)
2664        .await?;
2665    response.send(payload)?;
2666    Ok(())
2667}
2668
2669/// forward a project request to the host. These requests are disallowed
2670/// for guests.
2671async fn forward_mutating_project_request<T>(
2672    request: T,
2673    response: Response<T>,
2674    session: UserSession,
2675) -> Result<()>
2676where
2677    T: EntityMessage + RequestMessage,
2678{
2679    let project_id = ProjectId::from_proto(request.remote_entity_id());
2680    let host_connection_id = session
2681        .db()
2682        .await
2683        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2684        .await?;
2685    let payload = session
2686        .peer
2687        .forward_request(session.connection_id, host_connection_id, request)
2688        .await?;
2689    response.send(payload)?;
2690    Ok(())
2691}
2692
2693/// Notify other participants that a new buffer has been created
2694async fn create_buffer_for_peer(
2695    request: proto::CreateBufferForPeer,
2696    session: Session,
2697) -> Result<()> {
2698    session
2699        .db()
2700        .await
2701        .check_user_is_project_host(
2702            ProjectId::from_proto(request.project_id),
2703            session.connection_id,
2704        )
2705        .await?;
2706    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2707    session
2708        .peer
2709        .forward_send(session.connection_id, peer_id.into(), request)?;
2710    Ok(())
2711}
2712
2713/// Notify other participants that a buffer has been updated. This is
2714/// allowed for guests as long as the update is limited to selections.
2715async fn update_buffer(
2716    request: proto::UpdateBuffer,
2717    response: Response<proto::UpdateBuffer>,
2718    session: Session,
2719) -> Result<()> {
2720    let project_id = ProjectId::from_proto(request.project_id);
2721    let mut capability = Capability::ReadOnly;
2722
2723    for op in request.operations.iter() {
2724        match op.variant {
2725            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2726            Some(_) => capability = Capability::ReadWrite,
2727        }
2728    }
2729
2730    let host = {
2731        let guard = session
2732            .db()
2733            .await
2734            .connections_for_buffer_update(
2735                project_id,
2736                session.principal_id(),
2737                session.connection_id,
2738                capability,
2739            )
2740            .await?;
2741
2742        let (host, guests) = &*guard;
2743
2744        broadcast(
2745            Some(session.connection_id),
2746            guests.clone(),
2747            |connection_id| {
2748                session
2749                    .peer
2750                    .forward_send(session.connection_id, connection_id, request.clone())
2751            },
2752        );
2753
2754        *host
2755    };
2756
2757    if host != session.connection_id {
2758        session
2759            .peer
2760            .forward_request(session.connection_id, host, request.clone())
2761            .await?;
2762    }
2763
2764    response.send(proto::Ack {})?;
2765    Ok(())
2766}
2767
2768/// Notify other participants that a project has been updated.
2769async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2770    request: T,
2771    session: Session,
2772) -> Result<()> {
2773    let project_id = ProjectId::from_proto(request.remote_entity_id());
2774    let project_connection_ids = session
2775        .db()
2776        .await
2777        .project_connection_ids(project_id, session.connection_id, false)
2778        .await?;
2779
2780    broadcast(
2781        Some(session.connection_id),
2782        project_connection_ids.iter().copied(),
2783        |connection_id| {
2784            session
2785                .peer
2786                .forward_send(session.connection_id, connection_id, request.clone())
2787        },
2788    );
2789    Ok(())
2790}
2791
2792/// Start following another user in a call.
2793async fn follow(
2794    request: proto::Follow,
2795    response: Response<proto::Follow>,
2796    session: UserSession,
2797) -> Result<()> {
2798    let room_id = RoomId::from_proto(request.room_id);
2799    let project_id = request.project_id.map(ProjectId::from_proto);
2800    let leader_id = request
2801        .leader_id
2802        .ok_or_else(|| anyhow!("invalid leader id"))?
2803        .into();
2804    let follower_id = session.connection_id;
2805
2806    session
2807        .db()
2808        .await
2809        .check_room_participants(room_id, leader_id, session.connection_id)
2810        .await?;
2811
2812    let response_payload = session
2813        .peer
2814        .forward_request(session.connection_id, leader_id, request)
2815        .await?;
2816    response.send(response_payload)?;
2817
2818    if let Some(project_id) = project_id {
2819        let room = session
2820            .db()
2821            .await
2822            .follow(room_id, project_id, leader_id, follower_id)
2823            .await?;
2824        room_updated(&room, &session.peer);
2825    }
2826
2827    Ok(())
2828}
2829
2830/// Stop following another user in a call.
2831async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2832    let room_id = RoomId::from_proto(request.room_id);
2833    let project_id = request.project_id.map(ProjectId::from_proto);
2834    let leader_id = request
2835        .leader_id
2836        .ok_or_else(|| anyhow!("invalid leader id"))?
2837        .into();
2838    let follower_id = session.connection_id;
2839
2840    session
2841        .db()
2842        .await
2843        .check_room_participants(room_id, leader_id, session.connection_id)
2844        .await?;
2845
2846    session
2847        .peer
2848        .forward_send(session.connection_id, leader_id, request)?;
2849
2850    if let Some(project_id) = project_id {
2851        let room = session
2852            .db()
2853            .await
2854            .unfollow(room_id, project_id, leader_id, follower_id)
2855            .await?;
2856        room_updated(&room, &session.peer);
2857    }
2858
2859    Ok(())
2860}
2861
2862/// Notify everyone following you of your current location.
2863async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2864    let room_id = RoomId::from_proto(request.room_id);
2865    let database = session.db.lock().await;
2866
2867    let connection_ids = if let Some(project_id) = request.project_id {
2868        let project_id = ProjectId::from_proto(project_id);
2869        database
2870            .project_connection_ids(project_id, session.connection_id, true)
2871            .await?
2872    } else {
2873        database
2874            .room_connection_ids(room_id, session.connection_id)
2875            .await?
2876    };
2877
2878    // For now, don't send view update messages back to that view's current leader.
2879    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2880        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2881        _ => None,
2882    });
2883
2884    for connection_id in connection_ids.iter().cloned() {
2885        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2886            session
2887                .peer
2888                .forward_send(session.connection_id, connection_id, request.clone())?;
2889        }
2890    }
2891    Ok(())
2892}
2893
2894/// Get public data about users.
2895async fn get_users(
2896    request: proto::GetUsers,
2897    response: Response<proto::GetUsers>,
2898    session: Session,
2899) -> Result<()> {
2900    let user_ids = request
2901        .user_ids
2902        .into_iter()
2903        .map(UserId::from_proto)
2904        .collect();
2905    let users = session
2906        .db()
2907        .await
2908        .get_users_by_ids(user_ids)
2909        .await?
2910        .into_iter()
2911        .map(|user| proto::User {
2912            id: user.id.to_proto(),
2913            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2914            github_login: user.github_login,
2915        })
2916        .collect();
2917    response.send(proto::UsersResponse { users })?;
2918    Ok(())
2919}
2920
2921/// Search for users (to invite) buy Github login
2922async fn fuzzy_search_users(
2923    request: proto::FuzzySearchUsers,
2924    response: Response<proto::FuzzySearchUsers>,
2925    session: UserSession,
2926) -> Result<()> {
2927    let query = request.query;
2928    let users = match query.len() {
2929        0 => vec![],
2930        1 | 2 => session
2931            .db()
2932            .await
2933            .get_user_by_github_login(&query)
2934            .await?
2935            .into_iter()
2936            .collect(),
2937        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2938    };
2939    let users = users
2940        .into_iter()
2941        .filter(|user| user.id != session.user_id())
2942        .map(|user| proto::User {
2943            id: user.id.to_proto(),
2944            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2945            github_login: user.github_login,
2946        })
2947        .collect();
2948    response.send(proto::UsersResponse { users })?;
2949    Ok(())
2950}
2951
2952/// Send a contact request to another user.
2953async fn request_contact(
2954    request: proto::RequestContact,
2955    response: Response<proto::RequestContact>,
2956    session: UserSession,
2957) -> Result<()> {
2958    let requester_id = session.user_id();
2959    let responder_id = UserId::from_proto(request.responder_id);
2960    if requester_id == responder_id {
2961        return Err(anyhow!("cannot add yourself as a contact"))?;
2962    }
2963
2964    let notifications = session
2965        .db()
2966        .await
2967        .send_contact_request(requester_id, responder_id)
2968        .await?;
2969
2970    // Update outgoing contact requests of requester
2971    let mut update = proto::UpdateContacts::default();
2972    update.outgoing_requests.push(responder_id.to_proto());
2973    for connection_id in session
2974        .connection_pool()
2975        .await
2976        .user_connection_ids(requester_id)
2977    {
2978        session.peer.send(connection_id, update.clone())?;
2979    }
2980
2981    // Update incoming contact requests of responder
2982    let mut update = proto::UpdateContacts::default();
2983    update
2984        .incoming_requests
2985        .push(proto::IncomingContactRequest {
2986            requester_id: requester_id.to_proto(),
2987        });
2988    let connection_pool = session.connection_pool().await;
2989    for connection_id in connection_pool.user_connection_ids(responder_id) {
2990        session.peer.send(connection_id, update.clone())?;
2991    }
2992
2993    send_notifications(&connection_pool, &session.peer, notifications);
2994
2995    response.send(proto::Ack {})?;
2996    Ok(())
2997}
2998
2999/// Accept or decline a contact request
3000async fn respond_to_contact_request(
3001    request: proto::RespondToContactRequest,
3002    response: Response<proto::RespondToContactRequest>,
3003    session: UserSession,
3004) -> Result<()> {
3005    let responder_id = session.user_id();
3006    let requester_id = UserId::from_proto(request.requester_id);
3007    let db = session.db().await;
3008    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3009        db.dismiss_contact_notification(responder_id, requester_id)
3010            .await?;
3011    } else {
3012        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3013
3014        let notifications = db
3015            .respond_to_contact_request(responder_id, requester_id, accept)
3016            .await?;
3017        let requester_busy = db.is_user_busy(requester_id).await?;
3018        let responder_busy = db.is_user_busy(responder_id).await?;
3019
3020        let pool = session.connection_pool().await;
3021        // Update responder with new contact
3022        let mut update = proto::UpdateContacts::default();
3023        if accept {
3024            update
3025                .contacts
3026                .push(contact_for_user(requester_id, requester_busy, &pool));
3027        }
3028        update
3029            .remove_incoming_requests
3030            .push(requester_id.to_proto());
3031        for connection_id in pool.user_connection_ids(responder_id) {
3032            session.peer.send(connection_id, update.clone())?;
3033        }
3034
3035        // Update requester with new contact
3036        let mut update = proto::UpdateContacts::default();
3037        if accept {
3038            update
3039                .contacts
3040                .push(contact_for_user(responder_id, responder_busy, &pool));
3041        }
3042        update
3043            .remove_outgoing_requests
3044            .push(responder_id.to_proto());
3045
3046        for connection_id in pool.user_connection_ids(requester_id) {
3047            session.peer.send(connection_id, update.clone())?;
3048        }
3049
3050        send_notifications(&pool, &session.peer, notifications);
3051    }
3052
3053    response.send(proto::Ack {})?;
3054    Ok(())
3055}
3056
3057/// Remove a contact.
3058async fn remove_contact(
3059    request: proto::RemoveContact,
3060    response: Response<proto::RemoveContact>,
3061    session: UserSession,
3062) -> Result<()> {
3063    let requester_id = session.user_id();
3064    let responder_id = UserId::from_proto(request.user_id);
3065    let db = session.db().await;
3066    let (contact_accepted, deleted_notification_id) =
3067        db.remove_contact(requester_id, responder_id).await?;
3068
3069    let pool = session.connection_pool().await;
3070    // Update outgoing contact requests of requester
3071    let mut update = proto::UpdateContacts::default();
3072    if contact_accepted {
3073        update.remove_contacts.push(responder_id.to_proto());
3074    } else {
3075        update
3076            .remove_outgoing_requests
3077            .push(responder_id.to_proto());
3078    }
3079    for connection_id in pool.user_connection_ids(requester_id) {
3080        session.peer.send(connection_id, update.clone())?;
3081    }
3082
3083    // Update incoming contact requests of responder
3084    let mut update = proto::UpdateContacts::default();
3085    if contact_accepted {
3086        update.remove_contacts.push(requester_id.to_proto());
3087    } else {
3088        update
3089            .remove_incoming_requests
3090            .push(requester_id.to_proto());
3091    }
3092    for connection_id in pool.user_connection_ids(responder_id) {
3093        session.peer.send(connection_id, update.clone())?;
3094        if let Some(notification_id) = deleted_notification_id {
3095            session.peer.send(
3096                connection_id,
3097                proto::DeleteNotification {
3098                    notification_id: notification_id.to_proto(),
3099                },
3100            )?;
3101        }
3102    }
3103
3104    response.send(proto::Ack {})?;
3105    Ok(())
3106}
3107
3108/// Creates a new channel.
3109async fn create_channel(
3110    request: proto::CreateChannel,
3111    response: Response<proto::CreateChannel>,
3112    session: UserSession,
3113) -> Result<()> {
3114    let db = session.db().await;
3115
3116    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3117    let (channel, membership) = db
3118        .create_channel(&request.name, parent_id, session.user_id())
3119        .await?;
3120
3121    let root_id = channel.root_id();
3122    let channel = Channel::from_model(channel);
3123
3124    response.send(proto::CreateChannelResponse {
3125        channel: Some(channel.to_proto()),
3126        parent_id: request.parent_id,
3127    })?;
3128
3129    let mut connection_pool = session.connection_pool().await;
3130    if let Some(membership) = membership {
3131        connection_pool.subscribe_to_channel(
3132            membership.user_id,
3133            membership.channel_id,
3134            membership.role,
3135        );
3136        let update = proto::UpdateUserChannels {
3137            channel_memberships: vec![proto::ChannelMembership {
3138                channel_id: membership.channel_id.to_proto(),
3139                role: membership.role.into(),
3140            }],
3141            ..Default::default()
3142        };
3143        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3144            session.peer.send(connection_id, update.clone())?;
3145        }
3146    }
3147
3148    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149        if !role.can_see_channel(channel.visibility) {
3150            continue;
3151        }
3152
3153        let update = proto::UpdateChannels {
3154            channels: vec![channel.to_proto()],
3155            ..Default::default()
3156        };
3157        session.peer.send(connection_id, update.clone())?;
3158    }
3159
3160    Ok(())
3161}
3162
3163/// Delete a channel
3164async fn delete_channel(
3165    request: proto::DeleteChannel,
3166    response: Response<proto::DeleteChannel>,
3167    session: UserSession,
3168) -> Result<()> {
3169    let db = session.db().await;
3170
3171    let channel_id = request.channel_id;
3172    let (root_channel, removed_channels) = db
3173        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3174        .await?;
3175    response.send(proto::Ack {})?;
3176
3177    // Notify members of removed channels
3178    let mut update = proto::UpdateChannels::default();
3179    update
3180        .delete_channels
3181        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3182
3183    let connection_pool = session.connection_pool().await;
3184    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3185        session.peer.send(connection_id, update.clone())?;
3186    }
3187
3188    Ok(())
3189}
3190
3191/// Invite someone to join a channel.
3192async fn invite_channel_member(
3193    request: proto::InviteChannelMember,
3194    response: Response<proto::InviteChannelMember>,
3195    session: UserSession,
3196) -> Result<()> {
3197    let db = session.db().await;
3198    let channel_id = ChannelId::from_proto(request.channel_id);
3199    let invitee_id = UserId::from_proto(request.user_id);
3200    let InviteMemberResult {
3201        channel,
3202        notifications,
3203    } = db
3204        .invite_channel_member(
3205            channel_id,
3206            invitee_id,
3207            session.user_id(),
3208            request.role().into(),
3209        )
3210        .await?;
3211
3212    let update = proto::UpdateChannels {
3213        channel_invitations: vec![channel.to_proto()],
3214        ..Default::default()
3215    };
3216
3217    let connection_pool = session.connection_pool().await;
3218    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3219        session.peer.send(connection_id, update.clone())?;
3220    }
3221
3222    send_notifications(&connection_pool, &session.peer, notifications);
3223
3224    response.send(proto::Ack {})?;
3225    Ok(())
3226}
3227
3228/// remove someone from a channel
3229async fn remove_channel_member(
3230    request: proto::RemoveChannelMember,
3231    response: Response<proto::RemoveChannelMember>,
3232    session: UserSession,
3233) -> Result<()> {
3234    let db = session.db().await;
3235    let channel_id = ChannelId::from_proto(request.channel_id);
3236    let member_id = UserId::from_proto(request.user_id);
3237
3238    let RemoveChannelMemberResult {
3239        membership_update,
3240        notification_id,
3241    } = db
3242        .remove_channel_member(channel_id, member_id, session.user_id())
3243        .await?;
3244
3245    let mut connection_pool = session.connection_pool().await;
3246    notify_membership_updated(
3247        &mut connection_pool,
3248        membership_update,
3249        member_id,
3250        &session.peer,
3251    );
3252    for connection_id in connection_pool.user_connection_ids(member_id) {
3253        if let Some(notification_id) = notification_id {
3254            session
3255                .peer
3256                .send(
3257                    connection_id,
3258                    proto::DeleteNotification {
3259                        notification_id: notification_id.to_proto(),
3260                    },
3261                )
3262                .trace_err();
3263        }
3264    }
3265
3266    response.send(proto::Ack {})?;
3267    Ok(())
3268}
3269
3270/// Toggle the channel between public and private.
3271/// Care is taken to maintain the invariant that public channels only descend from public channels,
3272/// (though members-only channels can appear at any point in the hierarchy).
3273async fn set_channel_visibility(
3274    request: proto::SetChannelVisibility,
3275    response: Response<proto::SetChannelVisibility>,
3276    session: UserSession,
3277) -> Result<()> {
3278    let db = session.db().await;
3279    let channel_id = ChannelId::from_proto(request.channel_id);
3280    let visibility = request.visibility().into();
3281
3282    let channel_model = db
3283        .set_channel_visibility(channel_id, visibility, session.user_id())
3284        .await?;
3285    let root_id = channel_model.root_id();
3286    let channel = Channel::from_model(channel_model);
3287
3288    let mut connection_pool = session.connection_pool().await;
3289    for (user_id, role) in connection_pool
3290        .channel_user_ids(root_id)
3291        .collect::<Vec<_>>()
3292        .into_iter()
3293    {
3294        let update = if role.can_see_channel(channel.visibility) {
3295            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3296            proto::UpdateChannels {
3297                channels: vec![channel.to_proto()],
3298                ..Default::default()
3299            }
3300        } else {
3301            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3302            proto::UpdateChannels {
3303                delete_channels: vec![channel.id.to_proto()],
3304                ..Default::default()
3305            }
3306        };
3307
3308        for connection_id in connection_pool.user_connection_ids(user_id) {
3309            session.peer.send(connection_id, update.clone())?;
3310        }
3311    }
3312
3313    response.send(proto::Ack {})?;
3314    Ok(())
3315}
3316
3317/// Alter the role for a user in the channel.
3318async fn set_channel_member_role(
3319    request: proto::SetChannelMemberRole,
3320    response: Response<proto::SetChannelMemberRole>,
3321    session: UserSession,
3322) -> Result<()> {
3323    let db = session.db().await;
3324    let channel_id = ChannelId::from_proto(request.channel_id);
3325    let member_id = UserId::from_proto(request.user_id);
3326    let result = db
3327        .set_channel_member_role(
3328            channel_id,
3329            session.user_id(),
3330            member_id,
3331            request.role().into(),
3332        )
3333        .await?;
3334
3335    match result {
3336        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3337            let mut connection_pool = session.connection_pool().await;
3338            notify_membership_updated(
3339                &mut connection_pool,
3340                membership_update,
3341                member_id,
3342                &session.peer,
3343            )
3344        }
3345        db::SetMemberRoleResult::InviteUpdated(channel) => {
3346            let update = proto::UpdateChannels {
3347                channel_invitations: vec![channel.to_proto()],
3348                ..Default::default()
3349            };
3350
3351            for connection_id in session
3352                .connection_pool()
3353                .await
3354                .user_connection_ids(member_id)
3355            {
3356                session.peer.send(connection_id, update.clone())?;
3357            }
3358        }
3359    }
3360
3361    response.send(proto::Ack {})?;
3362    Ok(())
3363}
3364
3365/// Change the name of a channel
3366async fn rename_channel(
3367    request: proto::RenameChannel,
3368    response: Response<proto::RenameChannel>,
3369    session: UserSession,
3370) -> Result<()> {
3371    let db = session.db().await;
3372    let channel_id = ChannelId::from_proto(request.channel_id);
3373    let channel_model = db
3374        .rename_channel(channel_id, session.user_id(), &request.name)
3375        .await?;
3376    let root_id = channel_model.root_id();
3377    let channel = Channel::from_model(channel_model);
3378
3379    response.send(proto::RenameChannelResponse {
3380        channel: Some(channel.to_proto()),
3381    })?;
3382
3383    let connection_pool = session.connection_pool().await;
3384    let update = proto::UpdateChannels {
3385        channels: vec![channel.to_proto()],
3386        ..Default::default()
3387    };
3388    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3389        if role.can_see_channel(channel.visibility) {
3390            session.peer.send(connection_id, update.clone())?;
3391        }
3392    }
3393
3394    Ok(())
3395}
3396
3397/// Move a channel to a new parent.
3398async fn move_channel(
3399    request: proto::MoveChannel,
3400    response: Response<proto::MoveChannel>,
3401    session: UserSession,
3402) -> Result<()> {
3403    let channel_id = ChannelId::from_proto(request.channel_id);
3404    let to = ChannelId::from_proto(request.to);
3405
3406    let (root_id, channels) = session
3407        .db()
3408        .await
3409        .move_channel(channel_id, to, session.user_id())
3410        .await?;
3411
3412    let connection_pool = session.connection_pool().await;
3413    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3414        let channels = channels
3415            .iter()
3416            .filter_map(|channel| {
3417                if role.can_see_channel(channel.visibility) {
3418                    Some(channel.to_proto())
3419                } else {
3420                    None
3421                }
3422            })
3423            .collect::<Vec<_>>();
3424        if channels.is_empty() {
3425            continue;
3426        }
3427
3428        let update = proto::UpdateChannels {
3429            channels,
3430            ..Default::default()
3431        };
3432
3433        session.peer.send(connection_id, update.clone())?;
3434    }
3435
3436    response.send(Ack {})?;
3437    Ok(())
3438}
3439
3440/// Get the list of channel members
3441async fn get_channel_members(
3442    request: proto::GetChannelMembers,
3443    response: Response<proto::GetChannelMembers>,
3444    session: UserSession,
3445) -> Result<()> {
3446    let db = session.db().await;
3447    let channel_id = ChannelId::from_proto(request.channel_id);
3448    let members = db
3449        .get_channel_participant_details(channel_id, session.user_id())
3450        .await?;
3451    response.send(proto::GetChannelMembersResponse { members })?;
3452    Ok(())
3453}
3454
3455/// Accept or decline a channel invitation.
3456async fn respond_to_channel_invite(
3457    request: proto::RespondToChannelInvite,
3458    response: Response<proto::RespondToChannelInvite>,
3459    session: UserSession,
3460) -> Result<()> {
3461    let db = session.db().await;
3462    let channel_id = ChannelId::from_proto(request.channel_id);
3463    let RespondToChannelInvite {
3464        membership_update,
3465        notifications,
3466    } = db
3467        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3468        .await?;
3469
3470    let mut connection_pool = session.connection_pool().await;
3471    if let Some(membership_update) = membership_update {
3472        notify_membership_updated(
3473            &mut connection_pool,
3474            membership_update,
3475            session.user_id(),
3476            &session.peer,
3477        );
3478    } else {
3479        let update = proto::UpdateChannels {
3480            remove_channel_invitations: vec![channel_id.to_proto()],
3481            ..Default::default()
3482        };
3483
3484        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3485            session.peer.send(connection_id, update.clone())?;
3486        }
3487    };
3488
3489    send_notifications(&connection_pool, &session.peer, notifications);
3490
3491    response.send(proto::Ack {})?;
3492
3493    Ok(())
3494}
3495
3496/// Join the channels' room
3497async fn join_channel(
3498    request: proto::JoinChannel,
3499    response: Response<proto::JoinChannel>,
3500    session: UserSession,
3501) -> Result<()> {
3502    let channel_id = ChannelId::from_proto(request.channel_id);
3503    join_channel_internal(channel_id, Box::new(response), session).await
3504}
3505
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for the `JoinChannel` and `JoinRoom` responses so that
/// `join_channel_internal` can be shared by handlers of either request.
trait JoinChannelInternalResponse {
    /// Send the join result to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forward to `Response::send` for this concrete message type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forward to `Response::send` for this concrete message type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3519
/// Shared implementation for joining a channel's room.
///
/// In order, this:
/// 1. Removes the user from any stale room connection left by a previous session.
/// 2. Joins the channel in the database. This yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. Builds LiveKit credentials when a LiveKit client is configured. Guests get a guest
///    token and `can_publish = false`; every other role gets a room token and
///    `can_publish = true`. If token creation fails (`trace_err` yields `None`), the
///    response carries no LiveKit connection info.
/// 4. Replies to the caller, then sends the membership update (if any) and the room update.
/// 5. After the scoped `db` / connection-pool guards are dropped, broadcasts the channel
///    update and refreshes the user's contacts.
///
/// `response` is any `JoinChannelInternalResponse`, currently the `JoinChannel` and
/// `JoinRoom` response handles.
///
/// # Errors
/// Propagates database, room-leave, and send failures. Also fails if the database did not
/// return a channel for the joined room.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the `db` and connection-pool guards are released before `channel_updated`.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before `leave_room_for_session`, which receives the
            // whole session (and presumably needs the database itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may only subscribe; all other roles may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before the broadcasts below.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above are dropped here; a fresh pool guard is taken below.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3610
3611/// Start editing the channel notes
3612async fn join_channel_buffer(
3613    request: proto::JoinChannelBuffer,
3614    response: Response<proto::JoinChannelBuffer>,
3615    session: UserSession,
3616) -> Result<()> {
3617    let db = session.db().await;
3618    let channel_id = ChannelId::from_proto(request.channel_id);
3619
3620    let open_response = db
3621        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3622        .await?;
3623
3624    let collaborators = open_response.collaborators.clone();
3625    response.send(open_response)?;
3626
3627    let update = UpdateChannelBufferCollaborators {
3628        channel_id: channel_id.to_proto(),
3629        collaborators: collaborators.clone(),
3630    };
3631    channel_buffer_updated(
3632        session.connection_id,
3633        collaborators
3634            .iter()
3635            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3636        &update,
3637        &session.peer,
3638    );
3639
3640    Ok(())
3641}
3642
3643/// Edit the channel notes
3644async fn update_channel_buffer(
3645    request: proto::UpdateChannelBuffer,
3646    session: UserSession,
3647) -> Result<()> {
3648    let db = session.db().await;
3649    let channel_id = ChannelId::from_proto(request.channel_id);
3650
3651    let (collaborators, non_collaborators, epoch, version) = db
3652        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3653        .await?;
3654
3655    channel_buffer_updated(
3656        session.connection_id,
3657        collaborators,
3658        &proto::UpdateChannelBuffer {
3659            channel_id: channel_id.to_proto(),
3660            operations: request.operations,
3661        },
3662        &session.peer,
3663    );
3664
3665    let pool = &*session.connection_pool().await;
3666
3667    broadcast(
3668        None,
3669        non_collaborators
3670            .iter()
3671            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3672        |peer_id| {
3673            session.peer.send(
3674                peer_id,
3675                proto::UpdateChannels {
3676                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3677                        channel_id: channel_id.to_proto(),
3678                        epoch: epoch as u64,
3679                        version: version.clone(),
3680                    }],
3681                    ..Default::default()
3682                },
3683            )
3684        },
3685    );
3686
3687    Ok(())
3688}
3689
3690/// Rejoin the channel notes after a connection blip
3691async fn rejoin_channel_buffers(
3692    request: proto::RejoinChannelBuffers,
3693    response: Response<proto::RejoinChannelBuffers>,
3694    session: UserSession,
3695) -> Result<()> {
3696    let db = session.db().await;
3697    let buffers = db
3698        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3699        .await?;
3700
3701    for rejoined_buffer in &buffers {
3702        let collaborators_to_notify = rejoined_buffer
3703            .buffer
3704            .collaborators
3705            .iter()
3706            .filter_map(|c| Some(c.peer_id?.into()));
3707        channel_buffer_updated(
3708            session.connection_id,
3709            collaborators_to_notify,
3710            &proto::UpdateChannelBufferCollaborators {
3711                channel_id: rejoined_buffer.buffer.channel_id,
3712                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3713            },
3714            &session.peer,
3715        );
3716    }
3717
3718    response.send(proto::RejoinChannelBuffersResponse {
3719        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3720    })?;
3721
3722    Ok(())
3723}
3724
3725/// Stop editing the channel notes
3726async fn leave_channel_buffer(
3727    request: proto::LeaveChannelBuffer,
3728    response: Response<proto::LeaveChannelBuffer>,
3729    session: UserSession,
3730) -> Result<()> {
3731    let db = session.db().await;
3732    let channel_id = ChannelId::from_proto(request.channel_id);
3733
3734    let left_buffer = db
3735        .leave_channel_buffer(channel_id, session.connection_id)
3736        .await?;
3737
3738    response.send(Ack {})?;
3739
3740    channel_buffer_updated(
3741        session.connection_id,
3742        left_buffer.connections,
3743        &proto::UpdateChannelBufferCollaborators {
3744            channel_id: channel_id.to_proto(),
3745            collaborators: left_buffer.collaborators,
3746        },
3747        &session.peer,
3748    );
3749
3750    Ok(())
3751}
3752
3753fn channel_buffer_updated<T: EnvelopedMessage>(
3754    sender_id: ConnectionId,
3755    collaborators: impl IntoIterator<Item = ConnectionId>,
3756    message: &T,
3757    peer: &Peer,
3758) {
3759    broadcast(Some(sender_id), collaborators, |peer_id| {
3760        peer.send(peer_id, message.clone())
3761    });
3762}
3763
3764fn send_notifications(
3765    connection_pool: &ConnectionPool,
3766    peer: &Peer,
3767    notifications: db::NotificationBatch,
3768) {
3769    for (user_id, notification) in notifications {
3770        for connection_id in connection_pool.user_connection_ids(user_id) {
3771            if let Err(error) = peer.send(
3772                connection_id,
3773                proto::AddNotification {
3774                    notification: Some(notification.clone()),
3775                },
3776            ) {
3777                tracing::error!(
3778                    "failed to send notification to {:?} {}",
3779                    connection_id,
3780                    error
3781                );
3782            }
3783        }
3784    }
3785}
3786
3787/// Send a message to the channel
3788async fn send_channel_message(
3789    request: proto::SendChannelMessage,
3790    response: Response<proto::SendChannelMessage>,
3791    session: UserSession,
3792) -> Result<()> {
3793    // Validate the message body.
3794    let body = request.body.trim().to_string();
3795    if body.len() > MAX_MESSAGE_LEN {
3796        return Err(anyhow!("message is too long"))?;
3797    }
3798    if body.is_empty() {
3799        return Err(anyhow!("message can't be blank"))?;
3800    }
3801
3802    // TODO: adjust mentions if body is trimmed
3803
3804    let timestamp = OffsetDateTime::now_utc();
3805    let nonce = request
3806        .nonce
3807        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3808
3809    let channel_id = ChannelId::from_proto(request.channel_id);
3810    let CreatedChannelMessage {
3811        message_id,
3812        participant_connection_ids,
3813        channel_members,
3814        notifications,
3815    } = session
3816        .db()
3817        .await
3818        .create_channel_message(
3819            channel_id,
3820            session.user_id(),
3821            &body,
3822            &request.mentions,
3823            timestamp,
3824            nonce.clone().into(),
3825            match request.reply_to_message_id {
3826                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3827                None => None,
3828            },
3829        )
3830        .await?;
3831
3832    let message = proto::ChannelMessage {
3833        sender_id: session.user_id().to_proto(),
3834        id: message_id.to_proto(),
3835        body,
3836        mentions: request.mentions,
3837        timestamp: timestamp.unix_timestamp() as u64,
3838        nonce: Some(nonce),
3839        reply_to_message_id: request.reply_to_message_id,
3840        edited_at: None,
3841    };
3842    broadcast(
3843        Some(session.connection_id),
3844        participant_connection_ids,
3845        |connection| {
3846            session.peer.send(
3847                connection,
3848                proto::ChannelMessageSent {
3849                    channel_id: channel_id.to_proto(),
3850                    message: Some(message.clone()),
3851                },
3852            )
3853        },
3854    );
3855    response.send(proto::SendChannelMessageResponse {
3856        message: Some(message),
3857    })?;
3858
3859    let pool = &*session.connection_pool().await;
3860    broadcast(
3861        None,
3862        channel_members
3863            .iter()
3864            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3865        |peer_id| {
3866            session.peer.send(
3867                peer_id,
3868                proto::UpdateChannels {
3869                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3870                        channel_id: channel_id.to_proto(),
3871                        message_id: message_id.to_proto(),
3872                    }],
3873                    ..Default::default()
3874                },
3875            )
3876        },
3877    );
3878    send_notifications(pool, &session.peer, notifications);
3879
3880    Ok(())
3881}
3882
3883/// Delete a channel message
3884async fn remove_channel_message(
3885    request: proto::RemoveChannelMessage,
3886    response: Response<proto::RemoveChannelMessage>,
3887    session: UserSession,
3888) -> Result<()> {
3889    let channel_id = ChannelId::from_proto(request.channel_id);
3890    let message_id = MessageId::from_proto(request.message_id);
3891    let (connection_ids, existing_notification_ids) = session
3892        .db()
3893        .await
3894        .remove_channel_message(channel_id, message_id, session.user_id())
3895        .await?;
3896
3897    broadcast(
3898        Some(session.connection_id),
3899        connection_ids,
3900        move |connection| {
3901            session.peer.send(connection, request.clone())?;
3902
3903            for notification_id in &existing_notification_ids {
3904                session.peer.send(
3905                    connection,
3906                    proto::DeleteNotification {
3907                        notification_id: (*notification_id).to_proto(),
3908                    },
3909                )?;
3910            }
3911
3912            Ok(())
3913        },
3914    );
3915    response.send(proto::Ack {})?;
3916    Ok(())
3917}
3918
3919async fn update_channel_message(
3920    request: proto::UpdateChannelMessage,
3921    response: Response<proto::UpdateChannelMessage>,
3922    session: UserSession,
3923) -> Result<()> {
3924    let channel_id = ChannelId::from_proto(request.channel_id);
3925    let message_id = MessageId::from_proto(request.message_id);
3926    let updated_at = OffsetDateTime::now_utc();
3927    let UpdatedChannelMessage {
3928        message_id,
3929        participant_connection_ids,
3930        notifications,
3931        reply_to_message_id,
3932        timestamp,
3933        deleted_mention_notification_ids,
3934        updated_mention_notifications,
3935    } = session
3936        .db()
3937        .await
3938        .update_channel_message(
3939            channel_id,
3940            message_id,
3941            session.user_id(),
3942            request.body.as_str(),
3943            &request.mentions,
3944            updated_at,
3945        )
3946        .await?;
3947
3948    let nonce = request
3949        .nonce
3950        .clone()
3951        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3952
3953    let message = proto::ChannelMessage {
3954        sender_id: session.user_id().to_proto(),
3955        id: message_id.to_proto(),
3956        body: request.body.clone(),
3957        mentions: request.mentions.clone(),
3958        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3959        nonce: Some(nonce),
3960        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3961        edited_at: Some(updated_at.unix_timestamp() as u64),
3962    };
3963
3964    response.send(proto::Ack {})?;
3965
3966    let pool = &*session.connection_pool().await;
3967    broadcast(
3968        Some(session.connection_id),
3969        participant_connection_ids,
3970        |connection| {
3971            session.peer.send(
3972                connection,
3973                proto::ChannelMessageUpdate {
3974                    channel_id: channel_id.to_proto(),
3975                    message: Some(message.clone()),
3976                },
3977            )?;
3978
3979            for notification_id in &deleted_mention_notification_ids {
3980                session.peer.send(
3981                    connection,
3982                    proto::DeleteNotification {
3983                        notification_id: (*notification_id).to_proto(),
3984                    },
3985                )?;
3986            }
3987
3988            for notification in &updated_mention_notifications {
3989                session.peer.send(
3990                    connection,
3991                    proto::UpdateNotification {
3992                        notification: Some(notification.clone()),
3993                    },
3994                )?;
3995            }
3996
3997            Ok(())
3998        },
3999    );
4000
4001    send_notifications(pool, &session.peer, notifications);
4002
4003    Ok(())
4004}
4005
4006/// Mark a channel message as read
4007async fn acknowledge_channel_message(
4008    request: proto::AckChannelMessage,
4009    session: UserSession,
4010) -> Result<()> {
4011    let channel_id = ChannelId::from_proto(request.channel_id);
4012    let message_id = MessageId::from_proto(request.message_id);
4013    let notifications = session
4014        .db()
4015        .await
4016        .observe_channel_message(channel_id, session.user_id(), message_id)
4017        .await?;
4018    send_notifications(
4019        &*session.connection_pool().await,
4020        &session.peer,
4021        notifications,
4022    );
4023    Ok(())
4024}
4025
4026/// Mark a buffer version as synced
4027async fn acknowledge_buffer_version(
4028    request: proto::AckBufferOperation,
4029    session: UserSession,
4030) -> Result<()> {
4031    let buffer_id = BufferId::from_proto(request.buffer_id);
4032    session
4033        .db()
4034        .await
4035        .observe_buffer_version(
4036            buffer_id,
4037            session.user_id(),
4038            request.epoch as i32,
4039            &request.version,
4040        )
4041        .await?;
4042    Ok(())
4043}
4044
4045struct CompleteWithLanguageModelRateLimit;
4046
4047impl RateLimit for CompleteWithLanguageModelRateLimit {
4048    fn capacity() -> usize {
4049        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4050            .ok()
4051            .and_then(|v| v.parse().ok())
4052            .unwrap_or(120) // Picked arbitrarily
4053    }
4054
4055    fn refill_duration() -> chrono::Duration {
4056        chrono::Duration::hours(1)
4057    }
4058
4059    fn db_name() -> &'static str {
4060        "complete-with-language-model"
4061    }
4062}
4063
4064async fn complete_with_language_model(
4065    request: proto::CompleteWithLanguageModel,
4066    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4067    session: Session,
4068    open_ai_api_key: Option<Arc<str>>,
4069    google_ai_api_key: Option<Arc<str>>,
4070    anthropic_api_key: Option<Arc<str>>,
4071) -> Result<()> {
4072    let Some(session) = session.for_user() else {
4073        return Err(anyhow!("user not found"))?;
4074    };
4075    authorize_access_to_language_models(&session).await?;
4076    session
4077        .rate_limiter
4078        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4079        .await?;
4080
4081    if request.model.starts_with("gpt") {
4082        let api_key =
4083            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4084        complete_with_open_ai(request, response, session, api_key).await?;
4085    } else if request.model.starts_with("gemini") {
4086        let api_key = google_ai_api_key
4087            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4088        complete_with_google_ai(request, response, session, api_key).await?;
4089    } else if request.model.starts_with("claude") {
4090        let api_key = anthropic_api_key
4091            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4092        complete_with_anthropic(request, response, session, api_key).await?;
4093    }
4094
4095    Ok(())
4096}
4097
4098async fn complete_with_open_ai(
4099    request: proto::CompleteWithLanguageModel,
4100    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4101    session: UserSession,
4102    api_key: Arc<str>,
4103) -> Result<()> {
4104    let mut completion_stream = open_ai::stream_completion(
4105        &session.http_client,
4106        OPEN_AI_API_URL,
4107        &api_key,
4108        crate::ai::language_model_request_to_open_ai(request)?,
4109    )
4110    .await
4111    .context("open_ai::stream_completion request failed within collab")?;
4112
4113    while let Some(event) = completion_stream.next().await {
4114        let event = event?;
4115        response.send(proto::LanguageModelResponse {
4116            choices: event
4117                .choices
4118                .into_iter()
4119                .map(|choice| proto::LanguageModelChoiceDelta {
4120                    index: choice.index,
4121                    delta: Some(proto::LanguageModelResponseMessage {
4122                        role: choice.delta.role.map(|role| match role {
4123                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4124                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4125                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4126                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4127                        } as i32),
4128                        content: choice.delta.content,
4129                        tool_calls: choice
4130                            .delta
4131                            .tool_calls
4132                            .into_iter()
4133                            .map(|delta| proto::ToolCallDelta {
4134                                index: delta.index as u32,
4135                                id: delta.id,
4136                                variant: match delta.function {
4137                                    Some(function) => {
4138                                        let name = function.name;
4139                                        let arguments = function.arguments;
4140
4141                                        Some(proto::tool_call_delta::Variant::Function(
4142                                            proto::tool_call_delta::FunctionCallDelta {
4143                                                name,
4144                                                arguments,
4145                                            },
4146                                        ))
4147                                    }
4148                                    None => None,
4149                                },
4150                            })
4151                            .collect(),
4152                    }),
4153                    finish_reason: choice.finish_reason,
4154                })
4155                .collect(),
4156        })?;
4157    }
4158
4159    Ok(())
4160}
4161
4162async fn complete_with_google_ai(
4163    request: proto::CompleteWithLanguageModel,
4164    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4165    session: UserSession,
4166    api_key: Arc<str>,
4167) -> Result<()> {
4168    let mut stream = google_ai::stream_generate_content(
4169        &session.http_client,
4170        google_ai::API_URL,
4171        api_key.as_ref(),
4172        crate::ai::language_model_request_to_google_ai(request)?,
4173    )
4174    .await
4175    .context("google_ai::stream_generate_content request failed")?;
4176
4177    while let Some(event) = stream.next().await {
4178        let event = event?;
4179        response.send(proto::LanguageModelResponse {
4180            choices: event
4181                .candidates
4182                .unwrap_or_default()
4183                .into_iter()
4184                .map(|candidate| proto::LanguageModelChoiceDelta {
4185                    index: candidate.index as u32,
4186                    delta: Some(proto::LanguageModelResponseMessage {
4187                        role: Some(match candidate.content.role {
4188                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4189                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4190                        } as i32),
4191                        content: Some(
4192                            candidate
4193                                .content
4194                                .parts
4195                                .into_iter()
4196                                .filter_map(|part| match part {
4197                                    google_ai::Part::TextPart(part) => Some(part.text),
4198                                    google_ai::Part::InlineDataPart(_) => None,
4199                                })
4200                                .collect(),
4201                        ),
4202                        // Tool calls are not supported for Google
4203                        tool_calls: Vec::new(),
4204                    }),
4205                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4206                })
4207                .collect(),
4208        })?;
4209    }
4210
4211    Ok(())
4212}
4213
4214async fn complete_with_anthropic(
4215    request: proto::CompleteWithLanguageModel,
4216    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4217    session: UserSession,
4218    api_key: Arc<str>,
4219) -> Result<()> {
4220    let model = anthropic::Model::from_id(&request.model)?;
4221
4222    let mut system_message = String::new();
4223    let messages = request
4224        .messages
4225        .into_iter()
4226        .filter_map(|message| {
4227            match message.role() {
4228                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4229                    role: anthropic::Role::User,
4230                    content: message.content,
4231                }),
4232                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4233                    role: anthropic::Role::Assistant,
4234                    content: message.content,
4235                }),
4236                // Anthropic's API breaks system instructions out as a separate field rather
4237                // than having a system message role.
4238                LanguageModelRole::LanguageModelSystem => {
4239                    if !system_message.is_empty() {
4240                        system_message.push_str("\n\n");
4241                    }
4242                    system_message.push_str(&message.content);
4243
4244                    None
4245                }
4246                // We don't yet support tool calls for Anthropic
4247                LanguageModelRole::LanguageModelTool => None,
4248            }
4249        })
4250        .collect();
4251
4252    let mut stream = anthropic::stream_completion(
4253        &session.http_client,
4254        "https://api.anthropic.com",
4255        &api_key,
4256        anthropic::Request {
4257            model,
4258            messages,
4259            stream: true,
4260            system: system_message,
4261            max_tokens: 4092,
4262        },
4263    )
4264    .await?;
4265
4266    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4267
4268    while let Some(event) = stream.next().await {
4269        let event = event?;
4270
4271        match event {
4272            anthropic::ResponseEvent::MessageStart { message } => {
4273                if let Some(role) = message.role {
4274                    if role == "assistant" {
4275                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4276                    } else if role == "user" {
4277                        current_role = proto::LanguageModelRole::LanguageModelUser;
4278                    }
4279                }
4280            }
4281            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4282                match content_block {
4283                    anthropic::ContentBlock::Text { text } => {
4284                        if !text.is_empty() {
4285                            response.send(proto::LanguageModelResponse {
4286                                choices: vec![proto::LanguageModelChoiceDelta {
4287                                    index: 0,
4288                                    delta: Some(proto::LanguageModelResponseMessage {
4289                                        role: Some(current_role as i32),
4290                                        content: Some(text),
4291                                        tool_calls: Vec::new(),
4292                                    }),
4293                                    finish_reason: None,
4294                                }],
4295                            })?;
4296                        }
4297                    }
4298                }
4299            }
4300            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4301                anthropic::TextDelta::TextDelta { text } => {
4302                    response.send(proto::LanguageModelResponse {
4303                        choices: vec![proto::LanguageModelChoiceDelta {
4304                            index: 0,
4305                            delta: Some(proto::LanguageModelResponseMessage {
4306                                role: Some(current_role as i32),
4307                                content: Some(text),
4308                                tool_calls: Vec::new(),
4309                            }),
4310                            finish_reason: None,
4311                        }],
4312                    })?;
4313                }
4314            },
4315            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4316                if let Some(stop_reason) = delta.stop_reason {
4317                    response.send(proto::LanguageModelResponse {
4318                        choices: vec![proto::LanguageModelChoiceDelta {
4319                            index: 0,
4320                            delta: None,
4321                            finish_reason: Some(stop_reason),
4322                        }],
4323                    })?;
4324                }
4325            }
4326            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4327            anthropic::ResponseEvent::MessageStop {} => {}
4328            anthropic::ResponseEvent::Ping {} => {}
4329        }
4330    }
4331
4332    Ok(())
4333}
4334
4335struct CountTokensWithLanguageModelRateLimit;
4336
4337impl RateLimit for CountTokensWithLanguageModelRateLimit {
4338    fn capacity() -> usize {
4339        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4340            .ok()
4341            .and_then(|v| v.parse().ok())
4342            .unwrap_or(600) // Picked arbitrarily
4343    }
4344
4345    fn refill_duration() -> chrono::Duration {
4346        chrono::Duration::hours(1)
4347    }
4348
4349    fn db_name() -> &'static str {
4350        "count-tokens-with-language-model"
4351    }
4352}
4353
4354async fn count_tokens_with_language_model(
4355    request: proto::CountTokensWithLanguageModel,
4356    response: Response<proto::CountTokensWithLanguageModel>,
4357    session: UserSession,
4358    google_ai_api_key: Option<Arc<str>>,
4359) -> Result<()> {
4360    authorize_access_to_language_models(&session).await?;
4361
4362    if !request.model.starts_with("gemini") {
4363        return Err(anyhow!(
4364            "counting tokens for model: {:?} is not supported",
4365            request.model
4366        ))?;
4367    }
4368
4369    session
4370        .rate_limiter
4371        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4372        .await?;
4373
4374    let api_key = google_ai_api_key
4375        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4376    let tokens_response = google_ai::count_tokens(
4377        &session.http_client,
4378        google_ai::API_URL,
4379        &api_key,
4380        crate::ai::count_tokens_request_to_google_ai(request)?,
4381    )
4382    .await?;
4383    response.send(proto::CountTokensResponse {
4384        token_count: tokens_response.total_tokens as u32,
4385    })?;
4386    Ok(())
4387}
4388
4389struct ComputeEmbeddingsRateLimit;
4390
4391impl RateLimit for ComputeEmbeddingsRateLimit {
4392    fn capacity() -> usize {
4393        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4394            .ok()
4395            .and_then(|v| v.parse().ok())
4396            .unwrap_or(120) // Picked arbitrarily
4397    }
4398
4399    fn refill_duration() -> chrono::Duration {
4400        chrono::Duration::hours(1)
4401    }
4402
4403    fn db_name() -> &'static str {
4404        "compute-embeddings"
4405    }
4406}
4407
4408async fn compute_embeddings(
4409    request: proto::ComputeEmbeddings,
4410    response: Response<proto::ComputeEmbeddings>,
4411    session: UserSession,
4412    api_key: Option<Arc<str>>,
4413) -> Result<()> {
4414    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4415    authorize_access_to_language_models(&session).await?;
4416
4417    session
4418        .rate_limiter
4419        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4420        .await?;
4421
4422    let embeddings = match request.model.as_str() {
4423        "openai/text-embedding-3-small" => {
4424            open_ai::embed(
4425                &session.http_client,
4426                OPEN_AI_API_URL,
4427                &api_key,
4428                OpenAiEmbeddingModel::TextEmbedding3Small,
4429                request.texts.iter().map(|text| text.as_str()),
4430            )
4431            .await?
4432        }
4433        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4434    };
4435
4436    let embeddings = request
4437        .texts
4438        .iter()
4439        .map(|text| {
4440            let mut hasher = sha2::Sha256::new();
4441            hasher.update(text.as_bytes());
4442            let result = hasher.finalize();
4443            result.to_vec()
4444        })
4445        .zip(
4446            embeddings
4447                .data
4448                .into_iter()
4449                .map(|embedding| embedding.embedding),
4450        )
4451        .collect::<HashMap<_, _>>();
4452
4453    let db = session.db().await;
4454    db.save_embeddings(&request.model, &embeddings)
4455        .await
4456        .context("failed to save embeddings")
4457        .trace_err();
4458
4459    response.send(proto::ComputeEmbeddingsResponse {
4460        embeddings: embeddings
4461            .into_iter()
4462            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4463            .collect(),
4464    })?;
4465    Ok(())
4466}
4467
4468struct GetCachedEmbeddingsRateLimit;
4469
4470impl RateLimit for GetCachedEmbeddingsRateLimit {
4471    fn capacity() -> usize {
4472        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4473            .ok()
4474            .and_then(|v| v.parse().ok())
4475            .unwrap_or(120) // Picked arbitrarily
4476    }
4477
4478    fn refill_duration() -> chrono::Duration {
4479        chrono::Duration::hours(1)
4480    }
4481
4482    fn db_name() -> &'static str {
4483        "get-cached-embeddings"
4484    }
4485}
4486
4487async fn get_cached_embeddings(
4488    request: proto::GetCachedEmbeddings,
4489    response: Response<proto::GetCachedEmbeddings>,
4490    session: UserSession,
4491) -> Result<()> {
4492    authorize_access_to_language_models(&session).await?;
4493
4494    session
4495        .rate_limiter
4496        .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4497        .await?;
4498
4499    let db = session.db().await;
4500    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4501
4502    response.send(proto::GetCachedEmbeddingsResponse {
4503        embeddings: embeddings
4504            .into_iter()
4505            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4506            .collect(),
4507    })?;
4508    Ok(())
4509}
4510
4511async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4512    let db = session.db().await;
4513    let flags = db.get_user_flags(session.user_id()).await?;
4514    if flags.iter().any(|flag| flag == "language-models") {
4515        Ok(())
4516    } else {
4517        Err(anyhow!("permission denied"))?
4518    }
4519}
4520
4521/// Start receiving chat updates for a channel
4522async fn join_channel_chat(
4523    request: proto::JoinChannelChat,
4524    response: Response<proto::JoinChannelChat>,
4525    session: UserSession,
4526) -> Result<()> {
4527    let channel_id = ChannelId::from_proto(request.channel_id);
4528
4529    let db = session.db().await;
4530    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4531        .await?;
4532    let messages = db
4533        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4534        .await?;
4535    response.send(proto::JoinChannelChatResponse {
4536        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4537        messages,
4538    })?;
4539    Ok(())
4540}
4541
4542/// Stop receiving chat updates for a channel
4543async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4544    let channel_id = ChannelId::from_proto(request.channel_id);
4545    session
4546        .db()
4547        .await
4548        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4549        .await?;
4550    Ok(())
4551}
4552
4553/// Retrieve the chat history for a channel
4554async fn get_channel_messages(
4555    request: proto::GetChannelMessages,
4556    response: Response<proto::GetChannelMessages>,
4557    session: UserSession,
4558) -> Result<()> {
4559    let channel_id = ChannelId::from_proto(request.channel_id);
4560    let messages = session
4561        .db()
4562        .await
4563        .get_channel_messages(
4564            channel_id,
4565            session.user_id(),
4566            MESSAGE_COUNT_PER_PAGE,
4567            Some(MessageId::from_proto(request.before_message_id)),
4568        )
4569        .await?;
4570    response.send(proto::GetChannelMessagesResponse {
4571        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4572        messages,
4573    })?;
4574    Ok(())
4575}
4576
4577/// Retrieve specific chat messages
4578async fn get_channel_messages_by_id(
4579    request: proto::GetChannelMessagesById,
4580    response: Response<proto::GetChannelMessagesById>,
4581    session: UserSession,
4582) -> Result<()> {
4583    let message_ids = request
4584        .message_ids
4585        .iter()
4586        .map(|id| MessageId::from_proto(*id))
4587        .collect::<Vec<_>>();
4588    let messages = session
4589        .db()
4590        .await
4591        .get_channel_messages_by_id(session.user_id(), &message_ids)
4592        .await?;
4593    response.send(proto::GetChannelMessagesResponse {
4594        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4595        messages,
4596    })?;
4597    Ok(())
4598}
4599
4600/// Retrieve the current users notifications
4601async fn get_notifications(
4602    request: proto::GetNotifications,
4603    response: Response<proto::GetNotifications>,
4604    session: UserSession,
4605) -> Result<()> {
4606    let notifications = session
4607        .db()
4608        .await
4609        .get_notifications(
4610            session.user_id(),
4611            NOTIFICATION_COUNT_PER_PAGE,
4612            request
4613                .before_id
4614                .map(|id| db::NotificationId::from_proto(id)),
4615        )
4616        .await?;
4617    response.send(proto::GetNotificationsResponse {
4618        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4619        notifications,
4620    })?;
4621    Ok(())
4622}
4623
4624/// Mark notifications as read
4625async fn mark_notification_as_read(
4626    request: proto::MarkNotificationRead,
4627    response: Response<proto::MarkNotificationRead>,
4628    session: UserSession,
4629) -> Result<()> {
4630    let database = &session.db().await;
4631    let notifications = database
4632        .mark_notification_as_read_by_id(
4633            session.user_id(),
4634            NotificationId::from_proto(request.notification_id),
4635        )
4636        .await?;
4637    send_notifications(
4638        &*session.connection_pool().await,
4639        &session.peer,
4640        notifications,
4641    );
4642    response.send(proto::Ack {})?;
4643    Ok(())
4644}
4645
4646/// Get the current users information
4647async fn get_private_user_info(
4648    _request: proto::GetPrivateUserInfo,
4649    response: Response<proto::GetPrivateUserInfo>,
4650    session: UserSession,
4651) -> Result<()> {
4652    let db = session.db().await;
4653
4654    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4655    let user = db
4656        .get_user_by_id(session.user_id())
4657        .await?
4658        .ok_or_else(|| anyhow!("user not found"))?;
4659    let flags = db.get_user_flags(session.user_id()).await?;
4660
4661    response.send(proto::GetPrivateUserInfoResponse {
4662        metrics_id,
4663        staff: user.admin,
4664        flags,
4665    })?;
4666    Ok(())
4667}
4668
4669fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4670    match message {
4671        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4672        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4673        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4674        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4675        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4676            code: frame.code.into(),
4677            reason: frame.reason,
4678        })),
4679    }
4680}
4681
4682fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4683    match message {
4684        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4685        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4686        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4687        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4688        AxumMessage::Close(frame) => {
4689            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4690                code: frame.code.into(),
4691                reason: frame.reason,
4692            }))
4693        }
4694    }
4695}
4696
4697fn notify_membership_updated(
4698    connection_pool: &mut ConnectionPool,
4699    result: MembershipUpdated,
4700    user_id: UserId,
4701    peer: &Peer,
4702) {
4703    for membership in &result.new_channels.channel_memberships {
4704        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4705    }
4706    for channel_id in &result.removed_channels {
4707        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4708    }
4709
4710    let user_channels_update = proto::UpdateUserChannels {
4711        channel_memberships: result
4712            .new_channels
4713            .channel_memberships
4714            .iter()
4715            .map(|cm| proto::ChannelMembership {
4716                channel_id: cm.channel_id.to_proto(),
4717                role: cm.role.into(),
4718            })
4719            .collect(),
4720        ..Default::default()
4721    };
4722
4723    let mut update = build_channels_update(result.new_channels, vec![]);
4724    update.delete_channels = result
4725        .removed_channels
4726        .into_iter()
4727        .map(|id| id.to_proto())
4728        .collect();
4729    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4730
4731    for connection_id in connection_pool.user_connection_ids(user_id) {
4732        peer.send(connection_id, user_channels_update.clone())
4733            .trace_err();
4734        peer.send(connection_id, update.clone()).trace_err();
4735    }
4736}
4737
4738fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4739    proto::UpdateUserChannels {
4740        channel_memberships: channels
4741            .channel_memberships
4742            .iter()
4743            .map(|m| proto::ChannelMembership {
4744                channel_id: m.channel_id.to_proto(),
4745                role: m.role.into(),
4746            })
4747            .collect(),
4748        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4749        observed_channel_message_id: channels.observed_channel_messages.clone(),
4750    }
4751}
4752
4753fn build_channels_update(
4754    channels: ChannelsForUser,
4755    channel_invites: Vec<db::Channel>,
4756) -> proto::UpdateChannels {
4757    let mut update = proto::UpdateChannels::default();
4758
4759    for channel in channels.channels {
4760        update.channels.push(channel.to_proto());
4761    }
4762
4763    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4764    update.latest_channel_message_ids = channels.latest_channel_messages;
4765
4766    for (channel_id, participants) in channels.channel_participants {
4767        update
4768            .channel_participants
4769            .push(proto::ChannelParticipants {
4770                channel_id: channel_id.to_proto(),
4771                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4772            });
4773    }
4774
4775    for channel in channel_invites {
4776        update.channel_invitations.push(channel.to_proto());
4777    }
4778
4779    update.hosted_projects = channels.hosted_projects;
4780    update
4781}
4782
4783fn build_initial_contacts_update(
4784    contacts: Vec<db::Contact>,
4785    pool: &ConnectionPool,
4786) -> proto::UpdateContacts {
4787    let mut update = proto::UpdateContacts::default();
4788
4789    for contact in contacts {
4790        match contact {
4791            db::Contact::Accepted { user_id, busy } => {
4792                update.contacts.push(contact_for_user(user_id, busy, &pool));
4793            }
4794            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4795            db::Contact::Incoming { user_id } => {
4796                update
4797                    .incoming_requests
4798                    .push(proto::IncomingContactRequest {
4799                        requester_id: user_id.to_proto(),
4800                    })
4801            }
4802        }
4803    }
4804
4805    update
4806}
4807
4808fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4809    proto::Contact {
4810        user_id: user_id.to_proto(),
4811        online: pool.is_user_online(user_id),
4812        busy,
4813    }
4814}
4815
4816fn room_updated(room: &proto::Room, peer: &Peer) {
4817    broadcast(
4818        None,
4819        room.participants
4820            .iter()
4821            .filter_map(|participant| Some(participant.peer_id?.into())),
4822        |peer_id| {
4823            peer.send(
4824                peer_id,
4825                proto::RoomUpdated {
4826                    room: Some(room.clone()),
4827                },
4828            )
4829        },
4830    );
4831}
4832
4833fn channel_updated(
4834    channel: &db::channel::Model,
4835    room: &proto::Room,
4836    peer: &Peer,
4837    pool: &ConnectionPool,
4838) {
4839    let participants = room
4840        .participants
4841        .iter()
4842        .map(|p| p.user_id)
4843        .collect::<Vec<_>>();
4844
4845    broadcast(
4846        None,
4847        pool.channel_connection_ids(channel.root_id())
4848            .filter_map(|(channel_id, role)| {
4849                role.can_see_channel(channel.visibility).then(|| channel_id)
4850            }),
4851        |peer_id| {
4852            peer.send(
4853                peer_id,
4854                proto::UpdateChannels {
4855                    channel_participants: vec![proto::ChannelParticipants {
4856                        channel_id: channel.id.to_proto(),
4857                        participant_user_ids: participants.clone(),
4858                    }],
4859                    ..Default::default()
4860                },
4861            )
4862        },
4863    );
4864}
4865
4866async fn send_remote_projects_update(
4867    user_id: UserId,
4868    mut status: proto::RemoteProjectsUpdate,
4869    session: &Session,
4870) {
4871    let pool = session.connection_pool().await;
4872    for dev_server in &mut status.dev_servers {
4873        dev_server.status =
4874            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4875    }
4876    let connections = pool.user_connection_ids(user_id);
4877    for connection_id in connections {
4878        session.peer.send(connection_id, status.clone()).trace_err();
4879    }
4880}
4881
4882async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4883    let db = session.db().await;
4884
4885    let contacts = db.get_contacts(user_id).await?;
4886    let busy = db.is_user_busy(user_id).await?;
4887
4888    let pool = session.connection_pool().await;
4889    let updated_contact = contact_for_user(user_id, busy, &pool);
4890    for contact in contacts {
4891        if let db::Contact::Accepted {
4892            user_id: contact_user_id,
4893            ..
4894        } = contact
4895        {
4896            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4897                session
4898                    .peer
4899                    .send(
4900                        contact_conn_id,
4901                        proto::UpdateContacts {
4902                            contacts: vec![updated_contact.clone()],
4903                            remove_contacts: Default::default(),
4904                            incoming_requests: Default::default(),
4905                            remove_incoming_requests: Default::default(),
4906                            outgoing_requests: Default::default(),
4907                            remove_outgoing_requests: Default::default(),
4908                        },
4909                    )
4910                    .trace_err();
4911            }
4912        }
4913    }
4914    Ok(())
4915}
4916
4917async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
4918    log::info!("lost dev server connection, unsharing projects");
4919    let project_ids = session
4920        .db()
4921        .await
4922        .get_stale_dev_server_projects(session.connection_id)
4923        .await?;
4924
4925    for project_id in project_ids {
4926        // not unshare re-checks the connection ids match, so we get away with no transaction
4927        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
4928    }
4929
4930    let user_id = session.dev_server().user_id;
4931    let update = session.db().await.remote_projects_update(user_id).await?;
4932
4933    send_remote_projects_update(user_id, update, session).await;
4934
4935    Ok(())
4936}
4937
/// Removes `connection_id` from the room it is in and propagates the effects.
///
/// If the database reports that a room was left, this:
/// - notifies the collaborators of each left project (`project_left`),
/// - broadcasts the updated room,
/// - refreshes the channel's participant list when the room has a channel,
/// - sends `CallCanceled` to users whose pending calls were canceled,
/// - re-broadcasts contact status for the session's user and those users,
/// - removes the participant from LiveKit, and deletes the LiveKit room when
///   `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately if the database reports no room was left.
///
/// # Errors
/// Propagates errors from `leave_room` and from `update_user_contacts`.
/// Peer send and LiveKit failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-sent to their contacts.
    let mut contacts_to_update = HashSet::default();

    // Filled in from `left_room` inside the `if let` below; declared up front
    // so the values outlive that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): taken before `room` is moved out, so the `room` value
        // broadcast below carries a default (empty) `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms also refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5011
5012async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5013    let left_channel_buffers = session
5014        .db()
5015        .await
5016        .leave_channel_buffers(session.connection_id)
5017        .await?;
5018
5019    for left_buffer in left_channel_buffers {
5020        channel_buffer_updated(
5021            session.connection_id,
5022            left_buffer.connections,
5023            &proto::UpdateChannelBufferCollaborators {
5024                channel_id: left_buffer.channel_id.to_proto(),
5025                collaborators: left_buffer.collaborators,
5026            },
5027            &session.peer,
5028        );
5029    }
5030
5031    Ok(())
5032}
5033
5034fn project_left(project: &db::LeftProject, session: &UserSession) {
5035    for connection_id in &project.connection_ids {
5036        if project.should_unshare {
5037            session
5038                .peer
5039                .send(
5040                    *connection_id,
5041                    proto::UnshareProject {
5042                        project_id: project.id.to_proto(),
5043                    },
5044                )
5045                .trace_err();
5046        } else {
5047            session
5048                .peer
5049                .send(
5050                    *connection_id,
5051                    proto::RemoveProjectCollaborator {
5052                        project_id: project.id.to_proto(),
5053                        peer_id: Some(session.connection_id.into()),
5054                    },
5055                )
5056                .trace_err();
5057        }
5058    }
5059}
5060
5061pub trait ResultExt {
5062    type Ok;
5063
5064    fn trace_err(self) -> Option<Self::Ok>;
5065}
5066
5067impl<T, E> ResultExt for Result<T, E>
5068where
5069    E: std::fmt::Debug,
5070{
5071    type Ok = T;
5072
5073    #[track_caller]
5074    fn trace_err(self) -> Option<T> {
5075        match self {
5076            Ok(value) => Some(value),
5077            Err(error) => {
5078                tracing::error!("{:?}", error);
5079                None
5080            }
5081        }
5082    }
5083}