rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth,
   5    db::{
   6        self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, DevServerId, InviteMemberResult, MembershipUpdated,
   8        MessageId, NotificationId, PrincipalId, Project, ProjectId, RejoinedProject,
   9        RemoteProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
  10        ServerId, UpdatedChannelMessage, User, UserId,
  11    },
  12    executor::Executor,
  13    AppState, Error, RateLimit, RateLimiter, Result,
  14};
  15use anyhow::{anyhow, Context as _};
  16use async_tungstenite::tungstenite::{
  17    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  18};
  19use axum::{
  20    body::Body,
  21    extract::{
  22        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  23        ConnectInfo, WebSocketUpgrade,
  24    },
  25    headers::{Header, HeaderName},
  26    http::StatusCode,
  27    middleware,
  28    response::IntoResponse,
  29    routing::get,
  30    Extension, Router, TypedHeader,
  31};
  32use collections::{HashMap, HashSet};
  33pub use connection_pool::{ConnectionPool, ZedVersion};
  34use core::fmt::{self, Debug, Formatter};
  35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  36use sha2::Digest;
  37
  38use futures::{
  39    channel::oneshot,
  40    future::{self, BoxFuture},
  41    stream::FuturesUnordered,
  42    FutureExt, SinkExt, StreamExt, TryStreamExt,
  43};
  44use prometheus::{register_int_gauge, IntGauge};
  45use rpc::{
  46    proto::{
  47        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  48        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  49    },
  50    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  51};
  52use semantic_version::SemanticVersion;
  53use serde::{Serialize, Serializer};
  54use std::{
  55    any::TypeId,
  56    future::Future,
  57    marker::PhantomData,
  58    mem,
  59    net::SocketAddr,
  60    ops::{Deref, DerefMut},
  61    rc::Rc,
  62    sync::{
  63        atomic::{AtomicBool, Ordering::SeqCst},
  64        Arc, OnceLock,
  65    },
  66    time::{Duration, Instant},
  67};
  68use time::OffsetDateTime;
  69use tokio::sync::{watch, Semaphore};
  70use tower::ServiceBuilder;
  71use tracing::{
  72    field::{self},
  73    info_span, instrument, Instrument,
  74};
  75use util::http::IsahcHttpClient;
  76
  77pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  78
  79// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  80pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  81
  82const MESSAGE_COUNT_PER_PAGE: usize = 100;
  83const MAX_MESSAGE_LEN: usize = 1024;
  84const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  85
  86type MessageHandler =
  87    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  88
  89struct Response<R> {
  90    peer: Arc<Peer>,
  91    receipt: Receipt<R>,
  92    responded: Arc<AtomicBool>,
  93}
  94
  95impl<R: RequestMessage> Response<R> {
  96    fn send(self, payload: R::Response) -> Result<()> {
  97        self.responded.store(true, SeqCst);
  98        self.peer.respond(self.receipt, payload)?;
  99        Ok(())
 100    }
 101}
 102
 103struct StreamingResponse<R: RequestMessage> {
 104    peer: Arc<Peer>,
 105    receipt: Receipt<R>,
 106}
 107
 108impl<R: RequestMessage> StreamingResponse<R> {
 109    fn send(&self, payload: R::Response) -> Result<()> {
 110        self.peer.respond(self.receipt, payload)?;
 111        Ok(())
 112    }
 113}
 114
 115#[derive(Clone, Debug)]
 116pub enum Principal {
 117    User(User),
 118    Impersonated { user: User, admin: User },
 119    DevServer(dev_server::Model),
 120}
 121
 122impl Principal {
 123    fn update_span(&self, span: &tracing::Span) {
 124        match &self {
 125            Principal::User(user) => {
 126                span.record("user_id", &user.id.0);
 127                span.record("login", &user.github_login);
 128            }
 129            Principal::Impersonated { user, admin } => {
 130                span.record("user_id", &user.id.0);
 131                span.record("login", &user.github_login);
 132                span.record("impersonator", &admin.github_login);
 133            }
 134            Principal::DevServer(dev_server) => {
 135                span.record("dev_server_id", &dev_server.id.0);
 136            }
 137        }
 138    }
 139}
 140
/// Per-connection state passed to every message handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection (user, impersonated user, or dev server).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool behind a `parking_lot` mutex; acquire it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    // Not read by the methods defined on `Session` here (hence the underscore); presumably
    // kept only so the session owns an executor handle — TODO confirm.
    _executor: Executor,
}
 153
 154impl Session {
 155    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        let guard = self.db.lock().await;
 159        #[cfg(test)]
 160        tokio::task::yield_now().await;
 161        guard
 162    }
 163
 164    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 165        #[cfg(test)]
 166        tokio::task::yield_now().await;
 167        let guard = self.connection_pool.lock();
 168        ConnectionPoolGuard {
 169            guard,
 170            _not_send: PhantomData,
 171        }
 172    }
 173
 174    fn for_user(self) -> Option<UserSession> {
 175        UserSession::new(self)
 176    }
 177
 178    fn for_dev_server(self) -> Option<DevServerSession> {
 179        DevServerSession::new(self)
 180    }
 181
 182    fn user_id(&self) -> Option<UserId> {
 183        match &self.principal {
 184            Principal::User(user) => Some(user.id),
 185            Principal::Impersonated { user, .. } => Some(user.id),
 186            Principal::DevServer(_) => None,
 187        }
 188    }
 189
 190    fn dev_server_id(&self) -> Option<DevServerId> {
 191        match &self.principal {
 192            Principal::User(_) | Principal::Impersonated { .. } => None,
 193            Principal::DevServer(dev_server) => Some(dev_server.id),
 194        }
 195    }
 196
 197    fn principal_id(&self) -> PrincipalId {
 198        match &self.principal {
 199            Principal::User(user) => PrincipalId::UserId(user.id),
 200            Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
 201            Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
 202        }
 203    }
 204}
 205
 206impl Debug for Session {
 207    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 208        let mut result = f.debug_struct("Session");
 209        match &self.principal {
 210            Principal::User(user) => {
 211                result.field("user", &user.github_login);
 212            }
 213            Principal::Impersonated { user, admin } => {
 214                result.field("user", &user.github_login);
 215                result.field("impersonator", &admin.github_login);
 216            }
 217            Principal::DevServer(dev_server) => {
 218                result.field("dev_server", &dev_server.id);
 219            }
 220        }
 221        result.field("connection_id", &self.connection_id).finish()
 222    }
 223}
 224
 225struct UserSession(Session);
 226
 227impl UserSession {
 228    pub fn new(s: Session) -> Option<Self> {
 229        s.user_id().map(|_| UserSession(s))
 230    }
 231    pub fn user_id(&self) -> UserId {
 232        self.0.user_id().unwrap()
 233    }
 234}
 235
 236impl Deref for UserSession {
 237    type Target = Session;
 238
 239    fn deref(&self) -> &Self::Target {
 240        &self.0
 241    }
 242}
 243impl DerefMut for UserSession {
 244    fn deref_mut(&mut self) -> &mut Self::Target {
 245        &mut self.0
 246    }
 247}
 248
 249struct DevServerSession(Session);
 250
 251impl DevServerSession {
 252    pub fn new(s: Session) -> Option<Self> {
 253        s.dev_server_id().map(|_| DevServerSession(s))
 254    }
 255    pub fn dev_server_id(&self) -> DevServerId {
 256        self.0.dev_server_id().unwrap()
 257    }
 258}
 259
 260impl Deref for DevServerSession {
 261    type Target = Session;
 262
 263    fn deref(&self) -> &Self::Target {
 264        &self.0
 265    }
 266}
 267impl DerefMut for DevServerSession {
 268    fn deref_mut(&mut self) -> &mut Self::Target {
 269        &mut self.0
 270    }
 271}
 272
 273fn user_handler<M: RequestMessage, Fut>(
 274    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 275) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 276where
 277    Fut: Send + Future<Output = Result<()>>,
 278{
 279    let handler = Arc::new(handler);
 280    move |message, response, session| {
 281        let handler = handler.clone();
 282        Box::pin(async move {
 283            if let Some(user_session) = session.for_user() {
 284                Ok(handler(message, response, user_session).await?)
 285            } else {
 286                Err(Error::Internal(anyhow!(
 287                    "must be a user to call {}",
 288                    M::NAME
 289                )))
 290            }
 291        })
 292    }
 293}
 294
 295fn dev_server_handler<M: RequestMessage, Fut>(
 296    handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
 297) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 298where
 299    Fut: Send + Future<Output = Result<()>>,
 300{
 301    let handler = Arc::new(handler);
 302    move |message, response, session| {
 303        let handler = handler.clone();
 304        Box::pin(async move {
 305            if let Some(dev_server_session) = session.for_dev_server() {
 306                Ok(handler(message, response, dev_server_session).await?)
 307            } else {
 308                Err(Error::Internal(anyhow!(
 309                    "must be a dev server to call {}",
 310                    M::NAME
 311                )))
 312            }
 313        })
 314    }
 315}
 316
 317fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 318    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 319) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 320where
 321    InnertRetFut: Send + Future<Output = Result<()>>,
 322{
 323    let handler = Arc::new(handler);
 324    move |message, session| {
 325        let handler = handler.clone();
 326        Box::pin(async move {
 327            if let Some(user_session) = session.for_user() {
 328                Ok(handler(message, user_session).await?)
 329            } else {
 330                Err(Error::Internal(anyhow!(
 331                    "must be a user to call {}",
 332                    M::NAME
 333                )))
 334            }
 335        })
 336    }
 337}
 338
 339struct DbHandle(Arc<Database>);
 340
 341impl Deref for DbHandle {
 342    type Target = Database;
 343
 344    fn deref(&self) -> &Self::Target {
 345        self.0.as_ref()
 346    }
 347}
 348
/// The collab RPC server: owns the RPC peer, the connection pool, and the table of message
/// handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` and back to `false` by the test-only `reset`.
    teardown: watch::Sender<bool>,
}
 357
