rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
   8        MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
   9        RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37
  38use futures::{
  39    channel::oneshot,
  40    future::{self, BoxFuture},
  41    stream::FuturesUnordered,
  42    FutureExt, SinkExt, StreamExt, TryStreamExt,
  43};
  44use prometheus::{register_int_gauge, IntGauge};
  45use rpc::{
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  48        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        atomic::{AtomicBool, Ordering::SeqCst},
  64        Arc, OnceLock,
  65    },
  66    time::{Duration, Instant},
  67};
  68use time::OffsetDateTime;
  69use tokio::sync::{watch, Semaphore};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    field::{self},
  73    info_span, instrument, Instrument,
  74};
  75use util::http::IsahcHttpClient;
  76
  77use self::connection_pool::VersionedMessage;
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105struct StreamingResponse<R: RequestMessage> {
 106    peer: Arc<Peer>,
 107    receipt: Receipt<R>,
 108}
 109
 110impl<R: RequestMessage> StreamingResponse<R> {
 111    fn send(&self, payload: R::Response) -> Result<()> {
 112        self.peer.respond(self.receipt, payload)?;
 113        Ok(())
 114    }
 115}
 116
 117#[derive(Clone, Debug)]
 118pub enum Principal {
 119    User(User),
 120    Impersonated { user: User, admin: User },
 121    DevServer(dev_server::Model),
 122}
 123
 124impl Principal {
 125    fn update_span(&self, span: &tracing::Span) {
 126        match &self {
 127            Principal::User(user) => {
 128                span.record("user_id", &user.id.0);
 129                span.record("login", &user.github_login);
 130            }
 131            Principal::Impersonated { user, admin } => {
 132                span.record("user_id", &user.id.0);
 133                span.record("login", &user.github_login);
 134                span.record("impersonator", &admin.github_login);
 135            }
 136            Principal::DevServer(dev_server) => {
 137                span.record("dev_server_id", &dev_server.id.0);
 138            }
 139        }
 140    }
 141}
 142
/// Per-connection state passed to every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Locked through `Session::db`, which awaits this async mutex.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Locked through `Session::connection_pool`, which wraps the guard in a `!Send` type.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    // Not read anywhere in the code visible here; NOTE(review): presumably kept for ownership.
    _executor: Executor,
}
 155
 156impl Session {
 157    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 158        #[cfg(test)]
 159        tokio::task::yield_now().await;
 160        let guard = self.db.lock().await;
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        guard
 164    }
 165
 166    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 167        #[cfg(test)]
 168        tokio::task::yield_now().await;
 169        let guard = self.connection_pool.lock();
 170        ConnectionPoolGuard {
 171            guard,
 172            _not_send: PhantomData,
 173        }
 174    }
 175
 176    fn for_user(self) -> Option<UserSession> {
 177        UserSession::new(self)
 178    }
 179
 180    fn for_dev_server(self) -> Option<DevServerSession> {
 181        DevServerSession::new(self)
 182    }
 183
 184    fn user_id(&self) -> Option<UserId> {
 185        match &self.principal {
 186            Principal::User(user) => Some(user.id),
 187            Principal::Impersonated { user, .. } => Some(user.id),
 188            Principal::DevServer(_) => None,
 189        }
 190    }
 191
 192    fn dev_server_id(&self) -> Option<DevServerId> {
 193        match &self.principal {
 194            Principal::User(_) | Principal::Impersonated { .. } => None,
 195            Principal::DevServer(dev_server) => Some(dev_server.id),
 196        }
 197    }
 198
 199    fn principal_id(&self) -> PrincipalId {
 200        match &self.principal {
 201            Principal::User(user) => PrincipalId::UserId(user.id),
 202            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 203            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 204        }
 205    }
 206}
 207
 208impl Debug for Session {
 209    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 210        let mut result = f.debug_struct("Session");
 211        match &self.principal {
 212            Principal::User(user) => {
 213                result.field("user", &user.github_login);
 214            }
 215            Principal::Impersonated { user, admin } => {
 216                result.field("user", &user.github_login);
 217                result.field("impersonator", &admin.github_login);
 218            }
 219            Principal::DevServer(dev_server) => {
 220                result.field("dev_server", &dev_server.id);
 221            }
 222        }
 223        result.field("connection_id", &self.connection_id).finish()
 224    }
 225}
 226
 227struct UserSession(Session);
 228
 229impl UserSession {
 230    pub fn new(s: Session) -> Option<Self> {
 231        s.user_id().map(|_| UserSession(s))
 232    }
 233    pub fn user_id(&self) -> UserId {
 234        self.0.user_id().unwrap()
 235    }
 236}
 237
 238impl Deref for UserSession {
 239    type Target = Session;
 240
 241    fn deref(&self) -> &Self::Target {
 242        &self.0
 243    }
 244}
 245impl DerefMut for UserSession {
 246    fn deref_mut(&mut self) -> &mut Self::Target {
 247        &mut self.0
 248    }
 249}
 250
 251struct DevServerSession(Session);
 252
 253impl DevServerSession {
 254    pub fn new(s: Session) -> Option<Self> {
 255        s.dev_server_id().map(|_| DevServerSession(s))
 256    }
 257    pub fn dev_server_id(&self) -> DevServerId {
 258        self.0.dev_server_id().unwrap()
 259    }
 260
 261    fn dev_server(&self) -> &dev_server::Model {
 262        match &self.0.principal {
 263            Principal::DevServer(dev_server) => dev_server,
 264            _ => unreachable!(),
 265        }
 266    }
 267}
 268
 269impl Deref for DevServerSession {
 270    type Target = Session;
 271
 272    fn deref(&self) -> &Self::Target {
 273        &self.0
 274    }
 275}
 276impl DerefMut for DevServerSession {
 277    fn deref_mut(&mut self) -> &mut Self::Target {
 278        &mut self.0
 279    }
 280}
 281
 282fn user_handler<M: RequestMessage, Fut>(
 283    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 284) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 285where
 286    Fut: Send + Future<Output = Result<()>>,
 287{
 288    let handler = Arc::new(handler);
 289    move |message, response, session| {
 290        let handler = handler.clone();
 291        Box::pin(async move {
 292            if let Some(user_session) = session.for_user() {
 293                Ok(handler(message, response, user_session).await?)
 294            } else {
 295                Err(Error::Internal(anyhow!(
 296                    "must be a user to call {}",
 297                    M::NAME
 298                )))
 299            }
 300        })
 301    }
 302}
 303
 304fn dev_server_handler<M: RequestMessage, Fut>(
 305    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 306) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 307where
 308    Fut: Send + Future<Output = Result<()>>,
 309{
 310    let handler = Arc::new(handler);
 311    move |message, response, session| {
 312        let handler = handler.clone();
 313        Box::pin(async move {
 314            if let Some(dev_server_session) = session.for_dev_server() {
 315                Ok(handler(message, response, dev_server_session).await?)
 316            } else {
 317                Err(Error::Internal(anyhow!(
 318                    "must be a dev server to call {}",
 319                    M::NAME
 320                )))
 321            }
 322        })
 323    }
 324}
 325
 326fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 327    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 328) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 329where
 330    InnertRetFut: Send + Future<Output = Result<()>>,
 331{
 332    let handler = Arc::new(handler);
 333    move |message, session| {
 334        let handler = handler.clone();
 335        Box::pin(async move {
 336            if let Some(user_session) = session.for_user() {
 337                Ok(handler(message, user_session).await?)
 338            } else {
 339                Err(Error::Internal(anyhow!(
 340                    "must be a user to call {}",
 341                    M::NAME
 342                )))
 343            }
 344        })
 345    }
 346}
 347
 348struct DbHandle(Arc<Database>);
 349
 350impl Deref for DbHandle {
 351    type Target = Database;
 352
 353    fn deref(&self) -> &Self::Target {
 354        self.0.as_ref()
 355    }
 356}
 357
