rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
   8        MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
   9        RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37
  38use futures::{
  39    channel::oneshot,
  40    future::{self, BoxFuture},
  41    stream::FuturesUnordered,
  42    FutureExt, SinkExt, StreamExt, TryStreamExt,
  43};
  44use prometheus::{register_int_gauge, IntGauge};
  45use rpc::{
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  48        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        atomic::{AtomicBool, Ordering::SeqCst},
  64        Arc, OnceLock,
  65    },
  66    time::{Duration, Instant},
  67};
  68use time::OffsetDateTime;
  69use tokio::sync::{watch, Semaphore};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    field::{self},
  73    info_span, instrument, Instrument,
  74};
  75use util::http::IsahcHttpClient;
  76
  77use self::connection_pool::VersionedMessage;
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105struct StreamingResponse<R: RequestMessage> {
 106    peer: Arc<Peer>,
 107    receipt: Receipt<R>,
 108}
 109
 110impl<R: RequestMessage> StreamingResponse<R> {
 111    fn send(&self, payload: R::Response) -> Result<()> {
 112        self.peer.respond(self.receipt, payload)?;
 113        Ok(())
 114    }
 115}
 116
 117#[derive(Clone, Debug)]
 118pub enum Principal {
 119    User(User),
 120    Impersonated { user: User, admin: User },
 121    DevServer(dev_server::Model),
 122}
 123
 124impl Principal {
 125    fn update_span(&self, span: &tracing::Span) {
 126        match &self {
 127            Principal::User(user) => {
 128                span.record("user_id", &user.id.0);
 129                span.record("login", &user.github_login);
 130            }
 131            Principal::Impersonated { user, admin } => {
 132                span.record("user_id", &user.id.0);
 133                span.record("login", &user.github_login);
 134                span.record("impersonator", &admin.github_login);
 135            }
 136            Principal::DevServer(dev_server) => {
 137                span.record("dev_server_id", &dev_server.id.0);
 138            }
 139        }
 140    }
 141}
 142
 143#[derive(Clone)]
 144struct Session {
 145    principal: Principal,
 146    connection_id: ConnectionId,
 147    db: Arc<tokio::sync::Mutex<DbHandle>>,
 148    peer: Arc<Peer>,
 149    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 150    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 151    http_client: IsahcHttpClient,
 152    rate_limiter: Arc<RateLimiter>,
 153    _executor: Executor,
 154}
 155
 156impl Session {
 157    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 158        #[cfg(test)]
 159        tokio::task::yield_now().await;
 160        let guard = self.db.lock().await;
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        guard
 164    }
 165
 166    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 167        #[cfg(test)]
 168        tokio::task::yield_now().await;
 169        let guard = self.connection_pool.lock();
 170        ConnectionPoolGuard {
 171            guard,
 172            _not_send: PhantomData,
 173        }
 174    }
 175
 176    fn for_user(self) -> Option<UserSession> {
 177        UserSession::new(self)
 178    }
 179
 180    fn for_dev_server(self) -> Option<DevServerSession> {
 181        DevServerSession::new(self)
 182    }
 183
 184    fn user_id(&self) -> Option<UserId> {
 185        match &self.principal {
 186            Principal::User(user) => Some(user.id),
 187            Principal::Impersonated { user, .. } => Some(user.id),
 188            Principal::DevServer(_) => None,
 189        }
 190    }
 191
 192    fn dev_server_id(&self) -> Option<DevServerId> {
 193        match &self.principal {
 194            Principal::User(_) | Principal::Impersonated { .. } => None,
 195            Principal::DevServer(dev_server) => Some(dev_server.id),
 196        }
 197    }
 198
 199    fn principal_id(&self) -> PrincipalId {
 200        match &self.principal {
 201            Principal::User(user) => PrincipalId::UserId(user.id),
 202            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 203            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 204        }
 205    }
 206}
 207
 208impl Debug for Session {
 209    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 210        let mut result = f.debug_struct("Session");
 211        match &self.principal {
 212            Principal::User(user) => {
 213                result.field("user", &user.github_login);
 214            }
 215            Principal::Impersonated { user, admin } => {
 216                result.field("user", &user.github_login);
 217                result.field("impersonator", &admin.github_login);
 218            }
 219            Principal::DevServer(dev_server) => {
 220                result.field("dev_server", &dev_server.id);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct UserSession(Session);
 228
 229impl UserSession {
 230    pub fn new(s: Session) -> Option<Self> {
 231        s.user_id().map(|_| UserSession(s))
 232    }
 233    pub fn user_id(&self) -> UserId {
 234        self.0.user_id().unwrap()
 235    }
 236}
 237
 238impl Deref for UserSession {
 239    type Target = Session;
 240
 241    fn deref(&self) -> &Self::Target {
 242        &self.0
 243    }
 244}
 245impl DerefMut for UserSession {
 246    fn deref_mut(&mut self) -> &mut Self::Target {
 247        &mut self.0
 248    }
 249}
 250
 251struct DevServerSession(Session);
 252
 253impl DevServerSession {
 254    pub fn new(s: Session) -> Option<Self> {
 255        s.dev_server_id().map(|_| DevServerSession(s))
 256    }
 257    pub fn dev_server_id(&self) -> DevServerId {
 258        self.0.dev_server_id().unwrap()
 259    }
 260
 261    fn dev_server(&self) -> &dev_server::Model {
 262        match &self.0.principal {
 263            Principal::DevServer(dev_server) => dev_server,
 264            _ => unreachable!(),
 265        }
 266    }
 267}
 268
 269impl Deref for DevServerSession {
 270    type Target = Session;
 271
 272    fn deref(&self) -> &Self::Target {
 273        &self.0
 274    }
 275}
 276impl DerefMut for DevServerSession {
 277    fn deref_mut(&mut self) -> &mut Self::Target {
 278        &mut self.0
 279    }
 280}
 281
 282fn user_handler<M: RequestMessage, Fut>(
 283    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 284) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 285where
 286    Fut: Send + Future<Output = Result<()>>,
 287{
 288    let handler = Arc::new(handler);
 289    move |message, response, session| {
 290        let handler = handler.clone();
 291        Box::pin(async move {
 292            if let Some(user_session) = session.for_user() {
 293                Ok(handler(message, response, user_session).await?)
 294            } else {
 295                Err(Error::Internal(anyhow!(
 296                    "must be a user to call {}",
 297                    M::NAME
 298                )))
 299            }
 300        })
 301    }
 302}
 303
 304fn dev_server_handler<M: RequestMessage, Fut>(
 305    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 306) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 307where
 308    Fut: Send + Future<Output = Result<()>>,
 309{
 310    let handler = Arc::new(handler);
 311    move |message, response, session| {
 312        let handler = handler.clone();
 313        Box::pin(async move {
 314            if let Some(dev_server_session) = session.for_dev_server() {
 315                Ok(handler(message, response, dev_server_session).await?)
 316            } else {
 317                Err(Error::Internal(anyhow!(
 318                    "must be a dev server to call {}",
 319                    M::NAME
 320                )))
 321            }
 322        })
 323    }
 324}
 325
 326fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 327    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 328) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 329where
 330    InnertRetFut: Send + Future<Output = Result<()>>,
 331{
 332    let handler = Arc::new(handler);
 333    move |message, session| {
 334        let handler = handler.clone();
 335        Box::pin(async move {
 336            if let Some(user_session) = session.for_user() {
 337                Ok(handler(message, user_session).await?)
 338            } else {
 339                Err(Error::Internal(anyhow!(
 340                    "must be a user to call {}",
 341                    M::NAME
 342                )))
 343            }
 344        })
 345    }
 346}
 347
 348struct DbHandle(Arc<Database>);
 349
 350impl Deref for DbHandle {
 351    type Target = Database;
 352
 353    fn deref(&self) -> &Self::Target {
 354        self.0.as_ref()
 355    }
 356}
 357
 358pub struct Server {
 359    id: parking_lot::Mutex<ServerId>,
 360    peer: Arc<Peer>,
 361    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 362    app_state: Arc<AppState>,
 363    handlers: HashMap<TypeId, MessageHandler>,
 364    teardown: watch::Sender<bool>,
 365}
 366
 367pub(crate) struct ConnectionPoolGuard<'a> {
 368    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 369    _not_send: PhantomData<Rc<()>>,
 370}
 371
 372#[derive(Serialize)]
 373pub struct ServerSnapshot<'a> {
 374    peer: &'a Peer,
 375    #[serde(serialize_with = "serialize_deref")]
 376    connection_pool: ConnectionPoolGuard<'a>,
 377}
 378
 379pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 380where
 381    S: Serializer,
 382    T: Deref<Target = U>,
 383    U: Serialize,
 384{
 385    Serialize::serialize(value.deref(), serializer)
 386}
 387
 388impl Server {
 389    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 390        let mut server = Self {
 391            id: parking_lot::Mutex::new(id),
 392            peer: Peer::new(id.0 as u32),
 393            app_state: app_state.clone(),
 394            connection_pool: Default::default(),
 395            handlers: Default::default(),
 396            teardown: watch::channel(false).0,
 397        };
 398
 399        server
 400            .add_request_handler(ping)
 401            .add_request_handler(user_handler(create_room))
 402            .add_request_handler(user_handler(join_room))
 403            .add_request_handler(user_handler(rejoin_room))
 404            .add_request_handler(user_handler(leave_room))
 405            .add_request_handler(user_handler(set_room_participant_role))
 406            .add_request_handler(user_handler(call))
 407            .add_request_handler(user_handler(cancel_call))
 408            .add_message_handler(user_message_handler(decline_call))
 409            .add_request_handler(user_handler(update_participant_location))
 410            .add_request_handler(user_handler(share_project))
 411            .add_message_handler(unshare_project)
 412            .add_request_handler(user_handler(join_project))
 413            .add_request_handler(user_handler(join_hosted_project))
 414            .add_request_handler(user_handler(rejoin_remote_projects))
 415            .add_request_handler(user_handler(create_remote_project))
 416            .add_request_handler(user_handler(create_dev_server))
 417            .add_request_handler(user_handler(delete_dev_server))
 418            .add_request_handler(dev_server_handler(share_remote_project))
 419            .add_request_handler(dev_server_handler(shutdown_dev_server))
 420            .add_request_handler(dev_server_handler(reconnect_dev_server))
 421            .add_message_handler(user_message_handler(leave_project))
 422            .add_request_handler(update_project)
 423            .add_request_handler(update_worktree)
 424            .add_message_handler(start_language_server)
 425            .add_message_handler(update_language_server)
 426            .add_message_handler(update_diagnostic_summary)
 427            .add_message_handler(update_worktree_settings)
 428            .add_request_handler(user_handler(
 429                forward_read_only_project_request::<proto::GetHover>,
 430            ))
 431            .add_request_handler(user_handler(
 432                forward_read_only_project_request::<proto::GetDefinition>,
 433            ))
 434            .add_request_handler(user_handler(
 435                forward_read_only_project_request::<proto::GetTypeDefinition>,
 436            ))
 437            .add_request_handler(user_handler(
 438                forward_read_only_project_request::<proto::GetReferences>,
 439            ))
 440            .add_request_handler(user_handler(
 441                forward_read_only_project_request::<proto::SearchProject>,
 442            ))
 443            .add_request_handler(user_handler(
 444                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 445            ))
 446            .add_request_handler(user_handler(
 447                forward_read_only_project_request::<proto::GetProjectSymbols>,
 448            ))
 449            .add_request_handler(user_handler(
 450                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 451            ))
 452            .add_request_handler(user_handler(
 453                forward_read_only_project_request::<proto::OpenBufferById>,
 454            ))
 455            .add_request_handler(user_handler(
 456                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 457            ))
 458            .add_request_handler(user_handler(
 459                forward_read_only_project_request::<proto::InlayHints>,
 460            ))
 461            .add_request_handler(user_handler(
 462                forward_read_only_project_request::<proto::OpenBufferByPath>,
 463            ))
 464            .add_request_handler(user_handler(
 465                forward_mutating_project_request::<proto::GetCompletions>,
 466            ))
 467            .add_request_handler(user_handler(
 468                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 469            ))
 470            .add_request_handler(user_handler(
 471                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 472            ))
 473            .add_request_handler(user_handler(
 474                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 475            ))
 476            .add_request_handler(user_handler(
 477                forward_mutating_project_request::<proto::GetCodeActions>,
 478            ))
 479            .add_request_handler(user_handler(
 480                forward_mutating_project_request::<proto::ApplyCodeAction>,
 481            ))
 482            .add_request_handler(user_handler(
 483                forward_mutating_project_request::<proto::PrepareRename>,
 484            ))
 485            .add_request_handler(user_handler(
 486                forward_mutating_project_request::<proto::PerformRename>,
 487            ))
 488            .add_request_handler(user_handler(
 489                forward_mutating_project_request::<proto::ReloadBuffers>,
 490            ))
 491            .add_request_handler(user_handler(
 492                forward_mutating_project_request::<proto::FormatBuffers>,
 493            ))
 494            .add_request_handler(user_handler(
 495                forward_mutating_project_request::<proto::CreateProjectEntry>,
 496            ))
 497            .add_request_handler(user_handler(
 498                forward_mutating_project_request::<proto::RenameProjectEntry>,
 499            ))
 500            .add_request_handler(user_handler(
 501                forward_mutating_project_request::<proto::CopyProjectEntry>,
 502            ))
 503            .add_request_handler(user_handler(
 504                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 505            ))
 506            .add_request_handler(user_handler(
 507                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 508            ))
 509            .add_request_handler(user_handler(
 510                forward_mutating_project_request::<proto::OnTypeFormatting>,
 511            ))
 512            .add_request_handler(user_handler(
 513                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 514            ))
 515            .add_request_handler(user_handler(
 516                forward_mutating_project_request::<proto::BlameBuffer>,
 517            ))
 518            .add_request_handler(user_handler(
 519                forward_mutating_project_request::<proto::MultiLspQuery>,
 520            ))
 521            .add_message_handler(create_buffer_for_peer)
 522            .add_request_handler(update_buffer)
 523            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 524            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 525            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 526            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 527            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 528            .add_request_handler(get_users)
 529            .add_request_handler(user_handler(fuzzy_search_users))
 530            .add_request_handler(user_handler(request_contact))
 531            .add_request_handler(user_handler(remove_contact))
 532            .add_request_handler(user_handler(respond_to_contact_request))
 533            .add_request_handler(user_handler(create_channel))
 534            .add_request_handler(user_handler(delete_channel))
 535            .add_request_handler(user_handler(invite_channel_member))
 536            .add_request_handler(user_handler(remove_channel_member))
 537            .add_request_handler(user_handler(set_channel_member_role))
 538            .add_request_handler(user_handler(set_channel_visibility))
 539            .add_request_handler(user_handler(rename_channel))
 540            .add_request_handler(user_handler(join_channel_buffer))
 541            .add_request_handler(user_handler(leave_channel_buffer))
 542            .add_message_handler(user_message_handler(update_channel_buffer))
 543            .add_request_handler(user_handler(rejoin_channel_buffers))
 544            .add_request_handler(user_handler(get_channel_members))
 545            .add_request_handler(user_handler(respond_to_channel_invite))
 546            .add_request_handler(user_handler(join_channel))
 547            .add_request_handler(user_handler(join_channel_chat))
 548            .add_message_handler(user_message_handler(leave_channel_chat))
 549            .add_request_handler(user_handler(send_channel_message))
 550            .add_request_handler(user_handler(remove_channel_message))
 551            .add_request_handler(user_handler(update_channel_message))
 552            .add_request_handler(user_handler(get_channel_messages))
 553            .add_request_handler(user_handler(get_channel_messages_by_id))
 554            .add_request_handler(user_handler(get_notifications))
 555            .add_request_handler(user_handler(mark_notification_as_read))
 556            .add_request_handler(user_handler(move_channel))
 557            .add_request_handler(user_handler(follow))
 558            .add_message_handler(user_message_handler(unfollow))
 559            .add_message_handler(user_message_handler(update_followers))
 560            .add_request_handler(user_handler(get_private_user_info))
 561            .add_message_handler(user_message_handler(acknowledge_channel_message))
 562            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 563            .add_streaming_request_handler({
 564                let app_state = app_state.clone();
 565                move |request, response, session| {
 566                    complete_with_language_model(
 567                        request,
 568                        response,
 569                        session,
 570                        app_state.config.openai_api_key.clone(),
 571                        app_state.config.google_ai_api_key.clone(),
 572                        app_state.config.anthropic_api_key.clone(),
 573                    )
 574                }
 575            })
 576            .add_request_handler({
 577                let app_state = app_state.clone();
 578                user_handler(move |request, response, session| {
 579                    count_tokens_with_language_model(
 580                        request,
 581                        response,
 582                        session,
 583                        app_state.config.google_ai_api_key.clone(),
 584                    )
 585                })
 586            })
 587            .add_request_handler({
 588                user_handler(move |request, response, session| {
 589                    get_cached_embeddings(request, response, session)
 590                })
 591            })
 592            .add_request_handler({
 593                let app_state = app_state.clone();
 594                user_handler(move |request, response, session| {
 595                    compute_embeddings(
 596                        request,
 597                        response,
 598                        session,
 599                        app_state.config.openai_api_key.clone(),
 600                    )
 601                })
 602            });
 603
 604        Arc::new(server)
 605    }
 606
 607    pub async fn start(&self) -> Result<()> {
 608        let server_id = *self.id.lock();
 609        let app_state = self.app_state.clone();
 610        let peer = self.peer.clone();
 611        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 612        let pool = self.connection_pool.clone();
 613        let live_kit_client = self.app_state.live_kit_client.clone();
 614
 615        let span = info_span!("start server");
 616        self.app_state.executor.spawn_detached(
 617            async move {
 618                tracing::info!("waiting for cleanup timeout");
 619                timeout.await;
 620                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 621                if let Some((room_ids, channel_ids)) = app_state
 622                    .db
 623                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 624                    .await
 625                    .trace_err()
 626                {
 627                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 628                    tracing::info!(
 629                        stale_channel_buffer_count = channel_ids.len(),
 630                        "retrieved stale channel buffers"
 631                    );
 632
 633                    for channel_id in channel_ids {
 634                        if let Some(refreshed_channel_buffer) = app_state
 635                            .db
 636                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 637                            .await
 638                            .trace_err()
 639                        {
 640                            for connection_id in refreshed_channel_buffer.connection_ids {
 641                                peer.send(
 642                                    connection_id,
 643                                    proto::UpdateChannelBufferCollaborators {
 644                                        channel_id: channel_id.to_proto(),
 645                                        collaborators: refreshed_channel_buffer
 646                                            .collaborators
 647                                            .clone(),
 648                                    },
 649                                )
 650                                .trace_err();
 651                            }
 652                        }
 653                    }
 654
 655                    for room_id in room_ids {
 656                        let mut contacts_to_update = HashSet::default();
 657                        let mut canceled_calls_to_user_ids = Vec::new();
 658                        let mut live_kit_room = String::new();
 659                        let mut delete_live_kit_room = false;
 660
 661                        if let Some(mut refreshed_room) = app_state
 662                            .db
 663                            .clear_stale_room_participants(room_id, server_id)
 664                            .await
 665                            .trace_err()
 666                        {
 667                            tracing::info!(
 668                                room_id = room_id.0,
 669                                new_participant_count = refreshed_room.room.participants.len(),
 670                                "refreshed room"
 671                            );
 672                            room_updated(&refreshed_room.room, &peer);
 673                            if let Some(channel) = refreshed_room.channel.as_ref() {
 674                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 675                            }
 676                            contacts_to_update
 677                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 678                            contacts_to_update
 679                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 680                            canceled_calls_to_user_ids =
 681                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 682                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 683                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 684                        }
 685
 686                        {
 687                            let pool = pool.lock();
 688                            for canceled_user_id in canceled_calls_to_user_ids {
 689                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 690                                    peer.send(
 691                                        connection_id,
 692                                        proto::CallCanceled {
 693                                            room_id: room_id.to_proto(),
 694                                        },
 695                                    )
 696                                    .trace_err();
 697                                }
 698                            }
 699                        }
 700
 701                        for user_id in contacts_to_update {
 702                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 703                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 704                            if let Some((busy, contacts)) = busy.zip(contacts) {
 705                                let pool = pool.lock();
 706                                let updated_contact = contact_for_user(user_id, busy, &pool);
 707                                for contact in contacts {
 708                                    if let db::Contact::Accepted {
 709                                        user_id: contact_user_id,
 710                                        ..
 711                                    } = contact
 712                                    {
 713                                        for contact_conn_id in
 714                                            pool.user_connection_ids(contact_user_id)
 715                                        {
 716                                            peer.send(
 717                                                contact_conn_id,
 718                                                proto::UpdateContacts {
 719                                                    contacts: vec![updated_contact.clone()],
 720                                                    remove_contacts: Default::default(),
 721                                                    incoming_requests: Default::default(),
 722                                                    remove_incoming_requests: Default::default(),
 723                                                    outgoing_requests: Default::default(),
 724                                                    remove_outgoing_requests: Default::default(),
 725                                                },
 726                                            )
 727                                            .trace_err();
 728                                        }
 729                                    }
 730                                }
 731                            }
 732                        }
 733
 734                        if let Some(live_kit) = live_kit_client.as_ref() {
 735                            if delete_live_kit_room {
 736                                live_kit.delete_room(live_kit_room).await.trace_err();
 737                            }
 738                        }
 739                    }
 740                }
 741
 742                app_state
 743                    .db
 744                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 745                    .await
 746                    .trace_err();
 747            }
 748            .instrument(span),
 749        );
 750        Ok(())
 751    }
 752
 753    pub fn teardown(&self) {
 754        self.peer.teardown();
 755        self.connection_pool.lock().reset();
 756        let _ = self.teardown.send(true);
 757    }
 758
 759    #[cfg(test)]
 760    pub fn reset(&self, id: ServerId) {
 761        self.teardown();
 762        *self.id.lock() = id;
 763        self.peer.reset(id.0 as u32);
 764        let _ = self.teardown.send(false);
 765    }
 766
 767    #[cfg(test)]
 768    pub fn id(&self) -> ServerId {
 769        *self.id.lock()
 770    }
 771
 772    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 773    where
 774        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 775        Fut: 'static + Send + Future<Output = Result<()>>,
 776        M: EnvelopedMessage,
 777    {
 778        let prev_handler = self.handlers.insert(
 779            TypeId::of::<M>(),
 780            Box::new(move |envelope, session| {
 781                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 782                let received_at = envelope.received_at;
 783                tracing::info!("message received");
 784                let start_time = Instant::now();
 785                let future = (handler)(*envelope, session);
 786                async move {
 787                    let result = future.await;
 788                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 789                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 790                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 791                    let payload_type = M::NAME;
 792
 793                    match result {
 794                        Err(error) => {
 795                            tracing::error!(
 796                                ?error,
 797                                total_duration_ms,
 798                                processing_duration_ms,
 799                                queue_duration_ms,
 800                                payload_type,
 801                                "error handling message"
 802                            )
 803                        }
 804                        Ok(()) => tracing::info!(
 805                            total_duration_ms,
 806                            processing_duration_ms,
 807                            queue_duration_ms,
 808                            "finished handling message"
 809                        ),
 810                    }
 811                }
 812                .boxed()
 813            }),
 814        );
 815        if prev_handler.is_some() {
 816            panic!("registered a handler for the same message twice");
 817        }
 818        self
 819    }
 820
 821    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 822    where
 823        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 824        Fut: 'static + Send + Future<Output = Result<()>>,
 825        M: EnvelopedMessage,
 826    {
 827        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 828        self
 829    }
 830
 831    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 832    where
 833        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 834        Fut: Send + Future<Output = Result<()>>,
 835        M: RequestMessage,
 836    {
 837        let handler = Arc::new(handler);
 838        self.add_handler(move |envelope, session| {
 839            let receipt = envelope.receipt();
 840            let handler = handler.clone();
 841            async move {
 842                let peer = session.peer.clone();
 843                let responded = Arc::new(AtomicBool::default());
 844                let response = Response {
 845                    peer: peer.clone(),
 846                    responded: responded.clone(),
 847                    receipt,
 848                };
 849                match (handler)(envelope.payload, response, session).await {
 850                    Ok(()) => {
 851                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 852                            Ok(())
 853                        } else {
 854                            Err(anyhow!("handler did not send a response"))?
 855                        }
 856                    }
 857                    Err(error) => {
 858                        let proto_err = match &error {
 859                            Error::Internal(err) => err.to_proto(),
 860                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 861                        };
 862                        peer.respond_with_error(receipt, proto_err)?;
 863                        Err(error)
 864                    }
 865                }
 866            }
 867        })
 868    }
 869
 870    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 871    where
 872        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 873        Fut: Send + Future<Output = Result<()>>,
 874        M: RequestMessage,
 875    {
 876        let handler = Arc::new(handler);
 877        self.add_handler(move |envelope, session| {
 878            let receipt = envelope.receipt();
 879            let handler = handler.clone();
 880            async move {
 881                let peer = session.peer.clone();
 882                let response = StreamingResponse {
 883                    peer: peer.clone(),
 884                    receipt,
 885                };
 886                match (handler)(envelope.payload, response, session).await {
 887                    Ok(()) => {
 888                        peer.end_stream(receipt)?;
 889                        Ok(())
 890                    }
 891                    Err(error) => {
 892                        let proto_err = match &error {
 893                            Error::Internal(err) => err.to_proto(),
 894                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 895                        };
 896                        peer.respond_with_error(receipt, proto_err)?;
 897                        Err(error)
 898                    }
 899                }
 900            }
 901        })
 902    }
 903
 904    #[allow(clippy::too_many_arguments)]
 905    pub fn handle_connection(
 906        self: &Arc<Self>,
 907        connection: Connection,
 908        address: String,
 909        principal: Principal,
 910        zed_version: ZedVersion,
 911        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 912        executor: Executor,
 913    ) -> impl Future<Output = ()> {
 914        let this = self.clone();
 915        let span = info_span!("handle connection", %address,
 916            connection_id=field::Empty,
 917            user_id=field::Empty,
 918            login=field::Empty,
 919            impersonator=field::Empty,
 920            dev_server_id=field::Empty
 921        );
 922        principal.update_span(&span);
 923
 924        let mut teardown = self.teardown.subscribe();
 925        async move {
 926            if *teardown.borrow() {
 927                tracing::error!("server is tearing down");
 928                return
 929            }
 930            let (connection_id, handle_io, mut incoming_rx) = this
 931                .peer
 932                .add_connection(connection, {
 933                    let executor = executor.clone();
 934                    move |duration| executor.sleep(duration)
 935                });
 936            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 937            tracing::info!("connection opened");
 938
 939            let http_client = match IsahcHttpClient::new() {
 940                Ok(http_client) => http_client,
 941                Err(error) => {
 942                    tracing::error!(?error, "failed to create HTTP client");
 943                    return;
 944                }
 945            };
 946
 947            let session = Session {
 948                principal: principal.clone(),
 949                connection_id,
 950                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 951                peer: this.peer.clone(),
 952                connection_pool: this.connection_pool.clone(),
 953                live_kit_client: this.app_state.live_kit_client.clone(),
 954                http_client,
 955                rate_limiter: this.app_state.rate_limiter.clone(),
 956                _executor: executor.clone(),
 957            };
 958
 959            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 960                tracing::error!(?error, "failed to send initial client update");
 961                return;
 962            }
 963
 964            let handle_io = handle_io.fuse();
 965            futures::pin_mut!(handle_io);
 966
 967            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 968            // This prevents deadlocks when e.g., client A performs a request to client B and
 969            // client B performs a request to client A. If both clients stop processing further
 970            // messages until their respective request completes, they won't have a chance to
 971            // respond to the other client's request and cause a deadlock.
 972            //
 973            // This arrangement ensures we will attempt to process earlier messages first, but fall
 974            // back to processing messages arrived later in the spirit of making progress.
 975            let mut foreground_message_handlers = FuturesUnordered::new();
 976            let concurrent_handlers = Arc::new(Semaphore::new(256));
 977            loop {
 978                let next_message = async {
 979                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 980                    let message = incoming_rx.next().await;
 981                    (permit, message)
 982                }.fuse();
 983                futures::pin_mut!(next_message);
 984                futures::select_biased! {
 985                    _ = teardown.changed().fuse() => return,
 986                    result = handle_io => {
 987                        if let Err(error) = result {
 988                            tracing::error!(?error, "error handling I/O");
 989                        }
 990                        break;
 991                    }
 992                    _ = foreground_message_handlers.next() => {}
 993                    next_message = next_message => {
 994                        let (permit, message) = next_message;
 995                        if let Some(message) = message {
 996                            let type_name = message.payload_type_name();
 997                            // note: we copy all the fields from the parent span so we can query them in the logs.
 998                            // (https://github.com/tokio-rs/tracing/issues/2670).
 999                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1000                                user_id=field::Empty,
1001                                login=field::Empty,
1002                                impersonator=field::Empty,
1003                                dev_server_id=field::Empty
1004                            );
1005                            principal.update_span(&span);
1006                            let span_enter = span.enter();
1007                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1008                                let is_background = message.is_background();
1009                                let handle_message = (handler)(message, session.clone());
1010                                drop(span_enter);
1011
1012                                let handle_message = async move {
1013                                    handle_message.await;
1014                                    drop(permit);
1015                                }.instrument(span);
1016                                if is_background {
1017                                    executor.spawn_detached(handle_message);
1018                                } else {
1019                                    foreground_message_handlers.push(handle_message);
1020                                }
1021                            } else {
1022                                tracing::error!("no message handler");
1023                            }
1024                        } else {
1025                            tracing::info!("connection closed");
1026                            break;
1027                        }
1028                    }
1029                }
1030            }
1031
1032            drop(foreground_message_handlers);
1033            tracing::info!("signing out");
1034            if let Err(error) = connection_lost(session, teardown, executor).await {
1035                tracing::error!(?error, "error signing out");
1036            }
1037
1038        }.instrument(span)
1039    }
1040
1041    async fn send_initial_client_update(
1042        &self,
1043        connection_id: ConnectionId,
1044        principal: &Principal,
1045        zed_version: ZedVersion,
1046        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1047        session: &Session,
1048    ) -> Result<()> {
1049        self.peer.send(
1050            connection_id,
1051            proto::Hello {
1052                peer_id: Some(connection_id.into()),
1053            },
1054        )?;
1055        tracing::info!("sent hello message");
1056        if let Some(send_connection_id) = send_connection_id.take() {
1057            let _ = send_connection_id.send(connection_id);
1058        }
1059
1060        match principal {
1061            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1062                if !user.connected_once {
1063                    self.peer.send(connection_id, proto::ShowContacts {})?;
1064                    self.app_state
1065                        .db
1066                        .set_user_connected_once(user.id, true)
1067                        .await?;
1068                }
1069
1070                let (contacts, channels_for_user, channel_invites, remote_projects) =
1071                    future::try_join4(
1072                        self.app_state.db.get_contacts(user.id),
1073                        self.app_state.db.get_channels_for_user(user.id),
1074                        self.app_state.db.get_channel_invites_for_user(user.id),
1075                        self.app_state.db.remote_projects_update(user.id),
1076                    )
1077                    .await?;
1078
1079                {
1080                    let mut pool = self.connection_pool.lock();
1081                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1082                    for membership in &channels_for_user.channel_memberships {
1083                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1084                    }
1085                    self.peer.send(
1086                        connection_id,
1087                        build_initial_contacts_update(contacts, &pool),
1088                    )?;
1089                    self.peer.send(
1090                        connection_id,
1091                        build_update_user_channels(&channels_for_user),
1092                    )?;
1093                    self.peer.send(
1094                        connection_id,
1095                        build_channels_update(channels_for_user, channel_invites),
1096                    )?;
1097                }
1098                send_remote_projects_update(user.id, remote_projects, session).await;
1099
1100                if let Some(incoming_call) =
1101                    self.app_state.db.incoming_call_for_user(user.id).await?
1102                {
1103                    self.peer.send(connection_id, incoming_call)?;
1104                }
1105
1106                update_user_contacts(user.id, &session).await?;
1107            }
1108            Principal::DevServer(dev_server) => {
1109                {
1110                    let mut pool = self.connection_pool.lock();
1111                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1112                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1113                    };
1114                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1115                }
1116
1117                let projects = self
1118                    .app_state
1119                    .db
1120                    .get_remote_projects_for_dev_server(dev_server.id)
1121                    .await?;
1122                self.peer
1123                    .send(connection_id, proto::DevServerInstructions { projects })?;
1124
1125                let status = self
1126                    .app_state
1127                    .db
1128                    .remote_projects_update(dev_server.user_id)
1129                    .await?;
1130                send_remote_projects_update(dev_server.user_id, status, &session).await;
1131            }
1132        }
1133
1134        Ok(())
1135    }
1136
1137    pub async fn invite_code_redeemed(
1138        self: &Arc<Self>,
1139        inviter_id: UserId,
1140        invitee_id: UserId,
1141    ) -> Result<()> {
1142        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1143            if let Some(code) = &user.invite_code {
1144                let pool = self.connection_pool.lock();
1145                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1146                for connection_id in pool.user_connection_ids(inviter_id) {
1147                    self.peer.send(
1148                        connection_id,
1149                        proto::UpdateContacts {
1150                            contacts: vec![invitee_contact.clone()],
1151                            ..Default::default()
1152                        },
1153                    )?;
1154                    self.peer.send(
1155                        connection_id,
1156                        proto::UpdateInviteInfo {
1157                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1158                            count: user.invite_count as u32,
1159                        },
1160                    )?;
1161                }
1162            }
1163        }
1164        Ok(())
1165    }
1166
1167    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1168        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1169            if let Some(invite_code) = &user.invite_code {
1170                let pool = self.connection_pool.lock();
1171                for connection_id in pool.user_connection_ids(user_id) {
1172                    self.peer.send(
1173                        connection_id,
1174                        proto::UpdateInviteInfo {
1175                            url: format!(
1176                                "{}{}",
1177                                self.app_state.config.invite_link_prefix, invite_code
1178                            ),
1179                            count: user.invite_count as u32,
1180                        },
1181                    )?;
1182                }
1183            }
1184        }
1185        Ok(())
1186    }
1187
1188    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1189        ServerSnapshot {
1190            connection_pool: ConnectionPoolGuard {
1191                guard: self.connection_pool.lock(),
1192                _not_send: PhantomData,
1193            },
1194            peer: &self.peer,
1195        }
1196    }
1197}
1198
1199impl<'a> Deref for ConnectionPoolGuard<'a> {
1200    type Target = ConnectionPool;
1201
1202    fn deref(&self) -> &Self::Target {
1203        &self.guard
1204    }
1205}
1206
1207impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1208    fn deref_mut(&mut self) -> &mut Self::Target {
1209        &mut self.guard
1210    }
1211}
1212
1213impl<'a> Drop for ConnectionPoolGuard<'a> {
1214    fn drop(&mut self) {
1215        #[cfg(test)]
1216        self.check_invariants();
1217    }
1218}
1219
1220fn broadcast<F>(
1221    sender_id: Option<ConnectionId>,
1222    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1223    mut f: F,
1224) where
1225    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1226{
1227    for receiver_id in receiver_ids {
1228        if Some(receiver_id) != sender_id {
1229            if let Err(error) = f(receiver_id) {
1230                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1231            }
1232        }
1233    }
1234}
1235
1236pub struct ProtocolVersion(u32);
1237
1238impl Header for ProtocolVersion {
1239    fn name() -> &'static HeaderName {
1240        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1241        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1242    }
1243
1244    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1245    where
1246        Self: Sized,
1247        I: Iterator<Item = &'i axum::http::HeaderValue>,
1248    {
1249        let version = values
1250            .next()
1251            .ok_or_else(axum::headers::Error::invalid)?
1252            .to_str()
1253            .map_err(|_| axum::headers::Error::invalid())?
1254            .parse()
1255            .map_err(|_| axum::headers::Error::invalid())?;
1256        Ok(Self(version))
1257    }
1258
1259    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1260        values.extend([self.0.to_string().parse().unwrap()]);
1261    }
1262}
1263
1264pub struct AppVersionHeader(SemanticVersion);
1265impl Header for AppVersionHeader {
1266    fn name() -> &'static HeaderName {
1267        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1268        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1269    }
1270
1271    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1272    where
1273        Self: Sized,
1274        I: Iterator<Item = &'i axum::http::HeaderValue>,
1275    {
1276        let version = values
1277            .next()
1278            .ok_or_else(axum::headers::Error::invalid)?
1279            .to_str()
1280            .map_err(|_| axum::headers::Error::invalid())?
1281            .parse()
1282            .map_err(|_| axum::headers::Error::invalid())?;
1283        Ok(Self(version))
1284    }
1285
1286    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1287        values.extend([self.0.to_string().parse().unwrap()]);
1288    }
1289}
1290
1291pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1292    Router::new()
1293        .route("/rpc", get(handle_websocket_request))
1294        .layer(
1295            ServiceBuilder::new()
1296                .layer(Extension(server.app_state.clone()))
1297                .layer(middleware::from_fn(auth::validate_header)),
1298        )
1299        .route("/metrics", get(handle_metrics))
1300        .layer(Extension(server))
1301}
1302
1303pub async fn handle_websocket_request(
1304    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1305    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1306    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1307    Extension(server): Extension<Arc<Server>>,
1308    Extension(principal): Extension<Principal>,
1309    ws: WebSocketUpgrade,
1310) -> axum::response::Response {
1311    if protocol_version != rpc::PROTOCOL_VERSION {
1312        return (
1313            StatusCode::UPGRADE_REQUIRED,
1314            "client must be upgraded".to_string(),
1315        )
1316            .into_response();
1317    }
1318
1319    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1320        return (
1321            StatusCode::UPGRADE_REQUIRED,
1322            "no version header found".to_string(),
1323        )
1324            .into_response();
1325    };
1326
1327    if !version.can_collaborate() {
1328        return (
1329            StatusCode::UPGRADE_REQUIRED,
1330            "client must be upgraded".to_string(),
1331        )
1332            .into_response();
1333    }
1334
1335    let socket_address = socket_address.to_string();
1336    ws.on_upgrade(move |socket| {
1337        let socket = socket
1338            .map_ok(to_tungstenite_message)
1339            .err_into()
1340            .with(|message| async move { Ok(to_axum_message(message)) });
1341        let connection = Connection::new(Box::pin(socket));
1342        async move {
1343            server
1344                .handle_connection(
1345                    connection,
1346                    socket_address,
1347                    principal,
1348                    version,
1349                    None,
1350                    Executor::Production,
1351                )
1352                .await;
1353        }
1354    })
1355}
1356
1357pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1358    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1359    let connections_metric = CONNECTIONS_METRIC
1360        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1361
1362    let connections = server
1363        .connection_pool
1364        .lock()
1365        .connections()
1366        .filter(|connection| !connection.admin)
1367        .count();
1368    connections_metric.set(connections as _);
1369
1370    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1371    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1372        register_int_gauge!(
1373            "shared_projects",
1374            "number of open projects with one or more guests"
1375        )
1376        .unwrap()
1377    });
1378
1379    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1380    shared_projects_metric.set(shared_projects as _);
1381
1382    let encoder = prometheus::TextEncoder::new();
1383    let metric_families = prometheus::gather();
1384    let encoded_metrics = encoder
1385        .encode_to_string(&metric_families)
1386        .map_err(|err| anyhow!("{}", err))?;
1387    Ok(encoded_metrics)
1388}
1389
1390#[instrument(err, skip(executor))]
1391async fn connection_lost(
1392    session: Session,
1393    mut teardown: watch::Receiver<bool>,
1394    executor: Executor,
1395) -> Result<()> {
1396    session.peer.disconnect(session.connection_id);
1397    session
1398        .connection_pool()
1399        .await
1400        .remove_connection(session.connection_id)?;
1401
1402    session
1403        .db()
1404        .await
1405        .connection_lost(session.connection_id)
1406        .await
1407        .trace_err();
1408
1409    futures::select_biased! {
1410        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1411            match &session.principal {
1412                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1413                    let session = session.for_user().unwrap();
1414
1415                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1416                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1417                    leave_channel_buffers_for_session(&session)
1418                        .await
1419                        .trace_err();
1420
1421                    if !session
1422                        .connection_pool()
1423                        .await
1424                        .is_user_online(session.user_id())
1425                    {
1426                        let db = session.db().await;
1427                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1428                            room_updated(&room, &session.peer);
1429                        }
1430                    }
1431
1432                    update_user_contacts(session.user_id(), &session).await?;
1433                },
1434            Principal::DevServer(_) => {
1435                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1436            },
1437        }
1438        },
1439        _ = teardown.changed().fuse() => {}
1440    }
1441
1442    Ok(())
1443}
1444
1445/// Acknowledges a ping from a client, used to keep the connection alive.
1446async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1447    response.send(proto::Ack {})?;
1448    Ok(())
1449}
1450
1451/// Creates a new room for calling (outside of channels)
1452async fn create_room(
1453    _request: proto::CreateRoom,
1454    response: Response<proto::CreateRoom>,
1455    session: UserSession,
1456) -> Result<()> {
1457    let live_kit_room = nanoid::nanoid!(30);
1458
1459    let live_kit_connection_info = util::maybe!(async {
1460        let live_kit = session.live_kit_client.as_ref();
1461        let live_kit = live_kit?;
1462        let user_id = session.user_id().to_string();
1463
1464        let token = live_kit
1465            .room_token(&live_kit_room, &user_id.to_string())
1466            .trace_err()?;
1467
1468        Some(proto::LiveKitConnectionInfo {
1469            server_url: live_kit.url().into(),
1470            token,
1471            can_publish: true,
1472        })
1473    })
1474    .await;
1475
1476    let room = session
1477        .db()
1478        .await
1479        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1480        .await?;
1481
1482    response.send(proto::CreateRoomResponse {
1483        room: Some(room.clone()),
1484        live_kit_connection_info,
1485    })?;
1486
1487    update_user_contacts(session.user_id(), &session).await?;
1488    Ok(())
1489}
1490
1491/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1492async fn join_room(
1493    request: proto::JoinRoom,
1494    response: Response<proto::JoinRoom>,
1495    session: UserSession,
1496) -> Result<()> {
1497    let room_id = RoomId::from_proto(request.id);
1498
1499    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1500
1501    if let Some(channel_id) = channel_id {
1502        return join_channel_internal(channel_id, Box::new(response), session).await;
1503    }
1504
1505    let joined_room = {
1506        let room = session
1507            .db()
1508            .await
1509            .join_room(room_id, session.user_id(), session.connection_id)
1510            .await?;
1511        room_updated(&room.room, &session.peer);
1512        room.into_inner()
1513    };
1514
1515    for connection_id in session
1516        .connection_pool()
1517        .await
1518        .user_connection_ids(session.user_id())
1519    {
1520        session
1521            .peer
1522            .send(
1523                connection_id,
1524                proto::CallCanceled {
1525                    room_id: room_id.to_proto(),
1526                },
1527            )
1528            .trace_err();
1529    }
1530
1531    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1532        if let Some(token) = live_kit
1533            .room_token(
1534                &joined_room.room.live_kit_room,
1535                &session.user_id().to_string(),
1536            )
1537            .trace_err()
1538        {
1539            Some(proto::LiveKitConnectionInfo {
1540                server_url: live_kit.url().into(),
1541                token,
1542                can_publish: true,
1543            })
1544        } else {
1545            None
1546        }
1547    } else {
1548        None
1549    };
1550
1551    response.send(proto::JoinRoomResponse {
1552        room: Some(joined_room.room),
1553        channel_id: None,
1554        live_kit_connection_info,
1555    })?;
1556
1557    update_user_contacts(session.user_id(), &session).await?;
1558    Ok(())
1559}
1560
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Order of operations:
/// 1. The database rejoins the room for this session's (new) connection and
///    returns the room together with `reshared_projects` and
///    `rejoined_projects`.
/// 2. The caller receives a `RejoinRoomResponse` describing all of the above,
///    before any other message is sent.
/// 3. The room is re-broadcast.
/// 4. For each re-shared project, collaborators are told that the old peer id
///    was replaced by this connection's, and they are forwarded the project's
///    worktree list from this connection.
/// 5. `notify_rejoined_projects` handles the rejoined projects.
/// 6. Outside the inner scope (after `into_inner`), `channel_updated` runs if
///    the room belongs to a channel, then `update_user_contacts` runs for the
///    user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Inner scope: the database result is consumed via `into_inner` before the
    // channel broadcast and contact update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at this connection's new peer id.
            // Failures are logged, not propagated.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to its collaborators,
            // sent as if from this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Drains the rejoined projects' worktrees while streaming them.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Only rooms that belong to a channel carry a channel here.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1650    Ok(())