/// Lock guard over the shared [`ConnectionPool`], as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 362
/// Serializable view of a server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock (via its guard) for as long as it lives.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the value the guard dereferences to, not as the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 369
 370pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 371where
 372    S: Serializer,
 373    T: Deref<Target = U>,
 374    U: Serialize,
 375{
 376    Serialize::serialize(value.deref(), serializer)
 377}
 378
 379impl Server {
 380    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 381        let mut server = Self {
 382            id: parking_lot::Mutex::new(id),
 383            peer: Peer::new(id.0 as u32),
 384            app_state: app_state.clone(),
 385            connection_pool: Default::default(),
 386            handlers: Default::default(),
 387            teardown: watch::channel(false).0,
 388        };
 389
 390        server
 391            .add_request_handler(ping)
 392            .add_request_handler(user_handler(create_room))
 393            .add_request_handler(user_handler(join_room))
 394            .add_request_handler(user_handler(rejoin_room))
 395            .add_request_handler(user_handler(leave_room))
 396            .add_request_handler(user_handler(set_room_participant_role))
 397            .add_request_handler(user_handler(call))
 398            .add_request_handler(user_handler(cancel_call))
 399            .add_message_handler(user_message_handler(decline_call))
 400            .add_request_handler(user_handler(update_participant_location))
 401            .add_request_handler(user_handler(share_project))
 402            .add_message_handler(unshare_project)
 403            .add_request_handler(user_handler(join_project))
 404            .add_request_handler(user_handler(join_hosted_project))
 405            .add_request_handler(user_handler(rejoin_remote_projects))
 406            .add_request_handler(user_handler(create_remote_project))
 407            .add_request_handler(user_handler(create_dev_server))
 408            .add_request_handler(dev_server_handler(share_remote_project))
 409            .add_request_handler(dev_server_handler(shutdown_dev_server))
 410            .add_request_handler(dev_server_handler(reconnect_dev_server))
 411            .add_message_handler(user_message_handler(leave_project))
 412            .add_request_handler(update_project)
 413            .add_request_handler(update_worktree)
 414            .add_message_handler(start_language_server)
 415            .add_message_handler(update_language_server)
 416            .add_message_handler(update_diagnostic_summary)
 417            .add_message_handler(update_worktree_settings)
 418            .add_request_handler(user_handler(
 419                forward_read_only_project_request::<proto::GetHover>,
 420            ))
 421            .add_request_handler(user_handler(
 422                forward_read_only_project_request::<proto::GetDefinition>,
 423            ))
 424            .add_request_handler(user_handler(
 425                forward_read_only_project_request::<proto::GetTypeDefinition>,
 426            ))
 427            .add_request_handler(user_handler(
 428                forward_read_only_project_request::<proto::GetReferences>,
 429            ))
 430            .add_request_handler(user_handler(
 431                forward_read_only_project_request::<proto::SearchProject>,
 432            ))
 433            .add_request_handler(user_handler(
 434                forward_read_only_project_request::<proto::GetDocumentHighlights>,
 435            ))
 436            .add_request_handler(user_handler(
 437                forward_read_only_project_request::<proto::GetProjectSymbols>,
 438            ))
 439            .add_request_handler(user_handler(
 440                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
 441            ))
 442            .add_request_handler(user_handler(
 443                forward_read_only_project_request::<proto::OpenBufferById>,
 444            ))
 445            .add_request_handler(user_handler(
 446                forward_read_only_project_request::<proto::SynchronizeBuffers>,
 447            ))
 448            .add_request_handler(user_handler(
 449                forward_read_only_project_request::<proto::InlayHints>,
 450            ))
 451            .add_request_handler(user_handler(
 452                forward_read_only_project_request::<proto::OpenBufferByPath>,
 453            ))
 454            .add_request_handler(user_handler(
 455                forward_mutating_project_request::<proto::GetCompletions>,
 456            ))
 457            .add_request_handler(user_handler(
 458                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 459            ))
 460            .add_request_handler(user_handler(
 461                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 462            ))
 463            .add_request_handler(user_handler(
 464                forward_mutating_project_request::<proto::GetCodeActions>,
 465            ))
 466            .add_request_handler(user_handler(
 467                forward_mutating_project_request::<proto::ApplyCodeAction>,
 468            ))
 469            .add_request_handler(user_handler(
 470                forward_mutating_project_request::<proto::PrepareRename>,
 471            ))
 472            .add_request_handler(user_handler(
 473                forward_mutating_project_request::<proto::PerformRename>,
 474            ))
 475            .add_request_handler(user_handler(
 476                forward_mutating_project_request::<proto::ReloadBuffers>,
 477            ))
 478            .add_request_handler(user_handler(
 479                forward_mutating_project_request::<proto::FormatBuffers>,
 480            ))
 481            .add_request_handler(user_handler(
 482                forward_mutating_project_request::<proto::CreateProjectEntry>,
 483            ))
 484            .add_request_handler(user_handler(
 485                forward_mutating_project_request::<proto::RenameProjectEntry>,
 486            ))
 487            .add_request_handler(user_handler(
 488                forward_mutating_project_request::<proto::CopyProjectEntry>,
 489            ))
 490            .add_request_handler(user_handler(
 491                forward_mutating_project_request::<proto::DeleteProjectEntry>,
 492            ))
 493            .add_request_handler(user_handler(
 494                forward_mutating_project_request::<proto::ExpandProjectEntry>,
 495            ))
 496            .add_request_handler(user_handler(
 497                forward_mutating_project_request::<proto::OnTypeFormatting>,
 498            ))
 499            .add_request_handler(user_handler(
 500                forward_mutating_project_request::<proto::SaveBuffer>,
 501            ))
 502            .add_request_handler(user_handler(
 503                forward_mutating_project_request::<proto::BlameBuffer>,
 504            ))
 505            .add_request_handler(user_handler(
 506                forward_mutating_project_request::<proto::MultiLspQuery>,
 507            ))
 508            .add_message_handler(create_buffer_for_peer)
 509            .add_request_handler(update_buffer)
 510            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 511            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 512            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 513            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 514            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 515            .add_request_handler(get_users)
 516            .add_request_handler(user_handler(fuzzy_search_users))
 517            .add_request_handler(user_handler(request_contact))
 518            .add_request_handler(user_handler(remove_contact))
 519            .add_request_handler(user_handler(respond_to_contact_request))
 520            .add_request_handler(user_handler(create_channel))
 521            .add_request_handler(user_handler(delete_channel))
 522            .add_request_handler(user_handler(invite_channel_member))
 523            .add_request_handler(user_handler(remove_channel_member))
 524            .add_request_handler(user_handler(set_channel_member_role))
 525            .add_request_handler(user_handler(set_channel_visibility))
 526            .add_request_handler(user_handler(rename_channel))
 527            .add_request_handler(user_handler(join_channel_buffer))
 528            .add_request_handler(user_handler(leave_channel_buffer))
 529            .add_message_handler(user_message_handler(update_channel_buffer))
 530            .add_request_handler(user_handler(rejoin_channel_buffers))
 531            .add_request_handler(user_handler(get_channel_members))
 532            .add_request_handler(user_handler(respond_to_channel_invite))
 533            .add_request_handler(user_handler(join_channel))
 534            .add_request_handler(user_handler(join_channel_chat))
 535            .add_message_handler(user_message_handler(leave_channel_chat))
 536            .add_request_handler(user_handler(send_channel_message))
 537            .add_request_handler(user_handler(remove_channel_message))
 538            .add_request_handler(user_handler(update_channel_message))
 539            .add_request_handler(user_handler(get_channel_messages))
 540            .add_request_handler(user_handler(get_channel_messages_by_id))
 541            .add_request_handler(user_handler(get_notifications))
 542            .add_request_handler(user_handler(mark_notification_as_read))
 543            .add_request_handler(user_handler(move_channel))
 544            .add_request_handler(user_handler(follow))
 545            .add_message_handler(user_message_handler(unfollow))
 546            .add_message_handler(user_message_handler(update_followers))
 547            .add_request_handler(user_handler(get_private_user_info))
 548            .add_message_handler(user_message_handler(acknowledge_channel_message))
 549            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 550            .add_streaming_request_handler({
 551                let app_state = app_state.clone();
 552                move |request, response, session| {
 553                    complete_with_language_model(
 554                        request,
 555                        response,
 556                        session,
 557                        app_state.config.openai_api_key.clone(),
 558                        app_state.config.google_ai_api_key.clone(),
 559                        app_state.config.anthropic_api_key.clone(),
 560                    )
 561                }
 562            })
 563            .add_request_handler({
 564                let app_state = app_state.clone();
 565                user_handler(move |request, response, session| {
 566                    count_tokens_with_language_model(
 567                        request,
 568                        response,
 569                        session,
 570                        app_state.config.google_ai_api_key.clone(),
 571                    )
 572                })
 573            })
 574            .add_request_handler({
 575                user_handler(move |request, response, session| {
 576                    get_cached_embeddings(request, response, session)
 577                })
 578            })
 579            .add_request_handler({
 580                let app_state = app_state.clone();
 581                user_handler(move |request, response, session| {
 582                    compute_embeddings(
 583                        request,
 584                        response,
 585                        session,
 586                        app_state.config.openai_api_key.clone(),
 587                    )
 588                })
 589            });
 590
 591        Arc::new(server)
 592    }
 593
    /// Schedules background cleanup of state left behind by stale servers and returns
    /// `Ok(())` immediately.
    ///
    /// The detached task runs inside the `start server` span. It waits for
    /// [`CLEANUP_TIMEOUT`], then fetches stale room and channel-buffer ids via
    /// `stale_server_resource_ids`, and processes them in two passes:
    /// - Channel buffers: clears stale collaborators and sends the refreshed collaborator
    ///   list to each connection the database returns.
    /// - Rooms: clears stale participants, broadcasts the updated room (and its channel, if
    ///   any), and sends `CallCanceled` to users whose calls were canceled. It then pushes
    ///   refreshed contact entries for the affected users, and deletes the LiveKit room if no
    ///   participants remain.
    ///
    /// Finally it calls `delete_stale_servers`. Each fallible step goes through
    /// `trace_err()`; a failed step is skipped rather than aborting the whole cleanup
    /// (presumably `trace_err` also logs the error).
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Pass 1: drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to the returned connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Pass 2: clean up each stale room.
                    for room_id in room_ids {
                        // Filled in from the refreshed room, if the refresh succeeds; the
                        // defaults make the steps below no-ops otherwise.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need their
                            // contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool lock is released before the
                        // `.await`s in the contacts loop below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry and push it to every
                        // connection of each of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Lock only after both awaits have completed.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when the refreshed room ended up empty.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 739
 740    pub fn teardown(&self) {
 741        self.peer.teardown();
 742        self.connection_pool.lock().reset();
 743        let _ = self.teardown.send(true);
 744    }
 745
 746    #[cfg(test)]
 747    pub fn reset(&self, id: ServerId) {
 748        self.teardown();
 749        *self.id.lock() = id;
 750        self.peer.reset(id.0 as u32);
 751        let _ = self.teardown.send(false);
 752    }
 753
 754    #[cfg(test)]
 755    pub fn id(&self) -> ServerId {
 756        *self.id.lock()
 757    }
 758
 759    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 760    where
 761        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 762        Fut: 'static + Send + Future<Output = Result<()>>,
 763        M: EnvelopedMessage,
 764    {
 765        let prev_handler = self.handlers.insert(
 766            TypeId::of::<M>(),
 767            Box::new(move |envelope, session| {
 768                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 769                let received_at = envelope.received_at;
 770                    tracing::info!(
 771                        "message received"
 772                    );
 773                let start_time = Instant::now();
 774                let future = (handler)(*envelope, session);
 775                async move {
 776                    let result = future.await;
 777                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 778                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 779                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 780                    let payload_type = M::NAME;
 781                    match result {
 782                        Err(error) => {
 783                            // todo!(), why isn't this logged inside the span?
 784                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, payload_type, "error handling message")
 785                        }
 786                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 787                    }
 788                }
 789                .boxed()
 790            }),
 791        );
 792        if prev_handler.is_some() {
 793            panic!("registered a handler for the same message twice");
 794        }
 795        self
 796    }
 797
 798    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 799    where
 800        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 801        Fut: 'static + Send + Future<Output = Result<()>>,
 802        M: EnvelopedMessage,
 803    {
 804        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 805        self
 806    }
 807
 808    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 809    where
 810        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 811        Fut: Send + Future<Output = Result<()>>,
 812        M: RequestMessage,
 813    {
 814        let handler = Arc::new(handler);
 815        self.add_handler(move |envelope, session| {
 816            let receipt = envelope.receipt();
 817            let handler = handler.clone();
 818            async move {
 819                let peer = session.peer.clone();
 820                let responded = Arc::new(AtomicBool::default());
 821                let response = Response {
 822                    peer: peer.clone(),
 823                    responded: responded.clone(),
 824                    receipt,
 825                };
 826                match (handler)(envelope.payload, response, session).await {
 827                    Ok(()) => {
 828                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 829                            Ok(())
 830                        } else {
 831                            Err(anyhow!("handler did not send a response"))?
 832                        }
 833                    }
 834                    Err(error) => {
 835                        let proto_err = match &error {
 836                            Error::Internal(err) => err.to_proto(),
 837                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 838                        };
 839                        peer.respond_with_error(receipt, proto_err)?;
 840                        Err(error)
 841                    }
 842                }
 843            }
 844        })
 845    }
 846
 847    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 848    where
 849        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 850        Fut: Send + Future<Output = Result<()>>,
 851        M: RequestMessage,
 852    {
 853        let handler = Arc::new(handler);
 854        self.add_handler(move |envelope, session| {
 855            let receipt = envelope.receipt();
 856            let handler = handler.clone();
 857            async move {
 858                let peer = session.peer.clone();
 859                let response = StreamingResponse {
 860                    peer: peer.clone(),
 861                    receipt,
 862                };
 863                match (handler)(envelope.payload, response, session).await {
 864                    Ok(()) => {
 865                        peer.end_stream(receipt)?;
 866                        Ok(())
 867                    }
 868                    Err(error) => {
 869                        let proto_err = match &error {
 870                            Error::Internal(err) => err.to_proto(),
 871                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 872                        };
 873                        peer.respond_with_error(receipt, proto_err)?;
 874                        Err(error)
 875                    }
 876                }
 877            }
 878        })
 879    }
 880
 881    #[allow(clippy::too_many_arguments)]
 882    pub fn handle_connection(
 883        self: &Arc<Self>,
 884        connection: Connection,
 885        address: String,
 886        principal: Principal,
 887        zed_version: ZedVersion,
 888        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 889        executor: Executor,
 890    ) -> impl Future<Output = ()> {
 891        let this = self.clone();
 892        let span = info_span!("handle connection", %address,
 893            connection_id=field::Empty,
 894            user_id=field::Empty,
 895            login=field::Empty,
 896            impersonator=field::Empty,
 897            dev_server_id=field::Empty
 898        );
 899        principal.update_span(&span);
 900
 901        let mut teardown = self.teardown.subscribe();
 902        async move {
 903            if *teardown.borrow() {
 904                tracing::error!("server is tearing down");
 905                return
 906            }
 907            let (connection_id, handle_io, mut incoming_rx) = this
 908                .peer
 909                .add_connection(connection, {
 910                    let executor = executor.clone();
 911                    move |duration| executor.sleep(duration)
 912                });
 913            tracing::Span::current().record("connection_id", format!("{}", connection_id));
 914            tracing::info!("connection opened");
 915
 916            let http_client = match IsahcHttpClient::new() {
 917                Ok(http_client) => http_client,
 918                Err(error) => {
 919                    tracing::error!(?error, "failed to create HTTP client");
 920                    return;
 921                }
 922            };
 923
 924            let session = Session {
 925                principal: principal.clone(),
 926                connection_id,
 927                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
 928                peer: this.peer.clone(),
 929                connection_pool: this.connection_pool.clone(),
 930                live_kit_client: this.app_state.live_kit_client.clone(),
 931                http_client,
 932                rate_limiter: this.app_state.rate_limiter.clone(),
 933                _executor: executor.clone(),
 934            };
 935
 936            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
 937                tracing::error!(?error, "failed to send initial client update");
 938                return;
 939            }
 940
 941            let handle_io = handle_io.fuse();
 942            futures::pin_mut!(handle_io);
 943
 944            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
 945            // This prevents deadlocks when e.g., client A performs a request to client B and
 946            // client B performs a request to client A. If both clients stop processing further
 947            // messages until their respective request completes, they won't have a chance to
 948            // respond to the other client's request and cause a deadlock.
 949            //
 950            // This arrangement ensures we will attempt to process earlier messages first, but fall
 951            // back to processing messages arrived later in the spirit of making progress.
 952            let mut foreground_message_handlers = FuturesUnordered::new();
 953            let concurrent_handlers = Arc::new(Semaphore::new(256));
 954            loop {
 955                let next_message = async {
 956                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
 957                    let message = incoming_rx.next().await;
 958                    (permit, message)
 959                }.fuse();
 960                futures::pin_mut!(next_message);
 961                futures::select_biased! {
 962                    _ = teardown.changed().fuse() => return,
 963                    result = handle_io => {
 964                        if let Err(error) = result {
 965                            tracing::error!(?error, "error handling I/O");
 966                        }
 967                        break;
 968                    }
 969                    _ = foreground_message_handlers.next() => {}
 970                    next_message = next_message => {
 971                        let (permit, message) = next_message;
 972                        if let Some(message) = message {
 973                            let type_name = message.payload_type_name();
 974                            // note: we copy all the fields from the parent span so we can query them in the logs.
 975                            // (https://github.com/tokio-rs/tracing/issues/2670).
 976                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
 977                                user_id=field::Empty,
 978                                login=field::Empty,
 979                                impersonator=field::Empty,
 980                                dev_server_id=field::Empty
 981                            );
 982                            principal.update_span(&span);
 983                            let span_enter = span.enter();
 984                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
 985                                let is_background = message.is_background();
 986                                let handle_message = (handler)(message, session.clone());
 987                                drop(span_enter);
 988
 989                                let handle_message = async move {
 990                                    handle_message.await;
 991                                    drop(permit);
 992                                }.instrument(span);
 993                                if is_background {
 994                                    executor.spawn_detached(handle_message);
 995                                } else {
 996                                    foreground_message_handlers.push(handle_message);
 997                                }
 998                            } else {
 999                                tracing::error!("no message handler");
1000                            }
1001                        } else {
1002                            tracing::info!("connection closed");
1003                            break;
1004                        }
1005                    }
1006                }
1007            }
1008
1009            drop(foreground_message_handlers);
1010            tracing::info!("signing out");
1011            if let Err(error) = connection_lost(session, teardown, executor).await {
1012                tracing::error!(?error, "error signing out");
1013            }
1014
1015        }.instrument(span)
1016    }
1017
1018    async fn send_initial_client_update(
1019        &self,
1020        connection_id: ConnectionId,
1021        principal: &Principal,
1022        zed_version: ZedVersion,
1023        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1024        session: &Session,
1025    ) -> Result<()> {
1026        self.peer.send(
1027            connection_id,
1028            proto::Hello {
1029                peer_id: Some(connection_id.into()),
1030            },
1031        )?;
1032        tracing::info!("sent hello message");
1033        if let Some(send_connection_id) = send_connection_id.take() {
1034            let _ = send_connection_id.send(connection_id);
1035        }
1036
1037        match principal {
1038            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1039                if !user.connected_once {
1040                    self.peer.send(connection_id, proto::ShowContacts {})?;
1041                    self.app_state
1042                        .db
1043                        .set_user_connected_once(user.id, true)
1044                        .await?;
1045                }
1046
1047                let (contacts, channels_for_user, channel_invites) = future::try_join3(
1048                    self.app_state.db.get_contacts(user.id),
1049                    self.app_state.db.get_channels_for_user(user.id),
1050                    self.app_state.db.get_channel_invites_for_user(user.id),
1051                )
1052                .await?;
1053
1054                {
1055                    let mut pool = self.connection_pool.lock();
1056                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
1057                    for membership in &channels_for_user.channel_memberships {
1058                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1059                    }
1060                    self.peer.send(
1061                        connection_id,
1062                        build_initial_contacts_update(contacts, &pool),
1063                    )?;
1064                    self.peer.send(
1065                        connection_id,
1066                        build_update_user_channels(&channels_for_user),
1067                    )?;
1068                    self.peer.send(
1069                        connection_id,
1070                        build_channels_update(channels_for_user, channel_invites, &pool),
1071                    )?;
1072                }
1073
1074                if let Some(incoming_call) =
1075                    self.app_state.db.incoming_call_for_user(user.id).await?
1076                {
1077                    self.peer.send(connection_id, incoming_call)?;
1078                }
1079
1080                update_user_contacts(user.id, &session).await?;
1081            }
1082            Principal::DevServer(dev_server) => {
1083                {
1084                    let mut pool = self.connection_pool.lock();
1085                    if pool.dev_server_connection_id(dev_server.id).is_some() {
1086                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
1087                    };
1088                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
1089                }
1090                update_dev_server_status(dev_server, proto::DevServerStatus::Online, &session)
1091                    .await;
1092                // todo!() allow only one connection.
1093
1094                let projects = self
1095                    .app_state
1096                    .db
1097                    .get_remote_projects_for_dev_server(dev_server.id)
1098                    .await?;
1099                self.peer
1100                    .send(connection_id, proto::DevServerInstructions { projects })?;
1101            }
1102        }
1103
1104        Ok(())
1105    }
1106
1107    pub async fn invite_code_redeemed(
1108        self: &Arc<Self>,
1109        inviter_id: UserId,
1110        invitee_id: UserId,
1111    ) -> Result<()> {
1112        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1113            if let Some(code) = &user.invite_code {
1114                let pool = self.connection_pool.lock();
1115                let invitee_contact = contact_for_user(invitee_id, false, &pool);
1116                for connection_id in pool.user_connection_ids(inviter_id) {
1117                    self.peer.send(
1118                        connection_id,
1119                        proto::UpdateContacts {
1120                            contacts: vec![invitee_contact.clone()],
1121                            ..Default::default()
1122                        },
1123                    )?;
1124                    self.peer.send(
1125                        connection_id,
1126                        proto::UpdateInviteInfo {
1127                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1128                            count: user.invite_count as u32,
1129                        },
1130                    )?;
1131                }
1132            }
1133        }
1134        Ok(())
1135    }
1136
1137    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1138        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1139            if let Some(invite_code) = &user.invite_code {
1140                let pool = self.connection_pool.lock();
1141                for connection_id in pool.user_connection_ids(user_id) {
1142                    self.peer.send(
1143                        connection_id,
1144                        proto::UpdateInviteInfo {
1145                            url: format!(
1146                                "{}{}",
1147                                self.app_state.config.invite_link_prefix, invite_code
1148                            ),
1149                            count: user.invite_count as u32,
1150                        },
1151                    )?;
1152                }
1153            }
1154        }
1155        Ok(())
1156    }
1157
1158    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1159        ServerSnapshot {
1160            connection_pool: ConnectionPoolGuard {
1161                guard: self.connection_pool.lock(),
1162                _not_send: PhantomData,
1163            },
1164            peer: &self.peer,
1165        }
1166    }
1167}
1168
1169impl<'a> Deref for ConnectionPoolGuard<'a> {
1170    type Target = ConnectionPool;
1171
1172    fn deref(&self) -> &Self::Target {
1173        &self.guard
1174    }
1175}
1176
1177impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1178    fn deref_mut(&mut self) -> &mut Self::Target {
1179        &mut self.guard
1180    }
1181}
1182
1183impl<'a> Drop for ConnectionPoolGuard<'a> {
1184    fn drop(&mut self) {
1185        #[cfg(test)]
1186        self.check_invariants();
1187    }
1188}
1189
1190fn broadcast<F>(
1191    sender_id: Option<ConnectionId>,
1192    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1193    mut f: F,
1194) where
1195    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1196{
1197    for receiver_id in receiver_ids {
1198        if Some(receiver_id) != sender_id {
1199            if let Err(error) = f(receiver_id) {
1200                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1201            }
1202        }
1203    }
1204}
1205
1206pub struct ProtocolVersion(u32);
1207
1208impl Header for ProtocolVersion {
1209    fn name() -> &'static HeaderName {
1210        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1211        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1212    }
1213
1214    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1215    where
1216        Self: Sized,
1217        I: Iterator<Item = &'i axum::http::HeaderValue>,
1218    {
1219        let version = values
1220            .next()
1221            .ok_or_else(axum::headers::Error::invalid)?
1222            .to_str()
1223            .map_err(|_| axum::headers::Error::invalid())?
1224            .parse()
1225            .map_err(|_| axum::headers::Error::invalid())?;
1226        Ok(Self(version))
1227    }
1228
1229    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1230        values.extend([self.0.to_string().parse().unwrap()]);
1231    }
1232}
1233
1234pub struct AppVersionHeader(SemanticVersion);
1235impl Header for AppVersionHeader {
1236    fn name() -> &'static HeaderName {
1237        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1238        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1239    }
1240
1241    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1242    where
1243        Self: Sized,
1244        I: Iterator<Item = &'i axum::http::HeaderValue>,
1245    {
1246        let version = values
1247            .next()
1248            .ok_or_else(axum::headers::Error::invalid)?
1249            .to_str()
1250            .map_err(|_| axum::headers::Error::invalid())?
1251            .parse()
1252            .map_err(|_| axum::headers::Error::invalid())?;
1253        Ok(Self(version))
1254    }
1255
1256    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1257        values.extend([self.0.to_string().parse().unwrap()]);
1258    }
1259}
1260
/// Builds the RPC router.
///
/// `/rpc` is served by [`handle_websocket_request`] and is wrapped by the layers added right
/// after it: the server's app state as an `Extension`, and the `auth::validate_header`
/// middleware. `Router::layer` only wraps routes registered before the call, so `/metrics`,
/// which is added afterwards, is not behind that middleware. The final `Extension(server)`
/// layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Only covers `/rpc`, the sole route registered at this point.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Covers every route above, including `/metrics`.
        .layer(Extension(server))
}
1272
1273pub async fn handle_websocket_request(
1274    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1275    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1276    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1277    Extension(server): Extension<Arc<Server>>,
1278    Extension(principal): Extension<Principal>,
1279    ws: WebSocketUpgrade,
1280) -> axum::response::Response {
1281    if protocol_version != rpc::PROTOCOL_VERSION {
1282        return (
1283            StatusCode::UPGRADE_REQUIRED,
1284            "client must be upgraded".to_string(),
1285        )
1286            .into_response();
1287    }
1288
1289    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1290        return (
1291            StatusCode::UPGRADE_REQUIRED,
1292            "no version header found".to_string(),
1293        )
1294            .into_response();
1295    };
1296
1297    if !version.can_collaborate() {
1298        return (
1299            StatusCode::UPGRADE_REQUIRED,
1300            "client must be upgraded".to_string(),
1301        )
1302            .into_response();
1303    }
1304
1305    let socket_address = socket_address.to_string();
1306    ws.on_upgrade(move |socket| {
1307        let socket = socket
1308            .map_ok(to_tungstenite_message)
1309            .err_into()
1310            .with(|message| async move { Ok(to_axum_message(message)) });
1311        let connection = Connection::new(Box::pin(socket));
1312        async move {
1313            server
1314                .handle_connection(
1315                    connection,
1316                    socket_address,
1317                    principal,
1318                    version,
1319                    None,
1320                    Executor::Production,
1321                )
1322                .await;
1323        }
1324    })
1325}
1326
1327pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1328    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1329    let connections_metric = CONNECTIONS_METRIC
1330        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1331
1332    let connections = server
1333        .connection_pool
1334        .lock()
1335        .connections()
1336        .filter(|connection| !connection.admin)
1337        .count();
1338    connections_metric.set(connections as _);
1339
1340    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1341    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1342        register_int_gauge!(
1343            "shared_projects",
1344            "number of open projects with one or more guests"
1345        )
1346        .unwrap()
1347    });
1348
1349    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1350    shared_projects_metric.set(shared_projects as _);
1351
1352    let encoder = prometheus::TextEncoder::new();
1353    let metric_families = prometheus::gather();
1354    let encoded_metrics = encoder
1355        .encode_to_string(&metric_families)
1356        .map_err(|err| anyhow!("{}", err))?;
1357    Ok(encoded_metrics)
1358}
1359
1360#[instrument(err, skip(executor))]
1361async fn connection_lost(
1362    session: Session,
1363    mut teardown: watch::Receiver<bool>,
1364    executor: Executor,
1365) -> Result<()> {
1366    session.peer.disconnect(session.connection_id);
1367    session
1368        .connection_pool()
1369        .await
1370        .remove_connection(session.connection_id)?;
1371
1372    session
1373        .db()
1374        .await
1375        .connection_lost(session.connection_id)
1376        .await
1377        .trace_err();
1378
1379    futures::select_biased! {
1380        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1381            match &session.principal {
1382                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1383                    let session = session.for_user().unwrap();
1384
1385                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1386                    leave_room_for_session(&session, session.connection_id).await.trace_err();
1387                    leave_channel_buffers_for_session(&session)
1388                        .await
1389                        .trace_err();
1390
1391                    if !session
1392                        .connection_pool()
1393                        .await
1394                        .is_user_online(session.user_id())
1395                    {
1396                        let db = session.db().await;
1397                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1398                            room_updated(&room, &session.peer);
1399                        }
1400                    }
1401
1402                    update_user_contacts(session.user_id(), &session).await?;
1403                },
1404            Principal::DevServer(dev_server) => {
1405                lost_dev_server_connection(&session).await?;
1406                update_dev_server_status(&dev_server, proto::DevServerStatus::Offline, &session)
1407                    .await;
1408            },
1409        }
1410        },
1411        _ = teardown.changed().fuse() => {}
1412    }
1413
1414    Ok(())
1415}
1416
1417/// Acknowledges a ping from a client, used to keep the connection alive.
1418async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1419    response.send(proto::Ack {})?;
1420    Ok(())
1421}
1422
1423/// Creates a new room for calling (outside of channels)
1424async fn create_room(
1425    _request: proto::CreateRoom,
1426    response: Response<proto::CreateRoom>,
1427    session: UserSession,
1428) -> Result<()> {
1429    let live_kit_room = nanoid::nanoid!(30);
1430
1431    let live_kit_connection_info = util::maybe!(async {
1432        let live_kit = session.live_kit_client.as_ref();
1433        let live_kit = live_kit?;
1434        let user_id = session.user_id().to_string();
1435
1436        let token = live_kit
1437            .room_token(&live_kit_room, &user_id.to_string())
1438            .trace_err()?;
1439
1440        Some(proto::LiveKitConnectionInfo {
1441            server_url: live_kit.url().into(),
1442            token,
1443            can_publish: true,
1444        })
1445    })
1446    .await;
1447
1448    let room = session
1449        .db()
1450        .await
1451        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1452        .await?;
1453
1454    response.send(proto::CreateRoomResponse {
1455        room: Some(room.clone()),
1456        live_kit_connection_info,
1457    })?;
1458
1459    update_user_contacts(session.user_id(), &session).await?;
1460    Ok(())
1461}
1462
1463/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1464async fn join_room(
1465    request: proto::JoinRoom,
1466    response: Response<proto::JoinRoom>,
1467    session: UserSession,
1468) -> Result<()> {
1469    let room_id = RoomId::from_proto(request.id);
1470
1471    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1472
1473    if let Some(channel_id) = channel_id {
1474        return join_channel_internal(channel_id, Box::new(response), session).await;
1475    }
1476
1477    let joined_room = {
1478        let room = session
1479            .db()
1480            .await
1481            .join_room(room_id, session.user_id(), session.connection_id)
1482            .await?;
1483        room_updated(&room.room, &session.peer);
1484        room.into_inner()
1485    };
1486
1487    for connection_id in session
1488        .connection_pool()
1489        .await
1490        .user_connection_ids(session.user_id())
1491    {
1492        session
1493            .peer
1494            .send(
1495                connection_id,
1496                proto::CallCanceled {
1497                    room_id: room_id.to_proto(),
1498                },
1499            )
1500            .trace_err();
1501    }
1502
1503    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1504        if let Some(token) = live_kit
1505            .room_token(
1506                &joined_room.room.live_kit_room,
1507                &session.user_id().to_string(),
1508            )
1509            .trace_err()
1510        {
1511            Some(proto::LiveKitConnectionInfo {
1512                server_url: live_kit.url().into(),
1513                token,
1514                can_publish: true,
1515            })
1516        } else {
1517            None
1518        }
1519    } else {
1520        None
1521    };
1522
1523    response.send(proto::JoinRoomResponse {
1524        room: Some(joined_room.room),
1525        channel_id: None,
1526        live_kit_connection_info,
1527    })?;
1528
1529    update_user_contacts(session.user_id(), &session).await?;
1530    Ok(())
1531}
1532
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-attaches the user, under the session's current connection
/// id, to the room described by `request`. The client is immediately sent a
/// `RejoinRoomResponse` containing the room, its re-shared projects
/// (`reshared_projects`, with their collaborators) and its rejoined projects.
/// Afterwards:
/// - the updated room is broadcast;
/// - collaborators of each re-shared project are told the user's peer id moved
///   from `old_connection_id` to the current connection, and are forwarded the
///   project's current worktrees as if sent by this connection;
/// - state for each rejoined project is replayed via `notify_rejoined_projects`;
/// - if the room belongs to a channel, `channel_updated` is broadcast;
/// - finally the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the scope below and used once that scope has closed.
    let room;
    let channel;
    // Everything that needs `rejoined_room` happens in this scope; only the
    // owned room and channel (obtained via `into_inner`) leave it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first so the client has the room/project snapshot before the
        // follow-up messages sent below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator about the user's new peer id; failures are
            // only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to the collaborators,
            // attributed to this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1624