/// The RPC server. It owns the [`Peer`], the shared connection pool, and a table of message
/// handlers keyed by message [`TypeId`].
pub struct Server {
    // Behind a mutex so the test-only `reset` can swap in a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message type, inserted by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 366
/// Guard over the locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` in a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 371
/// A serializable view of the server's peer together with a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the guard's `Deref` target (see `serialize_deref`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 378
 379pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 380where
 381    S: Serializer,
 382    T: Deref<Target = U>,
 383    U: Serialize,
 384{
 385    Serialize::serialize(value.deref(), serializer)
 386}
 387
 388impl Server {
 389    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 390        let mut server = Self {
 391            id: parking_lot::Mutex::new(id),
 392            peer: Peer::new(id.0 as u32),
 393            app_state: app_state.clone(),
 394            connection_pool: Default::default(),
 395            handlers: Default::default(),
 396            teardown: watch::channel(false).0,
 397        };
 398
 399        server
 400            .add_request_handler(ping)
 401            .add_request_handler(user_handler(create_room))
 402            .add_request_handler(user_handler(join_room))
 403            .add_request_handler(user_handler(rejoin_room))
 404            .add_request_handler(user_handler(leave_room))
 405            .add_request_handler(user_handler(set_room_participant_role))
 406            .add_request_handler(user_handler(call))
 407            .add_request_handler(user_handler(cancel_call))
 408            .add_message_handler(user_message_handler(decline_call))
 409            .add_request_handler(user_handler(update_participant_location))
 410            .add_request_handler(user_handler(share_project))
 411            .add_message_handler(unshare_project)
 412            .add_request_handler(user_handler(join_project))
 413            .add_request_handler(user_handler(join_hosted_project))
 414            .add_request_handler(user_handler(rejoin_dev_server_projects))
 415            .add_request_handler(user_handler(create_dev_server_project))
 416            .add_request_handler(user_handler(delete_dev_server_project))
 417            .add_request_handler(user_handler(create_dev_server))
 418            .add_request_handler(user_handler(delete_dev_server))
 419            .add_request_handler(dev_server_handler(share_dev_server_project))
 420            .add_request_handler(dev_server_handler(shutdown_dev_server))
 421            .add_request_handler(dev_server_handler(reconnect_dev_server))
 422            .add_message_handler(user_message_handler(leave_project))
 423            .add_request_handler(update_project)
 424            .add_request_handler(update_worktree)
 425            .add_message_handler(start_language_server)
 426            .add_message_handler(update_language_server)
 427            .add_message_handler(update_diagnostic_summary)
 428            .add_message_handler(update_worktree_settings)
 429            .add_request_handler(user_handler(
 430                forward_read_only_project_request::<proto::GetHover>,
 431            ))
 432            .add_request_handler(user_handler(
 433                forward_read_only_project_request::<proto::GetDefinition>,
 434            ))
 435            .add_request_handler(user_handler(
 436                forward_read_only_project_request::<proto::GetTypeDefinition>,
 437            ))
 438            .add_request_handler(user_handler(
 439                forward_read_only_project_request::<proto::GetReferences>,
 440            ))
 441            .add_request_handler(user_handler(
 442                forward_read_only_project_request::<proto::SearchProject>,
 443            ))
 444            .add_request_handler(user_handler(
 445                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 446            ))
 447            .add_request_handler(user_handler(
 448                forward_read_only_project_request::<proto::GetProjectSymbols>,
 449            ))
 450            .add_request_handler(user_handler(
 451                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 452            ))
 453            .add_request_handler(user_handler(
 454                forward_read_only_project_request::<proto::OpenBufferById>,
 455            ))
 456            .add_request_handler(user_handler(
 457                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 458            ))
 459            .add_request_handler(user_handler(
 460                forward_read_only_project_request::<proto::InlayHints>,
 461            ))
 462            .add_request_handler(user_handler(
 463                forward_read_only_project_request::<proto::OpenBufferByPath>,
 464            ))
 465            .add_request_handler(user_handler(
 466                forward_mutating_project_request::<proto::GetCompletions>,
 467            ))
 468            .add_request_handler(user_handler(
 469                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 470            ))
 471            .add_request_handler(user_handler(
 472                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
 473            ))
 474            .add_request_handler(user_handler(
 475                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 476            ))
 477            .add_request_handler(user_handler(
 478                forward_mutating_project_request::<proto::GetCodeActions>,
 479            ))
 480            .add_request_handler(user_handler(
 481                forward_mutating_project_request::<proto::ApplyCodeAction>,
 482            ))
 483            .add_request_handler(user_handler(
 484                forward_mutating_project_request::<proto::PrepareRename>,
 485            ))
 486            .add_request_handler(user_handler(
 487                forward_mutating_project_request::<proto::PerformRename>,
 488            ))
 489            .add_request_handler(user_handler(
 490                forward_mutating_project_request::<proto::ReloadBuffers>,
 491            ))
 492            .add_request_handler(user_handler(
 493                forward_mutating_project_request::<proto::FormatBuffers>,
 494            ))
 495            .add_request_handler(user_handler(
 496                forward_mutating_project_request::<proto::CreateProjectEntry>,
 497            ))
 498            .add_request_handler(user_handler(
 499                forward_mutating_project_request::<proto::RenameProjectEntry>,
 500            ))
 501            .add_request_handler(user_handler(
 502                forward_mutating_project_request::<proto::CopyProjectEntry>,
 503            ))
 504            .add_request_handler(user_handler(
 505                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 506            ))
 507            .add_request_handler(user_handler(
 508                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 509            ))
 510            .add_request_handler(user_handler(
 511                forward_mutating_project_request::<proto::OnTypeFormatting>,
 512            ))
 513            .add_request_handler(user_handler(
 514                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
 515            ))
 516            .add_request_handler(user_handler(
 517                forward_mutating_project_request::<proto::BlameBuffer>,
 518            ))
 519            .add_request_handler(user_handler(
 520                forward_mutating_project_request::<proto::MultiLspQuery>,
 521            ))
 522            .add_message_handler(create_buffer_for_peer)
 523            .add_request_handler(update_buffer)
 524            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 525            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 526            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 527            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 528            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 529            .add_request_handler(get_users)
 530            .add_request_handler(user_handler(fuzzy_search_users))
 531            .add_request_handler(user_handler(request_contact))
 532            .add_request_handler(user_handler(remove_contact))
 533            .add_request_handler(user_handler(respond_to_contact_request))
 534            .add_request_handler(user_handler(create_channel))
 535            .add_request_handler(user_handler(delete_channel))
 536            .add_request_handler(user_handler(invite_channel_member))
 537            .add_request_handler(user_handler(remove_channel_member))
 538            .add_request_handler(user_handler(set_channel_member_role))
 539            .add_request_handler(user_handler(set_channel_visibility))
 540            .add_request_handler(user_handler(rename_channel))
 541            .add_request_handler(user_handler(join_channel_buffer))
 542            .add_request_handler(user_handler(leave_channel_buffer))
 543            .add_message_handler(user_message_handler(update_channel_buffer))
 544            .add_request_handler(user_handler(rejoin_channel_buffers))
 545            .add_request_handler(user_handler(get_channel_members))
 546            .add_request_handler(user_handler(respond_to_channel_invite))
 547            .add_request_handler(user_handler(join_channel))
 548            .add_request_handler(user_handler(join_channel_chat))
 549            .add_message_handler(user_message_handler(leave_channel_chat))
 550            .add_request_handler(user_handler(send_channel_message))
 551            .add_request_handler(user_handler(remove_channel_message))
 552            .add_request_handler(user_handler(update_channel_message))
 553            .add_request_handler(user_handler(get_channel_messages))
 554            .add_request_handler(user_handler(get_channel_messages_by_id))
 555            .add_request_handler(user_handler(get_notifications))
 556            .add_request_handler(user_handler(mark_notification_as_read))
 557            .add_request_handler(user_handler(move_channel))
 558            .add_request_handler(user_handler(follow))
 559            .add_message_handler(user_message_handler(unfollow))
 560            .add_message_handler(user_message_handler(update_followers))
 561            .add_request_handler(user_handler(get_private_user_info))
 562            .add_message_handler(user_message_handler(acknowledge_channel_message))
 563            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 564            .add_streaming_request_handler({
 565                let app_state = app_state.clone();
 566                move |request, response, session| {
 567                    complete_with_language_model(
 568                        request,
 569                        response,
 570                        session,
 571                        app_state.config.openai_api_key.clone(),
 572                        app_state.config.google_ai_api_key.clone(),
 573                        app_state.config.anthropic_api_key.clone(),
 574                    )
 575                }
 576            })
 577            .add_request_handler({
 578                let app_state = app_state.clone();
 579                user_handler(move |request, response, session| {
 580                    count_tokens_with_language_model(
 581                        request,
 582                        response,
 583                        session,
 584                        app_state.config.google_ai_api_key.clone(),
 585                    )
 586                })
 587            })
 588            .add_request_handler({
 589                user_handler(move |request, response, session| {
 590                    get_cached_embeddings(request, response, session)
 591                })
 592            })
 593            .add_request_handler({
 594                let app_state = app_state.clone();
 595                user_handler(move |request, response, session| {
 596                    compute_embeddings(
 597                        request,
 598                        response,
 599                        session,
 600                        app_state.config.openai_api_key.clone(),
 601                    )
 602                })
 603            });
 604
 605        Arc::new(server)
 606    }
 607
    /// Spawns a detached background task that cleans up resources the database reports as
    /// stale for this environment. The method itself returns `Ok(())` immediately.
    ///
    /// The task:
    /// 1. Waits [`CLEANUP_TIMEOUT`].
    /// 2. Clears stale collaborators from channel buffers and stale participants from rooms.
    /// 3. Notifies the affected connections: collaborator lists, room and channel updates,
    ///    canceled calls, and updated contact entries.
    /// 4. Deletes LiveKit rooms that were left with no participants.
    /// 5. Deletes the stale server records.
    ///
    /// Errors along the way are logged via `trace_err` and do not abort the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        // Cloned up front so the detached `async move` task owns everything it uses.
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // For each channel buffer: clear stale collaborators, then send the
                    // refreshed collaborator list to the connections the database returns.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults for when clearing the room fails: nothing to notify and
                        // nothing to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the database awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry (including busy
                        // status) to their accepted contacts. The pool is locked only after
                        // both queries have been awaited.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when no participants remain.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 753
 754    pub fn teardown(&self) {
 755        self.peer.teardown();
 756        self.connection_pool.lock().reset();
 757        let _ = self.teardown.send(true);
 758    }
 759
 760    #[cfg(test)]
 761    pub fn reset(&self, id: ServerId) {
 762        self.teardown();
 763        *self.id.lock() = id;
 764        self.peer.reset(id.0 as u32);
 765        let _ = self.teardown.send(false);
 766    }
 767
 768    #[cfg(test)]
 769    pub fn id(&self) -> ServerId {
 770        *self.id.lock()
 771    }
 772
 773    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 774    where
 775        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 776        Fut: 'static + Send + Future<Output = Result<()>>,
 777        M: EnvelopedMessage,
 778    {
 779        let prev_handler = self.handlers.insert(
 780            TypeId::of::<M>(),
 781            Box::new(move |envelope, session| {
 782                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 783                let received_at = envelope.received_at;
 784                tracing::info!("message received");
 785                let start_time = Instant::now();
 786                let future = (handler)(*envelope, session);
 787                async move {
 788                    let result = future.await;
 789                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 790                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 791                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 792                    let payload_type = M::NAME;
 793
 794                    match result {
 795                        Err(error) => {
 796                            tracing::error!(
 797                                ?error,
 798                                total_duration_ms,
 799                                processing_duration_ms,
 800                                queue_duration_ms,
 801                                payload_type,
 802                                "error handling message"
 803                            )
 804                        }
 805                        Ok(()) => tracing::info!(
 806                            total_duration_ms,
 807                            processing_duration_ms,
 808                            queue_duration_ms,
 809                            "finished handling message"
 810                        ),
 811                    }
 812                }
 813                .boxed()
 814            }),
 815        );
 816        if prev_handler.is_some() {
 817            panic!("registered a handler for the same message twice");
 818        }
 819        self
 820    }
 821
 822    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 823    where
 824        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 825        Fut: 'static + Send + Future<Output = Result<()>>,
 826        M: EnvelopedMessage,
 827    {
 828        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 829        self
 830    }
 831
    /// Registers a request/response handler for `M`.
    ///
    /// `handler` receives a [`Response`] through which it is expected to reply.
    /// - If it returns `Ok(())` without having replied, an error
    ///   ("handler did not send a response") is returned to the logging wrapper
    ///   installed by `add_handler`.
    /// - If it returns `Err`, the error is converted to a protocol error, sent back
    ///   to the requester, and then returned so it is logged as well.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so every invocation of the registered closure can clone it into its future.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Flag shared with the `Response`; presumably set when the handler replies
                // (NOTE(review): confirm against `Response::send`).
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors carry their own proto conversion; anything else is
                        // reported as `ErrorCode::Internal` with the error's display text.
                        // NOTE(review): `responded` is not checked here, so a handler that
                        // replied and then failed also gets an error response sent for the
                        // same receipt — confirm this is intended.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 870
    /// Registers a streaming request handler for `M`.
    ///
    /// `handler` receives a [`StreamingResponse`] through which it can send any
    /// number of replies. When it returns `Ok(())` the stream is closed with
    /// `peer.end_stream(receipt)`; when it returns `Err`, the error is converted to
    /// a protocol error, sent to the requester, and returned so the wrapper
    /// installed by `add_handler` logs it.
    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so every invocation of the registered closure can clone it into its future.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let response = StreamingResponse {
                    peer: peer.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        // Signal the requester that no further responses will follow.
                        peer.end_stream(receipt)?;
                        Ok(())
                    }
                    Err(error) => {
                        // Internal errors carry their own proto conversion; anything else is
                        // reported as `ErrorCode::Internal` with the error's display text.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 904
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Registers `connection` with the peer, builds the per-connection [`Session`],
    /// sends the initial client update, then dispatches incoming messages to the
    /// registered handlers. The loop ends when the connection closes, I/O fails,
    /// or the server tears down. In the first two cases `connection_lost` is run
    /// afterwards to clean up.
    ///
    /// * `address` – remote address, recorded on the tracing spans.
    /// * `principal` – who is connecting; also recorded on the spans.
    /// * `zed_version` – client version, forwarded to the connection pool.
    /// * `send_connection_id` – if provided, receives the assigned connection id
    ///   right after the `Hello` message is sent.
    /// * `executor` – used for timers, for spawning background handlers, and by
    ///   `connection_lost`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later (connection id below, the
        // principal-related ones by `update_span`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            // NOTE(review): on this early return `connection_lost` is not called. If
            // `send_initial_client_update` failed after registering the connection in the
            // pool, the connection may remain there — confirm this is acceptable.
            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers may be in flight for this connection: a permit is taken
            // before reading the next message and released when that message's handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in the order written: teardown, I/O completion,
                // progress on in-flight foreground handlers, then the next message.
                futures::select_biased! {
                    // Server teardown: exit without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is moved into the handler future so the slot is
                                // only freed once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1041
1042    async fn send_initial_client_update(
1043        &self,
1044        connection_id: ConnectionId,
1045        principal: &Principal,
1046        zed_version: ZedVersion,
1047        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1048        session: &Session,
1049    ) -> Result<()> {
1050        self.peer.send(
1051            connection_id,
1052            proto::Hello {
1053                peer_id: Some(connection_id.into()),
1054            },
1055        )?;
1056        tracing::info!("sent hello message");
1057        if let Some(send_connection_id) = send_connection_id.take() {
1058            let _ = send_connection_id.send(connection_id);
1059        }
1060
1061        match principal {
1062            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1063                if !user.connected_once {
1064                    self.peer.send(connection_id, proto::ShowContacts {})?;
1065                    self.app_state
1066                        .db
1067                        .set_user_connected_once(user.id, true)
1068                        .await?;
1069                }
1070
1071                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1072                    future::try_join4(
1073                        self.app_state.db.get_contacts(user.id),
1074                        self.app_state.db.get_channels_for_user(user.id),
1075                        self.app_state.db.get_channel_invites_for_user(user.id),
1076                        self.app_state.db.dev_server_projects_update(user.id),
1077                    )
1078                    .await?;
1079
1080                {
1081                    let mut pool = self.connection_pool.lock();
1082                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1083                    for membership in &channels_for_user.channel_memberships {
1084                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1085                    }
1086                    self.peer.send(
1087                        connection_id,
1088                        build_initial_contacts_update(contacts, &pool),
1089                    )?;
1090                    self.peer.send(
1091                        connection_id,
1092                        build_update_user_channels(&channels_for_user),
1093                    )?;
1094                    self.peer.send(
1095                        connection_id,
1096                        build_channels_update(channels_for_user, channel_invites),
1097                    )?;
1098                }
1099                send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1100
1101                if let Some(incoming_call) =
1102                    self.app_state.db.incoming_call_for_user(user.id).await?
1103                {
1104                    self.peer.send(connection_id, incoming_call)?;
1105                }
1106
1107                update_user_contacts(user.id, &session).await?;
1108            }
1109            Principal::DevServer(dev_server) => {
1110                {
1111                    let mut pool = self.connection_pool.lock();
1112                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1113                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1114                    };
1115                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1116                }
1117
1118                let projects = self
1119                    .app_state
1120                    .db
1121                    .get_projects_for_dev_server(dev_server.id)
1122                    .await?;
1123                self.peer
1124                    .send(connection_id, proto::DevServerInstructions { projects })?;
1125
1126                let status = self
1127                    .app_state
1128                    .db
1129                    .dev_server_projects_update(dev_server.user_id)
1130                    .await?;
1131                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1132            }
1133        }
1134
1135        Ok(())
1136    }
1137
1138    pub async fn invite_code_redeemed(
1139        self: &Arc<Self>,
1140        inviter_id: UserId,
1141        invitee_id: UserId,
1142    ) -> Result<()> {
1143        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1144            if let Some(code) = &user.invite_code {
1145                let pool = self.connection_pool.lock();
1146                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1147                for connection_id in pool.user_connection_ids(inviter_id) {
1148                    self.peer.send(
1149                        connection_id,
1150                        proto::UpdateContacts {
1151                            contacts: vec![invitee_contact.clone()],
1152                            ..Default::default()
1153                        },
1154                    )?;
1155                    self.peer.send(
1156                        connection_id,
1157                        proto::UpdateInviteInfo {
1158                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1159                            count: user.invite_count as u32,
1160                        },
1161                    )?;
1162                }
1163            }
1164        }
1165        Ok(())
1166    }
1167
1168    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1169        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1170            if let Some(invite_code) = &user.invite_code {
1171                let pool = self.connection_pool.lock();
1172                for connection_id in pool.user_connection_ids(user_id) {
1173                    self.peer.send(
1174                        connection_id,
1175                        proto::UpdateInviteInfo {
1176                            url: format!(
1177                                "{}{}",
1178                                self.app_state.config.invite_link_prefix, invite_code
1179                            ),
1180                            count: user.invite_count as u32,
1181                        },
1182                    )?;
1183                }
1184            }
1185        }
1186        Ok(())
1187    }
1188
1189    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1190        ServerSnapshot {
1191            connection_pool: ConnectionPoolGuard {
1192                guard: self.connection_pool.lock(),
1193                _not_send: PhantomData,
1194            },
1195            peer: &self.peer,
1196        }
1197    }
1198}
1199
1200impl<'a> Deref for ConnectionPoolGuard<'a> {
1201    type Target = ConnectionPool;
1202
1203    fn deref(&self) -> &Self::Target {
1204        &self.guard
1205    }
1206}
1207
1208impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1209    fn deref_mut(&mut self) -> &mut Self::Target {
1210        &mut self.guard
1211    }
1212}
1213
/// In test builds, validates the pool's invariants when the guard is dropped;
/// in other builds dropping the guard does nothing beyond releasing the lock.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Runs before the `guard` field is dropped, i.e. while the lock is still held.
        #[cfg(test)]
        self.check_invariants();
    }
}
1220
1221fn broadcast<F>(
1222    sender_id: Option<ConnectionId>,
1223    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1224    mut f: F,
1225) where
1226    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1227{
1228    for receiver_id in receiver_ids {
1229        if Some(receiver_id) != sender_id {
1230            if let Err(error) = f(receiver_id) {
1231                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1232            }
1233        }
1234    }
1235}
1236
1237pub struct ProtocolVersion(u32);
1238
1239impl Header for ProtocolVersion {
1240    fn name() -> &'static HeaderName {
1241        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1242        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1243    }
1244
1245    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1246    where
1247        Self: Sized,
1248        I: Iterator<Item = &'i axum::http::HeaderValue>,
1249    {
1250        let version = values
1251            .next()
1252            .ok_or_else(axum::headers::Error::invalid)?
1253            .to_str()
1254            .map_err(|_| axum::headers::Error::invalid())?
1255            .parse()
1256            .map_err(|_| axum::headers::Error::invalid())?;
1257        Ok(Self(version))
1258    }
1259
1260    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1261        values.extend([self.0.to_string().parse().unwrap()]);
1262    }
1263}
1264
1265pub struct AppVersionHeader(SemanticVersion);
1266impl Header for AppVersionHeader {
1267    fn name() -> &'static HeaderName {
1268        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1269        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1270    }
1271
1272    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1273    where
1274        Self: Sized,
1275        I: Iterator<Item = &'i axum::http::HeaderValue>,
1276    {
1277        let version = values
1278            .next()
1279            .ok_or_else(axum::headers::Error::invalid)?
1280            .to_str()
1281            .map_err(|_| axum::headers::Error::invalid())?
1282            .parse()
1283            .map_err(|_| axum::headers::Error::invalid())?;
1284        Ok(Self(version))
1285    }
1286
1287    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1288        values.extend([self.0.to_string().parse().unwrap()]);
1289    }
1290}
1291
/// Builds the collab server's HTTP router.
///
/// `Router::layer` wraps only the routes added before it. As a result:
/// - `/rpc` alone gets the `AppState` extension and the `auth::validate_header` middleware;
/// - `/metrics` is added afterwards and is not behind that middleware;
/// - the final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1303
1304pub async fn handle_websocket_request(
1305    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1306    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1307    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1308    Extension(server): Extension<Arc<Server>>,
1309    Extension(principal): Extension<Principal>,
1310    ws: WebSocketUpgrade,
1311) -> axum::response::Response {
1312    if protocol_version != rpc::PROTOCOL_VERSION {
1313        return (
1314            StatusCode::UPGRADE_REQUIRED,
1315            "client must be upgraded".to_string(),
1316        )
1317            .into_response();
1318    }
1319
1320    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1321        return (
1322            StatusCode::UPGRADE_REQUIRED,
1323            "no version header found".to_string(),
1324        )
1325            .into_response();
1326    };
1327
1328    if !version.can_collaborate() {
1329        return (
1330            StatusCode::UPGRADE_REQUIRED,
1331            "client must be upgraded".to_string(),
1332        )
1333            .into_response();
1334    }
1335
1336    let socket_address = socket_address.to_string();
1337    ws.on_upgrade(move |socket| {
1338        let socket = socket
1339            .map_ok(to_tungstenite_message)
1340            .err_into()
1341            .with(|message| async move { Ok(to_axum_message(message)) });
1342        let connection = Connection::new(Box::pin(socket));
1343        async move {
1344            server
1345                .handle_connection(
1346                    connection,
1347                    socket_address,
1348                    principal,
1349                    version,
1350                    None,
1351                    Executor::Production,
1352                )
1353                .await;
1354        }
1355    })
1356}
1357
1358pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1359    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1360    let connections_metric = CONNECTIONS_METRIC
1361        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1362
1363    let connections = server
1364        .connection_pool
1365        .lock()
1366        .connections()
1367        .filter(|connection| !connection.admin)
1368        .count();
1369    connections_metric.set(connections as _);
1370
1371    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1372    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1373        register_int_gauge!(
1374            "shared_projects",
1375            "number of open projects with one or more guests"
1376        )
1377        .unwrap()
1378    });
1379
1380    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1381    shared_projects_metric.set(shared_projects as _);
1382
1383    let encoder = prometheus::TextEncoder::new();
1384    let metric_families = prometheus::gather();
1385    let encoded_metrics = encoder
1386        .encode_to_string(&metric_families)
1387        .map_err(|err| anyhow!("{}", err))?;
1388    Ok(encoded_metrics)
1389}
1390
1391#[instrument(err, skip(executor))]
1392async fn connection_lost(
1393    session: Session,
1394    mut teardown: watch::Receiver<bool>,
1395    executor: Executor,
1396) -> Result<()> {
1397    session.peer.disconnect(session.connection_id);
1398    session
1399        .connection_pool()
1400        .await
1401        .remove_connection(session.connection_id)?;
1402
1403    session
1404        .db()
1405        .await
1406        .connection_lost(session.connection_id)
1407        .await
1408        .trace_err();
1409
1410    futures::select_biased! {
1411        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1412            match &session.principal {
1413                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1414                    let session = session.for_user().unwrap();
1415
1416                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1417                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1418                    leave_channel_buffers_for_session(&session)
1419                        .await
1420                        .trace_err();
1421
1422                    if !session
1423                        .connection_pool()
1424                        .await
1425                        .is_user_online(session.user_id())
1426                    {
1427                        let db = session.db().await;
1428                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1429                            room_updated(&room, &session.peer);
1430                        }
1431                    }
1432
1433                    update_user_contacts(session.user_id(), &session).await?;
1434                },
1435            Principal::DevServer(_) => {
1436                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1437            },
1438        }
1439        },
1440        _ = teardown.changed().fuse() => {}
1441    }
1442
1443    Ok(())
1444}
1445
1446/// Acknowledges a ping from a client, used to keep the connection alive.
1447async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1448    response.send(proto::Ack {})?;
1449    Ok(())
1450}
1451
1452/// Creates a new room for calling (outside of channels)
1453async fn create_room(
1454    _request: proto::CreateRoom,
1455    response: Response<proto::CreateRoom>,
1456    session: UserSession,
1457) -> Result<()> {
1458    let live_kit_room = nanoid::nanoid!(30);
1459
1460    let live_kit_connection_info = util::maybe!(async {
1461        let live_kit = session.live_kit_client.as_ref();
1462        let live_kit = live_kit?;
1463        let user_id = session.user_id().to_string();
1464
1465        let token = live_kit
1466            .room_token(&live_kit_room, &user_id.to_string())
1467            .trace_err()?;
1468
1469        Some(proto::LiveKitConnectionInfo {
1470            server_url: live_kit.url().into(),
1471            token,
1472            can_publish: true,
1473        })
1474    })
1475    .await;
1476
1477    let room = session
1478        .db()
1479        .await
1480        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1481        .await?;
1482
1483    response.send(proto::CreateRoomResponse {
1484        room: Some(room.clone()),
1485        live_kit_connection_info,
1486    })?;
1487
1488    update_user_contacts(session.user_id(), &session).await?;
1489    Ok(())
1490}
1491
1492/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1493async fn join_room(
1494    request: proto::JoinRoom,
1495    response: Response<proto::JoinRoom>,
1496    session: UserSession,
1497) -> Result<()> {
1498    let room_id = RoomId::from_proto(request.id);
1499
1500    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1501
1502    if let Some(channel_id) = channel_id {
1503        return join_channel_internal(channel_id, Box::new(response), session).await;
1504    }
1505
1506    let joined_room = {
1507        let room = session
1508            .db()
1509            .await
1510            .join_room(room_id, session.user_id(), session.connection_id)
1511            .await?;
1512        room_updated(&room.room, &session.peer);
1513        room.into_inner()
1514    };
1515
1516    for connection_id in session
1517        .connection_pool()
1518        .await
1519        .user_connection_ids(session.user_id())
1520    {
1521        session
1522            .peer
1523            .send(
1524                connection_id,
1525                proto::CallCanceled {
1526                    room_id: room_id.to_proto(),
1527                },
1528            )
1529            .trace_err();
1530    }
1531
1532    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1533        if let Some(token) = live_kit
1534            .room_token(
1535                &joined_room.room.live_kit_room,
1536                &session.user_id().to_string(),
1537            )
1538            .trace_err()
1539        {
1540            Some(proto::LiveKitConnectionInfo {
1541                server_url: live_kit.url().into(),
1542                token,
1543                can_publish: true,
1544            })
1545        } else {
1546            None
1547        }
1548    } else {
1549        None
1550    };
1551
1552    response.send(proto::JoinRoomResponse {
1553        room: Some(joined_room.room),
1554        channel_id: None,
1555        live_kit_connection_info,
1556    })?;
1557
1558    update_user_contacts(session.user_id(), &session).await?;
1559    Ok(())
1560}
1561
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Everything up to `into_inner()` runs while the value returned by the
/// database's `rejoin_room` is still held. During that time it:
/// - responds with the room plus the re-shared and re-joined projects,
/// - broadcasts the updated room,
/// - for each re-shared project, tells every collaborator that the caller's
///   peer id changed from `old_connection_id` to the current connection, and
///   forwards the project's worktree list to them,
/// - streams rejoined-project state via `notify_rejoined_projects`.
///
/// Afterwards the room's channel (if any) is re-broadcast and the caller's
/// contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in from inside the scope below, then used once the database
    // result has been released.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: collaborators learn the caller's new peer id; send
            // failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to its
            // collaborators, with the caller's connection as the sender.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // The connection pool is only accessed after the database result above
    // has gone out of scope.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1653