1651}
1652
1653fn notify_rejoined_projects(
1654    rejoined_projects: &mut Vec<RejoinedProject>,
1655    session: &UserSession,
1656) -> Result<()> {
1657    for project in rejoined_projects.iter() {
1658        for collaborator in &project.collaborators {
1659            session
1660                .peer
1661                .send(
1662                    collaborator.connection_id,
1663                    proto::UpdateProjectCollaborator {
1664                        project_id: project.id.to_proto(),
1665                        old_peer_id: Some(project.old_connection_id.into()),
1666                        new_peer_id: Some(session.connection_id.into()),
1667                    },
1668                )
1669                .trace_err();
1670        }
1671    }
1672
1673    for project in rejoined_projects {
1674        for worktree in mem::take(&mut project.worktrees) {
1675            #[cfg(any(test, feature = "test-support"))]
1676            const MAX_CHUNK_SIZE: usize = 2;
1677            #[cfg(not(any(test, feature = "test-support")))]
1678            const MAX_CHUNK_SIZE: usize = 256;
1679
1680            // Stream this worktree's entries.
1681            let message = proto::UpdateWorktree {
1682                project_id: project.id.to_proto(),
1683                worktree_id: worktree.id,
1684                abs_path: worktree.abs_path.clone(),
1685                root_name: worktree.root_name,
1686                updated_entries: worktree.updated_entries,
1687                removed_entries: worktree.removed_entries,
1688                scan_id: worktree.scan_id,
1689                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1690                updated_repositories: worktree.updated_repositories,
1691                removed_repositories: worktree.removed_repositories,
1692            };
1693            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1694                session.peer.send(session.connection_id, update.clone())?;
1695            }
1696
1697            // Stream this worktree's diagnostics.
1698            for summary in worktree.diagnostic_summaries {
1699                session.peer.send(
1700                    session.connection_id,
1701                    proto::UpdateDiagnosticSummary {
1702                        project_id: project.id.to_proto(),
1703                        worktree_id: worktree.id,
1704                        summary: Some(summary),
1705                    },
1706                )?;
1707            }
1708
1709            for settings_file in worktree.settings_files {
1710                session.peer.send(
1711                    session.connection_id,
1712                    proto::UpdateWorktreeSettings {
1713                        project_id: project.id.to_proto(),
1714                        worktree_id: worktree.id,
1715                        path: settings_file.path,
1716                        content: Some(settings_file.content),
1717                    },
1718                )?;
1719            }
1720        }
1721
1722        for language_server in &project.language_servers {
1723            session.peer.send(
1724                session.connection_id,
1725                proto::UpdateLanguageServer {
1726                    project_id: project.id.to_proto(),
1727                    language_server_id: language_server.id,
1728                    variant: Some(
1729                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1730                            proto::LspDiskBasedDiagnosticsUpdated {},
1731                        ),
1732                    ),
1733                },
1734            )?;
1735        }
1736    }
1737    Ok(())
1738}
1739
1740/// leave room disconnects from the room.
1741async fn leave_room(
1742    _: proto::LeaveRoom,
1743    response: Response<proto::LeaveRoom>,
1744    session: UserSession,
1745) -> Result<()> {
1746    leave_room_for_session(&session, session.connection_id).await?;
1747    response.send(proto::Ack {})?;
1748    Ok(())
1749}
1750
1751/// Updates the permissions of someone else in the room.
1752async fn set_room_participant_role(
1753    request: proto::SetRoomParticipantRole,
1754    response: Response<proto::SetRoomParticipantRole>,
1755    session: UserSession,
1756) -> Result<()> {
1757    let user_id = UserId::from_proto(request.user_id);
1758    let role = ChannelRole::from(request.role());
1759
1760    let (live_kit_room, can_publish) = {
1761        let room = session
1762            .db()
1763            .await
1764            .set_room_participant_role(
1765                session.user_id(),
1766                RoomId::from_proto(request.room_id),
1767                user_id,
1768                role,
1769            )
1770            .await?;
1771
1772        let live_kit_room = room.live_kit_room.clone();
1773        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1774        room_updated(&room, &session.peer);
1775        (live_kit_room, can_publish)
1776    };
1777
1778    if let Some(live_kit) = session.live_kit_client.as_ref() {
1779        live_kit
1780            .update_participant(
1781                live_kit_room.clone(),
1782                request.user_id.to_string(),
1783                live_kit_server::proto::ParticipantPermission {
1784                    can_subscribe: true,
1785                    can_publish,
1786                    can_publish_data: can_publish,
1787                    hidden: false,
1788                    recorder: false,
1789                },
1790            )
1791            .await
1792            .trace_err();
1793    }
1794
1795    response.send(proto::Ack {})?;
1796    Ok(())
1797}
1798
1799/// Call someone else into the current room
1800async fn call(
1801    request: proto::Call,
1802    response: Response<proto::Call>,
1803    session: UserSession,
1804) -> Result<()> {
1805    let room_id = RoomId::from_proto(request.room_id);
1806    let calling_user_id = session.user_id();
1807    let calling_connection_id = session.connection_id;
1808    let called_user_id = UserId::from_proto(request.called_user_id);
1809    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1810    if !session
1811        .db()
1812        .await
1813        .has_contact(calling_user_id, called_user_id)
1814        .await?
1815    {
1816        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1817    }
1818
1819    let incoming_call = {
1820        let (room, incoming_call) = &mut *session
1821            .db()
1822            .await
1823            .call(
1824                room_id,
1825                calling_user_id,
1826                calling_connection_id,
1827                called_user_id,
1828                initial_project_id,
1829            )
1830            .await?;
1831        room_updated(&room, &session.peer);
1832        mem::take(incoming_call)
1833    };
1834    update_user_contacts(called_user_id, &session).await?;
1835
1836    let mut calls = session
1837        .connection_pool()
1838        .await
1839        .user_connection_ids(called_user_id)
1840        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1841        .collect::<FuturesUnordered<_>>();
1842
1843    while let Some(call_response) = calls.next().await {
1844        match call_response.as_ref() {
1845            Ok(_) => {
1846                response.send(proto::Ack {})?;
1847                return Ok(());
1848            }
1849            Err(_) => {
1850                call_response.trace_err();
1851            }
1852        }
1853    }
1854
1855    {
1856        let room = session
1857            .db()
1858            .await
1859            .call_failed(room_id, called_user_id)
1860            .await?;
1861        room_updated(&room, &session.peer);
1862    }
1863    update_user_contacts(called_user_id, &session).await?;
1864
1865    Err(anyhow!("failed to ring user"))?
1866}
1867
1868/// Cancel an outgoing call.
1869async fn cancel_call(
1870    request: proto::CancelCall,
1871    response: Response<proto::CancelCall>,
1872    session: UserSession,
1873) -> Result<()> {
1874    let called_user_id = UserId::from_proto(request.called_user_id);
1875    let room_id = RoomId::from_proto(request.room_id);
1876    {
1877        let room = session
1878            .db()
1879            .await
1880            .cancel_call(room_id, session.connection_id, called_user_id)
1881            .await?;
1882        room_updated(&room, &session.peer);
1883    }
1884
1885    for connection_id in session
1886        .connection_pool()
1887        .await
1888        .user_connection_ids(called_user_id)
1889    {
1890        session
1891            .peer
1892            .send(
1893                connection_id,
1894                proto::CallCanceled {
1895                    room_id: room_id.to_proto(),
1896                },
1897            )
1898            .trace_err();
1899    }
1900    response.send(proto::Ack {})?;
1901
1902    update_user_contacts(called_user_id, &session).await?;
1903    Ok(())
1904}
1905
1906/// Decline an incoming call.
1907async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1908    let room_id = RoomId::from_proto(message.room_id);
1909    {
1910        let room = session
1911            .db()
1912            .await
1913            .decline_call(Some(room_id), session.user_id())
1914            .await?
1915            .ok_or_else(|| anyhow!("failed to decline call"))?;
1916        room_updated(&room, &session.peer);
1917    }
1918
1919    for connection_id in session
1920        .connection_pool()
1921        .await
1922        .user_connection_ids(session.user_id())
1923    {
1924        session
1925            .peer
1926            .send(
1927                connection_id,
1928                proto::CallCanceled {
1929                    room_id: room_id.to_proto(),
1930                },
1931            )
1932            .trace_err();
1933    }
1934    update_user_contacts(session.user_id(), &session).await?;
1935    Ok(())
1936}
1937
1938/// Updates other participants in the room with your current location.
1939async fn update_participant_location(
1940    request: proto::UpdateParticipantLocation,
1941    response: Response<proto::UpdateParticipantLocation>,
1942    session: UserSession,
1943) -> Result<()> {
1944    let room_id = RoomId::from_proto(request.room_id);
1945    let location = request
1946        .location
1947        .ok_or_else(|| anyhow!("invalid location"))?;
1948
1949    let db = session.db().await;
1950    let room = db
1951        .update_room_participant_location(room_id, session.connection_id, location)
1952        .await?;
1953
1954    room_updated(&room, &session.peer);
1955    response.send(proto::Ack {})?;
1956    Ok(())
1957}
1958
1959/// Share a project into the room.
1960async fn share_project(
1961    request: proto::ShareProject,
1962    response: Response<proto::ShareProject>,
1963    session: UserSession,
1964) -> Result<()> {
1965    let (project_id, room) = &*session
1966        .db()
1967        .await
1968        .share_project(
1969            RoomId::from_proto(request.room_id),
1970            session.connection_id,
1971            &request.worktrees,
1972            request
1973                .remote_project_id
1974                .map(|id| RemoteProjectId::from_proto(id)),
1975        )
1976        .await?;
1977    response.send(proto::ShareProjectResponse {
1978        project_id: project_id.to_proto(),
1979    })?;
1980    room_updated(&room, &session.peer);
1981
1982    Ok(())
1983}
1984
1985/// Unshare a project from the room.
1986async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1987    let project_id = ProjectId::from_proto(message.project_id);
1988    unshare_project_internal(
1989        project_id,
1990        session.connection_id,
1991        session.user_id(),
1992        &session,
1993    )
1994    .await
1995}
1996
1997async fn unshare_project_internal(
1998    project_id: ProjectId,
1999    connection_id: ConnectionId,
2000    user_id: Option<UserId>,
2001    session: &Session,
2002) -> Result<()> {
2003    let (room, guest_connection_ids) = &*session
2004        .db()
2005        .await
2006        .unshare_project(project_id, connection_id, user_id)
2007        .await?;
2008
2009    let message = proto::UnshareProject {
2010        project_id: project_id.to_proto(),
2011    };
2012
2013    broadcast(
2014        Some(connection_id),
2015        guest_connection_ids.iter().copied(),
2016        |conn_id| session.peer.send(conn_id, message.clone()),
2017    );
2018    if let Some(room) = room {
2019        room_updated(room, &session.peer);
2020    }
2021
2022    Ok(())
2023}
2024
2025/// DevServer makes a project available online
2026async fn share_remote_project(
2027    request: proto::ShareRemoteProject,
2028    response: Response<proto::ShareRemoteProject>,
2029    session: DevServerSession,
2030) -> Result<()> {
2031    let (remote_project, user_id, status) = session
2032        .db()
2033        .await
2034        .share_remote_project(
2035            RemoteProjectId::from_proto(request.remote_project_id),
2036            session.dev_server_id(),
2037            session.connection_id,
2038            &request.worktrees,
2039        )
2040        .await?;
2041    let Some(project_id) = remote_project.project_id else {
2042        return Err(anyhow!("failed to share remote project"))?;
2043    };
2044
2045    send_remote_projects_update(user_id, status, &session).await;
2046
2047    response.send(proto::ShareProjectResponse { project_id })?;
2048
2049    Ok(())
2050}
2051
2052/// Join someone elses shared project.
2053async fn join_project(
2054    request: proto::JoinProject,
2055    response: Response<proto::JoinProject>,
2056    session: UserSession,
2057) -> Result<()> {
2058    let project_id = ProjectId::from_proto(request.project_id);
2059
2060    tracing::info!(%project_id, "join project");
2061
2062    let db = session.db().await;
2063    let (project, replica_id) = &mut *db
2064        .join_project(project_id, session.connection_id, session.user_id())
2065        .await?;
2066    drop(db);
2067    tracing::info!(%project_id, "join remote project");
2068    join_project_internal(response, session, project, replica_id)
2069}
2070
2071trait JoinProjectInternalResponse {
2072    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2073}
2074impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2075    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2076        Response::<proto::JoinProject>::send(self, result)
2077    }
2078}
2079impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2080    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2081        Response::<proto::JoinHostedProject>::send(self, result)
2082    }
2083}
2084
2085fn join_project_internal(
2086    response: impl JoinProjectInternalResponse,
2087    session: UserSession,
2088    project: &mut Project,
2089    replica_id: &ReplicaId,
2090) -> Result<()> {
2091    let collaborators = project
2092        .collaborators
2093        .iter()
2094        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2095        .map(|collaborator| collaborator.to_proto())
2096        .collect::<Vec<_>>();
2097    let project_id = project.id;
2098    let guest_user_id = session.user_id();
2099
2100    let worktrees = project
2101        .worktrees
2102        .iter()
2103        .map(|(id, worktree)| proto::WorktreeMetadata {
2104            id: *id,
2105            root_name: worktree.root_name.clone(),
2106            visible: worktree.visible,
2107            abs_path: worktree.abs_path.clone(),
2108        })
2109        .collect::<Vec<_>>();
2110
2111    let add_project_collaborator = proto::AddProjectCollaborator {
2112        project_id: project_id.to_proto(),
2113        collaborator: Some(proto::Collaborator {
2114            peer_id: Some(session.connection_id.into()),
2115            replica_id: replica_id.0 as u32,
2116            user_id: guest_user_id.to_proto(),
2117        }),
2118    };
2119
2120    for collaborator in &collaborators {
2121        session
2122            .peer
2123            .send(
2124                collaborator.peer_id.unwrap().into(),
2125                add_project_collaborator.clone(),
2126            )
2127            .trace_err();
2128    }
2129
2130    // First, we send the metadata associated with each worktree.
2131    response.send(proto::JoinProjectResponse {
2132        project_id: project.id.0 as u64,
2133        worktrees: worktrees.clone(),
2134        replica_id: replica_id.0 as u32,
2135        collaborators: collaborators.clone(),
2136        language_servers: project.language_servers.clone(),
2137        role: project.role.into(),
2138        remote_project_id: project
2139            .remote_project_id
2140            .map(|remote_project_id| remote_project_id.0 as u64),
2141    })?;
2142
2143    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2144        #[cfg(any(test, feature = "test-support"))]
2145        const MAX_CHUNK_SIZE: usize = 2;
2146        #[cfg(not(any(test, feature = "test-support")))]
2147        const MAX_CHUNK_SIZE: usize = 256;
2148
2149        // Stream this worktree's entries.
2150        let message = proto::UpdateWorktree {
2151            project_id: project_id.to_proto(),
2152            worktree_id,
2153            abs_path: worktree.abs_path.clone(),
2154            root_name: worktree.root_name,
2155            updated_entries: worktree.entries,
2156            removed_entries: Default::default(),
2157            scan_id: worktree.scan_id,
2158            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2159            updated_repositories: worktree.repository_entries.into_values().collect(),
2160            removed_repositories: Default::default(),
2161        };
2162        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2163            session.peer.send(session.connection_id, update.clone())?;
2164        }
2165
2166        // Stream this worktree's diagnostics.
2167        for summary in worktree.diagnostic_summaries {
2168            session.peer.send(
2169                session.connection_id,
2170                proto::UpdateDiagnosticSummary {
2171                    project_id: project_id.to_proto(),
2172                    worktree_id: worktree.id,
2173                    summary: Some(summary),
2174                },
2175            )?;
2176        }
2177
2178        for settings_file in worktree.settings_files {
2179            session.peer.send(
2180                session.connection_id,
2181                proto::UpdateWorktreeSettings {
2182                    project_id: project_id.to_proto(),
2183                    worktree_id: worktree.id,
2184                    path: settings_file.path,
2185                    content: Some(settings_file.content),
2186                },
2187            )?;
2188        }
2189    }
2190
2191    for language_server in &project.language_servers {
2192        session.peer.send(
2193            session.connection_id,
2194            proto::UpdateLanguageServer {
2195                project_id: project_id.to_proto(),
2196                language_server_id: language_server.id,
2197                variant: Some(
2198                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2199                        proto::LspDiskBasedDiagnosticsUpdated {},
2200                    ),
2201                ),
2202            },
2203        )?;
2204    }
2205
2206    Ok(())
2207}
2208
2209/// Leave someone elses shared project.
2210async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2211    let sender_id = session.connection_id;
2212    let project_id = ProjectId::from_proto(request.project_id);
2213    let db = session.db().await;
2214    if db.is_hosted_project(project_id).await? {
2215        let project = db.leave_hosted_project(project_id, sender_id).await?;
2216        project_left(&project, &session);
2217        return Ok(());
2218    }
2219
2220    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2221    tracing::info!(
2222        %project_id,
2223        "leave project"
2224    );
2225
2226    project_left(&project, &session);
2227    if let Some(room) = room {
2228        room_updated(&room, &session.peer);
2229    }
2230
2231    Ok(())
2232}
2233
2234async fn join_hosted_project(
2235    request: proto::JoinHostedProject,
2236    response: Response<proto::JoinHostedProject>,
2237    session: UserSession,
2238) -> Result<()> {
2239    let (mut project, replica_id) = session
2240        .db()
2241        .await
2242        .join_hosted_project(
2243            ProjectId(request.project_id as i32),
2244            session.user_id(),
2245            session.connection_id,
2246        )
2247        .await?;
2248
2249    join_project_internal(response, session, &mut project, &replica_id)
2250}
2251
2252async fn create_remote_project(
2253    request: proto::CreateRemoteProject,
2254    response: Response<proto::CreateRemoteProject>,
2255    session: UserSession,
2256) -> Result<()> {
2257    let dev_server_id = DevServerId(request.dev_server_id as i32);
2258    let dev_server_connection_id = session
2259        .connection_pool()
2260        .await
2261        .dev_server_connection_id(dev_server_id);
2262    let Some(dev_server_connection_id) = dev_server_connection_id else {
2263        Err(ErrorCode::DevServerOffline
2264            .message("Cannot create a remote project when the dev server is offline".to_string())
2265            .anyhow())?
2266    };
2267
2268    let path = request.path.clone();
2269    //Check that the path exists on the dev server
2270    session
2271        .peer
2272        .forward_request(
2273            session.connection_id,
2274            dev_server_connection_id,
2275            proto::ValidateRemoteProjectRequest { path: path.clone() },
2276        )
2277        .await?;
2278
2279    let (remote_project, update) = session
2280        .db()
2281        .await
2282        .create_remote_project(
2283            DevServerId(request.dev_server_id as i32),
2284            &request.path,
2285            session.user_id(),
2286        )
2287        .await?;
2288
2289    let projects = session
2290        .db()
2291        .await
2292        .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2293        .await?;
2294
2295    session.peer.send(
2296        dev_server_connection_id,
2297        proto::DevServerInstructions { projects },
2298    )?;
2299
2300    send_remote_projects_update(session.user_id(), update, &session).await;
2301
2302    response.send(proto::CreateRemoteProjectResponse {
2303        remote_project: Some(remote_project.to_proto(None)),
2304    })?;
2305    Ok(())
2306}
2307
2308async fn create_dev_server(
2309    request: proto::CreateDevServer,
2310    response: Response<proto::CreateDevServer>,
2311    session: UserSession,
2312) -> Result<()> {
2313    let access_token = auth::random_token();
2314    let hashed_access_token = auth::hash_access_token(&access_token);
2315
2316    let (dev_server, status) = session
2317        .db()
2318        .await
2319        .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2320        .await?;
2321
2322    send_remote_projects_update(session.user_id(), status, &session).await;
2323
2324    response.send(proto::CreateDevServerResponse {
2325        dev_server_id: dev_server.id.0 as u64,
2326        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2327        name: request.name.clone(),
2328    })?;
2329    Ok(())
2330}
2331
2332async fn delete_dev_server(
2333    request: proto::DeleteDevServer,
2334    response: Response<proto::DeleteDevServer>,
2335    session: UserSession,
2336) -> Result<()> {
2337    let dev_server_id = DevServerId(request.dev_server_id as i32);
2338    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2339    if dev_server.user_id != session.user_id() {
2340        return Err(anyhow!(ErrorCode::Forbidden))?;
2341    }
2342
2343    let connection_id = session
2344        .connection_pool()
2345        .await
2346        .dev_server_connection_id(dev_server_id);
2347    if let Some(connection_id) = connection_id {
2348        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2349        session
2350            .peer
2351            .send(connection_id, proto::ShutdownDevServer {})?;
2352    }
2353
2354    let status = session
2355        .db()
2356        .await
2357        .delete_dev_server(dev_server_id, session.user_id())
2358        .await?;
2359
2360    send_remote_projects_update(session.user_id(), status, &session).await;
2361
2362    response.send(proto::Ack {})?;
2363    Ok(())
2364}
2365
2366async fn rejoin_remote_projects(
2367    request: proto::RejoinRemoteProjects,
2368    response: Response<proto::RejoinRemoteProjects>,
2369    session: UserSession,
2370) -> Result<()> {
2371    let mut rejoined_projects = {
2372        let db = session.db().await;
2373        db.rejoin_remote_projects(
2374            &request.rejoined_projects,
2375            session.user_id(),
2376            session.0.connection_id,
2377        )
2378        .await?
2379    };
2380    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2381
2382    response.send(proto::RejoinRemoteProjectsResponse {
2383        rejoined_projects: rejoined_projects
2384            .into_iter()
2385            .map(|project| project.to_proto())
2386            .collect(),
2387    })
2388}
2389
2390async fn reconnect_dev_server(
2391    request: proto::ReconnectDevServer,
2392    response: Response<proto::ReconnectDevServer>,
2393    session: DevServerSession,
2394) -> Result<()> {
2395    let reshared_projects = {
2396        let db = session.db().await;
2397        db.reshare_remote_projects(
2398            &request.reshared_projects,
2399            session.dev_server_id(),
2400            session.0.connection_id,
2401        )
2402        .await?
2403    };
2404
2405    for project in &reshared_projects {
2406        for collaborator in &project.collaborators {
2407            session
2408                .peer
2409                .send(
2410                    collaborator.connection_id,
2411                    proto::UpdateProjectCollaborator {
2412                        project_id: project.id.to_proto(),
2413                        old_peer_id: Some(project.old_connection_id.into()),
2414                        new_peer_id: Some(session.connection_id.into()),
2415                    },
2416                )
2417                .trace_err();
2418        }
2419
2420        broadcast(
2421            Some(session.connection_id),
2422            project
2423                .collaborators
2424                .iter()
2425                .map(|collaborator| collaborator.connection_id),
2426            |connection_id| {
2427                session.peer.forward_send(
2428                    session.connection_id,
2429                    connection_id,
2430                    proto::UpdateProject {
2431                        project_id: project.id.to_proto(),
2432                        worktrees: project.worktrees.clone(),
2433                    },
2434                )
2435            },
2436        );
2437    }
2438
2439    response.send(proto::ReconnectDevServerResponse {
2440        reshared_projects: reshared_projects
2441            .iter()
2442            .map(|project| proto::ResharedProject {
2443                id: project.id.to_proto(),
2444                collaborators: project
2445                    .collaborators
2446                    .iter()
2447                    .map(|collaborator| collaborator.to_proto())
2448                    .collect(),
2449            })
2450            .collect(),
2451    })?;
2452
2453    Ok(())
2454}
2455
2456async fn shutdown_dev_server(
2457    _: proto::ShutdownDevServer,
2458    response: Response<proto::ShutdownDevServer>,
2459    session: DevServerSession,
2460) -> Result<()> {
2461    response.send(proto::Ack {})?;
2462    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2463}
2464
2465async fn shutdown_dev_server_internal(
2466    dev_server_id: DevServerId,
2467    connection_id: ConnectionId,
2468    session: &Session,
2469) -> Result<()> {
2470    let (remote_projects, dev_server) = {
2471        let db = session.db().await;
2472        let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2473        let dev_server = db.get_dev_server(dev_server_id).await?;
2474        (remote_projects, dev_server)
2475    };
2476
2477    for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2478        unshare_project_internal(
2479            ProjectId::from_proto(project_id),
2480            connection_id,
2481            None,
2482            session,
2483        )
2484        .await?;
2485    }
2486
2487    session
2488        .connection_pool()
2489        .await
2490        .set_dev_server_offline(dev_server_id);
2491
2492    let status = session
2493        .db()
2494        .await
2495        .remote_projects_update(dev_server.user_id)
2496        .await?;
2497    send_remote_projects_update(dev_server.user_id, status, &session).await;
2498
2499    Ok(())
2500}
2501
2502/// Updates other participants with changes to the project
2503async fn update_project(
2504    request: proto::UpdateProject,
2505    response: Response<proto::UpdateProject>,
2506    session: Session,
2507) -> Result<()> {
2508    let project_id = ProjectId::from_proto(request.project_id);
2509    let (room, guest_connection_ids) = &*session
2510        .db()
2511        .await
2512        .update_project(project_id, session.connection_id, &request.worktrees)
2513        .await?;
2514    broadcast(
2515        Some(session.connection_id),
2516        guest_connection_ids.iter().copied(),
2517        |connection_id| {
2518            session
2519                .peer
2520                .forward_send(session.connection_id, connection_id, request.clone())
2521        },
2522    );
2523    if let Some(room) = room {
2524        room_updated(&room, &session.peer);
2525    }
2526    response.send(proto::Ack {})?;
2527
2528    Ok(())
2529}
2530
2531/// Updates other participants with changes to the worktree
2532async fn update_worktree(
2533    request: proto::UpdateWorktree,
2534    response: Response<proto::UpdateWorktree>,
2535    session: Session,
2536) -> Result<()> {
2537    let guest_connection_ids = session
2538        .db()
2539        .await
2540        .update_worktree(&request, session.connection_id)
2541        .await?;
2542
2543    broadcast(
2544        Some(session.connection_id),
2545        guest_connection_ids.iter().copied(),
2546        |connection_id| {
2547            session
2548                .peer
2549                .forward_send(session.connection_id, connection_id, request.clone())
2550        },
2551    );
2552    response.send(proto::Ack {})?;
2553    Ok(())
2554}
2555
2556/// Updates other participants with changes to the diagnostics
2557async fn update_diagnostic_summary(
2558    message: proto::UpdateDiagnosticSummary,
2559    session: Session,
2560) -> Result<()> {
2561    let guest_connection_ids = session
2562        .db()
2563        .await
2564        .update_diagnostic_summary(&message, session.connection_id)
2565        .await?;
2566
2567    broadcast(
2568        Some(session.connection_id),
2569        guest_connection_ids.iter().copied(),
2570        |connection_id| {
2571            session
2572                .peer
2573                .forward_send(session.connection_id, connection_id, message.clone())
2574        },
2575    );
2576
2577    Ok(())
2578}
2579
2580/// Updates other participants with changes to the worktree settings
2581async fn update_worktree_settings(
2582    message: proto::UpdateWorktreeSettings,
2583    session: Session,
2584) -> Result<()> {
2585    let guest_connection_ids = session
2586        .db()
2587        .await
2588        .update_worktree_settings(&message, session.connection_id)
2589        .await?;
2590
2591    broadcast(
2592        Some(session.connection_id),
2593        guest_connection_ids.iter().copied(),
2594        |connection_id| {
2595            session
2596                .peer
2597                .forward_send(session.connection_id, connection_id, message.clone())
2598        },
2599    );
2600
2601    Ok(())
2602}
2603
2604/// Notify other participants that a  language server has started.
2605async fn start_language_server(
2606    request: proto::StartLanguageServer,
2607    session: Session,
2608) -> Result<()> {
2609    let guest_connection_ids = session
2610        .db()
2611        .await
2612        .start_language_server(&request, session.connection_id)
2613        .await?;
2614
2615    broadcast(
2616        Some(session.connection_id),
2617        guest_connection_ids.iter().copied(),
2618        |connection_id| {
2619            session
2620                .peer
2621                .forward_send(session.connection_id, connection_id, request.clone())
2622        },
2623    );
2624    Ok(())
2625}
2626
2627/// Notify other participants that a language server has changed.
2628async fn update_language_server(
2629    request: proto::UpdateLanguageServer,
2630    session: Session,
2631) -> Result<()> {
2632    let project_id = ProjectId::from_proto(request.project_id);
2633    let project_connection_ids = session
2634        .db()
2635        .await
2636        .project_connection_ids(project_id, session.connection_id, true)
2637        .await?;
2638    broadcast(
2639        Some(session.connection_id),
2640        project_connection_ids.iter().copied(),
2641        |connection_id| {
2642            session
2643                .peer
2644                .forward_send(session.connection_id, connection_id, request.clone())
2645        },
2646    );
2647    Ok(())
2648}
2649
2650/// forward a project request to the host. These requests should be read only
2651/// as guests are allowed to send them.
2652async fn forward_read_only_project_request<T>(
2653    request: T,
2654    response: Response<T>,
2655    session: UserSession,
2656) -> Result<()>
2657where
2658    T: EntityMessage + RequestMessage,
2659{
2660    let project_id = ProjectId::from_proto(request.remote_entity_id());
2661    let host_connection_id = session
2662        .db()
2663        .await
2664        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2665        .await?;
2666    let payload = session
2667        .peer
2668        .forward_request(session.connection_id, host_connection_id, request)
2669        .await?;
2670    response.send(payload)?;
2671    Ok(())
2672}
2673
2674/// forward a project request to the host. These requests are disallowed
2675/// for guests.
2676async fn forward_mutating_project_request<T>(
2677    request: T,
2678    response: Response<T>,
2679    session: UserSession,
2680) -> Result<()>
2681where
2682    T: EntityMessage + RequestMessage,
2683{
2684    let project_id = ProjectId::from_proto(request.remote_entity_id());
2685
2686    let host_connection_id = session
2687        .db()
2688        .await
2689        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2690        .await?;
2691    let payload = session
2692        .peer
2693        .forward_request(session.connection_id, host_connection_id, request)
2694        .await?;
2695    response.send(payload)?;
2696    Ok(())
2697}
2698
2699/// forward a project request to the host. These requests are disallowed
2700/// for guests.
2701async fn forward_versioned_mutating_project_request<T>(
2702    request: T,
2703    response: Response<T>,
2704    session: UserSession,
2705) -> Result<()>
2706where
2707    T: EntityMessage + RequestMessage + VersionedMessage,
2708{
2709    let project_id = ProjectId::from_proto(request.remote_entity_id());
2710
2711    let host_connection_id = session
2712        .db()
2713        .await
2714        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2715        .await?;
2716    if let Some(host_version) = session
2717        .connection_pool()
2718        .await
2719        .connection(host_connection_id)
2720        .map(|c| c.zed_version)
2721    {
2722        if let Some(min_required_version) = request.required_host_version() {
2723            if min_required_version > host_version {
2724                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2725                    .with_tag("required", &min_required_version.to_string())))?;
2726            }
2727        }
2728    }
2729
2730    let payload = session
2731        .peer
2732        .forward_request(session.connection_id, host_connection_id, request)
2733        .await?;
2734    response.send(payload)?;
2735    Ok(())
2736}
2737
2738/// Notify other participants that a new buffer has been created
2739async fn create_buffer_for_peer(
2740    request: proto::CreateBufferForPeer,
2741    session: Session,
2742) -> Result<()> {
2743    session
2744        .db()
2745        .await
2746        .check_user_is_project_host(
2747            ProjectId::from_proto(request.project_id),
2748            session.connection_id,
2749        )
2750        .await?;
2751    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2752    session
2753        .peer
2754        .forward_send(session.connection_id, peer_id.into(), request)?;
2755    Ok(())
2756}
2757
2758/// Notify other participants that a buffer has been updated. This is
2759/// allowed for guests as long as the update is limited to selections.
2760async fn update_buffer(
2761    request: proto::UpdateBuffer,
2762    response: Response<proto::UpdateBuffer>,
2763    session: Session,
2764) -> Result<()> {
2765    let project_id = ProjectId::from_proto(request.project_id);
2766    let mut capability = Capability::ReadOnly;
2767
2768    for op in request.operations.iter() {
2769        match op.variant {
2770            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2771            Some(_) => capability = Capability::ReadWrite,
2772        }
2773    }
2774
2775    let host = {
2776        let guard = session
2777            .db()
2778            .await
2779            .connections_for_buffer_update(
2780                project_id,
2781                session.principal_id(),
2782                session.connection_id,
2783                capability,
2784            )
2785            .await?;
2786
2787        let (host, guests) = &*guard;
2788
2789        broadcast(
2790            Some(session.connection_id),
2791            guests.clone(),
2792            |connection_id| {
2793                session
2794                    .peer
2795                    .forward_send(session.connection_id, connection_id, request.clone())
2796            },
2797        );
2798
2799        *host
2800    };
2801
2802    if host != session.connection_id {
2803        session
2804            .peer
2805            .forward_request(session.connection_id, host, request.clone())
2806            .await?;
2807    }
2808
2809    response.send(proto::Ack {})?;
2810    Ok(())
2811}
2812
2813/// Notify other participants that a project has been updated.
2814async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2815    request: T,
2816    session: Session,
2817) -> Result<()> {
2818    let project_id = ProjectId::from_proto(request.remote_entity_id());
2819    let project_connection_ids = session
2820        .db()
2821        .await
2822        .project_connection_ids(project_id, session.connection_id, false)
2823        .await?;
2824
2825    broadcast(
2826        Some(session.connection_id),
2827        project_connection_ids.iter().copied(),
2828        |connection_id| {
2829            session
2830                .peer
2831                .forward_send(session.connection_id, connection_id, request.clone())
2832        },
2833    );
2834    Ok(())
2835}
2836
2837/// Start following another user in a call.
2838async fn follow(
2839    request: proto::Follow,
2840    response: Response<proto::Follow>,
2841    session: UserSession,
2842) -> Result<()> {
2843    let room_id = RoomId::from_proto(request.room_id);
2844    let project_id = request.project_id.map(ProjectId::from_proto);
2845    let leader_id = request
2846        .leader_id
2847        .ok_or_else(|| anyhow!("invalid leader id"))?
2848        .into();
2849    let follower_id = session.connection_id;
2850
2851    session
2852        .db()
2853        .await
2854        .check_room_participants(room_id, leader_id, session.connection_id)
2855        .await?;
2856
2857    let response_payload = session
2858        .peer
2859        .forward_request(session.connection_id, leader_id, request)
2860        .await?;
2861    response.send(response_payload)?;
2862
2863    if let Some(project_id) = project_id {
2864        let room = session
2865            .db()
2866            .await
2867            .follow(room_id, project_id, leader_id, follower_id)
2868            .await?;
2869        room_updated(&room, &session.peer);
2870    }
2871
2872    Ok(())
2873}
2874
2875/// Stop following another user in a call.
2876async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2877    let room_id = RoomId::from_proto(request.room_id);
2878    let project_id = request.project_id.map(ProjectId::from_proto);
2879    let leader_id = request
2880        .leader_id
2881        .ok_or_else(|| anyhow!("invalid leader id"))?
2882        .into();
2883    let follower_id = session.connection_id;
2884
2885    session
2886        .db()
2887        .await
2888        .check_room_participants(room_id, leader_id, session.connection_id)
2889        .await?;
2890
2891    session
2892        .peer
2893        .forward_send(session.connection_id, leader_id, request)?;
2894
2895    if let Some(project_id) = project_id {
2896        let room = session
2897            .db()
2898            .await
2899            .unfollow(room_id, project_id, leader_id, follower_id)
2900            .await?;
2901        room_updated(&room, &session.peer);
2902    }
2903
2904    Ok(())
2905}
2906
2907/// Notify everyone following you of your current location.
2908async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2909    let room_id = RoomId::from_proto(request.room_id);
2910    let database = session.db.lock().await;
2911
2912    let connection_ids = if let Some(project_id) = request.project_id {
2913        let project_id = ProjectId::from_proto(project_id);
2914        database
2915            .project_connection_ids(project_id, session.connection_id, true)
2916            .await?
2917    } else {
2918        database
2919            .room_connection_ids(room_id, session.connection_id)
2920            .await?
2921    };
2922
2923    // For now, don't send view update messages back to that view's current leader.
2924    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2925        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2926        _ => None,
2927    });
2928
2929    for connection_id in connection_ids.iter().cloned() {
2930        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2931            session
2932                .peer
2933                .forward_send(session.connection_id, connection_id, request.clone())?;
2934        }
2935    }
2936    Ok(())
2937}
2938
2939/// Get public data about users.
2940async fn get_users(
2941    request: proto::GetUsers,
2942    response: Response<proto::GetUsers>,
2943    session: Session,
2944) -> Result<()> {
2945    let user_ids = request
2946        .user_ids
2947        .into_iter()
2948        .map(UserId::from_proto)
2949        .collect();
2950    let users = session
2951        .db()
2952        .await
2953        .get_users_by_ids(user_ids)
2954        .await?
2955        .into_iter()
2956        .map(|user| proto::User {
2957            id: user.id.to_proto(),
2958            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2959            github_login: user.github_login,
2960        })
2961        .collect();
2962    response.send(proto::UsersResponse { users })?;
2963    Ok(())
2964}
2965
2966/// Search for users (to invite) buy Github login
2967async fn fuzzy_search_users(
2968    request: proto::FuzzySearchUsers,
2969    response: Response<proto::FuzzySearchUsers>,
2970    session: UserSession,
2971) -> Result<()> {
2972    let query = request.query;
2973    let users = match query.len() {
2974        0 => vec![],
2975        1 | 2 => session
2976            .db()
2977            .await
2978            .get_user_by_github_login(&query)
2979            .await?
2980            .into_iter()
2981            .collect(),
2982        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2983    };
2984    let users = users
2985        .into_iter()
2986        .filter(|user| user.id != session.user_id())
2987        .map(|user| proto::User {
2988            id: user.id.to_proto(),
2989            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2990            github_login: user.github_login,
2991        })
2992        .collect();
2993    response.send(proto::UsersResponse { users })?;
2994    Ok(())
2995}
2996
2997/// Send a contact request to another user.
2998async fn request_contact(
2999    request: proto::RequestContact,
3000    response: Response<proto::RequestContact>,
3001    session: UserSession,
3002) -> Result<()> {
3003    let requester_id = session.user_id();
3004    let responder_id = UserId::from_proto(request.responder_id);
3005    if requester_id == responder_id {
3006        return Err(anyhow!("cannot add yourself as a contact"))?;
3007    }
3008
3009    let notifications = session
3010        .db()
3011        .await
3012        .send_contact_request(requester_id, responder_id)
3013        .await?;
3014
3015    // Update outgoing contact requests of requester
3016    let mut update = proto::UpdateContacts::default();
3017    update.outgoing_requests.push(responder_id.to_proto());
3018    for connection_id in session
3019        .connection_pool()
3020        .await
3021        .user_connection_ids(requester_id)
3022    {
3023        session.peer.send(connection_id, update.clone())?;
3024    }
3025
3026    // Update incoming contact requests of responder
3027    let mut update = proto::UpdateContacts::default();
3028    update
3029        .incoming_requests
3030        .push(proto::IncomingContactRequest {
3031            requester_id: requester_id.to_proto(),
3032        });
3033    let connection_pool = session.connection_pool().await;
3034    for connection_id in connection_pool.user_connection_ids(responder_id) {
3035        session.peer.send(connection_id, update.clone())?;
3036    }
3037
3038    send_notifications(&connection_pool, &session.peer, notifications);
3039
3040    response.send(proto::Ack {})?;
3041    Ok(())
3042}
3043
/// Accept or decline a contact request, or dismiss its notification.
///
/// `request.response` selects the action:
/// - `Dismiss`: only `dismiss_contact_notification` is called; no contact
///   updates are pushed to either user.
/// - Anything else resolves the request, with `accept` true only for `Accept`.
///   Both users' connections then receive an `UpdateContacts` that removes the
///   pending request. When accepted, that update also adds the other user as a
///   contact, including their current busy state. Notifications returned by the
///   database are delivered afterwards.
///
/// The database guard `db` stays alive for the whole handler, including while
/// the connection pool is locked and messages are sent.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is read after the response is stored; it is embedded in
        // the contact entry each side receives when the request is accepted.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3101