1625fn notify_rejoined_projects(
1626    rejoined_projects: &mut Vec<RejoinedProject>,
1627    session: &UserSession,
1628) -> Result<()> {
1629    for project in rejoined_projects.iter() {
1630        for collaborator in &project.collaborators {
1631            session
1632                .peer
1633                .send(
1634                    collaborator.connection_id,
1635                    proto::UpdateProjectCollaborator {
1636                        project_id: project.id.to_proto(),
1637                        old_peer_id: Some(project.old_connection_id.into()),
1638                        new_peer_id: Some(session.connection_id.into()),
1639                    },
1640                )
1641                .trace_err();
1642        }
1643    }
1644
1645    for project in rejoined_projects {
1646        for worktree in mem::take(&mut project.worktrees) {
1647            #[cfg(any(test, feature = "test-support"))]
1648            const MAX_CHUNK_SIZE: usize = 2;
1649            #[cfg(not(any(test, feature = "test-support")))]
1650            const MAX_CHUNK_SIZE: usize = 256;
1651
1652            // Stream this worktree's entries.
1653            let message = proto::UpdateWorktree {
1654                project_id: project.id.to_proto(),
1655                worktree_id: worktree.id,
1656                abs_path: worktree.abs_path.clone(),
1657                root_name: worktree.root_name,
1658                updated_entries: worktree.updated_entries,
1659                removed_entries: worktree.removed_entries,
1660                scan_id: worktree.scan_id,
1661                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1662                updated_repositories: worktree.updated_repositories,
1663                removed_repositories: worktree.removed_repositories,
1664            };
1665            for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1666                session.peer.send(session.connection_id, update.clone())?;
1667            }
1668
1669            // Stream this worktree's diagnostics.
1670            for summary in worktree.diagnostic_summaries {
1671                session.peer.send(
1672                    session.connection_id,
1673                    proto::UpdateDiagnosticSummary {
1674                        project_id: project.id.to_proto(),
1675                        worktree_id: worktree.id,
1676                        summary: Some(summary),
1677                    },
1678                )?;
1679            }
1680
1681            for settings_file in worktree.settings_files {
1682                session.peer.send(
1683                    session.connection_id,
1684                    proto::UpdateWorktreeSettings {
1685                        project_id: project.id.to_proto(),
1686                        worktree_id: worktree.id,
1687                        path: settings_file.path,
1688                        content: Some(settings_file.content),
1689                    },
1690                )?;
1691            }
1692        }
1693
1694        for language_server in &project.language_servers {
1695            session.peer.send(
1696                session.connection_id,
1697                proto::UpdateLanguageServer {
1698                    project_id: project.id.to_proto(),
1699                    language_server_id: language_server.id,
1700                    variant: Some(
1701                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1702                            proto::LspDiskBasedDiagnosticsUpdated {},
1703                        ),
1704                    ),
1705                },
1706            )?;
1707        }
1708    }
1709    Ok(())
1710}
1711
1712/// leave room disconnects from the room.
1713async fn leave_room(
1714    _: proto::LeaveRoom,
1715    response: Response<proto::LeaveRoom>,
1716    session: UserSession,
1717) -> Result<()> {
1718    leave_room_for_session(&session, session.connection_id).await?;
1719    response.send(proto::Ack {})?;
1720    Ok(())
1721}
1722
1723/// Updates the permissions of someone else in the room.
1724async fn set_room_participant_role(
1725    request: proto::SetRoomParticipantRole,
1726    response: Response<proto::SetRoomParticipantRole>,
1727    session: UserSession,
1728) -> Result<()> {
1729    let user_id = UserId::from_proto(request.user_id);
1730    let role = ChannelRole::from(request.role());
1731
1732    let (live_kit_room, can_publish) = {
1733        let room = session
1734            .db()
1735            .await
1736            .set_room_participant_role(
1737                session.user_id(),
1738                RoomId::from_proto(request.room_id),
1739                user_id,
1740                role,
1741            )
1742            .await?;
1743
1744        let live_kit_room = room.live_kit_room.clone();
1745        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1746        room_updated(&room, &session.peer);
1747        (live_kit_room, can_publish)
1748    };
1749
1750    if let Some(live_kit) = session.live_kit_client.as_ref() {
1751        live_kit
1752            .update_participant(
1753                live_kit_room.clone(),
1754                request.user_id.to_string(),
1755                live_kit_server::proto::ParticipantPermission {
1756                    can_subscribe: true,
1757                    can_publish,
1758                    can_publish_data: can_publish,
1759                    hidden: false,
1760                    recorder: false,
1761                },
1762            )
1763            .await
1764            .trace_err();
1765    }
1766
1767    response.send(proto::Ack {})?;
1768    Ok(())
1769}
1770
/// Call someone else into the current room
///
/// Only contacts of the caller may be called. The call is recorded in the
/// database, the room is broadcast, and the called user's contacts are
/// refreshed. An incoming-call request is then sent to every connection of the
/// called user; as soon as one of them answers successfully, the caller gets an
/// `Ack`. If no connection answers successfully, the call is marked as failed,
/// the room and the called user's contacts are refreshed again, and the request
/// fails with "failed to ring user".
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the incoming-call payload out so it outlives the database result,
        // which is dropped at the end of this scope.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user at once; responses are consumed
    // in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // The first successful answer completes the call request.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // A failed attempt is only logged; keep waiting for the others.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: record the failure and rebroadcast.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1839