1654fn notify_rejoined_projects(
1655    rejoined_projects: &mut Vec<RejoinedProject>,
1656    session: &UserSession,
1657) -> Result<()> {
1658    for project in rejoined_projects.iter() {
1659        for collaborator in &project.collaborators {
1660            session
1661                .peer
1662                .send(
1663                    collaborator.connection_id,
1664                    proto::UpdateProjectCollaborator {
1665                        project_id: project.id.to_proto(),
1666                        old_peer_id: Some(project.old_connection_id.into()),
1667                        new_peer_id: Some(session.connection_id.into()),
1668                    },
1669                )
1670                .trace_err();
1671        }
1672    }
1673
1674    for project in rejoined_projects {
1675        for worktree in mem::take(&mut project.worktrees) {
1676            #[cfg(any(test, feature = "test-support"))]
1677            const MAX_CHUNK_SIZE: usize = 2;
1678            #[cfg(not(any(test, feature = "test-support")))]
1679            const MAX_CHUNK_SIZE: usize = 256;
1680
1681            // Stream this worktree's entries.
1682            let message = proto::UpdateWorktree {
1683                project_id: project.id.to_proto(),
1684                worktree_id: worktree.id,
1685                abs_path: worktree.abs_path.clone(),
1686                root_name: worktree.root_name,
1687                updated_entries: worktree.updated_entries,
1688                removed_entries: worktree.removed_entries,
1689                scan_id: worktree.scan_id,
1690                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1691                updated_repositories: worktree.updated_repositories,
1692                removed_repositories: worktree.removed_repositories,
1693            };
1694            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1695                session.peer.send(session.connection_id, update.clone())?;
1696            }
1697
1698            // Stream this worktree's diagnostics.
1699            for summary in worktree.diagnostic_summaries {
1700                session.peer.send(
1701                    session.connection_id,
1702                    proto::UpdateDiagnosticSummary {
1703                        project_id: project.id.to_proto(),
1704                        worktree_id: worktree.id,
1705                        summary: Some(summary),
1706                    },
1707                )?;
1708            }
1709
1710            for settings_file in worktree.settings_files {
1711                session.peer.send(
1712                    session.connection_id,
1713                    proto::UpdateWorktreeSettings {
1714                        project_id: project.id.to_proto(),
1715                        worktree_id: worktree.id,
1716                        path: settings_file.path,
1717                        content: Some(settings_file.content),
1718                    },
1719                )?;
1720            }
1721        }
1722
1723        for language_server in &project.language_servers {
1724            session.peer.send(
1725                session.connection_id,
1726                proto::UpdateLanguageServer {
1727                    project_id: project.id.to_proto(),
1728                    language_server_id: language_server.id,
1729                    variant: Some(
1730                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1731                            proto::LspDiskBasedDiagnosticsUpdated {},
1732                        ),
1733                    ),
1734                },
1735            )?;
1736        }
1737    }
1738    Ok(())
1739}
1740
1741/// leave room disconnects from the room.
1742async fn leave_room(
1743    _: proto::LeaveRoom,
1744    response: Response<proto::LeaveRoom>,
1745    session: UserSession,
1746) -> Result<()> {
1747    leave_room_for_session(&session, session.connection_id).await?;
1748    response.send(proto::Ack {})?;
1749    Ok(())
1750}
1751
1752/// Updates the permissions of someone else in the room.
1753async fn set_room_participant_role(
1754    request: proto::SetRoomParticipantRole,
1755    response: Response<proto::SetRoomParticipantRole>,
1756    session: UserSession,
1757) -> Result<()> {
1758    let user_id = UserId::from_proto(request.user_id);
1759    let role = ChannelRole::from(request.role());
1760
1761    let (live_kit_room, can_publish) = {
1762        let room = session
1763            .db()
1764            .await
1765            .set_room_participant_role(
1766                session.user_id(),
1767                RoomId::from_proto(request.room_id),
1768                user_id,
1769                role,
1770            )
1771            .await?;
1772
1773        let live_kit_room = room.live_kit_room.clone();
1774        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1775        room_updated(&room, &session.peer);
1776        (live_kit_room, can_publish)
1777    };
1778
1779    if let Some(live_kit) = session.live_kit_client.as_ref() {
1780        live_kit
1781            .update_participant(
1782                live_kit_room.clone(),
1783                request.user_id.to_string(),
1784                live_kit_server::proto::ParticipantPermission {
1785                    can_subscribe: true,
1786                    can_publish,
1787                    can_publish_data: can_publish,
1788                    hidden: false,
1789                    recorder: false,
1790                },
1791            )
1792            .await
1793            .trace_err();
1794    }
1795
1796    response.send(proto::Ack {})?;
1797    Ok(())
1798}
1799
1800/// Call someone else into the current room
1801async fn call(
1802    request: proto::Call,
1803    response: Response<proto::Call>,
1804    session: UserSession,
1805) -> Result<()> {
1806    let room_id = RoomId::from_proto(request.room_id);
1807    let calling_user_id = session.user_id();
1808    let calling_connection_id = session.connection_id;
1809    let called_user_id = UserId::from_proto(request.called_user_id);
1810    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1811    if !session
1812        .db()
1813        .await
1814        .has_contact(calling_user_id, called_user_id)
1815        .await?
1816    {
1817        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1818    }
1819
1820    let incoming_call = {
1821        let (room, incoming_call) = &mut *session
1822            .db()
1823            .await
1824            .call(
1825                room_id,
1826                calling_user_id,
1827                calling_connection_id,
1828                called_user_id,
1829                initial_project_id,
1830            )
1831            .await?;
1832        room_updated(&room, &session.peer);
1833        mem::take(incoming_call)
1834    };
1835    update_user_contacts(called_user_id, &session).await?;
1836
1837    let mut calls = session
1838        .connection_pool()
1839        .await
1840        .user_connection_ids(called_user_id)
1841        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1842        .collect::<FuturesUnordered<_>>();
1843
1844    while let Some(call_response) = calls.next().await {
1845        match call_response.as_ref() {
1846            Ok(_) => {
1847                response.send(proto::Ack {})?;
1848                return Ok(());
1849            }
1850            Err(_) => {
1851                call_response.trace_err();
1852            }
1853        }
1854    }
1855
1856    {
1857        let room = session
1858            .db()
1859            .await
1860            .call_failed(room_id, called_user_id)
1861            .await?;
1862        room_updated(&room, &session.peer);
1863    }
1864    update_user_contacts(called_user_id, &session).await?;
1865
1866    Err(anyhow!("failed to ring user"))?
1867}
1868
1869/// Cancel an outgoing call.
1870async fn cancel_call(
1871    request: proto::CancelCall,
1872    response: Response<proto::CancelCall>,
1873    session: UserSession,
1874) -> Result<()> {
1875    let called_user_id = UserId::from_proto(request.called_user_id);
1876    let room_id = RoomId::from_proto(request.room_id);
1877    {
1878        let room = session
1879            .db()
1880            .await
1881            .cancel_call(room_id, session.connection_id, called_user_id)
1882            .await?;
1883        room_updated(&room, &session.peer);
1884    }
1885
1886    for connection_id in session
1887        .connection_pool()
1888        .await
1889        .user_connection_ids(called_user_id)
1890    {
1891        session
1892            .peer
1893            .send(
1894                connection_id,
1895                proto::CallCanceled {
1896                    room_id: room_id.to_proto(),
1897                },
1898            )
1899            .trace_err();
1900    }
1901    response.send(proto::Ack {})?;
1902
1903    update_user_contacts(called_user_id, &session).await?;
1904    Ok(())
1905}
1906
1907/// Decline an incoming call.
1908async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1909    let room_id = RoomId::from_proto(message.room_id);
1910    {
1911        let room = session
1912            .db()
1913            .await
1914            .decline_call(Some(room_id), session.user_id())
1915            .await?
1916            .ok_or_else(|| anyhow!("failed to decline call"))?;
1917        room_updated(&room, &session.peer);
1918    }
1919
1920    for connection_id in session
1921        .connection_pool()
1922        .await
1923        .user_connection_ids(session.user_id())
1924    {
1925        session
1926            .peer
1927            .send(
1928                connection_id,
1929                proto::CallCanceled {
1930                    room_id: room_id.to_proto(),
1931                },
1932            )
1933            .trace_err();
1934    }
1935    update_user_contacts(session.user_id(), &session).await?;
1936    Ok(())
1937}
1938
1939/// Updates other participants in the room with your current location.
1940async fn update_participant_location(
1941    request: proto::UpdateParticipantLocation,
1942    response: Response<proto::UpdateParticipantLocation>,
1943    session: UserSession,
1944) -> Result<()> {
1945    let room_id = RoomId::from_proto(request.room_id);
1946    let location = request
1947        .location
1948        .ok_or_else(|| anyhow!("invalid location"))?;
1949
1950    let db = session.db().await;
1951    let room = db
1952        .update_room_participant_location(room_id, session.connection_id, location)
1953        .await?;
1954
1955    room_updated(&room, &session.peer);
1956    response.send(proto::Ack {})?;
1957    Ok(())
1958}
1959
1960/// Share a project into the room.
1961async fn share_project(
1962    request: proto::ShareProject,
1963    response: Response<proto::ShareProject>,
1964    session: UserSession,
1965) -> Result<()> {
1966    let (project_id, room) = &*session
1967        .db()
1968        .await
1969        .share_project(
1970            RoomId::from_proto(request.room_id),
1971            session.connection_id,
1972            &request.worktrees,
1973            request
1974                .dev_server_project_id
1975                .map(|id| DevServerProjectId::from_proto(id)),
1976        )
1977        .await?;
1978    response.send(proto::ShareProjectResponse {
1979        project_id: project_id.to_proto(),
1980    })?;
1981    room_updated(&room, &session.peer);
1982
1983    Ok(())
1984}
1985
1986/// Unshare a project from the room.
1987async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1988    let project_id = ProjectId::from_proto(message.project_id);
1989    unshare_project_internal(
1990        project_id,
1991        session.connection_id,
1992        session.user_id(),
1993        &session,
1994    )
1995    .await
1996}
1997
1998async fn unshare_project_internal(
1999    project_id: ProjectId,
2000    connection_id: ConnectionId,
2001    user_id: Option<UserId>,
2002    session: &Session,
2003) -> Result<()> {
2004    let (room, guest_connection_ids) = &*session
2005        .db()
2006        .await
2007        .unshare_project(project_id, connection_id, user_id)
2008        .await?;
2009
2010    let message = proto::UnshareProject {
2011        project_id: project_id.to_proto(),
2012    };
2013
2014    broadcast(
2015        Some(connection_id),
2016        guest_connection_ids.iter().copied(),
2017        |conn_id| session.peer.send(conn_id, message.clone()),
2018    );
2019    if let Some(room) = room {
2020        room_updated(room, &session.peer);
2021    }
2022
2023    Ok(())
2024}
2025
2026/// DevServer makes a project available online
2027async fn share_dev_server_project(
2028    request: proto::ShareDevServerProject,
2029    response: Response<proto::ShareDevServerProject>,
2030    session: DevServerSession,
2031) -> Result<()> {
2032    let (dev_server_project, user_id, status) = session
2033        .db()
2034        .await
2035        .share_dev_server_project(
2036            DevServerProjectId::from_proto(request.dev_server_project_id),
2037            session.dev_server_id(),
2038            session.connection_id,
2039            &request.worktrees,
2040        )
2041        .await?;
2042    let Some(project_id) = dev_server_project.project_id else {
2043        return Err(anyhow!("failed to share remote project"))?;
2044    };
2045
2046    send_dev_server_projects_update(user_id, status, &session).await;
2047
2048    response.send(proto::ShareProjectResponse { project_id })?;
2049
2050    Ok(())
2051}
2052
/// Join someone else's shared project.
///
/// Admits the caller through the database, releases the database handle, and
/// hands off to `join_project_internal`. That function notifies the other
/// collaborators and streams the project's state to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before the synchronous streaming in
    // `join_project_internal`; `project`/`replica_id` borrow from the value
    // returned by `join_project`, not from `db`.
    drop(db);
    // NOTE(review): this log line says "remote" for every join, not only for
    // dev-server projects; consider rewording.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2071