3102/// Remove a contact.
3103async fn remove_contact(
3104    request: proto::RemoveContact,
3105    response: Response<proto::RemoveContact>,
3106    session: UserSession,
3107) -> Result<()> {
3108    let requester_id = session.user_id();
3109    let responder_id = UserId::from_proto(request.user_id);
3110    let db = session.db().await;
3111    let (contact_accepted, deleted_notification_id) =
3112        db.remove_contact(requester_id, responder_id).await?;
3113
3114    let pool = session.connection_pool().await;
3115    // Update outgoing contact requests of requester
3116    let mut update = proto::UpdateContacts::default();
3117    if contact_accepted {
3118        update.remove_contacts.push(responder_id.to_proto());
3119    } else {
3120        update
3121            .remove_outgoing_requests
3122            .push(responder_id.to_proto());
3123    }
3124    for connection_id in pool.user_connection_ids(requester_id) {
3125        session.peer.send(connection_id, update.clone())?;
3126    }
3127
3128    // Update incoming contact requests of responder
3129    let mut update = proto::UpdateContacts::default();
3130    if contact_accepted {
3131        update.remove_contacts.push(requester_id.to_proto());
3132    } else {
3133        update
3134            .remove_incoming_requests
3135            .push(requester_id.to_proto());
3136    }
3137    for connection_id in pool.user_connection_ids(responder_id) {
3138        session.peer.send(connection_id, update.clone())?;
3139        if let Some(notification_id) = deleted_notification_id {
3140            session.peer.send(
3141                connection_id,
3142                proto::DeleteNotification {
3143                    notification_id: notification_id.to_proto(),
3144                },
3145            )?;
3146        }
3147    }
3148
3149    response.send(proto::Ack {})?;
3150    Ok(())
3151}
3152
3153/// Creates a new channel.
3154async fn create_channel(
3155    request: proto::CreateChannel,
3156    response: Response<proto::CreateChannel>,
3157    session: UserSession,
3158) -> Result<()> {
3159    let db = session.db().await;
3160
3161    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3162    let (channel, membership) = db
3163        .create_channel(&request.name, parent_id, session.user_id())
3164        .await?;
3165
3166    let root_id = channel.root_id();
3167    let channel = Channel::from_model(channel);
3168
3169    response.send(proto::CreateChannelResponse {
3170        channel: Some(channel.to_proto()),
3171        parent_id: request.parent_id,
3172    })?;
3173
3174    let mut connection_pool = session.connection_pool().await;
3175    if let Some(membership) = membership {
3176        connection_pool.subscribe_to_channel(
3177            membership.user_id,
3178            membership.channel_id,
3179            membership.role,
3180        );
3181        let update = proto::UpdateUserChannels {
3182            channel_memberships: vec![proto::ChannelMembership {
3183                channel_id: membership.channel_id.to_proto(),
3184                role: membership.role.into(),
3185            }],
3186            ..Default::default()
3187        };
3188        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3189            session.peer.send(connection_id, update.clone())?;
3190        }
3191    }
3192
3193    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3194        if !role.can_see_channel(channel.visibility) {
3195            continue;
3196        }
3197
3198        let update = proto::UpdateChannels {
3199            channels: vec![channel.to_proto()],
3200            ..Default::default()
3201        };
3202        session.peer.send(connection_id, update.clone())?;
3203    }
3204
3205    Ok(())
3206}
3207
3208/// Delete a channel
3209async fn delete_channel(
3210    request: proto::DeleteChannel,
3211    response: Response<proto::DeleteChannel>,
3212    session: UserSession,
3213) -> Result<()> {
3214    let db = session.db().await;
3215
3216    let channel_id = request.channel_id;
3217    let (root_channel, removed_channels) = db
3218        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3219        .await?;
3220    response.send(proto::Ack {})?;
3221
3222    // Notify members of removed channels
3223    let mut update = proto::UpdateChannels::default();
3224    update
3225        .delete_channels
3226        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3227
3228    let connection_pool = session.connection_pool().await;
3229    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3230        session.peer.send(connection_id, update.clone())?;
3231    }
3232
3233    Ok(())
3234}
3235
3236/// Invite someone to join a channel.
3237async fn invite_channel_member(
3238    request: proto::InviteChannelMember,
3239    response: Response<proto::InviteChannelMember>,
3240    session: UserSession,
3241) -> Result<()> {
3242    let db = session.db().await;
3243    let channel_id = ChannelId::from_proto(request.channel_id);
3244    let invitee_id = UserId::from_proto(request.user_id);
3245    let InviteMemberResult {
3246        channel,
3247        notifications,
3248    } = db
3249        .invite_channel_member(
3250            channel_id,
3251            invitee_id,
3252            session.user_id(),
3253            request.role().into(),
3254        )
3255        .await?;
3256
3257    let update = proto::UpdateChannels {
3258        channel_invitations: vec![channel.to_proto()],
3259        ..Default::default()
3260    };
3261
3262    let connection_pool = session.connection_pool().await;
3263    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3264        session.peer.send(connection_id, update.clone())?;
3265    }
3266
3267    send_notifications(&connection_pool, &session.peer, notifications);
3268
3269    response.send(proto::Ack {})?;
3270    Ok(())
3271}
3272
3273/// remove someone from a channel
3274async fn remove_channel_member(
3275    request: proto::RemoveChannelMember,
3276    response: Response<proto::RemoveChannelMember>,
3277    session: UserSession,
3278) -> Result<()> {
3279    let db = session.db().await;
3280    let channel_id = ChannelId::from_proto(request.channel_id);
3281    let member_id = UserId::from_proto(request.user_id);
3282
3283    let RemoveChannelMemberResult {
3284        membership_update,
3285        notification_id,
3286    } = db
3287        .remove_channel_member(channel_id, member_id, session.user_id())
3288        .await?;
3289
3290    let mut connection_pool = session.connection_pool().await;
3291    notify_membership_updated(
3292        &mut connection_pool,
3293        membership_update,
3294        member_id,
3295        &session.peer,
3296    );
3297    for connection_id in connection_pool.user_connection_ids(member_id) {
3298        if let Some(notification_id) = notification_id {
3299            session
3300                .peer
3301                .send(
3302                    connection_id,
3303                    proto::DeleteNotification {
3304                        notification_id: notification_id.to_proto(),
3305                    },
3306                )
3307                .trace_err();
3308        }
3309    }
3310
3311    response.send(proto::Ack {})?;
3312    Ok(())
3313}
3314
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// NOTE(review): that invariant is not checked in this handler; presumably
/// `db.set_channel_visibility` enforces it. Confirm in the db layer.
///
/// After the change, each user subscribed to the channel's root is processed
/// according to whether their role can see the new visibility:
/// - If it can, the user is (re)subscribed to the channel and sent the channel.
/// - If it cannot, the user is unsubscribed and told to delete the channel.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates the pool
    // (subscribe/unsubscribe), which cannot happen while it is being iterated.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3361