1840/// Cancel an outgoing call.
1841async fn cancel_call(
1842    request: proto::CancelCall,
1843    response: Response<proto::CancelCall>,
1844    session: UserSession,
1845) -> Result<()> {
1846    let called_user_id = UserId::from_proto(request.called_user_id);
1847    let room_id = RoomId::from_proto(request.room_id);
1848    {
1849        let room = session
1850            .db()
1851            .await
1852            .cancel_call(room_id, session.connection_id, called_user_id)
1853            .await?;
1854        room_updated(&room, &session.peer);
1855    }
1856
1857    for connection_id in session
1858        .connection_pool()
1859        .await
1860        .user_connection_ids(called_user_id)
1861    {
1862        session
1863            .peer
1864            .send(
1865                connection_id,
1866                proto::CallCanceled {
1867                    room_id: room_id.to_proto(),
1868                },
1869            )
1870            .trace_err();
1871    }
1872    response.send(proto::Ack {})?;
1873
1874    update_user_contacts(called_user_id, &session).await?;
1875    Ok(())
1876}
1877
1878/// Decline an incoming call.
1879async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1880    let room_id = RoomId::from_proto(message.room_id);
1881    {
1882        let room = session
1883            .db()
1884            .await
1885            .decline_call(Some(room_id), session.user_id())
1886            .await?
1887            .ok_or_else(|| anyhow!("failed to decline call"))?;
1888        room_updated(&room, &session.peer);
1889    }
1890
1891    for connection_id in session
1892        .connection_pool()
1893        .await
1894        .user_connection_ids(session.user_id())
1895    {
1896        session
1897            .peer
1898            .send(
1899                connection_id,
1900                proto::CallCanceled {
1901                    room_id: room_id.to_proto(),
1902                },
1903            )
1904            .trace_err();
1905    }
1906    update_user_contacts(session.user_id(), &session).await?;
1907    Ok(())
1908}
1909
1910/// Updates other participants in the room with your current location.
1911async fn update_participant_location(
1912    request: proto::UpdateParticipantLocation,
1913    response: Response<proto::UpdateParticipantLocation>,
1914    session: UserSession,
1915) -> Result<()> {
1916    let room_id = RoomId::from_proto(request.room_id);
1917    let location = request
1918        .location
1919        .ok_or_else(|| anyhow!("invalid location"))?;
1920
1921    let db = session.db().await;
1922    let room = db
1923        .update_room_participant_location(room_id, session.connection_id, location)
1924        .await?;
1925
1926    room_updated(&room, &session.peer);
1927    response.send(proto::Ack {})?;
1928    Ok(())
1929}
1930
1931/// Share a project into the room.
1932async fn share_project(
1933    request: proto::ShareProject,
1934    response: Response<proto::ShareProject>,
1935    session: UserSession,
1936) -> Result<()> {
1937    let (project_id, room) = &*session
1938        .db()
1939        .await
1940        .share_project(
1941            RoomId::from_proto(request.room_id),
1942            session.connection_id,
1943            &request.worktrees,
1944        )
1945        .await?;
1946    response.send(proto::ShareProjectResponse {
1947        project_id: project_id.to_proto(),
1948    })?;
1949    room_updated(&room, &session.peer);
1950
1951    Ok(())
1952}
1953
1954/// Unshare a project from the room.
1955async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1956    let project_id = ProjectId::from_proto(message.project_id);
1957    unshare_project_internal(project_id, &session).await
1958}
1959
1960async fn unshare_project_internal(project_id: ProjectId, session: &Session) -> Result<()> {
1961    let (room, guest_connection_ids) = &*session
1962        .db()
1963        .await
1964        .unshare_project(project_id, session.connection_id)
1965        .await?;
1966
1967    let message = proto::UnshareProject {
1968        project_id: project_id.to_proto(),
1969    };
1970
1971    broadcast(
1972        Some(session.connection_id),
1973        guest_connection_ids.iter().copied(),
1974        |conn_id| session.peer.send(conn_id, message.clone()),
1975    );
1976    if let Some(room) = room {
1977        room_updated(room, &session.peer);
1978    }
1979
1980    Ok(())
1981}
1982
1983/// Share a project into the room.
1984async fn share_remote_project(
1985    request: proto::ShareRemoteProject,
1986    response: Response<proto::ShareRemoteProject>,
1987    session: DevServerSession,
1988) -> Result<()> {
1989    let remote_project = session
1990        .db()
1991        .await
1992        .share_remote_project(
1993            RemoteProjectId::from_proto(request.remote_project_id),
1994            session.dev_server_id(),
1995            session.connection_id,
1996            &request.worktrees,
1997        )
1998        .await?;
1999    let Some(project_id) = remote_project.project_id else {
2000        return Err(anyhow!("failed to share remote project"))?;
2001    };
2002
2003    for (connection_id, _) in session
2004        .connection_pool()
2005        .await
2006        .channel_connection_ids(ChannelId::from_proto(remote_project.channel_id))
2007    {
2008        session
2009            .peer
2010            .send(
2011                connection_id,
2012                proto::UpdateChannels {
2013                    remote_projects: vec![remote_project.clone()],
2014                    ..Default::default()
2015                },
2016            )
2017            .trace_err();
2018    }
2019
2020    response.send(proto::ShareProjectResponse { project_id })?;
2021
2022    Ok(())
2023}
2024
2025/// Join someone elses shared project.
2026async fn join_project(
2027    request: proto::JoinProject,
2028    response: Response<proto::JoinProject>,
2029    session: UserSession,
2030) -> Result<()> {
2031    let project_id = ProjectId::from_proto(request.project_id);
2032
2033    tracing::info!(%project_id, "join project");
2034
2035    let db = session.db().await;
2036    let (project, replica_id) = &mut *db
2037        .join_project(project_id, session.connection_id, session.user_id())
2038        .await?;
2039    drop(db);
2040    tracing::info!(%project_id, "join remote project");
2041    join_project_internal(response, session, project, replica_id)
2042}
2043
/// Abstracts over the response handles of the requests that join a project
/// (`JoinProject` and `JoinHostedProject`). Both reply with a
/// `proto::JoinProjectResponse`, which lets `join_project_internal` serve either.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Forwards to `Response`'s own `send` via a fully-qualified path.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Forwards to `Response`'s own `send` via a fully-qualified path.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2057
2058fn join_project_internal(
2059    response: impl JoinProjectInternalResponse,
2060    session: UserSession,
2061    project: &mut Project,
2062    replica_id: &ReplicaId,
2063) -> Result<()> {
2064    let collaborators = project
2065        .collaborators
2066        .iter()
2067        .filter(|collaborator| collaborator.connection_id != session.connection_id)
2068        .map(|collaborator| collaborator.to_proto())
2069        .collect::<Vec<_>>();
2070    let project_id = project.id;
2071    let guest_user_id = session.user_id();
2072
2073    let worktrees = project
2074        .worktrees
2075        .iter()
2076        .map(|(id, worktree)| proto::WorktreeMetadata {
2077            id: *id,
2078            root_name: worktree.root_name.clone(),
2079            visible: worktree.visible,
2080            abs_path: worktree.abs_path.clone(),
2081        })
2082        .collect::<Vec<_>>();
2083
2084    for collaborator in &collaborators {
2085        session
2086            .peer
2087            .send(
2088                collaborator.peer_id.unwrap().into(),
2089                proto::AddProjectCollaborator {
2090                    project_id: project_id.to_proto(),
2091                    collaborator: Some(proto::Collaborator {
2092                        peer_id: Some(session.connection_id.into()),
2093                        replica_id: replica_id.0 as u32,
2094                        user_id: guest_user_id.to_proto(),
2095                    }),
2096                },
2097            )
2098            .trace_err();
2099    }
2100
2101    // First, we send the metadata associated with each worktree.
2102    response.send(proto::JoinProjectResponse {
2103        project_id: project.id.0 as u64,
2104        worktrees: worktrees.clone(),
2105        replica_id: replica_id.0 as u32,
2106        collaborators: collaborators.clone(),
2107        language_servers: project.language_servers.clone(),
2108        role: project.role.into(), // todo
2109    })?;
2110
2111    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2112        #[cfg(any(test, feature = "test-support"))]
2113        const MAX_CHUNK_SIZE: usize = 2;
2114        #[cfg(not(any(test, feature = "test-support")))]
2115        const MAX_CHUNK_SIZE: usize = 256;
2116
2117        // Stream this worktree's entries.
2118        let message = proto::UpdateWorktree {
2119            project_id: project_id.to_proto(),
2120            worktree_id,
2121            abs_path: worktree.abs_path.clone(),
2122            root_name: worktree.root_name,
2123            updated_entries: worktree.entries,
2124            removed_entries: Default::default(),
2125            scan_id: worktree.scan_id,
2126            is_last_update: worktree.scan_id == worktree.completed_scan_id,
2127            updated_repositories: worktree.repository_entries.into_values().collect(),
2128            removed_repositories: Default::default(),
2129        };
2130        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2131            session.peer.send(session.connection_id, update.clone())?;
2132        }
2133
2134        // Stream this worktree's diagnostics.
2135        for summary in worktree.diagnostic_summaries {
2136            session.peer.send(
2137                session.connection_id,
2138                proto::UpdateDiagnosticSummary {
2139                    project_id: project_id.to_proto(),
2140                    worktree_id: worktree.id,
2141                    summary: Some(summary),
2142                },
2143            )?;
2144        }
2145
2146        for settings_file in worktree.settings_files {
2147            session.peer.send(
2148                session.connection_id,
2149                proto::UpdateWorktreeSettings {
2150                    project_id: project_id.to_proto(),
2151                    worktree_id: worktree.id,
2152                    path: settings_file.path,
2153                    content: Some(settings_file.content),
2154                },
2155            )?;
2156        }
2157    }
2158
2159    for language_server in &project.language_servers {
2160        session.peer.send(
2161            session.connection_id,
2162            proto::UpdateLanguageServer {
2163                project_id: project_id.to_proto(),
2164                language_server_id: language_server.id,
2165                variant: Some(
2166                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2167                        proto::LspDiskBasedDiagnosticsUpdated {},
2168                    ),
2169                ),
2170            },
2171        )?;
2172    }
2173
2174    Ok(())
2175}
2176
2177/// Leave someone elses shared project.
2178async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2179    let sender_id = session.connection_id;
2180    let project_id = ProjectId::from_proto(request.project_id);
2181    let db = session.db().await;
2182    if db.is_hosted_project(project_id).await? {
2183        let project = db.leave_hosted_project(project_id, sender_id).await?;
2184        project_left(&project, &session);
2185        return Ok(());
2186    }
2187
2188    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2189    tracing::info!(
2190        %project_id,
2191        host_user_id = ?project.host_user_id,
2192        host_connection_id = ?project.host_connection_id,
2193        "leave project"
2194    );
2195
2196    project_left(&project, &session);
2197    if let Some(room) = room {
2198        room_updated(&room, &session.peer);
2199    }
2200
2201    Ok(())
2202}
2203
2204async fn join_hosted_project(
2205    request: proto::JoinHostedProject,
2206    response: Response<proto::JoinHostedProject>,
2207    session: UserSession,
2208) -> Result<()> {
2209    let (mut project, replica_id) = session
2210        .db()
2211        .await
2212        .join_hosted_project(
2213            ProjectId(request.project_id as i32),
2214            session.user_id(),
2215            session.connection_id,
2216        )
2217        .await?;
2218
2219    join_project_internal(response, session, &mut project, &replica_id)
2220}
2221
2222async fn create_remote_project(
2223    request: proto::CreateRemoteProject,
2224    response: Response<proto::CreateRemoteProject>,
2225    session: UserSession,
2226) -> Result<()> {
2227    let (channel, remote_project) = session
2228        .db()
2229        .await
2230        .create_remote_project(
2231            ChannelId(request.channel_id as i32),
2232            DevServerId(request.dev_server_id as i32),
2233            &request.name,
2234            &request.path,
2235            session.user_id(),
2236        )
2237        .await?;
2238
2239    let projects = session
2240        .db()
2241        .await
2242        .get_remote_projects_for_dev_server(remote_project.dev_server_id)
2243        .await?;
2244
2245    let update = proto::UpdateChannels {
2246        remote_projects: vec![remote_project.to_proto(None)],
2247        ..Default::default()
2248    };
2249    let connection_pool = session.connection_pool().await;
2250    for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2251        if role.can_see_all_descendants() {
2252            session.peer.send(connection_id, update.clone())?;
2253        }
2254    }
2255
2256    let dev_server_id = remote_project.dev_server_id;
2257    let dev_server_connection_id = connection_pool.dev_server_connection_id(dev_server_id);
2258    if let Some(dev_server_connection_id) = dev_server_connection_id {
2259        session.peer.send(
2260            dev_server_connection_id,
2261            proto::DevServerInstructions { projects },
2262        )?;
2263    }
2264
2265    response.send(proto::CreateRemoteProjectResponse {
2266        remote_project: Some(remote_project.to_proto(None)),
2267    })?;
2268    Ok(())
2269}
2270
2271async fn create_dev_server(
2272    request: proto::CreateDevServer,
2273    response: Response<proto::CreateDevServer>,
2274    session: UserSession,
2275) -> Result<()> {
2276    let access_token = auth::random_token();
2277    let hashed_access_token = auth::hash_access_token(&access_token);
2278
2279    let (channel, dev_server) = session
2280        .db()
2281        .await
2282        .create_dev_server(
2283            ChannelId(request.channel_id as i32),
2284            &request.name,
2285            &hashed_access_token,
2286            session.user_id(),
2287        )
2288        .await?;
2289
2290    let update = proto::UpdateChannels {
2291        dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2292        ..Default::default()
2293    };
2294    let connection_pool = session.connection_pool().await;
2295    for (connection_id, role) in connection_pool.channel_connection_ids(channel.root_id()) {
2296        if role.can_see_channel(channel.visibility) {
2297            session.peer.send(connection_id, update.clone())?;
2298        }
2299    }
2300
2301    response.send(proto::CreateDevServerResponse {
2302        dev_server_id: dev_server.id.0 as u64,
2303        channel_id: request.channel_id,
2304        access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2305        name: request.name.clone(),
2306    })?;
2307    Ok(())
2308}
2309
2310async fn rejoin_remote_projects(
2311    request: proto::RejoinRemoteProjects,
2312    response: Response<proto::RejoinRemoteProjects>,
2313    session: UserSession,
2314) -> Result<()> {
2315    let mut rejoined_projects = {
2316        let db = session.db().await;
2317        db.rejoin_remote_projects(
2318            &request.rejoined_projects,
2319            session.user_id(),
2320            session.0.connection_id,
2321        )
2322        .await?
2323    };
2324    notify_rejoined_projects(&mut rejoined_projects, &session)?;
2325
2326    response.send(proto::RejoinRemoteProjectsResponse {
2327        rejoined_projects: rejoined_projects
2328            .into_iter()
2329            .map(|project| project.to_proto())
2330            .collect(),
2331    })
2332}
2333
2334async fn reconnect_dev_server(
2335    request: proto::ReconnectDevServer,
2336    response: Response<proto::ReconnectDevServer>,
2337    session: DevServerSession,
2338) -> Result<()> {
2339    let reshared_projects = {
2340        let db = session.db().await;
2341        db.reshare_remote_projects(
2342            &request.reshared_projects,
2343            session.dev_server_id(),
2344            session.0.connection_id,
2345        )
2346        .await?
2347    };
2348
2349    for project in &reshared_projects {
2350        for collaborator in &project.collaborators {
2351            session
2352                .peer
2353                .send(
2354                    collaborator.connection_id,
2355                    proto::UpdateProjectCollaborator {
2356                        project_id: project.id.to_proto(),
2357                        old_peer_id: Some(project.old_connection_id.into()),
2358                        new_peer_id: Some(session.connection_id.into()),
2359                    },
2360                )
2361                .trace_err();
2362        }
2363
2364        broadcast(
2365            Some(session.connection_id),
2366            project
2367                .collaborators
2368                .iter()
2369                .map(|collaborator| collaborator.connection_id),
2370            |connection_id| {
2371                session.peer.forward_send(
2372                    session.connection_id,
2373                    connection_id,
2374                    proto::UpdateProject {
2375                        project_id: project.id.to_proto(),
2376                        worktrees: project.worktrees.clone(),
2377                    },
2378                )
2379            },
2380        );
2381    }
2382
2383    response.send(proto::ReconnectDevServerResponse {
2384        reshared_projects: reshared_projects
2385            .iter()
2386            .map(|project| proto::ResharedProject {
2387                id: project.id.to_proto(),
2388                collaborators: project
2389                    .collaborators
2390                    .iter()
2391                    .map(|collaborator| collaborator.to_proto())
2392                    .collect(),
2393            })
2394            .collect(),
2395    })?;
2396
2397    Ok(())
2398}
2399
2400async fn shutdown_dev_server(
2401    _: proto::ShutdownDevServer,
2402    response: Response<proto::ShutdownDevServer>,
2403    session: DevServerSession,
2404) -> Result<()> {
2405    response.send(proto::Ack {})?;
2406    let (remote_projects, dev_server) = {
2407        let dev_server_id = session.dev_server_id();
2408        let db = session.db().await;
2409        let remote_projects = db.get_remote_projects_for_dev_server(dev_server_id).await?;
2410        let dev_server = db.get_dev_server(dev_server_id).await?;
2411        (remote_projects, dev_server)
2412    };
2413
2414    for project_id in remote_projects.iter().filter_map(|p| p.project_id) {
2415        unshare_project_internal(ProjectId::from_proto(project_id), &session.0).await?;
2416    }
2417
2418    let update = proto::UpdateChannels {
2419        remote_projects,
2420        dev_servers: vec![dev_server.to_proto(proto::DevServerStatus::Offline)],
2421        ..Default::default()
2422    };
2423
2424    for (connection_id, _) in session
2425        .connection_pool()
2426        .await
2427        .channel_connection_ids(dev_server.channel_id)
2428    {
2429        session.peer.send(connection_id, update.clone()).trace_err();
2430    }
2431
2432    Ok(())
2433}
2434
2435/// Updates other participants with changes to the project
2436async fn update_project(
2437    request: proto::UpdateProject,
2438    response: Response<proto::UpdateProject>,
2439    session: Session,
2440) -> Result<()> {
2441    let project_id = ProjectId::from_proto(request.project_id);
2442    let (room, guest_connection_ids) = &*session
2443        .db()
2444        .await
2445        .update_project(project_id, session.connection_id, &request.worktrees)
2446        .await?;
2447    broadcast(
2448        Some(session.connection_id),
2449        guest_connection_ids.iter().copied(),
2450        |connection_id| {
2451            session
2452                .peer
2453                .forward_send(session.connection_id, connection_id, request.clone())
2454        },
2455    );
2456    if let Some(room) = room {
2457        room_updated(&room, &session.peer);
2458    }
2459    response.send(proto::Ack {})?;
2460
2461    Ok(())
2462}
2463
2464/// Updates other participants with changes to the worktree
2465async fn update_worktree(
2466    request: proto::UpdateWorktree,
2467    response: Response<proto::UpdateWorktree>,
2468    session: Session,
2469) -> Result<()> {
2470    let guest_connection_ids = session
2471        .db()
2472        .await
2473        .update_worktree(&request, session.connection_id)
2474        .await?;
2475
2476    broadcast(
2477        Some(session.connection_id),
2478        guest_connection_ids.iter().copied(),
2479        |connection_id| {
2480            session
2481                .peer
2482                .forward_send(session.connection_id, connection_id, request.clone())
2483        },
2484    );
2485    response.send(proto::Ack {})?;
2486    Ok(())
2487}
2488
2489/// Updates other participants with changes to the diagnostics
2490async fn update_diagnostic_summary(
2491    message: proto::UpdateDiagnosticSummary,
2492    session: Session,
2493) -> Result<()> {
2494    let guest_connection_ids = session
2495        .db()
2496        .await
2497        .update_diagnostic_summary(&message, session.connection_id)
2498        .await?;
2499
2500    broadcast(
2501        Some(session.connection_id),
2502        guest_connection_ids.iter().copied(),
2503        |connection_id| {
2504            session
2505                .peer
2506                .forward_send(session.connection_id, connection_id, message.clone())
2507        },
2508    );
2509
2510    Ok(())
2511}
2512
2513/// Updates other participants with changes to the worktree settings
2514async fn update_worktree_settings(
2515    message: proto::UpdateWorktreeSettings,
2516    session: Session,
2517) -> Result<()> {
2518    let guest_connection_ids = session
2519        .db()
2520        .await
2521        .update_worktree_settings(&message, session.connection_id)
2522        .await?;
2523
2524    broadcast(
2525        Some(session.connection_id),
2526        guest_connection_ids.iter().copied(),
2527        |connection_id| {
2528            session
2529                .peer
2530                .forward_send(session.connection_id, connection_id, message.clone())
2531        },
2532    );
2533
2534    Ok(())
2535}
2536
2537/// Notify other participants that a  language server has started.
2538async fn start_language_server(
2539    request: proto::StartLanguageServer,
2540    session: Session,
2541) -> Result<()> {
2542    let guest_connection_ids = session
2543        .db()
2544        .await
2545        .start_language_server(&request, session.connection_id)
2546        .await?;
2547
2548    broadcast(
2549        Some(session.connection_id),
2550        guest_connection_ids.iter().copied(),
2551        |connection_id| {
2552            session
2553                .peer
2554                .forward_send(session.connection_id, connection_id, request.clone())
2555        },
2556    );
2557    Ok(())
2558}
2559
2560/// Notify other participants that a language server has changed.
2561async fn update_language_server(
2562    request: proto::UpdateLanguageServer,
2563    session: Session,
2564) -> Result<()> {
2565    let project_id = ProjectId::from_proto(request.project_id);
2566    let project_connection_ids = session
2567        .db()
2568        .await
2569        .project_connection_ids(project_id, session.connection_id, true)
2570        .await?;
2571    broadcast(
2572        Some(session.connection_id),
2573        project_connection_ids.iter().copied(),
2574        |connection_id| {
2575            session
2576                .peer
2577                .forward_send(session.connection_id, connection_id, request.clone())
2578        },
2579    );
2580    Ok(())
2581}
2582
2583/// forward a project request to the host. These requests should be read only
2584/// as guests are allowed to send them.
2585async fn forward_read_only_project_request<T>(
2586    request: T,
2587    response: Response<T>,
2588    session: UserSession,
2589) -> Result<()>
2590where
2591    T: EntityMessage + RequestMessage,
2592{
2593    let project_id = ProjectId::from_proto(request.remote_entity_id());
2594    let host_connection_id = session
2595        .db()
2596        .await
2597        .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2598        .await?;
2599    let payload = session
2600        .peer
2601        .forward_request(session.connection_id, host_connection_id, request)
2602        .await?;
2603    response.send(payload)?;
2604    Ok(())
2605}
2606
2607/// forward a project request to the host. These requests are disallowed
2608/// for guests.
2609async fn forward_mutating_project_request<T>(
2610    request: T,
2611    response: Response<T>,
2612    session: UserSession,
2613) -> Result<()>
2614where
2615    T: EntityMessage + RequestMessage,
2616{
2617    let project_id = ProjectId::from_proto(request.remote_entity_id());
2618    let host_connection_id = session
2619        .db()
2620        .await
2621        .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2622        .await?;
2623    let payload = session
2624        .peer
2625        .forward_request(session.connection_id, host_connection_id, request)
2626        .await?;
2627    response.send(payload)?;
2628    Ok(())
2629}
2630
2631/// Notify other participants that a new buffer has been created
2632async fn create_buffer_for_peer(
2633    request: proto::CreateBufferForPeer,
2634    session: Session,
2635) -> Result<()> {
2636    session
2637        .db()
2638        .await
2639        .check_user_is_project_host(
2640            ProjectId::from_proto(request.project_id),
2641            session.connection_id,
2642        )
2643        .await?;
2644    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2645    session
2646        .peer
2647        .forward_send(session.connection_id, peer_id.into(), request)?;
2648    Ok(())
2649}
2650
2651/// Notify other participants that a buffer has been updated. This is
2652/// allowed for guests as long as the update is limited to selections.
2653async fn update_buffer(
2654    request: proto::UpdateBuffer,
2655    response: Response<proto::UpdateBuffer>,
2656    session: Session,
2657) -> Result<()> {
2658    let project_id = ProjectId::from_proto(request.project_id);
2659    let mut capability = Capability::ReadOnly;
2660
2661    for op in request.operations.iter() {
2662        match op.variant {
2663            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2664            Some(_) => capability = Capability::ReadWrite,
2665        }
2666    }
2667
2668    let host = {
2669        let guard = session
2670            .db()
2671            .await
2672            .connections_for_buffer_update(
2673                project_id,
2674                session.principal_id(),
2675                session.connection_id,
2676                capability,
2677            )
2678            .await?;
2679
2680        let (host, guests) = &*guard;
2681
2682        broadcast(
2683            Some(session.connection_id),
2684            guests.clone(),
2685            |connection_id| {
2686                session
2687                    .peer
2688                    .forward_send(session.connection_id, connection_id, request.clone())
2689            },
2690        );
2691
2692        *host
2693    };
2694
2695    if host != session.connection_id {
2696        session
2697            .peer
2698            .forward_request(session.connection_id, host, request.clone())
2699            .await?;
2700    }
2701
2702    response.send(proto::Ack {})?;
2703    Ok(())
2704}
2705
2706/// Notify other participants that a project has been updated.
2707async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2708    request: T,
2709    session: Session,
2710) -> Result<()> {
2711    let project_id = ProjectId::from_proto(request.remote_entity_id());
2712    let project_connection_ids = session
2713        .db()
2714        .await
2715        .project_connection_ids(project_id, session.connection_id, false)
2716        .await?;
2717
2718    broadcast(
2719        Some(session.connection_id),
2720        project_connection_ids.iter().copied(),
2721        |connection_id| {
2722            session
2723                .peer
2724                .forward_send(session.connection_id, connection_id, request.clone())
2725        },
2726    );
2727    Ok(())
2728}
2729
2730/// Start following another user in a call.
2731async fn follow(
2732    request: proto::Follow,
2733    response: Response<proto::Follow>,
2734    session: UserSession,
2735) -> Result<()> {
2736    let room_id = RoomId::from_proto(request.room_id);
2737    let project_id = request.project_id.map(ProjectId::from_proto);
2738    let leader_id = request
2739        .leader_id
2740        .ok_or_else(|| anyhow!("invalid leader id"))?
2741        .into();
2742    let follower_id = session.connection_id;
2743
2744    session
2745        .db()
2746        .await
2747        .check_room_participants(room_id, leader_id, session.connection_id)
2748        .await?;
2749
2750    let response_payload = session
2751        .peer
2752        .forward_request(session.connection_id, leader_id, request)
2753        .await?;
2754    response.send(response_payload)?;
2755
2756    if let Some(project_id) = project_id {
2757        let room = session
2758            .db()
2759            .await
2760            .follow(room_id, project_id, leader_id, follower_id)
2761            .await?;
2762        room_updated(&room, &session.peer);
2763    }
2764
2765    Ok(())
2766}
2767
2768/// Stop following another user in a call.
2769async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2770    let room_id = RoomId::from_proto(request.room_id);
2771    let project_id = request.project_id.map(ProjectId::from_proto);
2772    let leader_id = request
2773        .leader_id
2774        .ok_or_else(|| anyhow!("invalid leader id"))?
2775        .into();
2776    let follower_id = session.connection_id;
2777
2778    session
2779        .db()
2780        .await
2781        .check_room_participants(room_id, leader_id, session.connection_id)
2782        .await?;
2783
2784    session
2785        .peer
2786        .forward_send(session.connection_id, leader_id, request)?;
2787
2788    if let Some(project_id) = project_id {
2789        let room = session
2790            .db()
2791            .await
2792            .unfollow(room_id, project_id, leader_id, follower_id)
2793            .await?;
2794        room_updated(&room, &session.peer);
2795    }
2796
2797    Ok(())
2798}
2799
2800/// Notify everyone following you of your current location.
2801async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2802    let room_id = RoomId::from_proto(request.room_id);
2803    let database = session.db.lock().await;
2804
2805    let connection_ids = if let Some(project_id) = request.project_id {
2806        let project_id = ProjectId::from_proto(project_id);
2807        database
2808            .project_connection_ids(project_id, session.connection_id, true)
2809            .await?
2810    } else {
2811        database
2812            .room_connection_ids(room_id, session.connection_id)
2813            .await?
2814    };
2815
2816    // For now, don't send view update messages back to that view's current leader.
2817    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2818        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2819        _ => None,
2820    });
2821
2822    for connection_id in connection_ids.iter().cloned() {
2823        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2824            session
2825                .peer
2826                .forward_send(session.connection_id, connection_id, request.clone())?;
2827        }
2828    }
2829    Ok(())
2830}
2831
2832/// Get public data about users.
2833async fn get_users(
2834    request: proto::GetUsers,
2835    response: Response<proto::GetUsers>,
2836    session: Session,
2837) -> Result<()> {
2838    let user_ids = request
2839        .user_ids
2840        .into_iter()
2841        .map(UserId::from_proto)
2842        .collect();
2843    let users = session
2844        .db()
2845        .await
2846        .get_users_by_ids(user_ids)
2847        .await?
2848        .into_iter()
2849        .map(|user| proto::User {
2850            id: user.id.to_proto(),
2851            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2852            github_login: user.github_login,
2853        })
2854        .collect();
2855    response.send(proto::UsersResponse { users })?;
2856    Ok(())
2857}
2858
2859/// Search for users (to invite) buy Github login
2860async fn fuzzy_search_users(
2861    request: proto::FuzzySearchUsers,
2862    response: Response<proto::FuzzySearchUsers>,
2863    session: UserSession,
2864) -> Result<()> {
2865    let query = request.query;
2866    let users = match query.len() {
2867        0 => vec![],
2868        1 | 2 => session
2869            .db()
2870            .await
2871            .get_user_by_github_login(&query)
2872            .await?
2873            .into_iter()
2874            .collect(),
2875        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2876    };
2877    let users = users
2878        .into_iter()
2879        .filter(|user| user.id != session.user_id())
2880        .map(|user| proto::User {
2881            id: user.id.to_proto(),
2882            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2883            github_login: user.github_login,
2884        })
2885        .collect();
2886    response.send(proto::UsersResponse { users })?;
2887    Ok(())
2888}
2889
2890/// Send a contact request to another user.
2891async fn request_contact(
2892    request: proto::RequestContact,
2893    response: Response<proto::RequestContact>,
2894    session: UserSession,
2895) -> Result<()> {
2896    let requester_id = session.user_id();
2897    let responder_id = UserId::from_proto(request.responder_id);
2898    if requester_id == responder_id {
2899        return Err(anyhow!("cannot add yourself as a contact"))?;
2900    }
2901
2902    let notifications = session
2903        .db()
2904        .await
2905        .send_contact_request(requester_id, responder_id)
2906        .await?;
2907
2908    // Update outgoing contact requests of requester
2909    let mut update = proto::UpdateContacts::default();
2910    update.outgoing_requests.push(responder_id.to_proto());
2911    for connection_id in session
2912        .connection_pool()
2913        .await
2914        .user_connection_ids(requester_id)
2915    {
2916        session.peer.send(connection_id, update.clone())?;
2917    }
2918
2919    // Update incoming contact requests of responder
2920    let mut update = proto::UpdateContacts::default();
2921    update
2922        .incoming_requests
2923        .push(proto::IncomingContactRequest {
2924            requester_id: requester_id.to_proto(),
2925        });
2926    let connection_pool = session.connection_pool().await;
2927    for connection_id in connection_pool.user_connection_ids(responder_id) {
2928        session.peer.send(connection_id, update.clone())?;
2929    }
2930
2931    send_notifications(&connection_pool, &session.peer, notifications);
2932
2933    response.send(proto::Ack {})?;
2934    Ok(())
2935}
2936
2937/// Accept or decline a contact request
2938async fn respond_to_contact_request(
2939    request: proto::RespondToContactRequest,
2940    response: Response<proto::RespondToContactRequest>,
2941    session: UserSession,
2942) -> Result<()> {
2943    let responder_id = session.user_id();
2944    let requester_id = UserId::from_proto(request.requester_id);
2945    let db = session.db().await;
2946    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2947        db.dismiss_contact_notification(responder_id, requester_id)
2948            .await?;
2949    } else {
2950        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2951
2952        let notifications = db
2953            .respond_to_contact_request(responder_id, requester_id, accept)
2954            .await?;
2955        let requester_busy = db.is_user_busy(requester_id).await?;
2956        let responder_busy = db.is_user_busy(responder_id).await?;
2957
2958        let pool = session.connection_pool().await;
2959        // Update responder with new contact
2960        let mut update = proto::UpdateContacts::default();
2961        if accept {
2962            update
2963                .contacts
2964                .push(contact_for_user(requester_id, requester_busy, &pool));
2965        }
2966        update
2967            .remove_incoming_requests
2968            .push(requester_id.to_proto());
2969        for connection_id in pool.user_connection_ids(responder_id) {
2970            session.peer.send(connection_id, update.clone())?;
2971        }
2972
2973        // Update requester with new contact
2974        let mut update = proto::UpdateContacts::default();
2975        if accept {
2976            update
2977                .contacts
2978                .push(contact_for_user(responder_id, responder_busy, &pool));
2979        }
2980        update
2981            .remove_outgoing_requests
2982            .push(responder_id.to_proto());
2983
2984        for connection_id in pool.user_connection_ids(requester_id) {
2985            session.peer.send(connection_id, update.clone())?;
2986        }
2987
2988        send_notifications(&pool, &session.peer, notifications);
2989    }
2990
2991    response.send(proto::Ack {})?;
2992    Ok(())
2993}
2994
2995/// Remove a contact.
2996async fn remove_contact(
2997    request: proto::RemoveContact,
2998    response: Response<proto::RemoveContact>,
2999    session: UserSession,
3000) -> Result<()> {
3001    let requester_id = session.user_id();
3002    let responder_id = UserId::from_proto(request.user_id);
3003    let db = session.db().await;
3004    let (contact_accepted, deleted_notification_id) =
3005        db.remove_contact(requester_id, responder_id).await?;
3006
3007    let pool = session.connection_pool().await;
3008    // Update outgoing contact requests of requester
3009    let mut update = proto::UpdateContacts::default();
3010    if contact_accepted {
3011        update.remove_contacts.push(responder_id.to_proto());
3012    } else {
3013        update
3014            .remove_outgoing_requests
3015            .push(responder_id.to_proto());
3016    }
3017    for connection_id in pool.user_connection_ids(requester_id) {
3018        session.peer.send(connection_id, update.clone())?;
3019    }
3020
3021    // Update incoming contact requests of responder
3022    let mut update = proto::UpdateContacts::default();
3023    if contact_accepted {
3024        update.remove_contacts.push(requester_id.to_proto());
3025    } else {
3026        update
3027            .remove_incoming_requests
3028            .push(requester_id.to_proto());
3029    }
3030    for connection_id in pool.user_connection_ids(responder_id) {
3031        session.peer.send(connection_id, update.clone())?;
3032        if let Some(notification_id) = deleted_notification_id {
3033            session.peer.send(
3034                connection_id,
3035                proto::DeleteNotification {
3036                    notification_id: notification_id.to_proto(),
3037                },
3038            )?;
3039        }
3040    }
3041
3042    response.send(proto::Ack {})?;
3043    Ok(())
3044}
3045
3046/// Creates a new channel.
3047async fn create_channel(
3048    request: proto::CreateChannel,
3049    response: Response<proto::CreateChannel>,
3050    session: UserSession,
3051) -> Result<()> {
3052    let db = session.db().await;
3053
3054    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3055    let (channel, membership) = db
3056        .create_channel(&request.name, parent_id, session.user_id())
3057        .await?;
3058
3059    let root_id = channel.root_id();
3060    let channel = Channel::from_model(channel);
3061
3062    response.send(proto::CreateChannelResponse {
3063        channel: Some(channel.to_proto()),
3064        parent_id: request.parent_id,
3065    })?;
3066
3067    let mut connection_pool = session.connection_pool().await;
3068    if let Some(membership) = membership {
3069        connection_pool.subscribe_to_channel(
3070            membership.user_id,
3071            membership.channel_id,
3072            membership.role,
3073        );
3074        let update = proto::UpdateUserChannels {
3075            channel_memberships: vec![proto::ChannelMembership {
3076                channel_id: membership.channel_id.to_proto(),
3077                role: membership.role.into(),
3078            }],
3079            ..Default::default()
3080        };
3081        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3082            session.peer.send(connection_id, update.clone())?;
3083        }
3084    }
3085
3086    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3087        if !role.can_see_channel(channel.visibility) {
3088            continue;
3089        }
3090
3091        let update = proto::UpdateChannels {
3092            channels: vec![channel.to_proto()],
3093            ..Default::default()
3094        };
3095        session.peer.send(connection_id, update.clone())?;
3096    }
3097
3098    Ok(())
3099}
3100
3101/// Delete a channel
3102async fn delete_channel(
3103    request: proto::DeleteChannel,
3104    response: Response<proto::DeleteChannel>,
3105    session: UserSession,
3106) -> Result<()> {
3107    let db = session.db().await;
3108
3109    let channel_id = request.channel_id;
3110    let (root_channel, removed_channels) = db
3111        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3112        .await?;
3113    response.send(proto::Ack {})?;
3114
3115    // Notify members of removed channels
3116    let mut update = proto::UpdateChannels::default();
3117    update
3118        .delete_channels
3119        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3120
3121    let connection_pool = session.connection_pool().await;
3122    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3123        session.peer.send(connection_id, update.clone())?;
3124    }
3125
3126    Ok(())
3127}
3128
3129/// Invite someone to join a channel.
3130async fn invite_channel_member(
3131    request: proto::InviteChannelMember,
3132    response: Response<proto::InviteChannelMember>,
3133    session: UserSession,
3134) -> Result<()> {
3135    let db = session.db().await;
3136    let channel_id = ChannelId::from_proto(request.channel_id);
3137    let invitee_id = UserId::from_proto(request.user_id);
3138    let InviteMemberResult {
3139        channel,
3140        notifications,
3141    } = db
3142        .invite_channel_member(
3143            channel_id,
3144            invitee_id,
3145            session.user_id(),
3146            request.role().into(),
3147        )
3148        .await?;
3149
3150    let update = proto::UpdateChannels {
3151        channel_invitations: vec![channel.to_proto()],
3152        ..Default::default()
3153    };
3154
3155    let connection_pool = session.connection_pool().await;
3156    for connection_id in connection_pool.user_connection_ids(invitee_id) {
3157        session.peer.send(connection_id, update.clone())?;
3158    }
3159
3160    send_notifications(&connection_pool, &session.peer, notifications);
3161
3162    response.send(proto::Ack {})?;
3163    Ok(())
3164}
3165
3166/// remove someone from a channel
3167async fn remove_channel_member(
3168    request: proto::RemoveChannelMember,
3169    response: Response<proto::RemoveChannelMember>,
3170    session: UserSession,
3171) -> Result<()> {
3172    let db = session.db().await;
3173    let channel_id = ChannelId::from_proto(request.channel_id);
3174    let member_id = UserId::from_proto(request.user_id);
3175
3176    let RemoveChannelMemberResult {
3177        membership_update,
3178        notification_id,
3179    } = db
3180        .remove_channel_member(channel_id, member_id, session.user_id())
3181        .await?;
3182
3183    let mut connection_pool = session.connection_pool().await;
3184    notify_membership_updated(
3185        &mut connection_pool,
3186        membership_update,
3187        member_id,
3188        &session.peer,
3189    );
3190    for connection_id in connection_pool.user_connection_ids(member_id) {
3191        if let Some(notification_id) = notification_id {
3192            session
3193                .peer
3194                .send(
3195                    connection_id,
3196                    proto::DeleteNotification {
3197                        notification_id: notification_id.to_proto(),
3198                    },
3199                )
3200                .trace_err();
3201        }
3202    }
3203
3204    response.send(proto::Ack {})?;
3205    Ok(())
3206}
3207
3208/// Toggle the channel between public and private.
3209/// Care is taken to maintain the invariant that public channels only descend from public channels,
3210/// (though members-only channels can appear at any point in the hierarchy).
3211async fn set_channel_visibility(
3212    request: proto::SetChannelVisibility,
3213    response: Response<proto::SetChannelVisibility>,
3214    session: UserSession,
3215) -> Result<()> {
3216    let db = session.db().await;
3217    let channel_id = ChannelId::from_proto(request.channel_id);
3218    let visibility = request.visibility().into();
3219
3220    let channel_model = db
3221        .set_channel_visibility(channel_id, visibility, session.user_id())
3222        .await?;
3223    let root_id = channel_model.root_id();
3224    let channel = Channel::from_model(channel_model);
3225
3226    let mut connection_pool = session.connection_pool().await;
3227    for (user_id, role) in connection_pool
3228        .channel_user_ids(root_id)
3229        .collect::<Vec<_>>()
3230        .into_iter()
3231    {
3232        let update = if role.can_see_channel(channel.visibility) {
3233            connection_pool.subscribe_to_channel(user_id, channel_id, role);
3234            proto::UpdateChannels {
3235                channels: vec![channel.to_proto()],
3236                ..Default::default()
3237            }
3238        } else {
3239            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3240            proto::UpdateChannels {
3241                delete_channels: vec![channel.id.to_proto()],
3242                ..Default::default()
3243            }
3244        };
3245
3246        for connection_id in connection_pool.user_connection_ids(user_id) {
3247            session.peer.send(connection_id, update.clone())?;
3248        }
3249    }
3250
3251    response.send(proto::Ack {})?;
3252    Ok(())
3253}
3254
3255/// Alter the role for a user in the channel.
3256async fn set_channel_member_role(
3257    request: proto::SetChannelMemberRole,
3258    response: Response<proto::SetChannelMemberRole>,
3259    session: UserSession,
3260) -> Result<()> {
3261    let db = session.db().await;
3262    let channel_id = ChannelId::from_proto(request.channel_id);
3263    let member_id = UserId::from_proto(request.user_id);
3264    let result = db
3265        .set_channel_member_role(
3266            channel_id,
3267            session.user_id(),
3268            member_id,
3269            request.role().into(),
3270        )
3271        .await?;
3272
3273    match result {
3274        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3275            let mut connection_pool = session.connection_pool().await;
3276            notify_membership_updated(
3277                &mut connection_pool,
3278                membership_update,
3279                member_id,
3280                &session.peer,
3281            )
3282        }
3283        db::SetMemberRoleResult::InviteUpdated(channel) => {
3284            let update = proto::UpdateChannels {
3285                channel_invitations: vec![channel.to_proto()],
3286                ..Default::default()
3287            };
3288
3289            for connection_id in session
3290                .connection_pool()
3291                .await
3292                .user_connection_ids(member_id)
3293            {
3294                session.peer.send(connection_id, update.clone())?;
3295            }
3296        }
3297    }
3298
3299    response.send(proto::Ack {})?;
3300    Ok(())
3301}
3302
3303/// Change the name of a channel
3304async fn rename_channel(
3305    request: proto::RenameChannel,
3306    response: Response<proto::RenameChannel>,
3307    session: UserSession,
3308) -> Result<()> {
3309    let db = session.db().await;
3310    let channel_id = ChannelId::from_proto(request.channel_id);
3311    let channel_model = db
3312        .rename_channel(channel_id, session.user_id(), &request.name)
3313        .await?;
3314    let root_id = channel_model.root_id();
3315    let channel = Channel::from_model(channel_model);
3316
3317    response.send(proto::RenameChannelResponse {
3318        channel: Some(channel.to_proto()),
3319    })?;
3320
3321    let connection_pool = session.connection_pool().await;
3322    let update = proto::UpdateChannels {
3323        channels: vec![channel.to_proto()],
3324        ..Default::default()
3325    };
3326    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3327        if role.can_see_channel(channel.visibility) {
3328            session.peer.send(connection_id, update.clone())?;
3329        }
3330    }
3331
3332    Ok(())
3333}
3334
3335/// Move a channel to a new parent.
3336async fn move_channel(
3337    request: proto::MoveChannel,
3338    response: Response<proto::MoveChannel>,
3339    session: UserSession,
3340) -> Result<()> {
3341    let channel_id = ChannelId::from_proto(request.channel_id);
3342    let to = ChannelId::from_proto(request.to);
3343
3344    let (root_id, channels) = session
3345        .db()
3346        .await
3347        .move_channel(channel_id, to, session.user_id())
3348        .await?;
3349
3350    let connection_pool = session.connection_pool().await;
3351    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352        let channels = channels
3353            .iter()
3354            .filter_map(|channel| {
3355                if role.can_see_channel(channel.visibility) {
3356                    Some(channel.to_proto())
3357                } else {
3358                    None
3359                }
3360            })
3361            .collect::<Vec<_>>();
3362        if channels.is_empty() {
3363            continue;
3364        }
3365
3366        let update = proto::UpdateChannels {
3367            channels,
3368            ..Default::default()
3369        };
3370
3371        session.peer.send(connection_id, update.clone())?;
3372    }
3373
3374    response.send(Ack {})?;
3375    Ok(())
3376}
3377
3378/// Get the list of channel members
3379async fn get_channel_members(
3380    request: proto::GetChannelMembers,
3381    response: Response<proto::GetChannelMembers>,
3382    session: UserSession,
3383) -> Result<()> {
3384    let db = session.db().await;
3385    let channel_id = ChannelId::from_proto(request.channel_id);
3386    let members = db
3387        .get_channel_participant_details(channel_id, session.user_id())
3388        .await?;
3389    response.send(proto::GetChannelMembersResponse { members })?;
3390    Ok(())
3391}
3392
3393/// Accept or decline a channel invitation.
3394async fn respond_to_channel_invite(
3395    request: proto::RespondToChannelInvite,
3396    response: Response<proto::RespondToChannelInvite>,
3397    session: UserSession,
3398) -> Result<()> {
3399    let db = session.db().await;
3400    let channel_id = ChannelId::from_proto(request.channel_id);
3401    let RespondToChannelInvite {
3402        membership_update,
3403        notifications,
3404    } = db
3405        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3406        .await?;
3407
3408    let mut connection_pool = session.connection_pool().await;
3409    if let Some(membership_update) = membership_update {
3410        notify_membership_updated(
3411            &mut connection_pool,
3412            membership_update,
3413            session.user_id(),
3414            &session.peer,
3415        );
3416    } else {
3417        let update = proto::UpdateChannels {
3418            remove_channel_invitations: vec![channel_id.to_proto()],
3419            ..Default::default()
3420        };
3421
3422        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3423            session.peer.send(connection_id, update.clone())?;
3424        }
3425    };
3426
3427    send_notifications(&connection_pool, &session.peer, notifications);
3428
3429    response.send(proto::Ack {})?;
3430
3431    Ok(())
3432}
3433
3434/// Join the channels' room
3435async fn join_channel(
3436    request: proto::JoinChannel,
3437    response: Response<proto::JoinChannel>,
3438    session: UserSession,
3439) -> Result<()> {
3440    let channel_id = ChannelId::from_proto(request.channel_id);
3441    join_channel_internal(channel_id, Box::new(response), session).await
3442}
3443
/// A response handle that can answer a join request with a
/// `proto::JoinRoomResponse`. It is implemented for both `JoinChannel` and
/// `JoinRoom` responses, so `join_channel_internal` can reply to either.
trait JoinChannelInternalResponse {
    /// Send the join result back to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// NOTE(review): the fully-qualified `Response::<_>::send` calls below are
// intended to reach `Response`'s own `send`, not this trait's method of the same
// name. Do not shorten them to `self.send(..)`, which could resolve to the trait
// method and recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3457
/// Shared implementation for joining a channel's room.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, the
///    database guard is released, the stale room is left, and the guard is
///    reacquired.
/// 2. The user joins the channel's room in the database.
/// 3. If a LiveKit client is configured, credentials are minted. Guests get a
///    guest token with `can_publish: false`; every other role gets a room token
///    with `can_publish: true`. A token error becomes `None` via `trace_err`, so
///    the join still succeeds, just without LiveKit info.
/// 4. The requester is answered, any membership change is broadcast, and a
///    room update is sent.
/// 5. Once the inner block's database and connection-pool guards are dropped, a
///    channel update is sent and the user's contacts are refreshed.
///
/// `response` may be any `JoinChannelInternalResponse`, such as a `JoinChannel`
/// response.
///
/// Errors if the joined room comes back without a channel. That check runs
/// after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scope the database and connection-pool guards so they are dropped before
    // `channel_updated` reacquires the connection pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails
        // (`trace_err()?` short-circuits this closure to `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and may not publish; other roles
            // receive a full room token and may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Moves `channel` out of `joined_room`; `room` is still borrowed separately.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3548
3549/// Start editing the channel notes
3550async fn join_channel_buffer(
3551    request: proto::JoinChannelBuffer,
3552    response: Response<proto::JoinChannelBuffer>,
3553    session: UserSession,
3554) -> Result<()> {
3555    let db = session.db().await;
3556    let channel_id = ChannelId::from_proto(request.channel_id);
3557
3558    let open_response = db
3559        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3560        .await?;
3561
3562    let collaborators = open_response.collaborators.clone();
3563    response.send(open_response)?;
3564
3565    let update = UpdateChannelBufferCollaborators {
3566        channel_id: channel_id.to_proto(),
3567        collaborators: collaborators.clone(),
3568    };
3569    channel_buffer_updated(
3570        session.connection_id,
3571        collaborators
3572            .iter()
3573            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3574        &update,
3575        &session.peer,
3576    );
3577
3578    Ok(())
3579}
3580
3581/// Edit the channel notes
3582async fn update_channel_buffer(
3583    request: proto::UpdateChannelBuffer,
3584    session: UserSession,
3585) -> Result<()> {
3586    let db = session.db().await;
3587    let channel_id = ChannelId::from_proto(request.channel_id);
3588
3589    let (collaborators, non_collaborators, epoch, version) = db
3590        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3591        .await?;
3592
3593    channel_buffer_updated(
3594        session.connection_id,
3595        collaborators,
3596        &proto::UpdateChannelBuffer {
3597            channel_id: channel_id.to_proto(),
3598            operations: request.operations,
3599        },
3600        &session.peer,
3601    );
3602
3603    let pool = &*session.connection_pool().await;
3604
3605    broadcast(
3606        None,
3607        non_collaborators
3608            .iter()
3609            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3610        |peer_id| {
3611            session.peer.send(
3612                peer_id,
3613                proto::UpdateChannels {
3614                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3615                        channel_id: channel_id.to_proto(),
3616                        epoch: epoch as u64,
3617                        version: version.clone(),
3618                    }],
3619                    ..Default::default()
3620                },
3621            )
3622        },
3623    );
3624
3625    Ok(())
3626}
3627
3628/// Rejoin the channel notes after a connection blip
3629async fn rejoin_channel_buffers(
3630    request: proto::RejoinChannelBuffers,
3631    response: Response<proto::RejoinChannelBuffers>,
3632    session: UserSession,
3633) -> Result<()> {
3634    let db = session.db().await;
3635    let buffers = db
3636        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3637        .await?;
3638
3639    for rejoined_buffer in &buffers {
3640        let collaborators_to_notify = rejoined_buffer
3641            .buffer
3642            .collaborators
3643            .iter()
3644            .filter_map(|c| Some(c.peer_id?.into()));
3645        channel_buffer_updated(
3646            session.connection_id,
3647            collaborators_to_notify,
3648            &proto::UpdateChannelBufferCollaborators {
3649                channel_id: rejoined_buffer.buffer.channel_id,
3650                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3651            },
3652            &session.peer,
3653        );
3654    }
3655
3656    response.send(proto::RejoinChannelBuffersResponse {
3657        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3658    })?;
3659
3660    Ok(())
3661}
3662
3663/// Stop editing the channel notes
3664async fn leave_channel_buffer(
3665    request: proto::LeaveChannelBuffer,
3666    response: Response<proto::LeaveChannelBuffer>,
3667    session: UserSession,
3668) -> Result<()> {
3669    let db = session.db().await;
3670    let channel_id = ChannelId::from_proto(request.channel_id);
3671
3672    let left_buffer = db
3673        .leave_channel_buffer(channel_id, session.connection_id)
3674        .await?;
3675
3676    response.send(Ack {})?;
3677
3678    channel_buffer_updated(
3679        session.connection_id,
3680        left_buffer.connections,
3681        &proto::UpdateChannelBufferCollaborators {
3682            channel_id: channel_id.to_proto(),
3683            collaborators: left_buffer.collaborators,
3684        },
3685        &session.peer,
3686    );
3687
3688    Ok(())
3689}
3690
3691fn channel_buffer_updated<T: EnvelopedMessage>(
3692    sender_id: ConnectionId,
3693    collaborators: impl IntoIterator<Item = ConnectionId>,
3694    message: &T,
3695    peer: &Peer,
3696) {
3697    broadcast(Some(sender_id), collaborators, |peer_id| {
3698        peer.send(peer_id, message.clone())
3699    });
3700}
3701
3702fn send_notifications(
3703    connection_pool: &ConnectionPool,
3704    peer: &Peer,
3705    notifications: db::NotificationBatch,
3706) {
3707    for (user_id, notification) in notifications {
3708        for connection_id in connection_pool.user_connection_ids(user_id) {
3709            if let Err(error) = peer.send(
3710                connection_id,
3711                proto::AddNotification {
3712                    notification: Some(notification.clone()),
3713                },
3714            ) {
3715                tracing::error!(
3716                    "failed to send notification to {:?} {}",
3717                    connection_id,
3718                    error
3719                );
3720            }
3721        }
3722    }
3723}
3724
3725/// Send a message to the channel
3726async fn send_channel_message(
3727    request: proto::SendChannelMessage,
3728    response: Response<proto::SendChannelMessage>,
3729    session: UserSession,
3730) -> Result<()> {
3731    // Validate the message body.
3732    let body = request.body.trim().to_string();
3733    if body.len() > MAX_MESSAGE_LEN {
3734        return Err(anyhow!("message is too long"))?;
3735    }
3736    if body.is_empty() {
3737        return Err(anyhow!("message can't be blank"))?;
3738    }
3739
3740    // TODO: adjust mentions if body is trimmed
3741
3742    let timestamp = OffsetDateTime::now_utc();
3743    let nonce = request
3744        .nonce
3745        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3746
3747    let channel_id = ChannelId::from_proto(request.channel_id);
3748    let CreatedChannelMessage {
3749        message_id,
3750        participant_connection_ids,
3751        channel_members,
3752        notifications,
3753    } = session
3754        .db()
3755        .await
3756        .create_channel_message(
3757            channel_id,
3758            session.user_id(),
3759            &body,
3760            &request.mentions,
3761            timestamp,
3762            nonce.clone().into(),
3763            match request.reply_to_message_id {
3764                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3765                None => None,
3766            },
3767        )
3768        .await?;
3769
3770    let message = proto::ChannelMessage {
3771        sender_id: session.user_id().to_proto(),
3772        id: message_id.to_proto(),
3773        body,
3774        mentions: request.mentions,
3775        timestamp: timestamp.unix_timestamp() as u64,
3776        nonce: Some(nonce),
3777        reply_to_message_id: request.reply_to_message_id,
3778        edited_at: None,
3779    };
3780    broadcast(
3781        Some(session.connection_id),
3782        participant_connection_ids,
3783        |connection| {
3784            session.peer.send(
3785                connection,
3786                proto::ChannelMessageSent {
3787                    channel_id: channel_id.to_proto(),
3788                    message: Some(message.clone()),
3789                },
3790            )
3791        },
3792    );
3793    response.send(proto::SendChannelMessageResponse {
3794        message: Some(message),
3795    })?;
3796
3797    let pool = &*session.connection_pool().await;
3798    broadcast(
3799        None,
3800        channel_members
3801            .iter()
3802            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3803        |peer_id| {
3804            session.peer.send(
3805                peer_id,
3806                proto::UpdateChannels {
3807                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3808                        channel_id: channel_id.to_proto(),
3809                        message_id: message_id.to_proto(),
3810                    }],
3811                    ..Default::default()
3812                },
3813            )
3814        },
3815    );
3816    send_notifications(pool, &session.peer, notifications);
3817
3818    Ok(())
3819}
3820
3821/// Delete a channel message
3822async fn remove_channel_message(
3823    request: proto::RemoveChannelMessage,
3824    response: Response<proto::RemoveChannelMessage>,
3825    session: UserSession,
3826) -> Result<()> {
3827    let channel_id = ChannelId::from_proto(request.channel_id);
3828    let message_id = MessageId::from_proto(request.message_id);
3829    let (connection_ids, existing_notification_ids) = session
3830        .db()
3831        .await
3832        .remove_channel_message(channel_id, message_id, session.user_id())
3833        .await?;
3834
3835    broadcast(
3836        Some(session.connection_id),
3837        connection_ids,
3838        move |connection| {
3839            session.peer.send(connection, request.clone())?;
3840
3841            for notification_id in &existing_notification_ids {
3842                session.peer.send(
3843                    connection,
3844                    proto::DeleteNotification {
3845                        notification_id: (*notification_id).to_proto(),
3846                    },
3847                )?;
3848            }
3849
3850            Ok(())
3851        },
3852    );
3853    response.send(proto::Ack {})?;
3854    Ok(())
3855}
3856
3857async fn update_channel_message(
3858    request: proto::UpdateChannelMessage,
3859    response: Response<proto::UpdateChannelMessage>,
3860    session: UserSession,
3861) -> Result<()> {
3862    let channel_id = ChannelId::from_proto(request.channel_id);
3863    let message_id = MessageId::from_proto(request.message_id);
3864    let updated_at = OffsetDateTime::now_utc();
3865    let UpdatedChannelMessage {
3866        message_id,
3867        participant_connection_ids,
3868        notifications,
3869        reply_to_message_id,
3870        timestamp,
3871        deleted_mention_notification_ids,
3872        updated_mention_notifications,
3873    } = session
3874        .db()
3875        .await
3876        .update_channel_message(
3877            channel_id,
3878            message_id,
3879            session.user_id(),
3880            request.body.as_str(),
3881            &request.mentions,
3882            updated_at,
3883        )
3884        .await?;
3885
3886    let nonce = request
3887        .nonce
3888        .clone()
3889        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3890
3891    let message = proto::ChannelMessage {
3892        sender_id: session.user_id().to_proto(),
3893        id: message_id.to_proto(),
3894        body: request.body.clone(),
3895        mentions: request.mentions.clone(),
3896        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3897        nonce: Some(nonce),
3898        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3899        edited_at: Some(updated_at.unix_timestamp() as u64),
3900    };
3901
3902    response.send(proto::Ack {})?;
3903
3904    let pool = &*session.connection_pool().await;
3905    broadcast(
3906        Some(session.connection_id),
3907        participant_connection_ids,
3908        |connection| {
3909            session.peer.send(
3910                connection,
3911                proto::ChannelMessageUpdate {
3912                    channel_id: channel_id.to_proto(),
3913                    message: Some(message.clone()),
3914                },
3915            )?;
3916
3917            for notification_id in &deleted_mention_notification_ids {
3918                session.peer.send(
3919                    connection,
3920                    proto::DeleteNotification {
3921                        notification_id: (*notification_id).to_proto(),
3922                    },
3923                )?;
3924            }
3925
3926            for notification in &updated_mention_notifications {
3927                session.peer.send(
3928                    connection,
3929                    proto::UpdateNotification {
3930                        notification: Some(notification.clone()),
3931                    },
3932                )?;
3933            }
3934
3935            Ok(())
3936        },
3937    );
3938
3939    send_notifications(pool, &session.peer, notifications);
3940
3941    Ok(())
3942}
3943
3944/// Mark a channel message as read
3945async fn acknowledge_channel_message(
3946    request: proto::AckChannelMessage,
3947    session: UserSession,
3948) -> Result<()> {
3949    let channel_id = ChannelId::from_proto(request.channel_id);
3950    let message_id = MessageId::from_proto(request.message_id);
3951    let notifications = session
3952        .db()
3953        .await
3954        .observe_channel_message(channel_id, session.user_id(), message_id)
3955        .await?;
3956    send_notifications(
3957        &*session.connection_pool().await,
3958        &session.peer,
3959        notifications,
3960    );
3961    Ok(())
3962}
3963
3964/// Mark a buffer version as synced
3965async fn acknowledge_buffer_version(
3966    request: proto::AckBufferOperation,
3967    session: UserSession,
3968) -> Result<()> {
3969    let buffer_id = BufferId::from_proto(request.buffer_id);
3970    session
3971        .db()
3972        .await
3973        .observe_buffer_version(
3974            buffer_id,
3975            session.user_id(),
3976            request.epoch as i32,
3977            &request.version,
3978        )
3979        .await?;
3980    Ok(())
3981}
3982
3983struct CompleteWithLanguageModelRateLimit;
3984
3985impl RateLimit for CompleteWithLanguageModelRateLimit {
3986    fn capacity() -> usize {
3987        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3988            .ok()
3989            .and_then(|v| v.parse().ok())
3990            .unwrap_or(120) // Picked arbitrarily
3991    }
3992
3993    fn refill_duration() -> chrono::Duration {
3994        chrono::Duration::hours(1)
3995    }
3996
3997    fn db_name() -> &'static str {
3998        "complete-with-language-model"
3999    }
4000}
4001
4002async fn complete_with_language_model(
4003    request: proto::CompleteWithLanguageModel,
4004    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4005    session: Session,
4006    open_ai_api_key: Option<Arc<str>>,
4007    google_ai_api_key: Option<Arc<str>>,
4008    anthropic_api_key: Option<Arc<str>>,
4009) -> Result<()> {
4010    let Some(session) = session.for_user() else {
4011        return Err(anyhow!("user not found"))?;
4012    };
4013    authorize_access_to_language_models(&session).await?;
4014    session
4015        .rate_limiter
4016        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4017        .await?;
4018
4019    if request.model.starts_with("gpt") {
4020        let api_key =
4021            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4022        complete_with_open_ai(request, response, session, api_key).await?;
4023    } else if request.model.starts_with("gemini") {
4024        let api_key = google_ai_api_key
4025            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4026        complete_with_google_ai(request, response, session, api_key).await?;
4027    } else if request.model.starts_with("claude") {
4028        let api_key = anthropic_api_key
4029            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4030        complete_with_anthropic(request, response, session, api_key).await?;
4031    }
4032
4033    Ok(())
4034}
4035
4036async fn complete_with_open_ai(
4037    request: proto::CompleteWithLanguageModel,
4038    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4039    session: UserSession,
4040    api_key: Arc<str>,
4041) -> Result<()> {
4042    let mut completion_stream = open_ai::stream_completion(
4043        &session.http_client,
4044        OPEN_AI_API_URL,
4045        &api_key,
4046        crate::ai::language_model_request_to_open_ai(request)?,
4047    )
4048    .await
4049    .context("open_ai::stream_completion request failed")?;
4050
4051    while let Some(event) = completion_stream.next().await {
4052        let event = event?;
4053        response.send(proto::LanguageModelResponse {
4054            choices: event
4055                .choices
4056                .into_iter()
4057                .map(|choice| proto::LanguageModelChoiceDelta {
4058                    index: choice.index,
4059                    delta: Some(proto::LanguageModelResponseMessage {
4060                        role: choice.delta.role.map(|role| match role {
4061                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4062                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4063                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4064                        } as i32),
4065                        content: choice.delta.content,
4066                    }),
4067                    finish_reason: choice.finish_reason,
4068                })
4069                .collect(),
4070        })?;
4071    }
4072
4073    Ok(())
4074}
4075
4076async fn complete_with_google_ai(
4077    request: proto::CompleteWithLanguageModel,
4078    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4079    session: UserSession,
4080    api_key: Arc<str>,
4081) -> Result<()> {
4082    let mut stream = google_ai::stream_generate_content(
4083        &session.http_client,
4084        google_ai::API_URL,
4085        api_key.as_ref(),
4086        crate::ai::language_model_request_to_google_ai(request)?,
4087    )
4088    .await
4089    .context("google_ai::stream_generate_content request failed")?;
4090
4091    while let Some(event) = stream.next().await {
4092        let event = event?;
4093        response.send(proto::LanguageModelResponse {
4094            choices: event
4095                .candidates
4096                .unwrap_or_default()
4097                .into_iter()
4098                .map(|candidate| proto::LanguageModelChoiceDelta {
4099                    index: candidate.index as u32,
4100                    delta: Some(proto::LanguageModelResponseMessage {
4101                        role: Some(match candidate.content.role {
4102                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4103                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4104                        } as i32),
4105                        content: Some(
4106                            candidate
4107                                .content
4108                                .parts
4109                                .into_iter()
4110                                .filter_map(|part| match part {
4111                                    google_ai::Part::TextPart(part) => Some(part.text),
4112                                    google_ai::Part::InlineDataPart(_) => None,
4113                                })
4114                                .collect(),
4115                        ),
4116                    }),
4117                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4118                })
4119                .collect(),
4120        })?;
4121    }
4122
4123    Ok(())
4124}
4125
4126async fn complete_with_anthropic(
4127    request: proto::CompleteWithLanguageModel,
4128    response: StreamingResponse<proto::CompleteWithLanguageModel>,
4129    session: UserSession,
4130    api_key: Arc<str>,
4131) -> Result<()> {
4132    let model = anthropic::Model::from_id(&request.model)?;
4133
4134    let mut system_message = String::new();
4135    let messages = request
4136        .messages
4137        .into_iter()
4138        .filter_map(|message| match message.role() {
4139            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4140                role: anthropic::Role::User,
4141                content: message.content,
4142            }),
4143            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4144                role: anthropic::Role::Assistant,
4145                content: message.content,
4146            }),
4147            // Anthropic's API breaks system instructions out as a separate field rather
4148            // than having a system message role.
4149            LanguageModelRole::LanguageModelSystem => {
4150                if !system_message.is_empty() {
4151                    system_message.push_str("\n\n");
4152                }
4153                system_message.push_str(&message.content);
4154
4155                None
4156            }
4157        })
4158        .collect();
4159
4160    let mut stream = anthropic::stream_completion(
4161        &session.http_client,
4162        "https://api.anthropic.com",
4163        &api_key,
4164        anthropic::Request {
4165            model,
4166            messages,
4167            stream: true,
4168            system: system_message,
4169            max_tokens: 4092,
4170        },
4171    )
4172    .await?;
4173
4174    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4175
4176    while let Some(event) = stream.next().await {
4177        let event = event?;
4178
4179        match event {
4180            anthropic::ResponseEvent::MessageStart { message } => {
4181                if let Some(role) = message.role {
4182                    if role == "assistant" {
4183                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
4184                    } else if role == "user" {
4185                        current_role = proto::LanguageModelRole::LanguageModelUser;
4186                    }
4187                }
4188            }
4189            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4190                match content_block {
4191                    anthropic::ContentBlock::Text { text } => {
4192                        if !text.is_empty() {
4193                            response.send(proto::LanguageModelResponse {
4194                                choices: vec![proto::LanguageModelChoiceDelta {
4195                                    index: 0,
4196                                    delta: Some(proto::LanguageModelResponseMessage {
4197                                        role: Some(current_role as i32),
4198                                        content: Some(text),
4199                                    }),
4200                                    finish_reason: None,
4201                                }],
4202                            })?;
4203                        }
4204                    }
4205                }
4206            }
4207            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4208                anthropic::TextDelta::TextDelta { text } => {
4209                    response.send(proto::LanguageModelResponse {
4210                        choices: vec![proto::LanguageModelChoiceDelta {
4211                            index: 0,
4212                            delta: Some(proto::LanguageModelResponseMessage {
4213                                role: Some(current_role as i32),
4214                                content: Some(text),
4215                            }),
4216                            finish_reason: None,
4217                        }],
4218                    })?;
4219                }
4220            },
4221            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4222                if let Some(stop_reason) = delta.stop_reason {
4223                    response.send(proto::LanguageModelResponse {
4224                        choices: vec![proto::LanguageModelChoiceDelta {
4225                            index: 0,
4226                            delta: None,
4227                            finish_reason: Some(stop_reason),
4228                        }],
4229                    })?;
4230                }
4231            }
4232            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4233            anthropic::ResponseEvent::MessageStop {} => {}
4234            anthropic::ResponseEvent::Ping {} => {}
4235        }
4236    }
4237
4238    Ok(())
4239}
4240
4241struct CountTokensWithLanguageModelRateLimit;
4242
4243impl RateLimit for CountTokensWithLanguageModelRateLimit {
4244    fn capacity() -> usize {
4245        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4246            .ok()
4247            .and_then(|v| v.parse().ok())
4248            .unwrap_or(600) // Picked arbitrarily
4249    }
4250
4251    fn refill_duration() -> chrono::Duration {
4252        chrono::Duration::hours(1)
4253    }
4254
4255    fn db_name() -> &'static str {
4256        "count-tokens-with-language-model"
4257    }
4258}
4259
4260async fn count_tokens_with_language_model(
4261    request: proto::CountTokensWithLanguageModel,
4262    response: Response<proto::CountTokensWithLanguageModel>,
4263    session: UserSession,
4264    google_ai_api_key: Option<Arc<str>>,
4265) -> Result<()> {
4266    authorize_access_to_language_models(&session).await?;
4267
4268    if !request.model.starts_with("gemini") {
4269        return Err(anyhow!(
4270            "counting tokens for model: {:?} is not supported",
4271            request.model
4272        ))?;
4273    }
4274
4275    session
4276        .rate_limiter
4277        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4278        .await?;
4279
4280    let api_key = google_ai_api_key
4281        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4282    let tokens_response = google_ai::count_tokens(
4283        &session.http_client,
4284        google_ai::API_URL,
4285        &api_key,
4286        crate::ai::count_tokens_request_to_google_ai(request)?,
4287    )
4288    .await?;
4289    response.send(proto::CountTokensResponse {
4290        token_count: tokens_response.total_tokens as u32,
4291    })?;
4292    Ok(())
4293}
4294
4295struct ComputeEmbeddingsRateLimit;
4296
4297impl RateLimit for ComputeEmbeddingsRateLimit {
4298    fn capacity() -> usize {
4299        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4300            .ok()
4301            .and_then(|v| v.parse().ok())
4302            .unwrap_or(120) // Picked arbitrarily
4303    }
4304
4305    fn refill_duration() -> chrono::Duration {
4306        chrono::Duration::hours(1)
4307    }
4308
4309    fn db_name() -> &'static str {
4310        "compute-embeddings"
4311    }
4312}
4313
4314async fn compute_embeddings(
4315    request: proto::ComputeEmbeddings,
4316    response: Response<proto::ComputeEmbeddings>,
4317    session: UserSession,
4318    api_key: Option<Arc<str>>,
4319) -> Result<()> {
4320    let api_key = api_key.context("no OpenAI API key configured on the server")?;
4321    authorize_access_to_language_models(&session).await?;
4322
4323    session
4324        .rate_limiter
4325        .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4326        .await?;
4327
4328    let embeddings = match request.model.as_str() {
4329        "openai/text-embedding-3-small" => {
4330            open_ai::embed(
4331                &session.http_client,
4332                OPEN_AI_API_URL,
4333                &api_key,
4334                OpenAiEmbeddingModel::TextEmbedding3Small,
4335                request.texts.iter().map(|text| text.as_str()),
4336            )
4337            .await?
4338        }
4339        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4340    };
4341
4342    let embeddings = request
4343        .texts
4344        .iter()
4345        .map(|text| {
4346            let mut hasher = sha2::Sha256::new();
4347            hasher.update(text.as_bytes());
4348            let result = hasher.finalize();
4349            result.to_vec()
4350        })
4351        .zip(
4352            embeddings
4353                .data
4354                .into_iter()
4355                .map(|embedding| embedding.embedding),
4356        )
4357        .collect::<HashMap<_, _>>();
4358
4359    let db = session.db().await;
4360    db.save_embeddings(&request.model, &embeddings)
4361        .await
4362        .context("failed to save embeddings")
4363        .trace_err();
4364
4365    response.send(proto::ComputeEmbeddingsResponse {
4366        embeddings: embeddings
4367            .into_iter()
4368            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4369            .collect(),
4370    })?;
4371    Ok(())
4372}
4373
4374struct GetCachedEmbeddingsRateLimit;
4375
4376impl RateLimit for GetCachedEmbeddingsRateLimit {
4377    fn capacity() -> usize {
4378        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4379            .ok()
4380            .and_then(|v| v.parse().ok())
4381            .unwrap_or(120) // Picked arbitrarily
4382    }
4383
4384    fn refill_duration() -> chrono::Duration {
4385        chrono::Duration::hours(1)
4386    }
4387
4388    fn db_name() -> &'static str {
4389        "get-cached-embeddings"
4390    }
4391}
4392
4393async fn get_cached_embeddings(
4394    request: proto::GetCachedEmbeddings,
4395    response: Response<proto::GetCachedEmbeddings>,
4396    session: UserSession,
4397) -> Result<()> {
4398    authorize_access_to_language_models(&session).await?;
4399
4400    session
4401        .rate_limiter
4402        .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
4403        .await?;
4404
4405    let db = session.db().await;
4406    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4407
4408    response.send(proto::GetCachedEmbeddingsResponse {
4409        embeddings: embeddings
4410            .into_iter()
4411            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4412            .collect(),
4413    })?;
4414    Ok(())
4415}
4416
4417async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4418    let db = session.db().await;
4419    let flags = db.get_user_flags(session.user_id()).await?;
4420    if flags.iter().any(|flag| flag == "language-models") {
4421        Ok(())
4422    } else {
4423        Err(anyhow!("permission denied"))?
4424    }
4425}
4426
4427/// Start receiving chat updates for a channel
4428async fn join_channel_chat(
4429    request: proto::JoinChannelChat,
4430    response: Response<proto::JoinChannelChat>,
4431    session: UserSession,
4432) -> Result<()> {
4433    let channel_id = ChannelId::from_proto(request.channel_id);
4434
4435    let db = session.db().await;
4436    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4437        .await?;
4438    let messages = db
4439        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4440        .await?;
4441    response.send(proto::JoinChannelChatResponse {
4442        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4443        messages,
4444    })?;
4445    Ok(())
4446}
4447
4448/// Stop receiving chat updates for a channel
4449async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4450    let channel_id = ChannelId::from_proto(request.channel_id);
4451    session
4452        .db()
4453        .await
4454        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4455        .await?;
4456    Ok(())
4457}
4458
4459/// Retrieve the chat history for a channel
4460async fn get_channel_messages(
4461    request: proto::GetChannelMessages,
4462    response: Response<proto::GetChannelMessages>,
4463    session: UserSession,
4464) -> Result<()> {
4465    let channel_id = ChannelId::from_proto(request.channel_id);
4466    let messages = session
4467        .db()
4468        .await
4469        .get_channel_messages(
4470            channel_id,
4471            session.user_id(),
4472            MESSAGE_COUNT_PER_PAGE,
4473            Some(MessageId::from_proto(request.before_message_id)),
4474        )
4475        .await?;
4476    response.send(proto::GetChannelMessagesResponse {
4477        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4478        messages,
4479    })?;
4480    Ok(())
4481}
4482
4483/// Retrieve specific chat messages
4484async fn get_channel_messages_by_id(
4485    request: proto::GetChannelMessagesById,
4486    response: Response<proto::GetChannelMessagesById>,
4487    session: UserSession,
4488) -> Result<()> {
4489    let message_ids = request
4490        .message_ids
4491        .iter()
4492        .map(|id| MessageId::from_proto(*id))
4493        .collect::<Vec<_>>();
4494    let messages = session
4495        .db()
4496        .await
4497        .get_channel_messages_by_id(session.user_id(), &message_ids)
4498        .await?;
4499    response.send(proto::GetChannelMessagesResponse {
4500        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4501        messages,
4502    })?;
4503    Ok(())
4504}
4505
4506/// Retrieve the current users notifications
4507async fn get_notifications(
4508    request: proto::GetNotifications,
4509    response: Response<proto::GetNotifications>,
4510    session: UserSession,
4511) -> Result<()> {
4512    let notifications = session
4513        .db()
4514        .await
4515        .get_notifications(
4516            session.user_id(),
4517            NOTIFICATION_COUNT_PER_PAGE,
4518            request
4519                .before_id
4520                .map(|id| db::NotificationId::from_proto(id)),
4521        )
4522        .await?;
4523    response.send(proto::GetNotificationsResponse {
4524        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4525        notifications,
4526    })?;
4527    Ok(())
4528}
4529
4530/// Mark notifications as read
4531async fn mark_notification_as_read(
4532    request: proto::MarkNotificationRead,
4533    response: Response<proto::MarkNotificationRead>,
4534    session: UserSession,
4535) -> Result<()> {
4536    let database = &session.db().await;
4537    let notifications = database
4538        .mark_notification_as_read_by_id(
4539            session.user_id(),
4540            NotificationId::from_proto(request.notification_id),
4541        )
4542        .await?;
4543    send_notifications(
4544        &*session.connection_pool().await,
4545        &session.peer,
4546        notifications,
4547    );
4548    response.send(proto::Ack {})?;
4549    Ok(())
4550}
4551
4552/// Get the current users information
4553async fn get_private_user_info(
4554    _request: proto::GetPrivateUserInfo,
4555    response: Response<proto::GetPrivateUserInfo>,
4556    session: UserSession,
4557) -> Result<()> {
4558    let db = session.db().await;
4559
4560    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4561    let user = db
4562        .get_user_by_id(session.user_id())
4563        .await?
4564        .ok_or_else(|| anyhow!("user not found"))?;
4565    let flags = db.get_user_flags(session.user_id()).await?;
4566
4567    response.send(proto::GetPrivateUserInfoResponse {
4568        metrics_id,
4569        staff: user.admin,
4570        flags,
4571    })?;
4572    Ok(())
4573}
4574
4575fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4576    match message {
4577        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4578        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4579        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4580        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4581        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4582            code: frame.code.into(),
4583            reason: frame.reason,
4584        })),
4585    }
4586}
4587
4588fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4589    match message {
4590        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4591        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4592        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4593        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4594        AxumMessage::Close(frame) => {
4595            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4596                code: frame.code.into(),
4597                reason: frame.reason,
4598            }))
4599        }
4600    }
4601}
4602
4603fn notify_membership_updated(
4604    connection_pool: &mut ConnectionPool,
4605    result: MembershipUpdated,
4606    user_id: UserId,
4607    peer: &Peer,
4608) {
4609    for membership in &result.new_channels.channel_memberships {
4610        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4611    }
4612    for channel_id in &result.removed_channels {
4613        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4614    }
4615
4616    let user_channels_update = proto::UpdateUserChannels {
4617        channel_memberships: result
4618            .new_channels
4619            .channel_memberships
4620            .iter()
4621            .map(|cm| proto::ChannelMembership {
4622                channel_id: cm.channel_id.to_proto(),
4623                role: cm.role.into(),
4624            })
4625            .collect(),
4626        ..Default::default()
4627    };
4628
4629    let mut update = build_channels_update(result.new_channels, vec![], connection_pool);
4630    update.delete_channels = result
4631        .removed_channels
4632        .into_iter()
4633        .map(|id| id.to_proto())
4634        .collect();
4635    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4636
4637    for connection_id in connection_pool.user_connection_ids(user_id) {
4638        peer.send(connection_id, user_channels_update.clone())
4639            .trace_err();
4640        peer.send(connection_id, update.clone()).trace_err();
4641    }
4642}
4643
4644fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4645    proto::UpdateUserChannels {
4646        channel_memberships: channels
4647            .channel_memberships
4648            .iter()
4649            .map(|m| proto::ChannelMembership {
4650                channel_id: m.channel_id.to_proto(),
4651                role: m.role.into(),
4652            })
4653            .collect(),
4654        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4655        observed_channel_message_id: channels.observed_channel_messages.clone(),
4656    }
4657}
4658
4659fn build_channels_update(
4660    channels: ChannelsForUser,
4661    channel_invites: Vec<db::Channel>,
4662    pool: &ConnectionPool,
4663) -> proto::UpdateChannels {
4664    let mut update = proto::UpdateChannels::default();
4665
4666    for channel in channels.channels {
4667        update.channels.push(channel.to_proto());
4668    }
4669
4670    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4671    update.latest_channel_message_ids = channels.latest_channel_messages;
4672
4673    for (channel_id, participants) in channels.channel_participants {
4674        update
4675            .channel_participants
4676            .push(proto::ChannelParticipants {
4677                channel_id: channel_id.to_proto(),
4678                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4679            });
4680    }
4681
4682    for channel in channel_invites {
4683        update.channel_invitations.push(channel.to_proto());
4684    }
4685
4686    update.hosted_projects = channels.hosted_projects;
4687    update.dev_servers = channels
4688        .dev_servers
4689        .into_iter()
4690        .map(|dev_server| dev_server.to_proto(pool.dev_server_status(dev_server.id)))
4691        .collect();
4692    update.remote_projects = channels.remote_projects;
4693
4694    update
4695}
4696
4697fn build_initial_contacts_update(
4698    contacts: Vec<db::Contact>,
4699    pool: &ConnectionPool,
4700) -> proto::UpdateContacts {
4701    let mut update = proto::UpdateContacts::default();
4702
4703    for contact in contacts {
4704        match contact {
4705            db::Contact::Accepted { user_id, busy } => {
4706                update.contacts.push(contact_for_user(user_id, busy, &pool));
4707            }
4708            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4709            db::Contact::Incoming { user_id } => {
4710                update
4711                    .incoming_requests
4712                    .push(proto::IncomingContactRequest {
4713                        requester_id: user_id.to_proto(),
4714                    })
4715            }
4716        }
4717    }
4718
4719    update
4720}
4721
4722fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4723    proto::Contact {
4724        user_id: user_id.to_proto(),
4725        online: pool.is_user_online(user_id),
4726        busy,
4727    }
4728}
4729
4730fn room_updated(room: &proto::Room, peer: &Peer) {
4731    broadcast(
4732        None,
4733        room.participants
4734            .iter()
4735            .filter_map(|participant| Some(participant.peer_id?.into())),
4736        |peer_id| {
4737            peer.send(
4738                peer_id,
4739                proto::RoomUpdated {
4740                    room: Some(room.clone()),
4741                },
4742            )
4743        },
4744    );
4745}
4746
4747fn channel_updated(
4748    channel: &db::channel::Model,
4749    room: &proto::Room,
4750    peer: &Peer,
4751    pool: &ConnectionPool,
4752) {
4753    let participants = room
4754        .participants
4755        .iter()
4756        .map(|p| p.user_id)
4757        .collect::<Vec<_>>();
4758
4759    broadcast(
4760        None,
4761        pool.channel_connection_ids(channel.root_id())
4762            .filter_map(|(channel_id, role)| {
4763                role.can_see_channel(channel.visibility).then(|| channel_id)
4764            }),
4765        |peer_id| {
4766            peer.send(
4767                peer_id,
4768                proto::UpdateChannels {
4769                    channel_participants: vec![proto::ChannelParticipants {
4770                        channel_id: channel.id.to_proto(),
4771                        participant_user_ids: participants.clone(),
4772                    }],
4773                    ..Default::default()
4774                },
4775            )
4776        },
4777    );
4778}
4779
4780async fn update_dev_server_status(
4781    dev_server: &dev_server::Model,
4782    status: proto::DevServerStatus,
4783    session: &Session,
4784) {
4785    let pool = session.connection_pool().await;
4786    let connections = pool.channel_connection_ids(dev_server.channel_id);
4787    for (connection_id, _) in connections {
4788        session
4789            .peer
4790            .send(
4791                connection_id,
4792                proto::UpdateChannels {
4793                    dev_servers: vec![dev_server.to_proto(status)],
4794                    ..Default::default()
4795                },
4796            )
4797            .trace_err();
4798    }
4799}
4800
4801async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4802    let db = session.db().await;
4803
4804    let contacts = db.get_contacts(user_id).await?;
4805    let busy = db.is_user_busy(user_id).await?;
4806
4807    let pool = session.connection_pool().await;
4808    let updated_contact = contact_for_user(user_id, busy, &pool);
4809    for contact in contacts {
4810        if let db::Contact::Accepted {
4811            user_id: contact_user_id,
4812            ..
4813        } = contact
4814        {
4815            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4816                session
4817                    .peer
4818                    .send(
4819                        contact_conn_id,
4820                        proto::UpdateContacts {
4821                            contacts: vec![updated_contact.clone()],
4822                            remove_contacts: Default::default(),
4823                            incoming_requests: Default::default(),
4824                            remove_incoming_requests: Default::default(),
4825                            outgoing_requests: Default::default(),
4826                            remove_outgoing_requests: Default::default(),
4827                        },
4828                    )
4829                    .trace_err();
4830            }
4831        }
4832    }
4833    Ok(())
4834}
4835
4836async fn lost_dev_server_connection(session: &Session) -> Result<()> {
4837    log::info!("lost dev server connection, unsharing projects");
4838    let project_ids = session
4839        .db()
4840        .await
4841        .get_stale_dev_server_projects(session.connection_id)
4842        .await?;
4843
4844    for project_id in project_ids {
4845        // not unshare re-checks the connection ids match, so we get away with no transaction
4846        unshare_project_internal(project_id, &session).await?;
4847    }
4848
4849    Ok(())
4850}
4851
4852async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
4853    let mut contacts_to_update = HashSet::default();
4854
4855    let room_id;
4856    let canceled_calls_to_user_ids;
4857    let live_kit_room;
4858    let delete_live_kit_room;
4859    let room;
4860    let channel;
4861
4862    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4863        contacts_to_update.insert(session.user_id());
4864
4865        for project in left_room.left_projects.values() {
4866            project_left(project, session);
4867        }
4868
4869        room_id = RoomId::from_proto(left_room.room.id);
4870        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4871        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4872        delete_live_kit_room = left_room.deleted;
4873        room = mem::take(&mut left_room.room);
4874        channel = mem::take(&mut left_room.channel);
4875
4876        room_updated(&room, &session.peer);
4877    } else {
4878        return Ok(());
4879    }
4880
4881    if let Some(channel) = channel {
4882        channel_updated(
4883            &channel,
4884            &room,
4885            &session.peer,
4886            &*session.connection_pool().await,
4887        );
4888    }
4889
4890    {
4891        let pool = session.connection_pool().await;
4892        for canceled_user_id in canceled_calls_to_user_ids {
4893            for connection_id in pool.user_connection_ids(canceled_user_id) {
4894                session
4895                    .peer
4896                    .send(
4897                        connection_id,
4898                        proto::CallCanceled {
4899                            room_id: room_id.to_proto(),
4900                        },
4901                    )
4902                    .trace_err();
4903            }
4904            contacts_to_update.insert(canceled_user_id);
4905        }
4906    }
4907
4908    for contact_user_id in contacts_to_update {
4909        update_user_contacts(contact_user_id, &session).await?;
4910    }
4911
4912    if let Some(live_kit) = session.live_kit_client.as_ref() {
4913        live_kit
4914            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4915            .await
4916            .trace_err();
4917
4918        if delete_live_kit_room {
4919            live_kit.delete_room(live_kit_room).await.trace_err();
4920        }
4921    }
4922
4923    Ok(())
4924}
4925
4926async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4927    let left_channel_buffers = session
4928        .db()
4929        .await
4930        .leave_channel_buffers(session.connection_id)
4931        .await?;
4932
4933    for left_buffer in left_channel_buffers {
4934        channel_buffer_updated(
4935            session.connection_id,
4936            left_buffer.connections,
4937            &proto::UpdateChannelBufferCollaborators {
4938                channel_id: left_buffer.channel_id.to_proto(),
4939                collaborators: left_buffer.collaborators,
4940            },
4941            &session.peer,
4942        );
4943    }
4944
4945    Ok(())
4946}
4947
4948fn project_left(project: &db::LeftProject, session: &UserSession) {
4949    for connection_id in &project.connection_ids {
4950        if project.host_user_id == Some(session.user_id()) {
4951            session
4952                .peer
4953                .send(
4954                    *connection_id,
4955                    proto::UnshareProject {
4956                        project_id: project.id.to_proto(),
4957                    },
4958                )
4959                .trace_err();
4960        } else {
4961            session
4962                .peer
4963                .send(
4964                    *connection_id,
4965                    proto::RemoveProjectCollaborator {
4966                        project_id: project.id.to_proto(),
4967                        peer_id: Some(session.connection_id.into()),
4968                    },
4969                )
4970                .trace_err();
4971        }
4972    }
4973}
4974
4975pub trait ResultExt {
4976    type Ok;
4977
4978    fn trace_err(self) -> Option<Self::Ok>;
4979}
4980
4981impl<T, E> ResultExt for Result<T, E>
4982where
4983    E: std::fmt::Debug,
4984{
4985    type Ok = T;
4986
4987    #[track_caller]
4988    fn trace_err(self) -> Option<T> {
4989        match self {
4990            Ok(value) => Some(value),
4991            Err(error) => {
4992                tracing::error!("{:?}", error);
4993                None
4994            }
4995        }
4996    }
4997}