/// Abstracts over the response handles of requests that end in a project
/// join (`JoinProject` and `JoinHostedProject`). Both reply with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve either.
trait JoinProjectInternalResponse {
    /// Sends the join result to the requester, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully qualified `Response::<..>::send` path resolves to the inherent
// `Response::send`, making it explicit that this does not recurse into the
// trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2085
2086fn join_project_internal(
2087    response: impl JoinProjectInternalResponse,
2088    session: UserSession,
2089    project: &mut Project,
2090    replica_id: &ReplicaId,
2091) -> Result<()> {
2092    let collaborators = project
2093        .collaborators
2094        .iter()
2095        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2096        .map(|collaborator| collaborator.to_proto())
2097        .collect::<Vec<_>>();
2098    let project_id = project.id;
2099    let guest_user_id = session.user_id();
2100
2101    let worktrees = project
2102        .worktrees
2103        .iter()
2104        .map(|(id, worktree)| proto::WorktreeMetadata {
2105            id: *id,
2106            root_name: worktree.root_name.clone(),
2107            visible: worktree.visible,
2108            abs_path: worktree.abs_path.clone(),
2109        })
2110        .collect::<Vec<_>>();
2111
2112    let add_project_collaborator = proto::AddProjectCollaborator {
2113        project_id: project_id.to_proto(),
2114        collaborator: Some(proto::Collaborator {
2115            peer_id: Some(session.connection_id.into()),
2116            replica_id: replica_id.0 as u32,
2117            user_id: guest_user_id.to_proto(),
2118        }),
2119    };
2120
2121    for collaborator in &collaborators {
2122        session
2123            .peer
2124            .send(
2125                collaborator.peer_id.unwrap().into(),
2126                add_project_collaborator.clone(),
2127            )
2128            .trace_err();
2129    }
2130
2131    // First, we send the metadata associated with each worktree.
2132    response.send(proto::JoinProjectResponse {
2133        project_id: project.id.0 as u64,
2134        worktrees: worktrees.clone(),
2135        replica_id: replica_id.0 as u32,
2136        collaborators: collaborators.clone(),
2137        language_servers: project.language_servers.clone(),
2138        role: project.role.into(),
2139        dev_server_project_id: project
2140            .dev_server_project_id
2141            .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2142    })?;
2143
2144    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2145        #[cfg(any(test, feature = "test-support"))]
2146        const MAX_CHUNK_SIZE: usize = 2;
2147        #[cfg(not(any(test, feature = "test-support")))]
2148        const MAX_CHUNK_SIZE: usize = 256;
2149
2150        // Stream this worktree's entries.
2151        let message = proto::UpdateWorktree {
2152            project_id: project_id.to_proto(),
2153            worktree_id,
2154            abs_path: worktree.abs_path.clone(),
2155            root_name: worktree.root_name,
2156            updated_entries: worktree.entries,
2157            removed_entries: Default::default(),
2158            scan_id: worktree.scan_id,
2159            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2160            updated_repositories: worktree.repository_entries.into_values().collect(),
2161            removed_repositories: Default::default(),
2162        };
2163        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2164            session.peer.send(session.connection_id, update.clone())?;
2165        }
2166
2167        // Stream this worktree's diagnostics.
2168        for summary in worktree.diagnostic_summaries {
2169            session.peer.send(
2170                session.connection_id,
2171                proto::UpdateDiagnosticSummary {
2172                    project_id: project_id.to_proto(),
2173                    worktree_id: worktree.id,
2174                    summary: Some(summary),
2175                },
2176            )?;
2177        }
2178
2179        for settings_file in worktree.settings_files {
2180            session.peer.send(
2181                session.connection_id,
2182                proto::UpdateWorktreeSettings {
2183                    project_id: project_id.to_proto(),
2184                    worktree_id: worktree.id,
2185                    path: settings_file.path,
2186                    content: Some(settings_file.content),
2187                },
2188            )?;
2189        }
2190    }
2191
2192    for language_server in &project.language_servers {
2193        session.peer.send(
2194            session.connection_id,
2195            proto::UpdateLanguageServer {
2196                project_id: project_id.to_proto(),
2197                language_server_id: language_server.id,
2198                variant: Some(
2199                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2200                        proto::LspDiskBasedDiagnosticsUpdated {},
2201                    ),
2202                ),
2203            },
2204        )?;
2205    }
2206
2207    Ok(())
2208}
2209
2210/// Leave someone elses shared project.
2211async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2212    let sender_id = session.connection_id;
2213    let project_id = ProjectId::from_proto(request.project_id);
2214    let db = session.db().await;
2215    if db.is_hosted_project(project_id).await? {
2216        let project = db.leave_hosted_project(project_id, sender_id).await?;
2217        project_left(&project, &session);
2218        return Ok(());
2219    }
2220
2221    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2222    tracing::info!(
2223        %project_id,
2224        "leave project"
2225    );
2226
2227    project_left(&project, &session);
2228    if let Some(room) = room {
2229        room_updated(&room, &session.peer);
2230    }
2231
2232    Ok(())
2233}
2234
2235async fn join_hosted_project(
2236    request: proto::JoinHostedProject,
2237    response: Response<proto::JoinHostedProject>,
2238    session: UserSession,
2239) -> Result<()> {
2240    let (mut project, replica_id) = session
2241        .db()
2242        .await
2243        .join_hosted_project(
2244            ProjectId(request.project_id as i32),
2245            session.user_id(),
2246            session.connection_id,
2247        )
2248        .await?;
2249
2250    join_project_internal(response, session, &mut project, &replica_id)
2251}
2252
2253async fn create_dev_server_project(
2254    request: proto::CreateDevServerProject,
2255    response: Response<proto::CreateDevServerProject>,
2256    session: UserSession,
2257) -> Result<()> {
2258    let dev_server_id = DevServerId(request.dev_server_id as i32);
2259    let dev_server_connection_id = session
2260        .connection_pool()
2261        .await
2262        .dev_server_connection_id(dev_server_id);
2263    let Some(dev_server_connection_id) = dev_server_connection_id else {
2264        Err(ErrorCode::DevServerOffline
2265            .message("Cannot create a remote project when the dev server is offline".to_string())
2266            .anyhow())?
2267    };
2268
2269    let path = request.path.clone();
2270    //Check that the path exists on the dev server
2271    session
2272        .peer
2273        .forward_request(
2274            session.connection_id,
2275            dev_server_connection_id,
2276            proto::ValidateDevServerProjectRequest { path: path.clone() },
2277        )
2278        .await?;
2279
2280    let (dev_server_project, update) = session
2281        .db()
2282        .await
2283        .create_dev_server_project(
2284            DevServerId(request.dev_server_id as i32),
2285            &request.path,
2286            session.user_id(),
2287        )
2288        .await?;
2289
2290    let projects = session
2291        .db()
2292        .await
2293        .get_projects_for_dev_server(dev_server_project.dev_server_id)
2294        .await?;
2295
2296    session.peer.send(
2297        dev_server_connection_id,
2298        proto::DevServerInstructions { projects },
2299    )?;
2300
2301    send_dev_server_projects_update(session.user_id(), update, &session).await;
2302
2303    response.send(proto::CreateDevServerProjectResponse {
2304        dev_server_project: Some(dev_server_project.to_proto(None)),
2305    })?;
2306    Ok(())
2307}
2308
2309async fn create_dev_server(
2310    request: proto::CreateDevServer,
2311    response: Response<proto::CreateDevServer>,
2312    session: UserSession,
2313) -> Result<()> {
2314    let access_token = auth::random_token();
2315    let hashed_access_token = auth::hash_access_token(&access_token);
2316
2317    let (dev_server, status) = session
2318        .db()
2319        .await
2320        .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2321        .await?;
2322
2323    send_dev_server_projects_update(session.user_id(), status, &session).await;
2324
2325    response.send(proto::CreateDevServerResponse {
2326        dev_server_id: dev_server.id.0 as u64,
2327        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2328        name: request.name.clone(),
2329    })?;
2330    Ok(())
2331}
2332
2333async fn delete_dev_server(
2334    request: proto::DeleteDevServer,
2335    response: Response<proto::DeleteDevServer>,
2336    session: UserSession,
2337) -> Result<()> {
2338    let dev_server_id = DevServerId(request.dev_server_id as i32);
2339    let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2340    if dev_server.user_id != session.user_id() {
2341        return Err(anyhow!(ErrorCode::Forbidden))?;
2342    }
2343
2344    let connection_id = session
2345        .connection_pool()
2346        .await
2347        .dev_server_connection_id(dev_server_id);
2348    if let Some(connection_id) = connection_id {
2349        shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2350        session
2351            .peer
2352            .send(connection_id, proto::ShutdownDevServer {})?;
2353    }
2354
2355    let status = session
2356        .db()
2357        .await
2358        .delete_dev_server(dev_server_id, session.user_id())
2359        .await?;
2360
2361    send_dev_server_projects_update(session.user_id(), status, &session).await;
2362
2363    response.send(proto::Ack {})?;
2364    Ok(())
2365}
2366
2367async fn delete_dev_server_project(
2368    request: proto::DeleteDevServerProject,
2369    response: Response<proto::DeleteDevServerProject>,
2370    session: UserSession,
2371) -> Result<()> {
2372    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2373    let dev_server_project = session
2374        .db()
2375        .await
2376        .get_dev_server_project(dev_server_project_id)
2377        .await?;
2378
2379    let dev_server = session
2380        .db()
2381        .await
2382        .get_dev_server(dev_server_project.dev_server_id)
2383        .await?;
2384    if dev_server.user_id != session.user_id() {
2385        return Err(anyhow!(ErrorCode::Forbidden))?;
2386    }
2387
2388    let dev_server_connection_id = session
2389        .connection_pool()
2390        .await
2391        .dev_server_connection_id(dev_server.id);
2392
2393    if let Some(dev_server_connection_id) = dev_server_connection_id {
2394        let project = session
2395            .db()
2396            .await
2397            .find_dev_server_project(dev_server_project_id)
2398            .await;
2399        if let Ok(project) = project {
2400            unshare_project_internal(
2401                project.id,
2402                dev_server_connection_id,
2403                Some(session.user_id()),
2404                &session,
2405            )
2406            .await?;
2407        }
2408    }
2409
2410    let (projects, status) = session
2411        .db()
2412        .await
2413        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2414        .await?;
2415
2416    if let Some(dev_server_connection_id) = dev_server_connection_id {
2417        session.peer.send(
2418            dev_server_connection_id,
2419            proto::DevServerInstructions { projects },
2420        )?;
2421    }
2422
2423    send_dev_server_projects_update(session.user_id(), status, &session).await;
2424
2425    response.send(proto::Ack {})?;
2426    Ok(())
2427}
2428
2429async fn rejoin_dev_server_projects(
2430    request: proto::RejoinRemoteProjects,
2431    response: Response<proto::RejoinRemoteProjects>,
2432    session: UserSession,
2433) -> Result<()> {
2434    let mut rejoined_projects = {
2435        let db = session.db().await;
2436        db.rejoin_dev_server_projects(
2437            &request.rejoined_projects,
2438            session.user_id(),
2439            session.0.connection_id,
2440        )
2441        .await?
2442    };
2443    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2444
2445    response.send(proto::RejoinRemoteProjectsResponse {
2446        rejoined_projects: rejoined_projects
2447            .into_iter()
2448            .map(|project| project.to_proto())
2449            .collect(),
2450    })
2451}
2452
2453async fn reconnect_dev_server(
2454    request: proto::ReconnectDevServer,
2455    response: Response<proto::ReconnectDevServer>,
2456    session: DevServerSession,
2457) -> Result<()> {
2458    let reshared_projects = {
2459        let db = session.db().await;
2460        db.reshare_dev_server_projects(
2461            &request.reshared_projects,
2462            session.dev_server_id(),
2463            session.0.connection_id,
2464        )
2465        .await?
2466    };
2467
2468    for project in &reshared_projects {
2469        for collaborator in &project.collaborators {
2470            session
2471                .peer
2472                .send(
2473                    collaborator.connection_id,
2474                    proto::UpdateProjectCollaborator {
2475                        project_id: project.id.to_proto(),
2476                        old_peer_id: Some(project.old_connection_id.into()),
2477                        new_peer_id: Some(session.connection_id.into()),
2478                    },
2479                )
2480                .trace_err();
2481        }
2482
2483        broadcast(
2484            Some(session.connection_id),
2485            project
2486                .collaborators
2487                .iter()
2488                .map(|collaborator| collaborator.connection_id),
2489            |connection_id| {
2490                session.peer.forward_send(
2491                    session.connection_id,
2492                    connection_id,
2493                    proto::UpdateProject {
2494                        project_id: project.id.to_proto(),
2495                        worktrees: project.worktrees.clone(),
2496                    },
2497                )
2498            },
2499        );
2500    }
2501
2502    response.send(proto::ReconnectDevServerResponse {
2503        reshared_projects: reshared_projects
2504            .iter()
2505            .map(|project| proto::ResharedProject {
2506                id: project.id.to_proto(),
2507                collaborators: project
2508                    .collaborators
2509                    .iter()
2510                    .map(|collaborator| collaborator.to_proto())
2511                    .collect(),
2512            })
2513            .collect(),
2514    })?;
2515
2516    Ok(())
2517}
2518
2519async fn shutdown_dev_server(
2520    _: proto::ShutdownDevServer,
2521    response: Response<proto::ShutdownDevServer>,
2522    session: DevServerSession,
2523) -> Result<()> {
2524    response.send(proto::Ack {})?;
2525    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2526}
2527
2528async fn shutdown_dev_server_internal(
2529    dev_server_id: DevServerId,
2530    connection_id: ConnectionId,
2531    session: &Session,
2532) -> Result<()> {
2533    let (dev_server_projects, dev_server) = {
2534        let db = session.db().await;
2535        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2536        let dev_server = db.get_dev_server(dev_server_id).await?;
2537        (dev_server_projects, dev_server)
2538    };
2539
2540    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2541        unshare_project_internal(
2542            ProjectId::from_proto(project_id),
2543            connection_id,
2544            None,
2545            session,
2546        )
2547        .await?;
2548    }
2549
2550    session
2551        .connection_pool()
2552        .await
2553        .set_dev_server_offline(dev_server_id);
2554
2555    let status = session
2556        .db()
2557        .await
2558        .dev_server_projects_update(dev_server.user_id)
2559        .await?;
2560    send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2561
2562    Ok(())
2563}
2564
2565/// Updates other participants with changes to the project
2566async fn update_project(
2567    request: proto::UpdateProject,
2568    response: Response<proto::UpdateProject>,
2569    session: Session,
2570) -> Result<()> {
2571    let project_id = ProjectId::from_proto(request.project_id);
2572    let (room, guest_connection_ids) = &*session
2573        .db()
2574        .await
2575        .update_project(project_id, session.connection_id, &request.worktrees)
2576        .await?;
2577    broadcast(
2578        Some(session.connection_id),
2579        guest_connection_ids.iter().copied(),
2580        |connection_id| {
2581            session
2582                .peer
2583                .forward_send(session.connection_id, connection_id, request.clone())
2584        },
2585    );
2586    if let Some(room) = room {
2587        room_updated(&room, &session.peer);
2588    }
2589    response.send(proto::Ack {})?;
2590
2591    Ok(())
2592}
2593
2594/// Updates other participants with changes to the worktree
2595async fn update_worktree(
2596    request: proto::UpdateWorktree,
2597    response: Response<proto::UpdateWorktree>,
2598    session: Session,
2599) -> Result<()> {
2600    let guest_connection_ids = session
2601        .db()
2602        .await
2603        .update_worktree(&request, session.connection_id)
2604        .await?;
2605
2606    broadcast(
2607        Some(session.connection_id),
2608        guest_connection_ids.iter().copied(),
2609        |connection_id| {
2610            session
2611                .peer
2612                .forward_send(session.connection_id, connection_id, request.clone())
2613        },
2614    );
2615    response.send(proto::Ack {})?;
2616    Ok(())
2617}
2618
2619/// Updates other participants with changes to the diagnostics
2620async fn update_diagnostic_summary(
2621    message: proto::UpdateDiagnosticSummary,
2622    session: Session,
2623) -> Result<()> {
2624    let guest_connection_ids = session
2625        .db()
2626        .await
2627        .update_diagnostic_summary(&message, session.connection_id)
2628        .await?;
2629
2630    broadcast(
2631        Some(session.connection_id),
2632        guest_connection_ids.iter().copied(),
2633        |connection_id| {
2634            session
2635                .peer
2636                .forward_send(session.connection_id, connection_id, message.clone())
2637        },
2638    );
2639
2640    Ok(())
2641}
2642
2643/// Updates other participants with changes to the worktree settings
2644async fn update_worktree_settings(
2645    message: proto::UpdateWorktreeSettings,
2646    session: Session,
2647) -> Result<()> {
2648    let guest_connection_ids = session
2649        .db()
2650        .await
2651        .update_worktree_settings(&message, session.connection_id)
2652        .await?;
2653
2654    broadcast(
2655        Some(session.connection_id),
2656        guest_connection_ids.iter().copied(),
2657        |connection_id| {
2658            session
2659                .peer
2660                .forward_send(session.connection_id, connection_id, message.clone())
2661        },
2662    );
2663
2664    Ok(())
2665}
2666
2667/// Notify other participants that a  language server has started.
2668async fn start_language_server(
2669    request: proto::StartLanguageServer,
2670    session: Session,
2671) -> Result<()> {
2672    let guest_connection_ids = session
2673        .db()
2674        .await
2675        .start_language_server(&request, session.connection_id)
2676        .await?;
2677
2678    broadcast(
2679        Some(session.connection_id),
2680        guest_connection_ids.iter().copied(),
2681        |connection_id| {
2682            session
2683                .peer
2684                .forward_send(session.connection_id, connection_id, request.clone())
2685        },
2686    );
2687    Ok(())
2688}
2689
2690/// Notify other participants that a language server has changed.
2691async fn update_language_server(
2692    request: proto::UpdateLanguageServer,
2693    session: Session,
2694) -> Result<()> {
2695    let project_id = ProjectId::from_proto(request.project_id);
2696    let project_connection_ids = session
2697        .db()
2698        .await
2699        .project_connection_ids(project_id, session.connection_id, true)
2700        .await?;
2701    broadcast(
2702        Some(session.connection_id),
2703        project_connection_ids.iter().copied(),
2704        |connection_id| {
2705            session
2706                .peer
2707                .forward_send(session.connection_id, connection_id, request.clone())
2708        },
2709    );
2710    Ok(())
2711}
2712
2713/// forward a project request to the host. These requests should be read only
2714/// as guests are allowed to send them.
2715async fn forward_read_only_project_request<T>(
2716    request: T,
2717    response: Response<T>,
2718    session: UserSession,
2719) -> Result<()>
2720where
2721    T: EntityMessage + RequestMessage,
2722{
2723    let project_id = ProjectId::from_proto(request.remote_entity_id());
2724    let host_connection_id = session
2725        .db()
2726        .await
2727        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2728        .await?;
2729    let payload = session
2730        .peer
2731        .forward_request(session.connection_id, host_connection_id, request)
2732        .await?;
2733    response.send(payload)?;
2734    Ok(())
2735}
2736
2737/// forward a project request to the host. These requests are disallowed
2738/// for guests.
2739async fn forward_mutating_project_request<T>(
2740    request: T,
2741    response: Response<T>,
2742    session: UserSession,
2743) -> Result<()>
2744where
2745    T: EntityMessage + RequestMessage,
2746{
2747    let project_id = ProjectId::from_proto(request.remote_entity_id());
2748
2749    let host_connection_id = session
2750        .db()
2751        .await
2752        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2753        .await?;
2754    let payload = session
2755        .peer
2756        .forward_request(session.connection_id, host_connection_id, request)
2757        .await?;
2758    response.send(payload)?;
2759    Ok(())
2760}
2761
2762/// forward a project request to the host. These requests are disallowed
2763/// for guests.
2764async fn forward_versioned_mutating_project_request<T>(
2765    request: T,
2766    response: Response<T>,
2767    session: UserSession,
2768) -> Result<()>
2769where
2770    T: EntityMessage + RequestMessage + VersionedMessage,
2771{
2772    let project_id = ProjectId::from_proto(request.remote_entity_id());
2773
2774    let host_connection_id = session
2775        .db()
2776        .await
2777        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2778        .await?;
2779    if let Some(host_version) = session
2780        .connection_pool()
2781        .await
2782        .connection(host_connection_id)
2783        .map(|c| c.zed_version)
2784    {
2785        if let Some(min_required_version) = request.required_host_version() {
2786            if min_required_version > host_version {
2787                return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2788                    .with_tag("required", &min_required_version.to_string())))?;
2789            }
2790        }
2791    }
2792
2793    let payload = session
2794        .peer
2795        .forward_request(session.connection_id, host_connection_id, request)
2796        .await?;
2797    response.send(payload)?;
2798    Ok(())
2799}
2800
2801/// Notify other participants that a new buffer has been created
2802async fn create_buffer_for_peer(
2803    request: proto::CreateBufferForPeer,
2804    session: Session,
2805) -> Result<()> {
2806    session
2807        .db()
2808        .await
2809        .check_user_is_project_host(
2810            ProjectId::from_proto(request.project_id),
2811            session.connection_id,
2812        )
2813        .await?;
2814    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2815    session
2816        .peer
2817        .forward_send(session.connection_id, peer_id.into(), request)?;
2818    Ok(())
2819}
2820
2821/// Notify other participants that a buffer has been updated. This is
2822/// allowed for guests as long as the update is limited to selections.
2823async fn update_buffer(
2824    request: proto::UpdateBuffer,
2825    response: Response<proto::UpdateBuffer>,
2826    session: Session,
2827) -> Result<()> {
2828    let project_id = ProjectId::from_proto(request.project_id);
2829    let mut capability = Capability::ReadOnly;
2830
2831    for op in request.operations.iter() {
2832        match op.variant {
2833            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2834            Some(_) => capability = Capability::ReadWrite,
2835        }
2836    }
2837
2838    let host = {
2839        let guard = session
2840            .db()
2841            .await
2842            .connections_for_buffer_update(
2843                project_id,
2844                session.principal_id(),
2845                session.connection_id,
2846                capability,
2847            )
2848            .await?;
2849
2850        let (host, guests) = &*guard;
2851
2852        broadcast(
2853            Some(session.connection_id),
2854            guests.clone(),
2855            |connection_id| {
2856                session
2857                    .peer
2858                    .forward_send(session.connection_id, connection_id, request.clone())
2859            },
2860        );
2861
2862        *host
2863    };
2864
2865    if host != session.connection_id {
2866        session
2867            .peer
2868            .forward_request(session.connection_id, host, request.clone())
2869            .await?;
2870    }
2871
2872    response.send(proto::Ack {})?;
2873    Ok(())
2874}
2875
2876/// Notify other participants that a project has been updated.
2877async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2878    request: T,
2879    session: Session,
2880) -> Result<()> {
2881    let project_id = ProjectId::from_proto(request.remote_entity_id());
2882    let project_connection_ids = session
2883        .db()
2884        .await
2885        .project_connection_ids(project_id, session.connection_id, false)
2886        .await?;
2887
2888    broadcast(
2889        Some(session.connection_id),
2890        project_connection_ids.iter().copied(),
2891        |connection_id| {
2892            session
2893                .peer
2894                .forward_send(session.connection_id, connection_id, request.clone())
2895        },
2896    );
2897    Ok(())
2898}
2899
2900/// Start following another user in a call.
2901async fn follow(
2902    request: proto::Follow,
2903    response: Response<proto::Follow>,
2904    session: UserSession,
2905) -> Result<()> {
2906    let room_id = RoomId::from_proto(request.room_id);
2907    let project_id = request.project_id.map(ProjectId::from_proto);
2908    let leader_id = request
2909        .leader_id
2910        .ok_or_else(|| anyhow!("invalid leader id"))?
2911        .into();
2912    let follower_id = session.connection_id;
2913
2914    session
2915        .db()
2916        .await
2917        .check_room_participants(room_id, leader_id, session.connection_id)
2918        .await?;
2919
2920    let response_payload = session
2921        .peer
2922        .forward_request(session.connection_id, leader_id, request)
2923        .await?;
2924    response.send(response_payload)?;
2925
2926    if let Some(project_id) = project_id {
2927        let room = session
2928            .db()
2929            .await
2930            .follow(room_id, project_id, leader_id, follower_id)
2931            .await?;
2932        room_updated(&room, &session.peer);
2933    }
2934
2935    Ok(())
2936}
2937
2938/// Stop following another user in a call.
2939async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2940    let room_id = RoomId::from_proto(request.room_id);
2941    let project_id = request.project_id.map(ProjectId::from_proto);
2942    let leader_id = request
2943        .leader_id
2944        .ok_or_else(|| anyhow!("invalid leader id"))?
2945        .into();
2946    let follower_id = session.connection_id;
2947
2948    session
2949        .db()
2950        .await
2951        .check_room_participants(room_id, leader_id, session.connection_id)
2952        .await?;
2953
2954    session
2955        .peer
2956        .forward_send(session.connection_id, leader_id, request)?;
2957
2958    if let Some(project_id) = project_id {
2959        let room = session
2960            .db()
2961            .await
2962            .unfollow(room_id, project_id, leader_id, follower_id)
2963            .await?;
2964        room_updated(&room, &session.peer);
2965    }
2966
2967    Ok(())
2968}
2969
2970/// Notify everyone following you of your current location.
2971async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2972    let room_id = RoomId::from_proto(request.room_id);
2973    let database = session.db.lock().await;
2974
2975    let connection_ids = if let Some(project_id) = request.project_id {
2976        let project_id = ProjectId::from_proto(project_id);
2977        database
2978            .project_connection_ids(project_id, session.connection_id, true)
2979            .await?
2980    } else {
2981        database
2982            .room_connection_ids(room_id, session.connection_id)
2983            .await?
2984    };
2985
2986    // For now, don't send view update messages back to that view's current leader.
2987    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2988        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2989        _ => None,
2990    });
2991
2992    for connection_id in connection_ids.iter().cloned() {
2993        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2994            session
2995                .peer
2996                .forward_send(session.connection_id, connection_id, request.clone())?;
2997        }
2998    }
2999    Ok(())
3000}
3001
3002/// Get public data about users.
3003async fn get_users(
3004    request: proto::GetUsers,
3005    response: Response<proto::GetUsers>,
3006    session: Session,
3007) -> Result<()> {
3008    let user_ids = request
3009        .user_ids
3010        .into_iter()
3011        .map(UserId::from_proto)
3012        .collect();
3013    let users = session
3014        .db()
3015        .await
3016        .get_users_by_ids(user_ids)
3017        .await?
3018        .into_iter()
3019        .map(|user| proto::User {
3020            id: user.id.to_proto(),
3021            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3022            github_login: user.github_login,
3023        })
3024        .collect();
3025    response.send(proto::UsersResponse { users })?;
3026    Ok(())
3027}
3028
3029/// Search for users (to invite) buy Github login
3030async fn fuzzy_search_users(
3031    request: proto::FuzzySearchUsers,
3032    response: Response<proto::FuzzySearchUsers>,
3033    session: UserSession,
3034) -> Result<()> {
3035    let query = request.query;
3036    let users = match query.len() {
3037        0 => vec![],
3038        1 | 2 => session
3039            .db()
3040            .await
3041            .get_user_by_github_login(&query)
3042            .await?
3043            .into_iter()
3044            .collect(),
3045        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3046    };
3047    let users = users
3048        .into_iter()
3049        .filter(|user| user.id != session.user_id())
3050        .map(|user| proto::User {
3051            id: user.id.to_proto(),
3052            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3053            github_login: user.github_login,
3054        })
3055        .collect();
3056    response.send(proto::UsersResponse { users })?;
3057    Ok(())
3058}
3059
3060/// Send a contact request to another user.
3061async fn request_contact(
3062    request: proto::RequestContact,
3063    response: Response<proto::RequestContact>,
3064    session: UserSession,
3065) -> Result<()> {
3066    let requester_id = session.user_id();
3067    let responder_id = UserId::from_proto(request.responder_id);
3068    if requester_id == responder_id {
3069        return Err(anyhow!("cannot add yourself as a contact"))?;
3070    }
3071
3072    let notifications = session
3073        .db()
3074        .await
3075        .send_contact_request(requester_id, responder_id)
3076        .await?;
3077
3078    // Update outgoing contact requests of requester
3079    let mut update = proto::UpdateContacts::default();
3080    update.outgoing_requests.push(responder_id.to_proto());
3081    for connection_id in session
3082        .connection_pool()
3083        .await
3084        .user_connection_ids(requester_id)
3085    {
3086        session.peer.send(connection_id, update.clone())?;
3087    }
3088
3089    // Update incoming contact requests of responder
3090    let mut update = proto::UpdateContacts::default();
3091    update
3092        .incoming_requests
3093        .push(proto::IncomingContactRequest {
3094            requester_id: requester_id.to_proto(),
3095        });
3096    let connection_pool = session.connection_pool().await;
3097    for connection_id in connection_pool.user_connection_ids(responder_id) {
3098        session.peer.send(connection_id, update.clone())?;
3099    }
3100
3101    send_notifications(&connection_pool, &session.peer, notifications);
3102
3103    response.send(proto::Ack {})?;
3104    Ok(())
3105}
3106
3107/// Accept or decline a contact request
3108async fn respond_to_contact_request(
3109    request: proto::RespondToContactRequest,
3110    response: Response<proto::RespondToContactRequest>,
3111    session: UserSession,
3112) -> Result<()> {
3113    let responder_id = session.user_id();
3114    let requester_id = UserId::from_proto(request.requester_id);
3115    let db = session.db().await;
3116    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3117        db.dismiss_contact_notification(responder_id, requester_id)
3118            .await?;
3119    } else {
3120        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3121
3122        let notifications = db
3123            .respond_to_contact_request(responder_id, requester_id, accept)
3124            .await?;
3125        let requester_busy = db.is_user_busy(requester_id).await?;
3126        let responder_busy = db.is_user_busy(responder_id).await?;
3127
3128        let pool = session.connection_pool().await;
3129        // Update responder with new contact
3130        let mut update = proto::UpdateContacts::default();
3131        if accept {
3132            update
3133                .contacts
3134                .push(contact_for_user(requester_id, requester_busy, &pool));
3135        }
3136        update
3137            .remove_incoming_requests
3138            .push(requester_id.to_proto());
3139        for connection_id in pool.user_connection_ids(responder_id) {
3140            session.peer.send(connection_id, update.clone())?;
3141        }
3142
3143        // Update requester with new contact
3144        let mut update = proto::UpdateContacts::default();
3145        if accept {
3146            update
3147                .contacts
3148                .push(contact_for_user(responder_id, responder_busy, &pool));
3149        }
3150        update
3151            .remove_outgoing_requests
3152            .push(responder_id.to_proto());
3153
3154        for connection_id in pool.user_connection_ids(requester_id) {
3155            session.peer.send(connection_id, update.clone())?;
3156        }
3157
3158        send_notifications(&pool, &session.peer, notifications);
3159    }
3160
3161    response.send(proto::Ack {})?;
3162    Ok(())
3163}
3164
3165/// Remove a contact.
3166async fn remove_contact(
3167    request: proto::RemoveContact,
3168    response: Response<proto::RemoveContact>,
3169    session: UserSession,
3170) -> Result<()> {
3171    let requester_id = session.user_id();
3172    let responder_id = UserId::from_proto(request.user_id);
3173    let db = session.db().await;
3174    let (contact_accepted, deleted_notification_id) =
3175        db.remove_contact(requester_id, responder_id).await?;
3176
3177    let pool = session.connection_pool().await;
3178    // Update outgoing contact requests of requester
3179    let mut update = proto::UpdateContacts::default();
3180    if contact_accepted {
3181        update.remove_contacts.push(responder_id.to_proto());
3182    } else {
3183        update
3184            .remove_outgoing_requests
3185            .push(responder_id.to_proto());
3186    }
3187    for connection_id in pool.user_connection_ids(requester_id) {
3188        session.peer.send(connection_id, update.clone())?;
3189    }
3190
3191    // Update incoming contact requests of responder
3192    let mut update = proto::UpdateContacts::default();
3193    if contact_accepted {
3194        update.remove_contacts.push(requester_id.to_proto());
3195    } else {
3196        update
3197            .remove_incoming_requests
3198            .push(requester_id.to_proto());
3199    }
3200    for connection_id in pool.user_connection_ids(responder_id) {
3201        session.peer.send(connection_id, update.clone())?;
3202        if let Some(notification_id) = deleted_notification_id {
3203            session.peer.send(
3204                connection_id,
3205                proto::DeleteNotification {
3206                    notification_id: notification_id.to_proto(),
3207                },
3208            )?;
3209        }
3210    }
3211
3212    response.send(proto::Ack {})?;
3213    Ok(())
3214}
3215
3216/// Creates a new channel.
3217async fn create_channel(
3218    request: proto::CreateChannel,
3219    response: Response<proto::CreateChannel>,
3220    session: UserSession,
3221) -> Result<()> {
3222    let db = session.db().await;
3223
3224    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3225    let (channel, membership) = db
3226        .create_channel(&request.name, parent_id, session.user_id())
3227        .await?;
3228
3229    let root_id = channel.root_id();
3230    let channel = Channel::from_model(channel);
3231
3232    response.send(proto::CreateChannelResponse {
3233        channel: Some(channel.to_proto()),
3234        parent_id: request.parent_id,
3235    })?;
3236
3237    let mut connection_pool = session.connection_pool().await;
3238    if let Some(membership) = membership {
3239        connection_pool.subscribe_to_channel(
3240            membership.user_id,
3241            membership.channel_id,
3242            membership.role,
3243        );
3244        let update = proto::UpdateUserChannels {
3245            channel_memberships: vec![proto::ChannelMembership {
3246                channel_id: membership.channel_id.to_proto(),
3247                role: membership.role.into(),
3248            }],
3249            ..Default::default()
3250        };
3251        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3252            session.peer.send(connection_id, update.clone())?;
3253        }
3254    }
3255
3256    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3257        if !role.can_see_channel(channel.visibility) {
3258            continue;
3259        }
3260
3261        let update = proto::UpdateChannels {
3262            channels: vec![channel.to_proto()],
3263            ..Default::default()
3264        };
3265        session.peer.send(connection_id, update.clone())?;
3266    }
3267
3268    Ok(())
3269}
3270
3271/// Delete a channel
3272async fn delete_channel(
3273    request: proto::DeleteChannel,
3274    response: Response<proto::DeleteChannel>,
3275    session: UserSession,
3276) -> Result<()> {
3277    let db = session.db().await;
3278
3279    let channel_id = request.channel_id;
3280    let (root_channel, removed_channels) = db
3281        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3282        .await?;
3283    response.send(proto::Ack {})?;
3284
3285    // Notify members of removed channels
3286    let mut update = proto::UpdateChannels::default();
3287    update
3288        .delete_channels
3289        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3290
3291    let connection_pool = session.connection_pool().await;
3292    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3293        session.peer.send(connection_id, update.clone())?;
3294    }
3295
3296    Ok(())
3297}
3298
3299/// Invite someone to join a channel.
3300async fn invite_channel_member(
3301    request: proto::InviteChannelMember,
3302    response: Response<proto::InviteChannelMember>,
3303    session: UserSession,
3304) -> Result<()> {
3305    let db = session.db().await;
3306    let channel_id = ChannelId::from_proto(request.channel_id);
3307    let invitee_id = UserId::from_proto(request.user_id);
3308    let InviteMemberResult {
3309        channel,
3310        notifications,
3311    } = db
3312        .invite_channel_member(
3313            channel_id,
3314            invitee_id,
3315            session.user_id(),
3316            request.role().into(),
3317        )
3318        .await?;
3319
3320    let update = proto::UpdateChannels {
3321        channel_invitations: vec![channel.to_proto()],
3322        ..Default::default()
3323    };
3324
3325    let connection_pool = session.connection_pool().await;
3326    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3327        session.peer.send(connection_id, update.clone())?;
3328    }
3329
3330    send_notifications(&connection_pool, &session.peer, notifications);
3331
3332    response.send(proto::Ack {})?;
3333    Ok(())
3334}
3335
3336/// remove someone from a channel
3337async fn remove_channel_member(
3338    request: proto::RemoveChannelMember,
3339    response: Response<proto::RemoveChannelMember>,
3340    session: UserSession,
3341) -> Result<()> {
3342    let db = session.db().await;
3343    let channel_id = ChannelId::from_proto(request.channel_id);
3344    let member_id = UserId::from_proto(request.user_id);
3345
3346    let RemoveChannelMemberResult {
3347        membership_update,
3348        notification_id,
3349    } = db
3350        .remove_channel_member(channel_id, member_id, session.user_id())
3351        .await?;
3352
3353    let mut connection_pool = session.connection_pool().await;
3354    notify_membership_updated(
3355        &mut connection_pool,
3356        membership_update,
3357        member_id,
3358        &session.peer,
3359    );
3360    for connection_id in connection_pool.user_connection_ids(member_id) {
3361        if let Some(notification_id) = notification_id {
3362            session
3363                .peer
3364                .send(
3365                    connection_id,
3366                    proto::DeleteNotification {
3367                        notification_id: notification_id.to_proto(),
3368                    },
3369                )
3370                .trace_err();
3371        }
3372    }
3373
3374    response.send(proto::Ack {})?;
3375    Ok(())
3376}
3377
3378/// Toggle the channel between public and private.
3379/// Care is taken to maintain the invariant that public channels only descend from public channels,
3380/// (though members-only channels can appear at any point in the hierarchy).
3381async fn set_channel_visibility(
3382    request: proto::SetChannelVisibility,
3383    response: Response<proto::SetChannelVisibility>,
3384    session: UserSession,
3385) -> Result<()> {
3386    let db = session.db().await;
3387    let channel_id = ChannelId::from_proto(request.channel_id);
3388    let visibility = request.visibility().into();
3389
3390    let channel_model = db
3391        .set_channel_visibility(channel_id, visibility, session.user_id())
3392        .await?;
3393    let root_id = channel_model.root_id();
3394    let channel = Channel::from_model(channel_model);
3395
3396    let mut connection_pool = session.connection_pool().await;
3397    for (user_id, role) in connection_pool
3398        .channel_user_ids(root_id)
3399        .collect::<Vec<_>>()
3400        .into_iter()
3401    {
3402        let update = if role.can_see_channel(channel.visibility) {
3403            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3404            proto::UpdateChannels {
3405                channels: vec![channel.to_proto()],
3406                ..Default::default()
3407            }
3408        } else {
3409            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3410            proto::UpdateChannels {
3411                delete_channels: vec![channel.id.to_proto()],
3412                ..Default::default()
3413            }
3414        };
3415
3416        for connection_id in connection_pool.user_connection_ids(user_id) {
3417            session.peer.send(connection_id, update.clone())?;
3418        }
3419    }
3420
3421    response.send(proto::Ack {})?;
3422    Ok(())
3423}
3424
3425/// Alter the role for a user in the channel.
3426async fn set_channel_member_role(
3427    request: proto::SetChannelMemberRole,
3428    response: Response<proto::SetChannelMemberRole>,
3429    session: UserSession,
3430) -> Result<()> {
3431    let db = session.db().await;
3432    let channel_id = ChannelId::from_proto(request.channel_id);
3433    let member_id = UserId::from_proto(request.user_id);
3434    let result = db
3435        .set_channel_member_role(
3436            channel_id,
3437            session.user_id(),
3438            member_id,
3439            request.role().into(),
3440        )
3441        .await?;
3442
3443    match result {
3444        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3445            let mut connection_pool = session.connection_pool().await;
3446            notify_membership_updated(
3447                &mut connection_pool,
3448                membership_update,
3449                member_id,
3450                &session.peer,
3451            )
3452        }
3453        db::SetMemberRoleResult::InviteUpdated(channel) => {
3454            let update = proto::UpdateChannels {
3455                channel_invitations: vec![channel.to_proto()],
3456                ..Default::default()
3457            };
3458
3459            for connection_id in session
3460                .connection_pool()
3461                .await
3462                .user_connection_ids(member_id)
3463            {
3464                session.peer.send(connection_id, update.clone())?;
3465            }
3466        }
3467    }
3468
3469    response.send(proto::Ack {})?;
3470    Ok(())
3471}
3472
3473/// Change the name of a channel
3474async fn rename_channel(
3475    request: proto::RenameChannel,
3476    response: Response<proto::RenameChannel>,
3477    session: UserSession,
3478) -> Result<()> {
3479    let db = session.db().await;
3480    let channel_id = ChannelId::from_proto(request.channel_id);
3481    let channel_model = db
3482        .rename_channel(channel_id, session.user_id(), &request.name)
3483        .await?;
3484    let root_id = channel_model.root_id();
3485    let channel = Channel::from_model(channel_model);
3486
3487    response.send(proto::RenameChannelResponse {
3488        channel: Some(channel.to_proto()),
3489    })?;
3490
3491    let connection_pool = session.connection_pool().await;
3492    let update = proto::UpdateChannels {
3493        channels: vec![channel.to_proto()],
3494        ..Default::default()
3495    };
3496    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3497        if role.can_see_channel(channel.visibility) {
3498            session.peer.send(connection_id, update.clone())?;
3499        }
3500    }
3501
3502    Ok(())
3503}
3504
3505/// Move a channel to a new parent.
3506async fn move_channel(
3507    request: proto::MoveChannel,
3508    response: Response<proto::MoveChannel>,
3509    session: UserSession,
3510) -> Result<()> {
3511    let channel_id = ChannelId::from_proto(request.channel_id);
3512    let to = ChannelId::from_proto(request.to);
3513
3514    let (root_id, channels) = session
3515        .db()
3516        .await
3517        .move_channel(channel_id, to, session.user_id())
3518        .await?;
3519
3520    let connection_pool = session.connection_pool().await;
3521    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3522        let channels = channels
3523            .iter()
3524            .filter_map(|channel| {
3525                if role.can_see_channel(channel.visibility) {
3526                    Some(channel.to_proto())
3527                } else {
3528                    None
3529                }
3530            })
3531            .collect::<Vec<_>>();
3532        if channels.is_empty() {
3533            continue;
3534        }
3535
3536        let update = proto::UpdateChannels {
3537            channels,
3538            ..Default::default()
3539        };
3540
3541        session.peer.send(connection_id, update.clone())?;
3542    }
3543
3544    response.send(Ack {})?;
3545    Ok(())
3546}
3547
3548/// Get the list of channel members
3549async fn get_channel_members(
3550    request: proto::GetChannelMembers,
3551    response: Response<proto::GetChannelMembers>,
3552    session: UserSession,
3553) -> Result<()> {
3554    let db = session.db().await;
3555    let channel_id = ChannelId::from_proto(request.channel_id);
3556    let members = db
3557        .get_channel_participant_details(channel_id, session.user_id())
3558        .await?;
3559    response.send(proto::GetChannelMembersResponse { members })?;
3560    Ok(())
3561}
3562
3563/// Accept or decline a channel invitation.
3564async fn respond_to_channel_invite(
3565    request: proto::RespondToChannelInvite,
3566    response: Response<proto::RespondToChannelInvite>,
3567    session: UserSession,
3568) -> Result<()> {
3569    let db = session.db().await;
3570    let channel_id = ChannelId::from_proto(request.channel_id);
3571    let RespondToChannelInvite {
3572        membership_update,
3573        notifications,
3574    } = db
3575        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3576        .await?;
3577
3578    let mut connection_pool = session.connection_pool().await;
3579    if let Some(membership_update) = membership_update {
3580        notify_membership_updated(
3581            &mut connection_pool,
3582            membership_update,
3583            session.user_id(),
3584            &session.peer,
3585        );
3586    } else {
3587        let update = proto::UpdateChannels {
3588            remove_channel_invitations: vec![channel_id.to_proto()],
3589            ..Default::default()
3590        };
3591
3592        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3593            session.peer.send(connection_id, update.clone())?;
3594        }
3595    };
3596
3597    send_notifications(&connection_pool, &session.peer, notifications);
3598
3599    response.send(proto::Ack {})?;
3600
3601    Ok(())
3602}
3603
3604/// Join the channels' room
3605async fn join_channel(
3606    request: proto::JoinChannel,
3607    response: Response<proto::JoinChannel>,
3608    session: UserSession,
3609) -> Result<()> {
3610    let channel_id = ChannelId::from_proto(request.channel_id);
3611    join_channel_internal(channel_id, Box::new(response), session).await
3612}
3613
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>` so that `join_channel_internal` can reply to
/// either request type.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path resolves to `Response`'s inherent `send`
        // (inherent items take precedence over trait items in path
        // resolution), so this does not recurse into the trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: delegate to `Response`'s inherent `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3627