3362/// Alter the role for a user in the channel.
3363async fn set_channel_member_role(
3364    request: proto::SetChannelMemberRole,
3365    response: Response<proto::SetChannelMemberRole>,
3366    session: UserSession,
3367) -> Result<()> {
3368    let db = session.db().await;
3369    let channel_id = ChannelId::from_proto(request.channel_id);
3370    let member_id = UserId::from_proto(request.user_id);
3371    let result = db
3372        .set_channel_member_role(
3373            channel_id,
3374            session.user_id(),
3375            member_id,
3376            request.role().into(),
3377        )
3378        .await?;
3379
3380    match result {
3381        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3382            let mut connection_pool = session.connection_pool().await;
3383            notify_membership_updated(
3384                &mut connection_pool,
3385                membership_update,
3386                member_id,
3387                &session.peer,
3388            )
3389        }
3390        db::SetMemberRoleResult::InviteUpdated(channel) => {
3391            let update = proto::UpdateChannels {
3392                channel_invitations: vec![channel.to_proto()],
3393                ..Default::default()
3394            };
3395
3396            for connection_id in session
3397                .connection_pool()
3398                .await
3399                .user_connection_ids(member_id)
3400            {
3401                session.peer.send(connection_id, update.clone())?;
3402            }
3403        }
3404    }
3405
3406    response.send(proto::Ack {})?;
3407    Ok(())
3408}
3409
3410/// Change the name of a channel
3411async fn rename_channel(
3412    request: proto::RenameChannel,
3413    response: Response<proto::RenameChannel>,
3414    session: UserSession,
3415) -> Result<()> {
3416    let db = session.db().await;
3417    let channel_id = ChannelId::from_proto(request.channel_id);
3418    let channel_model = db
3419        .rename_channel(channel_id, session.user_id(), &request.name)
3420        .await?;
3421    let root_id = channel_model.root_id();
3422    let channel = Channel::from_model(channel_model);
3423
3424    response.send(proto::RenameChannelResponse {
3425        channel: Some(channel.to_proto()),
3426    })?;
3427
3428    let connection_pool = session.connection_pool().await;
3429    let update = proto::UpdateChannels {
3430        channels: vec![channel.to_proto()],
3431        ..Default::default()
3432    };
3433    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3434        if role.can_see_channel(channel.visibility) {
3435            session.peer.send(connection_id, update.clone())?;
3436        }
3437    }
3438
3439    Ok(())
3440}
3441
3442/// Move a channel to a new parent.
3443async fn move_channel(
3444    request: proto::MoveChannel,
3445    response: Response<proto::MoveChannel>,
3446    session: UserSession,
3447) -> Result<()> {
3448    let channel_id = ChannelId::from_proto(request.channel_id);
3449    let to = ChannelId::from_proto(request.to);
3450
3451    let (root_id, channels) = session
3452        .db()
3453        .await
3454        .move_channel(channel_id, to, session.user_id())
3455        .await?;
3456
3457    let connection_pool = session.connection_pool().await;
3458    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3459        let channels = channels
3460            .iter()
3461            .filter_map(|channel| {
3462                if role.can_see_channel(channel.visibility) {
3463                    Some(channel.to_proto())
3464                } else {
3465                    None
3466                }
3467            })
3468            .collect::<Vec<_>>();
3469        if channels.is_empty() {
3470            continue;
3471        }
3472
3473        let update = proto::UpdateChannels {
3474            channels,
3475            ..Default::default()
3476        };
3477
3478        session.peer.send(connection_id, update.clone())?;
3479    }
3480
3481    response.send(Ack {})?;
3482    Ok(())
3483}
3484
3485/// Get the list of channel members
3486async fn get_channel_members(
3487    request: proto::GetChannelMembers,
3488    response: Response<proto::GetChannelMembers>,
3489    session: UserSession,
3490) -> Result<()> {
3491    let db = session.db().await;
3492    let channel_id = ChannelId::from_proto(request.channel_id);
3493    let members = db
3494        .get_channel_participant_details(channel_id, session.user_id())
3495        .await?;
3496    response.send(proto::GetChannelMembersResponse { members })?;
3497    Ok(())
3498}
3499
3500/// Accept or decline a channel invitation.
3501async fn respond_to_channel_invite(
3502    request: proto::RespondToChannelInvite,
3503    response: Response<proto::RespondToChannelInvite>,
3504    session: UserSession,
3505) -> Result<()> {
3506    let db = session.db().await;
3507    let channel_id = ChannelId::from_proto(request.channel_id);
3508    let RespondToChannelInvite {
3509        membership_update,
3510        notifications,
3511    } = db
3512        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3513        .await?;
3514
3515    let mut connection_pool = session.connection_pool().await;
3516    if let Some(membership_update) = membership_update {
3517        notify_membership_updated(
3518            &mut connection_pool,
3519            membership_update,
3520            session.user_id(),
3521            &session.peer,
3522        );
3523    } else {
3524        let update = proto::UpdateChannels {
3525            remove_channel_invitations: vec![channel_id.to_proto()],
3526            ..Default::default()
3527        };
3528
3529        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3530            session.peer.send(connection_id, update.clone())?;
3531        }
3532    };
3533
3534    send_notifications(&connection_pool, &session.peer, notifications);
3535
3536    response.send(proto::Ack {})?;
3537
3538    Ok(())
3539}
3540
3541/// Join the channels' room
3542async fn join_channel(
3543    request: proto::JoinChannel,
3544    response: Response<proto::JoinChannel>,
3545    session: UserSession,
3546) -> Result<()> {
3547    let channel_id = ChannelId::from_proto(request.channel_id);
3548    join_channel_internal(channel_id, Box::new(response), session).await
3549}
3550
3551trait JoinChannelInternalResponse {
3552    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3553}
3554impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3555    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3556        Response::<proto::JoinChannel>::send(self, result)
3557    }
3558}
3559impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3560    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3561        Response::<proto::JoinRoom>::send(self, result)
3562    }
3563}
3564
/// Shared implementation for joining the room that belongs to a channel.
///
/// Called by `join_channel`; `response` may be any handle implementing
/// `JoinChannelInternalResponse`. In order, this:
/// 1. evicts a stale room connection of this user, if the database reports one;
/// 2. joins the channel's room in the database;
/// 3. replies with the room, the channel id and (optionally) LiveKit connection info;
/// 4. broadcasts any membership update, the room update and the channel update;
/// 5. refreshes the user's contacts.
///
/// # Errors
/// Returns an error if a database step fails, if sending the response fails, if the
/// database returns no channel for the joined room (reported only after the room has been
/// joined and the response sent), or if updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the old room (presumably because
            // `leave_room_for_session` acquires it itself), then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // LiveKit info is optional: it is `None` when no LiveKit client is configured, or
        // when token creation fails (`trace_err()?` turns the error into `None`). Guests get
        // a guest token and `can_publish = false`; every other role gets a room token and
        // `can_publish = true`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before broadcasting the join to anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // This connection-pool guard stays held until the end of the block, i.e. across the
        // membership notification and the room broadcast below.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // `channel` is moved out of `joined_room` here; a missing channel is reported as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3655