/// Join the room attached to `channel_id` on behalf of `session`, replying
/// through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, the
///    database handle is released and that connection leaves its room first.
/// 2. The channel's room is joined in the database, which also returns any
///    membership change and the user's role.
/// 3. If a LiveKit client is configured, connection info is built. Guests get a
///    guest token with `can_publish: false`; every other role gets a room token
///    with `can_publish: true`. If token creation fails, the error goes through
///    `trace_err` and no connection info is included.
/// 4. The `JoinRoomResponse` is sent, then any membership update is pushed and
///    the room update is broadcast.
/// 5. Once the inner block has dropped the database handle and the pool guard,
///    the channel update is broadcast and the user's contacts are refreshed.
///
/// # Errors
/// Returns an error if a database call, leaving the stale room, sending the
/// reply, or updating contacts fails, or if the joined room has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle before leaving the stale room
            // (presumably because `leave_room_for_session` acquires it itself),
            // then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails
        // (the `?` on `trace_err()` short-circuits the closure to `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may join the call but not publish media.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Leaving this block drops `db` and `connection_pool`.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3718
3719/// Start editing the channel notes
3720async fn join_channel_buffer(
3721    request: proto::JoinChannelBuffer,
3722    response: Response<proto::JoinChannelBuffer>,
3723    session: UserSession,
3724) -> Result<()> {
3725    let db = session.db().await;
3726    let channel_id = ChannelId::from_proto(request.channel_id);
3727
3728    let open_response = db
3729        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3730        .await?;
3731
3732    let collaborators = open_response.collaborators.clone();
3733    response.send(open_response)?;
3734
3735    let update = UpdateChannelBufferCollaborators {
3736        channel_id: channel_id.to_proto(),
3737        collaborators: collaborators.clone(),
3738    };
3739    channel_buffer_updated(
3740        session.connection_id,
3741        collaborators
3742            .iter()
3743            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3744        &update,
3745        &session.peer,
3746    );
3747
3748    Ok(())
3749}
3750
3751/// Edit the channel notes
3752async fn update_channel_buffer(
3753    request: proto::UpdateChannelBuffer,
3754    session: UserSession,
3755) -> Result<()> {
3756    let db = session.db().await;
3757    let channel_id = ChannelId::from_proto(request.channel_id);
3758
3759    let (collaborators, non_collaborators, epoch, version) = db
3760        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3761        .await?;
3762
3763    channel_buffer_updated(
3764        session.connection_id,
3765        collaborators,
3766        &proto::UpdateChannelBuffer {
3767            channel_id: channel_id.to_proto(),
3768            operations: request.operations,
3769        },
3770        &session.peer,
3771    );
3772
3773    let pool = &*session.connection_pool().await;
3774
3775    broadcast(
3776        None,
3777        non_collaborators
3778            .iter()
3779            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3780        |peer_id| {
3781            session.peer.send(
3782                peer_id,
3783                proto::UpdateChannels {
3784                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3785                        channel_id: channel_id.to_proto(),
3786                        epoch: epoch as u64,
3787                        version: version.clone(),
3788                    }],
3789                    ..Default::default()
3790                },
3791            )
3792        },
3793    );
3794
3795    Ok(())
3796}
3797
3798/// Rejoin the channel notes after a connection blip
3799async fn rejoin_channel_buffers(
3800    request: proto::RejoinChannelBuffers,
3801    response: Response<proto::RejoinChannelBuffers>,
3802    session: UserSession,
3803) -> Result<()> {
3804    let db = session.db().await;
3805    let buffers = db
3806        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3807        .await?;
3808
3809    for rejoined_buffer in &buffers {
3810        let collaborators_to_notify = rejoined_buffer
3811            .buffer
3812            .collaborators
3813            .iter()
3814            .filter_map(|c| Some(c.peer_id?.into()));
3815        channel_buffer_updated(
3816            session.connection_id,
3817            collaborators_to_notify,
3818            &proto::UpdateChannelBufferCollaborators {
3819                channel_id: rejoined_buffer.buffer.channel_id,
3820                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3821            },
3822            &session.peer,
3823        );
3824    }
3825
3826    response.send(proto::RejoinChannelBuffersResponse {
3827        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3828    })?;
3829
3830    Ok(())
3831}
3832
3833/// Stop editing the channel notes
3834async fn leave_channel_buffer(
3835    request: proto::LeaveChannelBuffer,
3836    response: Response<proto::LeaveChannelBuffer>,
3837    session: UserSession,
3838) -> Result<()> {
3839    let db = session.db().await;
3840    let channel_id = ChannelId::from_proto(request.channel_id);
3841
3842    let left_buffer = db
3843        .leave_channel_buffer(channel_id, session.connection_id)
3844        .await?;
3845
3846    response.send(Ack {})?;
3847
3848    channel_buffer_updated(
3849        session.connection_id,
3850        left_buffer.connections,
3851        &proto::UpdateChannelBufferCollaborators {
3852            channel_id: channel_id.to_proto(),
3853            collaborators: left_buffer.collaborators,
3854        },
3855        &session.peer,
3856    );
3857
3858    Ok(())
3859}
3860
3861fn channel_buffer_updated<T: EnvelopedMessage>(
3862    sender_id: ConnectionId,
3863    collaborators: impl IntoIterator<Item = ConnectionId>,
3864    message: &T,
3865    peer: &Peer,
3866) {
3867    broadcast(Some(sender_id), collaborators, |peer_id| {
3868        peer.send(peer_id, message.clone())
3869    });
3870}
3871
3872fn send_notifications(
3873    connection_pool: &ConnectionPool,
3874    peer: &Peer,
3875    notifications: db::NotificationBatch,
3876) {
3877    for (user_id, notification) in notifications {
3878        for connection_id in connection_pool.user_connection_ids(user_id) {
3879            if let Err(error) = peer.send(
3880                connection_id,
3881                proto::AddNotification {
3882                    notification: Some(notification.clone()),
3883                },
3884            ) {
3885                tracing::error!(
3886                    "failed to send notification to {:?} {}",
3887                    connection_id,
3888                    error
3889                );
3890            }
3891        }
3892    }
3893}
3894
3895/// Send a message to the channel
3896async fn send_channel_message(
3897    request: proto::SendChannelMessage,
3898    response: Response<proto::SendChannelMessage>,
3899    session: UserSession,
3900) -> Result<()> {
3901    // Validate the message body.
3902    let body = request.body.trim().to_string();
3903    if body.len() > MAX_MESSAGE_LEN {
3904        return Err(anyhow!("message is too long"))?;
3905    }
3906    if body.is_empty() {
3907        return Err(anyhow!("message can't be blank"))?;
3908    }
3909
3910    // TODO: adjust mentions if body is trimmed
3911
3912    let timestamp = OffsetDateTime::now_utc();
3913    let nonce = request
3914        .nonce
3915        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3916
3917    let channel_id = ChannelId::from_proto(request.channel_id);
3918    let CreatedChannelMessage {
3919        message_id,
3920        participant_connection_ids,
3921        channel_members,
3922        notifications,
3923    } = session
3924        .db()
3925        .await
3926        .create_channel_message(
3927            channel_id,
3928            session.user_id(),
3929            &body,
3930            &request.mentions,
3931            timestamp,
3932            nonce.clone().into(),
3933            match request.reply_to_message_id {
3934                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3935                None => None,
3936            },
3937        )
3938        .await?;
3939
3940    let message = proto::ChannelMessage {
3941        sender_id: session.user_id().to_proto(),
3942        id: message_id.to_proto(),
3943        body,
3944        mentions: request.mentions,
3945        timestamp: timestamp.unix_timestamp() as u64,
3946        nonce: Some(nonce),
3947        reply_to_message_id: request.reply_to_message_id,
3948        edited_at: None,
3949    };
3950    broadcast(
3951        Some(session.connection_id),
3952        participant_connection_ids,
3953        |connection| {
3954            session.peer.send(
3955                connection,
3956                proto::ChannelMessageSent {
3957                    channel_id: channel_id.to_proto(),
3958                    message: Some(message.clone()),
3959                },
3960            )
3961        },
3962    );
3963    response.send(proto::SendChannelMessageResponse {
3964        message: Some(message),
3965    })?;
3966
3967    let pool = &*session.connection_pool().await;
3968    broadcast(
3969        None,
3970        channel_members
3971            .iter()
3972            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3973        |peer_id| {
3974            session.peer.send(
3975                peer_id,
3976                proto::UpdateChannels {
3977                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3978                        channel_id: channel_id.to_proto(),
3979                        message_id: message_id.to_proto(),
3980                    }],
3981                    ..Default::default()
3982                },
3983            )
3984        },
3985    );
3986    send_notifications(pool, &session.peer, notifications);
3987
3988    Ok(())
3989}
3990
3991/// Delete a channel message
3992async fn remove_channel_message(
3993    request: proto::RemoveChannelMessage,
3994    response: Response<proto::RemoveChannelMessage>,
3995    session: UserSession,
3996) -> Result<()> {
3997    let channel_id = ChannelId::from_proto(request.channel_id);
3998    let message_id = MessageId::from_proto(request.message_id);
3999    let (connection_ids, existing_notification_ids) = session
4000        .db()
4001        .await
4002        .remove_channel_message(channel_id, message_id, session.user_id())
4003        .await?;
4004
4005    broadcast(
4006        Some(session.connection_id),
4007        connection_ids,
4008        move |connection| {
4009            session.peer.send(connection, request.clone())?;
4010
4011            for notification_id in &existing_notification_ids {
4012                session.peer.send(
4013                    connection,
4014                    proto::DeleteNotification {
4015                        notification_id: (*notification_id).to_proto(),
4016                    },
4017                )?;
4018            }
4019
4020            Ok(())
4021        },
4022    );
4023    response.send(proto::Ack {})?;
4024    Ok(())
4025}
4026
4027async fn update_channel_message(
4028    request: proto::UpdateChannelMessage,
4029    response: Response<proto::UpdateChannelMessage>,
4030    session: UserSession,
4031) -> Result<()> {
4032    let channel_id = ChannelId::from_proto(request.channel_id);
4033    let message_id = MessageId::from_proto(request.message_id);
4034    let updated_at = OffsetDateTime::now_utc();
4035    let UpdatedChannelMessage {
4036        message_id,
4037        participant_connection_ids,
4038        notifications,
4039        reply_to_message_id,
4040        timestamp,
4041        deleted_mention_notification_ids,
4042        updated_mention_notifications,
4043    } = session
4044        .db()
4045        .await
4046        .update_channel_message(
4047            channel_id,
4048            message_id,
4049            session.user_id(),
4050            request.body.as_str(),
4051            &request.mentions,
4052            updated_at,
4053        )
4054        .await?;
4055
4056    let nonce = request
4057        .nonce
4058        .clone()
4059        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4060
4061    let message = proto::ChannelMessage {
4062        sender_id: session.user_id().to_proto(),
4063        id: message_id.to_proto(),
4064        body: request.body.clone(),
4065        mentions: request.mentions.clone(),
4066        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4067        nonce: Some(nonce),
4068        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4069        edited_at: Some(updated_at.unix_timestamp() as u64),
4070    };
4071
4072    response.send(proto::Ack {})?;
4073
4074    let pool = &*session.connection_pool().await;
4075    broadcast(
4076        Some(session.connection_id),
4077        participant_connection_ids,
4078        |connection| {
4079            session.peer.send(
4080                connection,
4081                proto::ChannelMessageUpdate {
4082                    channel_id: channel_id.to_proto(),
4083                    message: Some(message.clone()),
4084                },
4085            )?;
4086
4087            for notification_id in &deleted_mention_notification_ids {
4088                session.peer.send(
4089                    connection,
4090                    proto::DeleteNotification {
4091                        notification_id: (*notification_id).to_proto(),
4092                    },
4093                )?;
4094            }
4095
4096            for notification in &updated_mention_notifications {
4097                session.peer.send(
4098                    connection,
4099                    proto::UpdateNotification {
4100                        notification: Some(notification.clone()),
4101                    },
4102                )?;
4103            }
4104
4105            Ok(())
4106        },
4107    );
4108
4109    send_notifications(pool, &session.peer, notifications);
4110
4111    Ok(())
4112}
4113
4114/// Mark a channel message as read
4115async fn acknowledge_channel_message(
4116    request: proto::AckChannelMessage,
4117    session: UserSession,
4118) -> Result<()> {
4119    let channel_id = ChannelId::from_proto(request.channel_id);
4120    let message_id = MessageId::from_proto(request.message_id);
4121    let notifications = session
4122        .db()
4123        .await
4124        .observe_channel_message(channel_id, session.user_id(), message_id)
4125        .await?;
4126    send_notifications(
4127        &*session.connection_pool().await,
4128        &session.peer,
4129        notifications,
4130    );
4131    Ok(())
4132}
4133
4134/// Mark a buffer version as synced
4135async fn acknowledge_buffer_version(
4136    request: proto::AckBufferOperation,
4137    session: UserSession,
4138) -> Result<()> {
4139    let buffer_id = BufferId::from_proto(request.buffer_id);
4140    session
4141        .db()
4142        .await
4143        .observe_buffer_version(
4144            buffer_id,
4145            session.user_id(),
4146            request.epoch as i32,
4147            &request.version,
4148        )
4149        .await?;
4150    Ok(())
4151}
4152
4153struct CompleteWithLanguageModelRateLimit;
4154
4155impl RateLimit for CompleteWithLanguageModelRateLimit {
4156    fn capacity() -> usize {
4157        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4158            .ok()
4159            .and_then(|v| v.parse().ok())
4160            .unwrap_or(120) // Picked arbitrarily
4161    }
4162
4163    fn refill_duration() -> chrono::Duration {
4164        chrono::Duration::hours(1)
4165    }
4166
4167    fn db_name() -> &'static str {
4168        "complete-with-language-model"
4169    }
4170}
4171
4172async fn complete_with_language_model(
4173    request: proto::CompleteWithLanguageModel,
4174    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4175    session: Session,
4176    open_ai_api_key: Option<Arc<str>>,
4177    google_ai_api_key: Option<Arc<str>>,
4178    anthropic_api_key: Option<Arc<str>>,
4179) -> Result<()> {
4180    let Some(session) = session.for_user() else {
4181        return Err(anyhow!("user not found"))?;
4182    };
4183    authorize_access_to_language_models(&session).await?;
4184    session
4185        .rate_limiter
4186        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4187        .await?;
4188
4189    if request.model.starts_with("gpt") {
4190        let api_key =
4191            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4192        complete_with_open_ai(request, response, session, api_key).await?;
4193    } else if request.model.starts_with("gemini") {
4194        let api_key = google_ai_api_key
4195            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4196        complete_with_google_ai(request, response, session, api_key).await?;
4197    } else if request.model.starts_with("claude") {
4198        let api_key = anthropic_api_key
4199            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4200        complete_with_anthropic(request, response, session, api_key).await?;
4201    }
4202
4203    Ok(())
4204}
4205
4206async fn complete_with_open_ai(
4207    request: proto::CompleteWithLanguageModel,
4208    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4209    session: UserSession,
4210    api_key: Arc<str>,
4211) -> Result<()> {
4212    let mut completion_stream = open_ai::stream_completion(
4213        &session.http_client,
4214        OPEN_AI_API_URL,
4215        &api_key,
4216        crate::ai::language_model_request_to_open_ai(request)?,
4217    )
4218    .await
4219    .context("open_ai::stream_completion request failed within collab")?;
4220
4221    while let Some(event) = completion_stream.next().await {
4222        let event = event?;
4223        response.send(proto::LanguageModelResponse {
4224            choices: event
4225                .choices
4226                .into_iter()
4227                .map(|choice| proto::LanguageModelChoiceDelta {
4228                    index: choice.index,
4229                    delta: Some(proto::LanguageModelResponseMessage {
4230                        role: choice.delta.role.map(|role| match role {
4231                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4232                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4233                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4234                            open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4235                        } as i32),
4236                        content: choice.delta.content,
4237                        tool_calls: choice
4238                            .delta
4239                            .tool_calls
4240                            .into_iter()
4241                            .map(|delta| proto::ToolCallDelta {
4242                                index: delta.index as u32,
4243                                id: delta.id,
4244                                variant: match delta.function {
4245                                    Some(function) => {
4246                                        let name = function.name;
4247                                        let arguments = function.arguments;
4248
4249                                        Some(proto::tool_call_delta::Variant::Function(
4250                                            proto::tool_call_delta::FunctionCallDelta {
4251                                                name,
4252                                                arguments,
4253                                            },
4254                                        ))
4255                                    }
4256                                    None => None,
4257                                },
4258                            })
4259                            .collect(),
4260                    }),
4261                    finish_reason: choice.finish_reason,
4262                })
4263                .collect(),
4264        })?;
4265    }
4266
4267    Ok(())
4268}
4269
4270async fn complete_with_google_ai(
4271    request: proto::CompleteWithLanguageModel,
4272    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4273    session: UserSession,
4274    api_key: Arc<str>,
4275) -> Result<()> {
4276    let mut stream = google_ai::stream_generate_content(
4277        &session.http_client,
4278        google_ai::API_URL,
4279        api_key.as_ref(),
4280        crate::ai::language_model_request_to_google_ai(request)?,
4281    )
4282    .await
4283    .context("google_ai::stream_generate_content request failed")?;
4284
4285    while let Some(event) = stream.next().await {
4286        let event = event?;
4287        response.send(proto::LanguageModelResponse {
4288            choices: event
4289                .candidates
4290                .unwrap_or_default()
4291                .into_iter()
4292                .map(|candidate| proto::LanguageModelChoiceDelta {
4293                    index: candidate.index as u32,
4294                    delta: Some(proto::LanguageModelResponseMessage {
4295                        role: Some(match candidate.content.role {
4296                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4297                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4298                        } as i32),
4299                        content: Some(
4300                            candidate
4301                                .content
4302                                .parts
4303                                .into_iter()
4304                                .filter_map(|part| match part {
4305                                    google_ai::Part::TextPart(part) => Some(part.text),
4306                                    google_ai::Part::InlineDataPart(_) => None,
4307                                })
4308                                .collect(),
4309                        ),
4310                        // Tool calls are not supported for Google
4311                        tool_calls: Vec::new(),
4312                    }),
4313                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4314                })
4315                .collect(),
4316        })?;
4317    }
4318
4319    Ok(())
4320}
4321
/// Streams a completion from Anthropic and forwards text chunks and the stop
/// reason to the client as `proto::LanguageModelResponse` messages.
///
/// User and assistant messages are passed through. System messages are joined
/// (separated by blank lines) into the request's top-level `system` field. Tool
/// messages are dropped.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        &session.http_client,
        "https://api.anthropic.com",
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 is an unusual cap (4096 is the common output-token
            // limit); confirm this value is intentional.
            max_tokens: 4092,
        },
    )
    .await?;

    // Role stamped on every outgoing text delta. It starts as assistant and is
    // overwritten by the role reported in `MessageStart`, if any.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Only "assistant" and "user" are recognized; any other role string
                // leaves `current_role` unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // The opening block may already carry text; empty openers
                        // are not forwarded.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // A stop reason is forwarded as a choice with no delta and only
                // `finish_reason` set.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing the client needs.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4442
4443struct CountTokensWithLanguageModelRateLimit;
4444
4445impl RateLimit for CountTokensWithLanguageModelRateLimit {
4446    fn capacity() -> usize {
4447        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4448            .ok()
4449            .and_then(|v| v.parse().ok())
4450            .unwrap_or(600) // Picked arbitrarily
4451    }
4452
4453    fn refill_duration() -> chrono::Duration {
4454        chrono::Duration::hours(1)
4455    }
4456
4457    fn db_name() -> &'static str {
4458        "count-tokens-with-language-model"
4459    }
4460}
4461
4462async fn count_tokens_with_language_model(
4463    request: proto::CountTokensWithLanguageModel,
4464    response: Response<proto::CountTokensWithLanguageModel>,
4465    session: UserSession,
4466    google_ai_api_key: Option<Arc<str>>,
4467) -> Result<()> {
4468    authorize_access_to_language_models(&session).await?;
4469
4470    if !request.model.starts_with("gemini") {
4471        return Err(anyhow!(
4472            "counting tokens for model: {:?} is not supported",
4473            request.model
4474        ))?;
4475    }
4476
4477    session
4478        .rate_limiter
4479        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4480        .await?;
4481
4482    let api_key = google_ai_api_key
4483        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4484    let tokens_response = google_ai::count_tokens(
4485        &session.http_client,
4486        google_ai::API_URL,
4487        &api_key,
4488        crate::ai::count_tokens_request_to_google_ai(request)?,
4489    )
4490    .await?;
4491    response.send(proto::CountTokensResponse {
4492        token_count: tokens_response.total_tokens as u32,
4493    })?;
4494    Ok(())
4495}
4496
4497struct ComputeEmbeddingsRateLimit;
4498
4499impl RateLimit for ComputeEmbeddingsRateLimit {
4500    fn capacity() -> usize {
4501        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4502            .ok()
4503            .and_then(|v| v.parse().ok())
4504            .unwrap_or(5000) // Picked arbitrarily
4505    }
4506
4507    fn refill_duration() -> chrono::Duration {
4508        chrono::Duration::hours(1)
4509    }
4510
4511    fn db_name() -> &'static str {
4512        "compute-embeddings"
4513    }
4514}
4515
4516async fn compute_embeddings(
4517    request: proto::ComputeEmbeddings,
4518    response: Response<proto::ComputeEmbeddings>,
4519    session: UserSession,
4520    api_key: Option<Arc<str>>,
4521) -> Result<()> {
4522    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4523    authorize_access_to_language_models(&session).await?;
4524
4525    session
4526        .rate_limiter
4527        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4528        .await?;
4529
4530    let embeddings = match request.model.as_str() {
4531        "openai/text-embedding-3-small" => {
4532            open_ai::embed(
4533                &session.http_client,
4534                OPEN_AI_API_URL,
4535                &api_key,
4536                OpenAiEmbeddingModel::TextEmbedding3Small,
4537                request.texts.iter().map(|text| text.as_str()),
4538            )
4539            .await?
4540        }
4541        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4542    };
4543
4544    let embeddings = request
4545        .texts
4546        .iter()
4547        .map(|text| {
4548            let mut hasher = sha2::Sha256::new();
4549            hasher.update(text.as_bytes());
4550            let result = hasher.finalize();
4551            result.to_vec()
4552        })
4553        .zip(
4554            embeddings
4555                .data
4556                .into_iter()
4557                .map(|embedding| embedding.embedding),
4558        )
4559        .collect::<HashMap<_, _>>();
4560
4561    let db = session.db().await;
4562    db.save_embeddings(&request.model, &embeddings)
4563        .await
4564        .context("failed to save embeddings")
4565        .trace_err();
4566
4567    response.send(proto::ComputeEmbeddingsResponse {
4568        embeddings: embeddings
4569            .into_iter()
4570            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4571            .collect(),
4572    })?;
4573    Ok(())
4574}
4575
4576async fn get_cached_embeddings(
4577    request: proto::GetCachedEmbeddings,
4578    response: Response<proto::GetCachedEmbeddings>,
4579    session: UserSession,
4580) -> Result<()> {
4581    authorize_access_to_language_models(&session).await?;
4582
4583    let db = session.db().await;
4584    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4585
4586    response.send(proto::GetCachedEmbeddingsResponse {
4587        embeddings: embeddings
4588            .into_iter()
4589            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4590            .collect(),
4591    })?;
4592    Ok(())
4593}
4594
4595async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4596    let db = session.db().await;
4597    let flags = db.get_user_flags(session.user_id()).await?;
4598    if flags.iter().any(|flag| flag == "language-models") {
4599        Ok(())
4600    } else {
4601        Err(anyhow!("permission denied"))?
4602    }
4603}
4604
4605/// Start receiving chat updates for a channel
4606async fn join_channel_chat(
4607    request: proto::JoinChannelChat,
4608    response: Response<proto::JoinChannelChat>,
4609    session: UserSession,
4610) -> Result<()> {
4611    let channel_id = ChannelId::from_proto(request.channel_id);
4612
4613    let db = session.db().await;
4614    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4615        .await?;
4616    let messages = db
4617        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4618        .await?;
4619    response.send(proto::JoinChannelChatResponse {
4620        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4621        messages,
4622    })?;
4623    Ok(())
4624}
4625
4626/// Stop receiving chat updates for a channel
4627async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4628    let channel_id = ChannelId::from_proto(request.channel_id);
4629    session
4630        .db()
4631        .await
4632        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4633        .await?;
4634    Ok(())
4635}
4636
4637/// Retrieve the chat history for a channel
4638async fn get_channel_messages(
4639    request: proto::GetChannelMessages,
4640    response: Response<proto::GetChannelMessages>,
4641    session: UserSession,
4642) -> Result<()> {
4643    let channel_id = ChannelId::from_proto(request.channel_id);
4644    let messages = session
4645        .db()
4646        .await
4647        .get_channel_messages(
4648            channel_id,
4649            session.user_id(),
4650            MESSAGE_COUNT_PER_PAGE,
4651            Some(MessageId::from_proto(request.before_message_id)),
4652        )
4653        .await?;
4654    response.send(proto::GetChannelMessagesResponse {
4655        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4656        messages,
4657    })?;
4658    Ok(())
4659}
4660
4661/// Retrieve specific chat messages
4662async fn get_channel_messages_by_id(
4663    request: proto::GetChannelMessagesById,
4664    response: Response<proto::GetChannelMessagesById>,
4665    session: UserSession,
4666) -> Result<()> {
4667    let message_ids = request
4668        .message_ids
4669        .iter()
4670        .map(|id| MessageId::from_proto(*id))
4671        .collect::<Vec<_>>();
4672    let messages = session
4673        .db()
4674        .await
4675        .get_channel_messages_by_id(session.user_id(), &message_ids)
4676        .await?;
4677    response.send(proto::GetChannelMessagesResponse {
4678        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4679        messages,
4680    })?;
4681    Ok(())
4682}
4683
4684/// Retrieve the current users notifications
4685async fn get_notifications(
4686    request: proto::GetNotifications,
4687    response: Response<proto::GetNotifications>,
4688    session: UserSession,
4689) -> Result<()> {
4690    let notifications = session
4691        .db()
4692        .await
4693        .get_notifications(
4694            session.user_id(),
4695            NOTIFICATION_COUNT_PER_PAGE,
4696            request
4697                .before_id
4698                .map(|id| db::NotificationId::from_proto(id)),
4699        )
4700        .await?;
4701    response.send(proto::GetNotificationsResponse {
4702        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4703        notifications,
4704    })?;
4705    Ok(())
4706}
4707
4708/// Mark notifications as read
4709async fn mark_notification_as_read(
4710    request: proto::MarkNotificationRead,
4711    response: Response<proto::MarkNotificationRead>,
4712    session: UserSession,
4713) -> Result<()> {
4714    let database = &session.db().await;
4715    let notifications = database
4716        .mark_notification_as_read_by_id(
4717            session.user_id(),
4718            NotificationId::from_proto(request.notification_id),
4719        )
4720        .await?;
4721    send_notifications(
4722        &*session.connection_pool().await,
4723        &session.peer,
4724        notifications,
4725    );
4726    response.send(proto::Ack {})?;
4727    Ok(())
4728}
4729
4730/// Get the current users information
4731async fn get_private_user_info(
4732    _request: proto::GetPrivateUserInfo,
4733    response: Response<proto::GetPrivateUserInfo>,
4734    session: UserSession,
4735) -> Result<()> {
4736    let db = session.db().await;
4737
4738    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4739    let user = db
4740        .get_user_by_id(session.user_id())
4741        .await?
4742        .ok_or_else(|| anyhow!("user not found"))?;
4743    let flags = db.get_user_flags(session.user_id()).await?;
4744
4745    response.send(proto::GetPrivateUserInfoResponse {
4746        metrics_id,
4747        staff: user.admin,
4748        flags,
4749    })?;
4750    Ok(())
4751}
4752
4753fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4754    match message {
4755        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4756        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4757        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4758        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4759        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4760            code: frame.code.into(),
4761            reason: frame.reason,
4762        })),
4763    }
4764}
4765
4766fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4767    match message {
4768        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4769        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4770        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4771        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4772        AxumMessage::Close(frame) => {
4773            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4774                code: frame.code.into(),
4775                reason: frame.reason,
4776            }))
4777        }
4778    }
4779}
4780
4781fn notify_membership_updated(
4782    connection_pool: &mut ConnectionPool,
4783    result: MembershipUpdated,
4784    user_id: UserId,
4785    peer: &Peer,
4786) {
4787    for membership in &result.new_channels.channel_memberships {
4788        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4789    }
4790    for channel_id in &result.removed_channels {
4791        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4792    }
4793
4794    let user_channels_update = proto::UpdateUserChannels {
4795        channel_memberships: result
4796            .new_channels
4797            .channel_memberships
4798            .iter()
4799            .map(|cm| proto::ChannelMembership {
4800                channel_id: cm.channel_id.to_proto(),
4801                role: cm.role.into(),
4802            })
4803            .collect(),
4804        ..Default::default()
4805    };
4806
4807    let mut update = build_channels_update(result.new_channels, vec![]);
4808    update.delete_channels = result
4809        .removed_channels
4810        .into_iter()
4811        .map(|id| id.to_proto())
4812        .collect();
4813    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4814
4815    for connection_id in connection_pool.user_connection_ids(user_id) {
4816        peer.send(connection_id, user_channels_update.clone())
4817            .trace_err();
4818        peer.send(connection_id, update.clone()).trace_err();
4819    }
4820}
4821
4822fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4823    proto::UpdateUserChannels {
4824        channel_memberships: channels
4825            .channel_memberships
4826            .iter()
4827            .map(|m| proto::ChannelMembership {
4828                channel_id: m.channel_id.to_proto(),
4829                role: m.role.into(),
4830            })
4831            .collect(),
4832        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4833        observed_channel_message_id: channels.observed_channel_messages.clone(),
4834    }
4835}
4836
4837fn build_channels_update(
4838    channels: ChannelsForUser,
4839    channel_invites: Vec<db::Channel>,
4840) -> proto::UpdateChannels {
4841    let mut update = proto::UpdateChannels::default();
4842
4843    for channel in channels.channels {
4844        update.channels.push(channel.to_proto());
4845    }
4846
4847    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4848    update.latest_channel_message_ids = channels.latest_channel_messages;
4849
4850    for (channel_id, participants) in channels.channel_participants {
4851        update
4852            .channel_participants
4853            .push(proto::ChannelParticipants {
4854                channel_id: channel_id.to_proto(),
4855                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4856            });
4857    }
4858
4859    for channel in channel_invites {
4860        update.channel_invitations.push(channel.to_proto());
4861    }
4862
4863    update.hosted_projects = channels.hosted_projects;
4864    update
4865}
4866
4867fn build_initial_contacts_update(
4868    contacts: Vec<db::Contact>,
4869    pool: &ConnectionPool,
4870) -> proto::UpdateContacts {
4871    let mut update = proto::UpdateContacts::default();
4872
4873    for contact in contacts {
4874        match contact {
4875            db::Contact::Accepted { user_id, busy } => {
4876                update.contacts.push(contact_for_user(user_id, busy, &pool));
4877            }
4878            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4879            db::Contact::Incoming { user_id } => {
4880                update
4881                    .incoming_requests
4882                    .push(proto::IncomingContactRequest {
4883                        requester_id: user_id.to_proto(),
4884                    })
4885            }
4886        }
4887    }
4888
4889    update
4890}
4891
4892fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4893    proto::Contact {
4894        user_id: user_id.to_proto(),
4895        online: pool.is_user_online(user_id),
4896        busy,
4897    }
4898}
4899
4900fn room_updated(room: &proto::Room, peer: &Peer) {
4901    broadcast(
4902        None,
4903        room.participants
4904            .iter()
4905            .filter_map(|participant| Some(participant.peer_id?.into())),
4906        |peer_id| {
4907            peer.send(
4908                peer_id,
4909                proto::RoomUpdated {
4910                    room: Some(room.clone()),
4911                },
4912            )
4913        },
4914    );
4915}
4916
4917fn channel_updated(
4918    channel: &db::channel::Model,
4919    room: &proto::Room,
4920    peer: &Peer,
4921    pool: &ConnectionPool,
4922) {
4923    let participants = room
4924        .participants
4925        .iter()
4926        .map(|p| p.user_id)
4927        .collect::<Vec<_>>();
4928
4929    broadcast(
4930        None,
4931        pool.channel_connection_ids(channel.root_id())
4932            .filter_map(|(channel_id, role)| {
4933                role.can_see_channel(channel.visibility).then(|| channel_id)
4934            }),
4935        |peer_id| {
4936            peer.send(
4937                peer_id,
4938                proto::UpdateChannels {
4939                    channel_participants: vec![proto::ChannelParticipants {
4940                        channel_id: channel.id.to_proto(),
4941                        participant_user_ids: participants.clone(),
4942                    }],
4943                    ..Default::default()
4944                },
4945            )
4946        },
4947    );
4948}
4949
4950async fn send_dev_server_projects_update(
4951    user_id: UserId,
4952    mut status: proto::DevServerProjectsUpdate,
4953    session: &Session,
4954) {
4955    let pool = session.connection_pool().await;
4956    for dev_server in &mut status.dev_servers {
4957        dev_server.status =
4958            pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
4959    }
4960    let connections = pool.user_connection_ids(user_id);
4961    for connection_id in connections {
4962        session.peer.send(connection_id, status.clone()).trace_err();
4963    }
4964}
4965
4966async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4967    let db = session.db().await;
4968
4969    let contacts = db.get_contacts(user_id).await?;
4970    let busy = db.is_user_busy(user_id).await?;
4971
4972    let pool = session.connection_pool().await;
4973    let updated_contact = contact_for_user(user_id, busy, &pool);
4974    for contact in contacts {
4975        if let db::Contact::Accepted {
4976            user_id: contact_user_id,
4977            ..
4978        } = contact
4979        {
4980            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4981                session
4982                    .peer
4983                    .send(
4984                        contact_conn_id,
4985                        proto::UpdateContacts {
4986                            contacts: vec![updated_contact.clone()],
4987                            remove_contacts: Default::default(),
4988                            incoming_requests: Default::default(),
4989                            remove_incoming_requests: Default::default(),
4990                            outgoing_requests: Default::default(),
4991                            remove_outgoing_requests: Default::default(),
4992                        },
4993                    )
4994                    .trace_err();
4995            }
4996        }
4997    }
4998    Ok(())
4999}
5000
5001async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5002    log::info!("lost dev server connection, unsharing projects");
5003    let project_ids = session
5004        .db()
5005        .await
5006        .get_stale_dev_server_projects(session.connection_id)
5007        .await?;
5008
5009    for project_id in project_ids {
5010        // not unshare re-checks the connection ids match, so we get away with no transaction
5011        unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5012    }
5013
5014    let user_id = session.dev_server().user_id;
5015    let update = session
5016        .db()
5017        .await
5018        .dev_server_projects_update(user_id)
5019        .await?;
5020
5021    send_dev_server_projects_update(user_id, update, session).await;
5022
5023    Ok(())
5024}
5025
/// Removes `connection_id` from the room it is in, if any, and notifies everyone
/// affected.
///
/// Notifications go to the left projects' collaborators, the remaining room
/// participants, the channel's members (if the room belongs to a channel), and
/// users whose pending calls were canceled. Contact state is refreshed for the
/// session's user and for those canceled callees. Finally, the participant is
/// removed from LiveKit, and the LiveKit room is deleted if the database reports
/// the room as deleted.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Moved out of `left_room` below; declared up front so they are definitely
    // initialized on the only path that continues past the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        // NOTE(review): `room.live_kit_room` was already taken (left empty) above,
        // so recipients of this broadcast see an empty LiveKit room name; confirm
        // clients don't rely on it here.
        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool handle is dropped before `update_user_contacts`, which
    // calls `connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5099