3656/// Start editing the channel notes
3657async fn join_channel_buffer(
3658    request: proto::JoinChannelBuffer,
3659    response: Response<proto::JoinChannelBuffer>,
3660    session: UserSession,
3661) -> Result<()> {
3662    let db = session.db().await;
3663    let channel_id = ChannelId::from_proto(request.channel_id);
3664
3665    let open_response = db
3666        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3667        .await?;
3668
3669    let collaborators = open_response.collaborators.clone();
3670    response.send(open_response)?;
3671
3672    let update = UpdateChannelBufferCollaborators {
3673        channel_id: channel_id.to_proto(),
3674        collaborators: collaborators.clone(),
3675    };
3676    channel_buffer_updated(
3677        session.connection_id,
3678        collaborators
3679            .iter()
3680            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3681        &update,
3682        &session.peer,
3683    );
3684
3685    Ok(())
3686}
3687
3688/// Edit the channel notes
3689async fn update_channel_buffer(
3690    request: proto::UpdateChannelBuffer,
3691    session: UserSession,
3692) -> Result<()> {
3693    let db = session.db().await;
3694    let channel_id = ChannelId::from_proto(request.channel_id);
3695
3696    let (collaborators, non_collaborators, epoch, version) = db
3697        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3698        .await?;
3699
3700    channel_buffer_updated(
3701        session.connection_id,
3702        collaborators,
3703        &proto::UpdateChannelBuffer {
3704            channel_id: channel_id.to_proto(),
3705            operations: request.operations,
3706        },
3707        &session.peer,
3708    );
3709
3710    let pool = &*session.connection_pool().await;
3711
3712    broadcast(
3713        None,
3714        non_collaborators
3715            .iter()
3716            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3717        |peer_id| {
3718            session.peer.send(
3719                peer_id,
3720                proto::UpdateChannels {
3721                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3722                        channel_id: channel_id.to_proto(),
3723                        epoch: epoch as u64,
3724                        version: version.clone(),
3725                    }],
3726                    ..Default::default()
3727                },
3728            )
3729        },
3730    );
3731
3732    Ok(())
3733}
3734
3735/// Rejoin the channel notes after a connection blip
3736async fn rejoin_channel_buffers(
3737    request: proto::RejoinChannelBuffers,
3738    response: Response<proto::RejoinChannelBuffers>,
3739    session: UserSession,
3740) -> Result<()> {
3741    let db = session.db().await;
3742    let buffers = db
3743        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3744        .await?;
3745
3746    for rejoined_buffer in &buffers {
3747        let collaborators_to_notify = rejoined_buffer
3748            .buffer
3749            .collaborators
3750            .iter()
3751            .filter_map(|c| Some(c.peer_id?.into()));
3752        channel_buffer_updated(
3753            session.connection_id,
3754            collaborators_to_notify,
3755            &proto::UpdateChannelBufferCollaborators {
3756                channel_id: rejoined_buffer.buffer.channel_id,
3757                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3758            },
3759            &session.peer,
3760        );
3761    }
3762
3763    response.send(proto::RejoinChannelBuffersResponse {
3764        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3765    })?;
3766
3767    Ok(())
3768}
3769
3770/// Stop editing the channel notes
3771async fn leave_channel_buffer(
3772    request: proto::LeaveChannelBuffer,
3773    response: Response<proto::LeaveChannelBuffer>,
3774    session: UserSession,
3775) -> Result<()> {
3776    let db = session.db().await;
3777    let channel_id = ChannelId::from_proto(request.channel_id);
3778
3779    let left_buffer = db
3780        .leave_channel_buffer(channel_id, session.connection_id)
3781        .await?;
3782
3783    response.send(Ack {})?;
3784
3785    channel_buffer_updated(
3786        session.connection_id,
3787        left_buffer.connections,
3788        &proto::UpdateChannelBufferCollaborators {
3789            channel_id: channel_id.to_proto(),
3790            collaborators: left_buffer.collaborators,
3791        },
3792        &session.peer,
3793    );
3794
3795    Ok(())
3796}
3797
3798fn channel_buffer_updated<T: EnvelopedMessage>(
3799    sender_id: ConnectionId,
3800    collaborators: impl IntoIterator<Item = ConnectionId>,
3801    message: &T,
3802    peer: &Peer,
3803) {
3804    broadcast(Some(sender_id), collaborators, |peer_id| {
3805        peer.send(peer_id, message.clone())
3806    });
3807}
3808
3809fn send_notifications(
3810    connection_pool: &ConnectionPool,
3811    peer: &Peer,
3812    notifications: db::NotificationBatch,
3813) {
3814    for (user_id, notification) in notifications {
3815        for connection_id in connection_pool.user_connection_ids(user_id) {
3816            if let Err(error) = peer.send(
3817                connection_id,
3818                proto::AddNotification {
3819                    notification: Some(notification.clone()),
3820                },
3821            ) {
3822                tracing::error!(
3823                    "failed to send notification to {:?} {}",
3824                    connection_id,
3825                    error
3826                );
3827            }
3828        }
3829    }
3830}
3831
3832/// Send a message to the channel
3833async fn send_channel_message(
3834    request: proto::SendChannelMessage,
3835    response: Response<proto::SendChannelMessage>,
3836    session: UserSession,
3837) -> Result<()> {
3838    // Validate the message body.
3839    let body = request.body.trim().to_string();
3840    if body.len() > MAX_MESSAGE_LEN {
3841        return Err(anyhow!("message is too long"))?;
3842    }
3843    if body.is_empty() {
3844        return Err(anyhow!("message can't be blank"))?;
3845    }
3846
3847    // TODO: adjust mentions if body is trimmed
3848
3849    let timestamp = OffsetDateTime::now_utc();
3850    let nonce = request
3851        .nonce
3852        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3853
3854    let channel_id = ChannelId::from_proto(request.channel_id);
3855    let CreatedChannelMessage {
3856        message_id,
3857        participant_connection_ids,
3858        channel_members,
3859        notifications,
3860    } = session
3861        .db()
3862        .await
3863        .create_channel_message(
3864            channel_id,
3865            session.user_id(),
3866            &body,
3867            &request.mentions,
3868            timestamp,
3869            nonce.clone().into(),
3870            match request.reply_to_message_id {
3871                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3872                None => None,
3873            },
3874        )
3875        .await?;
3876
3877    let message = proto::ChannelMessage {
3878        sender_id: session.user_id().to_proto(),
3879        id: message_id.to_proto(),
3880        body,
3881        mentions: request.mentions,
3882        timestamp: timestamp.unix_timestamp() as u64,
3883        nonce: Some(nonce),
3884        reply_to_message_id: request.reply_to_message_id,
3885        edited_at: None,
3886    };
3887    broadcast(
3888        Some(session.connection_id),
3889        participant_connection_ids,
3890        |connection| {
3891            session.peer.send(
3892                connection,
3893                proto::ChannelMessageSent {
3894                    channel_id: channel_id.to_proto(),
3895                    message: Some(message.clone()),
3896                },
3897            )
3898        },
3899    );
3900    response.send(proto::SendChannelMessageResponse {
3901        message: Some(message),
3902    })?;
3903
3904    let pool = &*session.connection_pool().await;
3905    broadcast(
3906        None,
3907        channel_members
3908            .iter()
3909            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3910        |peer_id| {
3911            session.peer.send(
3912                peer_id,
3913                proto::UpdateChannels {
3914                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3915                        channel_id: channel_id.to_proto(),
3916                        message_id: message_id.to_proto(),
3917                    }],
3918                    ..Default::default()
3919                },
3920            )
3921        },
3922    );
3923    send_notifications(pool, &session.peer, notifications);
3924
3925    Ok(())
3926}
3927
3928/// Delete a channel message
3929async fn remove_channel_message(
3930    request: proto::RemoveChannelMessage,
3931    response: Response<proto::RemoveChannelMessage>,
3932    session: UserSession,
3933) -> Result<()> {
3934    let channel_id = ChannelId::from_proto(request.channel_id);
3935    let message_id = MessageId::from_proto(request.message_id);
3936    let (connection_ids, existing_notification_ids) = session
3937        .db()
3938        .await
3939        .remove_channel_message(channel_id, message_id, session.user_id())
3940        .await?;
3941
3942    broadcast(
3943        Some(session.connection_id),
3944        connection_ids,
3945        move |connection| {
3946            session.peer.send(connection, request.clone())?;
3947
3948            for notification_id in &existing_notification_ids {
3949                session.peer.send(
3950                    connection,
3951                    proto::DeleteNotification {
3952                        notification_id: (*notification_id).to_proto(),
3953                    },
3954                )?;
3955            }
3956
3957            Ok(())
3958        },
3959    );
3960    response.send(proto::Ack {})?;
3961    Ok(())
3962}
3963
3964async fn update_channel_message(
3965    request: proto::UpdateChannelMessage,
3966    response: Response<proto::UpdateChannelMessage>,
3967    session: UserSession,
3968) -> Result<()> {
3969    let channel_id = ChannelId::from_proto(request.channel_id);
3970    let message_id = MessageId::from_proto(request.message_id);
3971    let updated_at = OffsetDateTime::now_utc();
3972    let UpdatedChannelMessage {
3973        message_id,
3974        participant_connection_ids,
3975        notifications,
3976        reply_to_message_id,
3977        timestamp,
3978        deleted_mention_notification_ids,
3979        updated_mention_notifications,
3980    } = session
3981        .db()
3982        .await
3983        .update_channel_message(
3984            channel_id,
3985            message_id,
3986            session.user_id(),
3987            request.body.as_str(),
3988            &request.mentions,
3989            updated_at,
3990        )
3991        .await?;
3992
3993    let nonce = request
3994        .nonce
3995        .clone()
3996        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3997
3998    let message = proto::ChannelMessage {
3999        sender_id: session.user_id().to_proto(),
4000        id: message_id.to_proto(),
4001        body: request.body.clone(),
4002        mentions: request.mentions.clone(),
4003        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4004        nonce: Some(nonce),
4005        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4006        edited_at: Some(updated_at.unix_timestamp() as u64),
4007    };
4008
4009    response.send(proto::Ack {})?;
4010
4011    let pool = &*session.connection_pool().await;
4012    broadcast(
4013        Some(session.connection_id),
4014        participant_connection_ids,
4015        |connection| {
4016            session.peer.send(
4017                connection,
4018                proto::ChannelMessageUpdate {
4019                    channel_id: channel_id.to_proto(),
4020                    message: Some(message.clone()),
4021                },
4022            )?;
4023
4024            for notification_id in &deleted_mention_notification_ids {
4025                session.peer.send(
4026                    connection,
4027                    proto::DeleteNotification {
4028                        notification_id: (*notification_id).to_proto(),
4029                    },
4030                )?;
4031            }
4032
4033            for notification in &updated_mention_notifications {
4034                session.peer.send(
4035                    connection,
4036                    proto::UpdateNotification {
4037                        notification: Some(notification.clone()),
4038                    },
4039                )?;
4040            }
4041
4042            Ok(())
4043        },
4044    );
4045
4046    send_notifications(pool, &session.peer, notifications);
4047
4048    Ok(())
4049}
4050
4051/// Mark a channel message as read
4052async fn acknowledge_channel_message(
4053    request: proto::AckChannelMessage,
4054    session: UserSession,
4055) -> Result<()> {
4056    let channel_id = ChannelId::from_proto(request.channel_id);
4057    let message_id = MessageId::from_proto(request.message_id);
4058    let notifications = session
4059        .db()
4060        .await
4061        .observe_channel_message(channel_id, session.user_id(), message_id)
4062        .await?;
4063    send_notifications(
4064        &*session.connection_pool().await,
4065        &session.peer,
4066        notifications,
4067    );
4068    Ok(())
4069}
4070
4071/// Mark a buffer version as synced
4072async fn acknowledge_buffer_version(
4073    request: proto::AckBufferOperation,
4074    session: UserSession,
4075) -> Result<()> {
4076    let buffer_id = BufferId::from_proto(request.buffer_id);
4077    session
4078        .db()
4079        .await
4080        .observe_buffer_version(
4081            buffer_id,
4082            session.user_id(),
4083            request.epoch as i32,
4084            &request.version,
4085        )
4086        .await?;
4087    Ok(())
4088}
4089
4090struct CompleteWithLanguageModelRateLimit;
4091
4092impl RateLimit for CompleteWithLanguageModelRateLimit {
4093    fn capacity() -> usize {
4094        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4095            .ok()
4096            .and_then(|v| v.parse().ok())
4097            .unwrap_or(120) // Picked arbitrarily
4098    }
4099
4100    fn refill_duration() -> chrono::Duration {
4101        chrono::Duration::hours(1)
4102    }
4103
4104    fn db_name() -> &'static str {
4105        "complete-with-language-model"
4106    }
4107}
4108
4109async fn complete_with_language_model(
4110    request: proto::CompleteWithLanguageModel,
4111    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4112    session: Session,
4113    open_ai_api_key: Option<Arc<str>>,
4114    google_ai_api_key: Option<Arc<str>>,
4115    anthropic_api_key: Option<Arc<str>>,
4116) -> Result<()> {
4117    let Some(session) = session.for_user() else {
4118        return Err(anyhow!("user not found"))?;
4119    };
4120    authorize_access_to_language_models(&session).await?;
4121    session
4122        .rate_limiter
4123        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4124        .await?;
4125
4126    if request.model.starts_with("gpt") {
4127        let api_key =
4128            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4129        complete_with_open_ai(request, response, session, api_key).await?;
4130    } else if request.model.starts_with("gemini") {
4131        let api_key = google_ai_api_key
4132            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4133        complete_with_google_ai(request, response, session, api_key).await?;
4134    } else if request.model.starts_with("claude") {
4135        let api_key = anthropic_api_key
4136            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4137        complete_with_anthropic(request, response, session, api_key).await?;
4138    }
4139
4140    Ok(())
4141}
4142
4143async fn complete_with_open_ai(
4144    request: proto::CompleteWithLanguageModel,
4145    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4146    session: UserSession,
4147    api_key: Arc<str>,
4148) -> Result<()> {
4149    let mut completion_stream = open_ai::stream_completion(
4150        &session.http_client,
4151        OPEN_AI_API_URL,
4152        &api_key,
4153        crate::ai::language_model_request_to_open_ai(request)?,
4154    )
4155    .await
4156    .context("open_ai::stream_completion request failed within collab")?;
4157
4158    while let Some(event) = completion_stream.next().await {
4159        let event = event?;
4160        response.send(proto::LanguageModelResponse {
4161            choices: event
4162                .choices
4163                .into_iter()
4164                .map(|choice| proto::LanguageModelChoiceDelta {
4165                    index: choice.index,
4166                    delta: Some(proto::LanguageModelResponseMessage {
4167                        role: choice.delta.role.map(|role| match role {
4168                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4169                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4170                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4171                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4172                        } as i32),
4173                        content: choice.delta.content,
4174                        tool_calls: choice
4175                            .delta
4176                            .tool_calls
4177                            .into_iter()
4178                            .map(|delta| proto::ToolCallDelta {
4179                                index: delta.index as u32,
4180                                id: delta.id,
4181                                variant: match delta.function {
4182                                    Some(function) => {
4183                                        let name = function.name;
4184                                        let arguments = function.arguments;
4185
4186                                        Some(proto::tool_call_delta::Variant::Function(
4187                                            proto::tool_call_delta::FunctionCallDelta {
4188                                                name,
4189                                                arguments,
4190                                            },
4191                                        ))
4192                                    }
4193                                    None => None,
4194                                },
4195                            })
4196                            .collect(),
4197                    }),
4198                    finish_reason: choice.finish_reason,
4199                })
4200                .collect(),
4201        })?;
4202    }
4203
4204    Ok(())
4205}
4206
4207async fn complete_with_google_ai(
4208    request: proto::CompleteWithLanguageModel,
4209    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4210    session: UserSession,
4211    api_key: Arc<str>,
4212) -> Result<()> {
4213    let mut stream = google_ai::stream_generate_content(
4214        &session.http_client,
4215        google_ai::API_URL,
4216        api_key.as_ref(),
4217        crate::ai::language_model_request_to_google_ai(request)?,
4218    )
4219    .await
4220    .context("google_ai::stream_generate_content request failed")?;
4221
4222    while let Some(event) = stream.next().await {
4223        let event = event?;
4224        response.send(proto::LanguageModelResponse {
4225            choices: event
4226                .candidates
4227                .unwrap_or_default()
4228                .into_iter()
4229                .map(|candidate| proto::LanguageModelChoiceDelta {
4230                    index: candidate.index as u32,
4231                    delta: Some(proto::LanguageModelResponseMessage {
4232                        role: Some(match candidate.content.role {
4233                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4234                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4235                        } as i32),
4236                        content: Some(
4237                            candidate
4238                                .content
4239                                .parts
4240                                .into_iter()
4241                                .filter_map(|part| match part {
4242                                    google_ai::Part::TextPart(part) => Some(part.text),
4243                                    google_ai::Part::InlineDataPart(_) => None,
4244                                })
4245                                .collect(),
4246                        ),
4247                        // Tool calls are not supported for Google
4248                        tool_calls: Vec::new(),
4249                    }),
4250                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4251                })
4252                .collect(),
4253        })?;
4254    }
4255
4256    Ok(())
4257}
4258
/// Streams a completion from Anthropic's Messages API back to the client.
///
/// `request.messages` are translated to Anthropic request messages:
/// - user and assistant messages are forwarded as-is;
/// - system messages are joined (separated by a blank line) into the request's
///   `system` field;
/// - tool messages are dropped.
///
/// Each streamed text chunk is sent as a single-choice (`index: 0`)
/// `LanguageModelResponse`. A stop reason reported by the stream is sent as that
/// choice's `finish_reason` with no delta.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - `request.model` is not accepted by `anthropic::Model::from_id`;
/// - the request or any streamed event fails;
/// - sending a response chunk fails.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates the contents of every system message, in order.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        &session.http_client,
        "https://api.anthropic.com",
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 may be a typo for 4096. Confirm the intended
            // output-token cap.
            max_tokens: 4092,
        },
    )
    .await?;

    // Role attached to every forwarded chunk. It defaults to assistant and is
    // updated from `MessageStart`.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Role strings other than "assistant"/"user" leave the current role unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                // A content block may carry initial text. Forward it unless it is empty.
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // Only the stop reason is forwarded from message-level deltas.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing the client needs.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4379
4380struct CountTokensWithLanguageModelRateLimit;
4381
4382impl RateLimit for CountTokensWithLanguageModelRateLimit {
4383    fn capacity() -> usize {
4384        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4385            .ok()
4386            .and_then(|v| v.parse().ok())
4387            .unwrap_or(600) // Picked arbitrarily
4388    }
4389
4390    fn refill_duration() -> chrono::Duration {
4391        chrono::Duration::hours(1)
4392    }
4393
4394    fn db_name() -> &'static str {
4395        "count-tokens-with-language-model"
4396    }
4397}
4398
4399async fn count_tokens_with_language_model(
4400    request: proto::CountTokensWithLanguageModel,
4401    response: Response<proto::CountTokensWithLanguageModel>,
4402    session: UserSession,
4403    google_ai_api_key: Option<Arc<str>>,
4404) -> Result<()> {
4405    authorize_access_to_language_models(&session).await?;
4406
4407    if !request.model.starts_with("gemini") {
4408        return Err(anyhow!(
4409            "counting tokens for model: {:?} is not supported",
4410            request.model
4411        ))?;
4412    }
4413
4414    session
4415        .rate_limiter
4416        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4417        .await?;
4418
4419    let api_key = google_ai_api_key
4420        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4421    let tokens_response = google_ai::count_tokens(
4422        &session.http_client,
4423        google_ai::API_URL,
4424        &api_key,
4425        crate::ai::count_tokens_request_to_google_ai(request)?,
4426    )
4427    .await?;
4428    response.send(proto::CountTokensResponse {
4429        token_count: tokens_response.total_tokens as u32,
4430    })?;
4431    Ok(())
4432}
4433
4434struct ComputeEmbeddingsRateLimit;
4435
4436impl RateLimit for ComputeEmbeddingsRateLimit {
4437    fn capacity() -> usize {
4438        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4439            .ok()
4440            .and_then(|v| v.parse().ok())
4441            .unwrap_or(120) // Picked arbitrarily
4442    }
4443
4444    fn refill_duration() -> chrono::Duration {
4445        chrono::Duration::hours(1)
4446    }
4447
4448    fn db_name() -> &'static str {
4449        "compute-embeddings"
4450    }
4451}
4452
4453async fn compute_embeddings(
4454    request: proto::ComputeEmbeddings,
4455    response: Response<proto::ComputeEmbeddings>,
4456    session: UserSession,
4457    api_key: Option<Arc<str>>,
4458) -> Result<()> {
4459    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4460    authorize_access_to_language_models(&session).await?;
4461
4462    session
4463        .rate_limiter
4464        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4465        .await?;
4466
4467    let embeddings = match request.model.as_str() {
4468        "openai/text-embedding-3-small" => {
4469            open_ai::embed(
4470                &session.http_client,
4471                OPEN_AI_API_URL,
4472                &api_key,
4473                OpenAiEmbeddingModel::TextEmbedding3Small,
4474                request.texts.iter().map(|text| text.as_str()),
4475            )
4476            .await?
4477        }
4478        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4479    };
4480
4481    let embeddings = request
4482        .texts
4483        .iter()
4484        .map(|text| {
4485            let mut hasher = sha2::Sha256::new();
4486            hasher.update(text.as_bytes());
4487            let result = hasher.finalize();
4488            result.to_vec()
4489        })
4490        .zip(
4491            embeddings
4492                .data
4493                .into_iter()
4494                .map(|embedding| embedding.embedding),
4495        )
4496        .collect::<HashMap<_, _>>();
4497
4498    let db = session.db().await;
4499    db.save_embeddings(&request.model, &embeddings)
4500        .await
4501        .context("failed to save embeddings")
4502        .trace_err();
4503
4504    response.send(proto::ComputeEmbeddingsResponse {
4505        embeddings: embeddings
4506            .into_iter()
4507            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4508            .collect(),
4509    })?;
4510    Ok(())
4511}
4512
4513struct GetCachedEmbeddingsRateLimit;
4514
4515impl RateLimit for GetCachedEmbeddingsRateLimit {
4516    fn capacity() -> usize {
4517        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4518            .ok()
4519            .and_then(|v| v.parse().ok())
4520            .unwrap_or(120) // Picked arbitrarily
4521    }
4522
4523    fn refill_duration() -> chrono::Duration {
4524        chrono::Duration::hours(1)
4525    }
4526
4527    fn db_name() -> &'static str {
4528        "get-cached-embeddings"
4529    }
4530}
4531
4532async fn get_cached_embeddings(
4533    request: proto::GetCachedEmbeddings,
4534    response: Response<proto::GetCachedEmbeddings>,
4535    session: UserSession,
4536) -> Result<()> {
4537    authorize_access_to_language_models(&session).await?;
4538
4539    session
4540        .rate_limiter
4541        .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4542        .await?;
4543
4544    let db = session.db().await;
4545    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4546
4547    response.send(proto::GetCachedEmbeddingsResponse {
4548        embeddings: embeddings
4549            .into_iter()
4550            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4551            .collect(),
4552    })?;
4553    Ok(())
4554}
4555
4556async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4557    let db = session.db().await;
4558    let flags = db.get_user_flags(session.user_id()).await?;
4559    if flags.iter().any(|flag| flag == "language-models") {
4560        Ok(())
4561    } else {
4562        Err(anyhow!("permission denied"))?
4563    }
4564}
4565
4566/// Start receiving chat updates for a channel
4567async fn join_channel_chat(
4568    request: proto::JoinChannelChat,
4569    response: Response<proto::JoinChannelChat>,
4570    session: UserSession,
4571) -> Result<()> {
4572    let channel_id = ChannelId::from_proto(request.channel_id);
4573
4574    let db = session.db().await;
4575    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4576        .await?;
4577    let messages = db
4578        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4579        .await?;
4580    response.send(proto::JoinChannelChatResponse {
4581        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4582        messages,
4583    })?;
4584    Ok(())
4585}
4586
4587/// Stop receiving chat updates for a channel
4588async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4589    let channel_id = ChannelId::from_proto(request.channel_id);
4590    session
4591        .db()
4592        .await
4593        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4594        .await?;
4595    Ok(())
4596}
4597
4598/// Retrieve the chat history for a channel
4599async fn get_channel_messages(
4600    request: proto::GetChannelMessages,
4601    response: Response<proto::GetChannelMessages>,
4602    session: UserSession,
4603) -> Result<()> {
4604    let channel_id = ChannelId::from_proto(request.channel_id);
4605    let messages = session
4606        .db()
4607        .await
4608        .get_channel_messages(
4609            channel_id,
4610            session.user_id(),
4611            MESSAGE_COUNT_PER_PAGE,
4612            Some(MessageId::from_proto(request.before_message_id)),
4613        )
4614        .await?;
4615    response.send(proto::GetChannelMessagesResponse {
4616        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4617        messages,
4618    })?;
4619    Ok(())
4620}
4621
4622/// Retrieve specific chat messages
4623async fn get_channel_messages_by_id(
4624    request: proto::GetChannelMessagesById,
4625    response: Response<proto::GetChannelMessagesById>,
4626    session: UserSession,
4627) -> Result<()> {
4628    let message_ids = request
4629        .message_ids
4630        .iter()
4631        .map(|id| MessageId::from_proto(*id))
4632        .collect::<Vec<_>>();
4633    let messages = session
4634        .db()
4635        .await
4636        .get_channel_messages_by_id(session.user_id(), &message_ids)
4637        .await?;
4638    response.send(proto::GetChannelMessagesResponse {
4639        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4640        messages,
4641    })?;
4642    Ok(())
4643}
4644
4645/// Retrieve the current users notifications
4646async fn get_notifications(
4647    request: proto::GetNotifications,
4648    response: Response<proto::GetNotifications>,
4649    session: UserSession,
4650) -> Result<()> {
4651    let notifications = session
4652        .db()
4653        .await
4654        .get_notifications(
4655            session.user_id(),
4656            NOTIFICATION_COUNT_PER_PAGE,
4657            request
4658                .before_id
4659                .map(|id| db::NotificationId::from_proto(id)),
4660        )
4661        .await?;
4662    response.send(proto::GetNotificationsResponse {
4663        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4664        notifications,
4665    })?;
4666    Ok(())
4667}
4668
4669/// Mark notifications as read
4670async fn mark_notification_as_read(
4671    request: proto::MarkNotificationRead,
4672    response: Response<proto::MarkNotificationRead>,
4673    session: UserSession,
4674) -> Result<()> {
4675    let database = &session.db().await;
4676    let notifications = database
4677        .mark_notification_as_read_by_id(
4678            session.user_id(),
4679            NotificationId::from_proto(request.notification_id),
4680        )
4681        .await?;
4682    send_notifications(
4683        &*session.connection_pool().await,
4684        &session.peer,
4685        notifications,
4686    );
4687    response.send(proto::Ack {})?;
4688    Ok(())
4689}
4690
4691/// Get the current users information
4692async fn get_private_user_info(
4693    _request: proto::GetPrivateUserInfo,
4694    response: Response<proto::GetPrivateUserInfo>,
4695    session: UserSession,
4696) -> Result<()> {
4697    let db = session.db().await;
4698
4699    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4700    let user = db
4701        .get_user_by_id(session.user_id())
4702        .await?
4703        .ok_or_else(|| anyhow!("user not found"))?;
4704    let flags = db.get_user_flags(session.user_id()).await?;
4705
4706    response.send(proto::GetPrivateUserInfoResponse {
4707        metrics_id,
4708        staff: user.admin,
4709        flags,
4710    })?;
4711    Ok(())
4712}
4713
4714fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4715    match message {
4716        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4717        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4718        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4719        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4720        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4721            code: frame.code.into(),
4722            reason: frame.reason,
4723        })),
4724    }
4725}
4726
4727fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4728    match message {
4729        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4730        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4731        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4732        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4733        AxumMessage::Close(frame) => {
4734            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4735                code: frame.code.into(),
4736                reason: frame.reason,
4737            }))
4738        }
4739    }
4740}
4741
4742fn notify_membership_updated(
4743    connection_pool: &mut ConnectionPool,
4744    result: MembershipUpdated,
4745    user_id: UserId,
4746    peer: &Peer,
4747) {
4748    for membership in &result.new_channels.channel_memberships {
4749        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4750    }
4751    for channel_id in &result.removed_channels {
4752        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4753    }
4754
4755    let user_channels_update = proto::UpdateUserChannels {
4756        channel_memberships: result
4757            .new_channels
4758            .channel_memberships
4759            .iter()
4760            .map(|cm| proto::ChannelMembership {
4761                channel_id: cm.channel_id.to_proto(),
4762                role: cm.role.into(),
4763            })
4764            .collect(),
4765        ..Default::default()
4766    };
4767
4768    let mut update = build_channels_update(result.new_channels, vec![]);
4769    update.delete_channels = result
4770        .removed_channels
4771        .into_iter()
4772        .map(|id| id.to_proto())
4773        .collect();
4774    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4775
4776    for connection_id in connection_pool.user_connection_ids(user_id) {
4777        peer.send(connection_id, user_channels_update.clone())
4778            .trace_err();
4779        peer.send(connection_id, update.clone()).trace_err();
4780    }
4781}
4782
4783fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4784    proto::UpdateUserChannels {
4785        channel_memberships: channels
4786            .channel_memberships
4787            .iter()
4788            .map(|m| proto::ChannelMembership {
4789                channel_id: m.channel_id.to_proto(),
4790                role: m.role.into(),
4791            })
4792            .collect(),
4793        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4794        observed_channel_message_id: channels.observed_channel_messages.clone(),
4795    }
4796}
4797
4798fn build_channels_update(
4799    channels: ChannelsForUser,
4800    channel_invites: Vec<db::Channel>,
4801) -> proto::UpdateChannels {
4802    let mut update = proto::UpdateChannels::default();
4803
4804    for channel in channels.channels {
4805        update.channels.push(channel.to_proto());
4806    }
4807
4808    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4809    update.latest_channel_message_ids = channels.latest_channel_messages;
4810
4811    for (channel_id, participants) in channels.channel_participants {
4812        update
4813            .channel_participants
4814            .push(proto::ChannelParticipants {
4815                channel_id: channel_id.to_proto(),
4816                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4817            });
4818    }
4819
4820    for channel in channel_invites {
4821        update.channel_invitations.push(channel.to_proto());
4822    }
4823
4824    update.hosted_projects = channels.hosted_projects;
4825    update
4826}
4827
4828fn build_initial_contacts_update(
4829    contacts: Vec<db::Contact>,
4830    pool: &ConnectionPool,
4831) -> proto::UpdateContacts {
4832    let mut update = proto::UpdateContacts::default();
4833
4834    for contact in contacts {
4835        match contact {
4836            db::Contact::Accepted { user_id, busy } => {
4837                update.contacts.push(contact_for_user(user_id, busy, &pool));
4838            }
4839            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4840            db::Contact::Incoming { user_id } => {
4841                update
4842                    .incoming_requests
4843                    .push(proto::IncomingContactRequest {
4844                        requester_id: user_id.to_proto(),
4845                    })
4846            }
4847        }
4848    }
4849
4850    update
4851}
4852
4853fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4854    proto::Contact {
4855        user_id: user_id.to_proto(),
4856        online: pool.is_user_online(user_id),
4857        busy,
4858    }
4859}
4860
4861fn room_updated(room: &proto::Room, peer: &Peer) {
4862    broadcast(
4863        None,
4864        room.participants
4865            .iter()
4866            .filter_map(|participant| Some(participant.peer_id?.into())),
4867        |peer_id| {
4868            peer.send(
4869                peer_id,
4870                proto::RoomUpdated {
4871                    room: Some(room.clone()),
4872                },
4873            )
4874        },
4875    );
4876}
4877
4878fn channel_updated(
4879    channel: &db::channel::Model,
4880    room: &proto::Room,
4881    peer: &Peer,
4882    pool: &ConnectionPool,
4883) {
4884    let participants = room
4885        .participants
4886        .iter()
4887        .map(|p| p.user_id)
4888        .collect::<Vec<_>>();
4889
4890    broadcast(
4891        None,
4892        pool.channel_connection_ids(channel.root_id())
4893            .filter_map(|(channel_id, role)| {
4894                role.can_see_channel(channel.visibility).then(|| channel_id)
4895            }),
4896        |peer_id| {
4897            peer.send(
4898                peer_id,
4899                proto::UpdateChannels {
4900                    channel_participants: vec![proto::ChannelParticipants {
4901                        channel_id: channel.id.to_proto(),
4902                        participant_user_ids: participants.clone(),
4903                    }],
4904                    ..Default::default()
4905                },
4906            )
4907        },
4908    );
4909}
4910
4911async fn send_remote_projects_update(
4912    user_id: UserId,
4913    mut status: proto::RemoteProjectsUpdate,
4914    session: &Session,
4915) {
4916    let pool = session.connection_pool().await;
4917    for dev_server in &mut status.dev_servers {
4918        dev_server.status =
4919            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4920    }
4921    let connections = pool.user_connection_ids(user_id);
4922    for connection_id in connections {
4923        session.peer.send(connection_id, status.clone()).trace_err();
4924    }
4925}
4926
4927async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4928    let db = session.db().await;
4929
4930    let contacts = db.get_contacts(user_id).await?;
4931    let busy = db.is_user_busy(user_id).await?;
4932
4933    let pool = session.connection_pool().await;
4934    let updated_contact = contact_for_user(user_id, busy, &pool);
4935    for contact in contacts {
4936        if let db::Contact::Accepted {
4937            user_id: contact_user_id,
4938            ..
4939        } = contact
4940        {
4941            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4942                session
4943                    .peer
4944                    .send(
4945                        contact_conn_id,
4946                        proto::UpdateContacts {
4947                            contacts: vec![updated_contact.clone()],
4948                            remove_contacts: Default::default(),
4949                            incoming_requests: Default::default(),
4950                            remove_incoming_requests: Default::default(),
4951                            outgoing_requests: Default::default(),
4952                            remove_outgoing_requests: Default::default(),
4953                        },
4954                    )
4955                    .trace_err();
4956            }
4957        }
4958    }
4959    Ok(())
4960}
4961
4962async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
4963    log::info!("lost dev server connection, unsharing projects");
4964    let project_ids = session
4965        .db()
4966        .await
4967        .get_stale_dev_server_projects(session.connection_id)
4968        .await?;
4969
4970    for project_id in project_ids {
4971        // not unshare re-checks the connection ids match, so we get away with no transaction
4972        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
4973    }
4974
4975    let user_id = session.dev_server().user_id;
4976    let update = session.db().await.remote_projects_update(user_id).await?;
4977
4978    send_remote_projects_update(user_id, update, session).await;
4979
4980    Ok(())
4981}
4982
/// Removes `connection_id` from its current room, if any, and notifies everyone
/// affected.
///
/// When the database reports that the connection left a room, this:
/// - notifies collaborators on each project the connection left (`project_left`);
/// - broadcasts the updated room to its participants, and, if the room belongs
///   to a channel, the updated participant list to that channel;
/// - sends `CallCanceled` to every user whose pending call was canceled;
/// - refreshes contact status for this user and for those callees;
/// - removes the user from the LiveKit room, and deletes the LiveKit room when
///   the database marked the room as deleted.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
/// LiveKit failures are logged, not propagated.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be refreshed once the room is left.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once inside the
    // `if let` below, and the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out *before* `room` is moved
        // out, so the `room` broadcast below carries a default (presumably empty)
        // `live_kit_room`. Confirm that is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection-pool guard so it is released before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5056
5057async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5058    let left_channel_buffers = session
5059        .db()
5060        .await
5061        .leave_channel_buffers(session.connection_id)
5062        .await?;
5063
5064    for left_buffer in left_channel_buffers {
5065        channel_buffer_updated(
5066            session.connection_id,
5067            left_buffer.connections,
5068            &proto::UpdateChannelBufferCollaborators {
5069                channel_id: left_buffer.channel_id.to_proto(),
5070                collaborators: left_buffer.collaborators,
5071            },
5072            &session.peer,
5073        );
5074    }
5075
5076    Ok(())
5077}
5078
5079fn project_left(project: &db::LeftProject, session: &UserSession) {
5080    for connection_id in &project.connection_ids {
5081        if project.should_unshare {
5082            session
5083                .peer
5084                .send(
5085                    *connection_id,
5086                    proto::UnshareProject {
5087                        project_id: project.id.to_proto(),
5088                    },
5089                )
5090                .trace_err();
5091        } else {
5092            session
5093                .peer
5094                .send(
5095                    *connection_id,
5096                    proto::RemoveProjectCollaborator {
5097                        project_id: project.id.to_proto(),
5098                        peer_id: Some(session.connection_id.into()),
5099                    },
5100                )
5101                .trace_err();
5102        }
5103    }
5104}
5105
/// Extension for results whose errors should be logged rather than propagated.
pub trait ResultExt {
    /// The success type of the extended result.
    type Ok;

    /// Returns the success value as `Some`, or logs the error and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5111
5112impl<T, E> ResultExt for Result<T, E>
5113where
5114    E: std::fmt::Debug,
5115{
5116    type Ok = T;
5117
5118    #[track_caller]
5119    fn trace_err(self) -> Option<T> {
5120        match self {
5121            Ok(value) => Some(value),
5122            Err(error) => {
5123                tracing::error!("{:?}", error);
5124                None
5125            }
5126        }
5127    }
5128}