5100async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5101    let left_channel_buffers = session
5102        .db()
5103        .await
5104        .leave_channel_buffers(session.connection_id)
5105        .await?;
5106
5107    for left_buffer in left_channel_buffers {
5108        channel_buffer_updated(
5109            session.connection_id,
5110            left_buffer.connections,
5111            &proto::UpdateChannelBufferCollaborators {
5112                channel_id: left_buffer.channel_id.to_proto(),
5113                collaborators: left_buffer.collaborators,
5114            },
5115            &session.peer,
5116        );
5117    }
5118
5119    Ok(())
5120}
5121
5122fn project_left(project: &db::LeftProject, session: &UserSession) {
5123    for connection_id in &project.connection_ids {
5124        if project.should_unshare {
5125            session
5126                .peer
5127                .send(
5128                    *connection_id,
5129                    proto::UnshareProject {
5130                        project_id: project.id.to_proto(),
5131                    },
5132                )
5133                .trace_err();
5134        } else {
5135            session
5136                .peer
5137                .send(
5138                    *connection_id,
5139                    proto::RemoveProjectCollaborator {
5140                        project_id: project.id.to_proto(),
5141                        peer_id: Some(session.connection_id.into()),
5142                    },
5143                )
5144                .trace_err();
5145        }
5146    }
5147}
5148
5149pub trait ResultExt {
5150    type Ok;
5151
5152    fn trace_err(self) -> Option<Self::Ok>;
5153}
5154
5155impl<T, E> ResultExt for Result<T, E>
5156where
5157    E: std::fmt::Debug,
5158{
5159    type Ok = T;
5160
5161    #[track_caller]
5162    fn trace_err(self) -> Option<T> {
5163        match self {
5164            Ok(value) => Some(value),
5165            Err(error) => {
5166                tracing::error!("{:?}", error);
5167                None
5168            }
5169        }
5170    }
5171}