rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use semantic_version::SemanticVersion;
  50use serde::{Serialize, Serializer};
  51use std::{
  52    any::TypeId,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc, OnceLock,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{
  69    field::{self},
  70    info_span, instrument, Instrument,
  71};
  72use util::http::IsahcHttpClient;
  73
    /// Fixed 30-second timeout.
    /// NOTE(review): presumably the window a dropped client has to reconnect; its use is not
    /// visible in this part of the file — confirm against callers.
  74pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  75
  76// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
    // `Server::start` sleeps for this long before it clears stale rooms, channel buffers and servers.
  77pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  78
    // Paging and size limits.
    // NOTE(review): these are not referenced in the visible part of the file; presumably they bound
    // channel-message pages, channel-message body length and notification pages — confirm.
  79const MESSAGE_COUNT_PER_PAGE: usize = 100;
  80const MAX_MESSAGE_LEN: usize = 1024;
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82
    /// Type-erased message handler as stored in `Server::handlers`.
    ///
    /// It receives the boxed envelope together with the `Session` it arrived on and returns a
    /// boxed future that resolves to `()`; `Server::add_handler` builds these wrappers and logs
    /// the handler's result inside the future.
  83type MessageHandler =
  84    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100struct StreamingResponse<R: RequestMessage> {
 101    peer: Arc<Peer>,
 102    receipt: Receipt<R>,
 103}
 104
 105impl<R: RequestMessage> StreamingResponse<R> {
 106    fn send(&self, payload: R::Response) -> Result<()> {
 107        self.peer.respond(self.receipt, payload)?;
 108        Ok(())
 109    }
 110}
 111
 112#[derive(Clone, Debug)]
 113pub enum Principal {
 114    User(User),
 115    Impersonated { user: User, admin: User },
 116    DevServer(dev_server::Model),
 117}
 118
 119impl Principal {
 120    fn update_span(&self, span: &tracing::Span) {
 121        match &self {
 122            Principal::User(user) => {
 123                span.record("user_id", &user.id.0);
 124                span.record("login", &user.github_login);
 125            }
 126            Principal::Impersonated { user, admin } => {
 127                span.record("user_id", &user.id.0);
 128                span.record("login", &user.github_login);
 129                span.record("impersonator", &admin.github_login);
 130            }
 131            Principal::DevServer(dev_server) => {
 132                span.record("dev_server_id", &dev_server.id.0);
 133            }
 134        }
 135    }
 136}
 137
 138#[derive(Clone)]
 139struct Session {
 140    principal: Principal,
 141    connection_id: ConnectionId,
 142    db: Arc<tokio::sync::Mutex<DbHandle>>,
 143    peer: Arc<Peer>,
 144    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 145    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 146    http_client: IsahcHttpClient,
 147    rate_limiter: Arc<RateLimiter>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn user_id(&self) -> Option<UserId> {
 176        match &self.principal {
 177            Principal::User(user) => Some(user.id),
 178            Principal::Impersonated { user, .. } => Some(user.id),
 179            Principal::DevServer(_) => None,
 180        }
 181    }
 182}
 183
 184impl Debug for Session {
 185    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 186        let mut result = f.debug_struct("Session");
 187        match &self.principal {
 188            Principal::User(user) => {
 189                result.field("user", &user.github_login);
 190            }
 191            Principal::Impersonated { user, admin } => {
 192                result.field("user", &user.github_login);
 193                result.field("impersonator", &admin.github_login);
 194            }
 195            Principal::DevServer(dev_server) => {
 196                result.field("dev_server", &dev_server.id);
 197            }
 198        }
 199        result.field("connection_id", &self.connection_id).finish()
 200    }
 201}
 202
 203struct UserSession(Session);
 204
 205impl UserSession {
 206    pub fn new(s: Session) -> Option<Self> {
 207        s.user_id().map(|_| UserSession(s))
 208    }
 209    pub fn user_id(&self) -> UserId {
 210        self.0.user_id().unwrap()
 211    }
 212}
 213
 214impl Deref for UserSession {
 215    type Target = Session;
 216
 217    fn deref(&self) -> &Self::Target {
 218        &self.0
 219    }
 220}
 221impl DerefMut for UserSession {
 222    fn deref_mut(&mut self) -> &mut Self::Target {
 223        &mut self.0
 224    }
 225}
 226
 227fn user_handler<M: RequestMessage, Fut>(
 228    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 230where
 231    Fut: Send + Future<Output = Result<()>>,
 232{
 233    let handler = Arc::new(handler);
 234    move |message, response, session| {
 235        let handler = handler.clone();
 236        Box::pin(async move {
 237            if let Some(user_session) = session.for_user() {
 238                Ok(handler(message, response, user_session).await?)
 239            } else {
 240                Err(Error::Internal(anyhow!("must be a user")))
 241            }
 242        })
 243    }
 244}
 245
 246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 247    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 249where
 250    InnertRetFut: Send + Future<Output = Result<()>>,
 251{
 252    let handler = Arc::new(handler);
 253    move |message, session| {
 254        let handler = handler.clone();
 255        Box::pin(async move {
 256            if let Some(user_session) = session.for_user() {
 257                Ok(handler(message, user_session).await?)
 258            } else {
 259                Err(Error::Internal(anyhow!("must be a user")))
 260            }
 261        })
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
    /// The RPC server: owns the peer, the connection pool and the table of message handlers.
 275pub struct Server {
        /// Current server id; behind a mutex because the test-only `reset` replaces it.
 276    id: parking_lot::Mutex<ServerId>,
 277    peer: Arc<Peer>,
 278    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 279    app_state: Arc<AppState>,
        /// One handler per message type, keyed by the message's `TypeId` (filled by `add_handler`).
 280    handlers: HashMap<TypeId, MessageHandler>,
        /// Watch channel set to `true` by `teardown`; `handle_connection` checks it before
        /// accepting a connection.
 281    teardown: watch::Sender<bool>,
 282}
 283
    /// Lock guard over the connection pool, as returned by `Session::connection_pool`.
 284pub(crate) struct ConnectionPoolGuard<'a> {
 285    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
        /// `Rc` is not `Send`, so this marker makes the whole guard `!Send`.
 286    _not_send: PhantomData<Rc<()>>,
 287}
 288
    /// Serializable view of the server: a borrowed peer plus a locked connection pool.
 289#[derive(Serialize)]
 290pub struct ServerSnapshot<'a> {
 291    peer: &'a Peer,
        /// Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
 292    #[serde(serialize_with = "serialize_deref")]
 293    connection_pool: ConnectionPoolGuard<'a>,
 294}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
        /// Builds a server with the given id and registers every RPC message handler.
        ///
        /// The peer is created from `id.0 as u32`, the connection pool and handler table start
        /// from their `Default` values, and the teardown watch channel starts at `false`.
        ///
        /// # Panics
        /// Every `add_*_handler` call goes through `add_handler`, which panics if two handlers
        /// are registered for the same message type.
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state: app_state.clone(),
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
            // Handlers wrapped in `user_handler` / `user_message_handler` reject sessions that
            // have no user id (dev servers) with a "must be a user" error; unwrapped handlers
            // receive the raw `Session`.
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(user_handler(create_room))
 319            .add_request_handler(user_handler(join_room))
 320            .add_request_handler(user_handler(rejoin_room))
 321            .add_request_handler(user_handler(leave_room))
 322            .add_request_handler(user_handler(set_room_participant_role))
 323            .add_request_handler(user_handler(call))
 324            .add_request_handler(user_handler(cancel_call))
 325            .add_message_handler(user_message_handler(decline_call))
 326            .add_request_handler(user_handler(update_participant_location))
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(user_handler(join_project))
 330            .add_request_handler(user_handler(join_hosted_project))
 331            .add_message_handler(user_message_handler(leave_project))
 332            .add_request_handler(update_project)
 333            .add_request_handler(update_worktree)
 334            .add_message_handler(start_language_server)
 335            .add_message_handler(update_language_server)
 336            .add_message_handler(update_diagnostic_summary)
 337            .add_message_handler(update_worktree_settings)
                // Project requests registered through `forward_read_only_project_request`.
 338            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 342            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 346            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 347            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 348            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
                // Project requests registered through `forward_mutating_project_request`.
 350            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 353            )
 354            .add_request_handler(
 355                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 356            )
 357            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 358            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 369            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 370            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 371            .add_message_handler(create_buffer_for_peer)
 372            .add_request_handler(update_buffer)
                // Messages registered through `broadcast_project_message_from_host`.
 373            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 376            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 377            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 378            .add_request_handler(get_users)
 379            .add_request_handler(user_handler(fuzzy_search_users))
 380            .add_request_handler(user_handler(request_contact))
 381            .add_request_handler(user_handler(remove_contact))
 382            .add_request_handler(user_handler(respond_to_contact_request))
 383            .add_request_handler(user_handler(create_channel))
 384            .add_request_handler(user_handler(delete_channel))
 385            .add_request_handler(user_handler(invite_channel_member))
 386            .add_request_handler(user_handler(remove_channel_member))
 387            .add_request_handler(user_handler(set_channel_member_role))
 388            .add_request_handler(user_handler(set_channel_visibility))
 389            .add_request_handler(user_handler(rename_channel))
 390            .add_request_handler(user_handler(join_channel_buffer))
 391            .add_request_handler(user_handler(leave_channel_buffer))
 392            .add_message_handler(user_message_handler(update_channel_buffer))
 393            .add_request_handler(user_handler(rejoin_channel_buffers))
 394            .add_request_handler(user_handler(get_channel_members))
 395            .add_request_handler(user_handler(respond_to_channel_invite))
 396            .add_request_handler(user_handler(join_channel))
 397            .add_request_handler(user_handler(join_channel_chat))
 398            .add_message_handler(user_message_handler(leave_channel_chat))
 399            .add_request_handler(user_handler(send_channel_message))
 400            .add_request_handler(user_handler(remove_channel_message))
 401            .add_request_handler(user_handler(update_channel_message))
 402            .add_request_handler(user_handler(get_channel_messages))
 403            .add_request_handler(user_handler(get_channel_messages_by_id))
 404            .add_request_handler(user_handler(get_notifications))
 405            .add_request_handler(user_handler(mark_notification_as_read))
 406            .add_request_handler(user_handler(move_channel))
 407            .add_request_handler(user_handler(follow))
 408            .add_message_handler(user_message_handler(unfollow))
 409            .add_message_handler(user_message_handler(update_followers))
 410            .add_request_handler(user_handler(get_private_user_info))
 411            .add_message_handler(user_message_handler(acknowledge_channel_message))
 412            .add_message_handler(user_message_handler(acknowledge_buffer_version))
                // The two language-model handlers capture a clone of `app_state` so they can
                // read the API keys from its config on every request.
 413            .add_streaming_request_handler({
 414                let app_state = app_state.clone();
 415                move |request, response, session| {
 416                    complete_with_language_model(
 417                        request,
 418                        response,
 419                        session,
 420                        app_state.config.openai_api_key.clone(),
 421                        app_state.config.google_ai_api_key.clone(),
 422                    )
 423                }
 424            })
 425            .add_request_handler({
 426                let app_state = app_state.clone();
 427                user_handler(move |request, response, session| {
 428                    count_tokens_with_language_model(
 429                        request,
 430                        response,
 431                        session,
 432                        app_state.config.google_ai_api_key.clone(),
 433                    )
 434                })
 435            });
 436
 437        Arc::new(server)
 438    }
 439
        /// Spawns a detached background task that cleans up stale server resources, then
        /// returns `Ok(())` without waiting for that task.
        ///
        /// After sleeping for `CLEANUP_TIMEOUT` the task:
        /// 1. asks the database for stale room and channel ids (`stale_server_resource_ids`);
        /// 2. for each channel, clears stale channel-buffer collaborators and sends the refreshed
        ///    collaborator list to the connection ids the database returns;
        /// 3. for each room, clears stale participants, announces the refreshed room (and its
        ///    channel, if any), sends `CallCanceled` to users whose calls were canceled, pushes
        ///    an updated contact entry to the accepted contacts of every affected user, and
        ///    deletes the LiveKit room when no participants remain;
        /// 4. calls `delete_stale_servers`.
        ///
        /// Every fallible step goes through `trace_err()` and is skipped when that yields `None`
        /// (NOTE(review): `trace_err` is defined elsewhere; presumably it logs the error).
 440    pub async fn start(&self) -> Result<()> {
 441        let server_id = *self.id.lock();
 442        let app_state = self.app_state.clone();
 443        let peer = self.peer.clone();
            // The sleep future is created here, before the task is spawned, and awaited inside it.
 444        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 445        let pool = self.connection_pool.clone();
 446        let live_kit_client = self.app_state.live_kit_client.clone();
 447
 448        let span = info_span!("start server");
 449        self.app_state.executor.spawn_detached(
 450            async move {
 451                tracing::info!("waiting for cleanup timeout");
 452                timeout.await;
 453                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 454                if let Some((room_ids, channel_ids)) = app_state
 455                    .db
 456                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 457                    .await
 458                    .trace_err()
 459                {
 460                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 461                    tracing::info!(
 462                        stale_channel_buffer_count = channel_ids.len(),
 463                        "retrieved stale channel buffers"
 464                    );
 465
                        // Channel buffers: send the refreshed collaborator list to every
                        // connection id the database returns for the channel.
 466                    for channel_id in channel_ids {
 467                        if let Some(refreshed_channel_buffer) = app_state
 468                            .db
 469                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 470                            .await
 471                            .trace_err()
 472                        {
 473                            for connection_id in refreshed_channel_buffer.connection_ids {
 474                                peer.send(
 475                                    connection_id,
 476                                    proto::UpdateChannelBufferCollaborators {
 477                                        channel_id: channel_id.to_proto(),
 478                                        collaborators: refreshed_channel_buffer
 479                                            .collaborators
 480                                            .clone(),
 481                                    },
 482                                )
 483                                .trace_err();
 484                            }
 485                        }
 486                    }
 487
 488                    for room_id in room_ids {
                            // Per-room state, left at these defaults when clearing the room fails.
 489                        let mut contacts_to_update = HashSet::default();
 490                        let mut canceled_calls_to_user_ids = Vec::new();
 491                        let mut live_kit_room = String::new();
 492                        let mut delete_live_kit_room = false;
 493
 494                        if let Some(mut refreshed_room) = app_state
 495                            .db
 496                            .clear_stale_room_participants(room_id, server_id)
 497                            .await
 498                            .trace_err()
 499                        {
 500                            tracing::info!(
 501                                room_id = room_id.0,
 502                                new_participant_count = refreshed_room.room.participants.len(),
 503                                "refreshed room"
 504                            );
 505                            room_updated(&refreshed_room.room, &peer);
 506                            if let Some(channel) = refreshed_room.channel.as_ref() {
 507                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 508                            }
                                // Both stale participants and users whose calls were canceled
                                // need their contact entry refreshed below.
 509                            contacts_to_update
 510                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 511                            contacts_to_update
 512                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 513                            canceled_calls_to_user_ids =
 514                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 515                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 516                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 517                        }
 518
                            // The pool lock is confined to this block, which contains no `.await`.
 519                        {
 520                            let pool = pool.lock();
 521                            for canceled_user_id in canceled_calls_to_user_ids {
 522                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 523                                    peer.send(
 524                                        connection_id,
 525                                        proto::CallCanceled {
 526                                            room_id: room_id.to_proto(),
 527                                        },
 528                                    )
 529                                    .trace_err();
 530                                }
 531                            }
 532                        }
 533
 534                        for user_id in contacts_to_update {
 535                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 536                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                                // Only proceed when both lookups succeeded; the pool is locked
                                // after the awaits above have completed.
 537                            if let Some((busy, contacts)) = busy.zip(contacts) {
 538                                let pool = pool.lock();
 539                                let updated_contact = contact_for_user(user_id, busy, &pool);
 540                                for contact in contacts {
                                        // Only accepted contacts receive the update.
 541                                    if let db::Contact::Accepted {
 542                                        user_id: contact_user_id,
 543                                        ..
 544                                    } = contact
 545                                    {
 546                                        for contact_conn_id in
 547                                            pool.user_connection_ids(contact_user_id)
 548                                        {
 549                                            peer.send(
 550                                                contact_conn_id,
 551                                                proto::UpdateContacts {
 552                                                    contacts: vec![updated_contact.clone()],
 553                                                    remove_contacts: Default::default(),
 554                                                    incoming_requests: Default::default(),
 555                                                    remove_incoming_requests: Default::default(),
 556                                                    outgoing_requests: Default::default(),
 557                                                    remove_outgoing_requests: Default::default(),
 558                                                },
 559                                            )
 560                                            .trace_err();
 561                                        }
 562                                    }
 563                                }
 564                            }
 565                        }
 566
                            // Delete the LiveKit room only when the refreshed room came back with
                            // no participants and a LiveKit client is configured.
 567                        if let Some(live_kit) = live_kit_client.as_ref() {
 568                            if delete_live_kit_room {
 569                                live_kit.delete_room(live_kit_room).await.trace_err();
 570                            }
 571                        }
 572                    }
 573                }
 574
                    // Runs even when fetching the stale resource ids failed.
 575                app_state
 576                    .db
 577                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 578                    .await
 579                    .trace_err();
 580            }
 581            .instrument(span),
 582        );
 583        Ok(())
 584    }
 585
 586    pub fn teardown(&self) {
 587        self.peer.teardown();
 588        self.connection_pool.lock().reset();
 589        let _ = self.teardown.send(true);
 590    }
 591
 592    #[cfg(test)]
 593    pub fn reset(&self, id: ServerId) {
 594        self.teardown();
 595        *self.id.lock() = id;
 596        self.peer.reset(id.0 as u32);
 597        let _ = self.teardown.send(false);
 598    }
 599
 600    #[cfg(test)]
 601    pub fn id(&self) -> ServerId {
 602        *self.id.lock()
 603    }
 604
 605    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 606    where
 607        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 608        Fut: 'static + Send + Future<Output = Result<()>>,
 609        M: EnvelopedMessage,
 610    {
 611        let prev_handler = self.handlers.insert(
 612            TypeId::of::<M>(),
 613            Box::new(move |envelope, session| {
 614                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 615                let received_at = envelope.received_at;
 616                    tracing::info!(
 617                        "message received"
 618                    );
 619                let start_time = Instant::now();
 620                let future = (handler)(*envelope, session);
 621                async move {
 622                    let result = future.await;
 623                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 624                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 625                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 626                    match result {
 627                        Err(error) => {
 628                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 629                        }
 630                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 631                    }
 632                }
 633                .boxed()
 634            }),
 635        );
 636        if prev_handler.is_some() {
 637            panic!("registered a handler for the same message twice");
 638        }
 639        self
 640    }
 641
 642    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 643    where
 644        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 645        Fut: 'static + Send + Future<Output = Result<()>>,
 646        M: EnvelopedMessage,
 647    {
 648        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 649        self
 650    }
 651
 652    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 653    where
 654        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 655        Fut: Send + Future<Output = Result<()>>,
 656        M: RequestMessage,
 657    {
 658        let handler = Arc::new(handler);
 659        self.add_handler(move |envelope, session| {
 660            let receipt = envelope.receipt();
 661            let handler = handler.clone();
 662            async move {
 663                let peer = session.peer.clone();
 664                let responded = Arc::new(AtomicBool::default());
 665                let response = Response {
 666                    peer: peer.clone(),
 667                    responded: responded.clone(),
 668                    receipt,
 669                };
 670                match (handler)(envelope.payload, response, session).await {
 671                    Ok(()) => {
 672                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 673                            Ok(())
 674                        } else {
 675                            Err(anyhow!("handler did not send a response"))?
 676                        }
 677                    }
 678                    Err(error) => {
 679                        let proto_err = match &error {
 680                            Error::Internal(err) => err.to_proto(),
 681                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 682                        };
 683                        peer.respond_with_error(receipt, proto_err)?;
 684                        Err(error)
 685                    }
 686                }
 687            }
 688        })
 689    }
 690
 691    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 692    where
 693        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 694        Fut: Send + Future<Output = Result<()>>,
 695        M: RequestMessage,
 696    {
 697        let handler = Arc::new(handler);
 698        self.add_handler(move |envelope, session| {
 699            let receipt = envelope.receipt();
 700            let handler = handler.clone();
 701            async move {
 702                let peer = session.peer.clone();
 703                let response = StreamingResponse {
 704                    peer: peer.clone(),
 705                    receipt,
 706                };
 707                match (handler)(envelope.payload, response, session).await {
 708                    Ok(()) => {
 709                        peer.end_stream(receipt)?;
 710                        Ok(())
 711                    }
 712                    Err(error) => {
 713                        let proto_err = match &error {
 714                            Error::Internal(err) => err.to_proto(),
 715                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 716                        };
 717                        peer.respond_with_error(receipt, proto_err)?;
 718                        Err(error)
 719                    }
 720                }
 721            }
 722        })
 723    }
 724
    /// Returns a future that drives a single client connection from registration until
    /// sign-out.
    ///
    /// The future, instrumented with a "handle connection" span:
    /// 1. returns immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer, builds a [`Session`], and sends the initial
    ///    client update (returning early if the HTTP client or that update fails);
    /// 3. loops, dispatching each incoming message to the handler registered for its payload
    ///    type, until the I/O future finishes or the incoming stream ends;
    /// 4. finally calls `connection_lost` to sign the connection out.
    ///
    /// A change on the teardown channel while looping makes the future return without
    /// calling `connection_lost`.
    ///
    /// `send_connection_id` is forwarded to `send_initial_client_update`.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // `impersonator` and `connection_id` start out empty; `connection_id` is recorded
        // below once the peer has assigned one.
        let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight handlers for this connection: a permit is acquired
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its arms in the order written, so teardown and I/O
                // completion take priority over message processing.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit travels with the handler future and is released
                                // only after the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped before the connection is signed out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 850
 851    async fn send_initial_client_update(
 852        &self,
 853        connection_id: ConnectionId,
 854        principal: &Principal,
 855        zed_version: ZedVersion,
 856        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 857        session: &Session,
 858    ) -> Result<()> {
 859        self.peer.send(
 860            connection_id,
 861            proto::Hello {
 862                peer_id: Some(connection_id.into()),
 863            },
 864        )?;
 865        tracing::info!("sent hello message");
 866
 867        let Principal::User(user) = principal else {
 868            return Ok(());
 869        };
 870
 871        if let Some(send_connection_id) = send_connection_id.take() {
 872            let _ = send_connection_id.send(connection_id);
 873        }
 874
 875        if !user.connected_once {
 876            self.peer.send(connection_id, proto::ShowContacts {})?;
 877            self.app_state
 878                .db
 879                .set_user_connected_once(user.id, true)
 880                .await?;
 881        }
 882
 883        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 884            self.app_state.db.get_contacts(user.id),
 885            self.app_state.db.get_channels_for_user(user.id),
 886            self.app_state.db.get_channel_invites_for_user(user.id),
 887        )
 888        .await?;
 889
 890        {
 891            let mut pool = self.connection_pool.lock();
 892            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 893            for membership in &channels_for_user.channel_memberships {
 894                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 895            }
 896            self.peer.send(
 897                connection_id,
 898                build_initial_contacts_update(contacts, &pool),
 899            )?;
 900            self.peer.send(
 901                connection_id,
 902                build_update_user_channels(&channels_for_user),
 903            )?;
 904            self.peer.send(
 905                connection_id,
 906                build_channels_update(channels_for_user, channel_invites),
 907            )?;
 908        }
 909
 910        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 911            self.peer.send(connection_id, incoming_call)?;
 912        }
 913
 914        update_user_contacts(user.id, &session).await?;
 915        Ok(())
 916    }
 917
 918    pub async fn invite_code_redeemed(
 919        self: &Arc<Self>,
 920        inviter_id: UserId,
 921        invitee_id: UserId,
 922    ) -> Result<()> {
 923        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 924            if let Some(code) = &user.invite_code {
 925                let pool = self.connection_pool.lock();
 926                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 927                for connection_id in pool.user_connection_ids(inviter_id) {
 928                    self.peer.send(
 929                        connection_id,
 930                        proto::UpdateContacts {
 931                            contacts: vec![invitee_contact.clone()],
 932                            ..Default::default()
 933                        },
 934                    )?;
 935                    self.peer.send(
 936                        connection_id,
 937                        proto::UpdateInviteInfo {
 938                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 939                            count: user.invite_count as u32,
 940                        },
 941                    )?;
 942                }
 943            }
 944        }
 945        Ok(())
 946    }
 947
 948    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 949        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 950            if let Some(invite_code) = &user.invite_code {
 951                let pool = self.connection_pool.lock();
 952                for connection_id in pool.user_connection_ids(user_id) {
 953                    self.peer.send(
 954                        connection_id,
 955                        proto::UpdateInviteInfo {
 956                            url: format!(
 957                                "{}{}",
 958                                self.app_state.config.invite_link_prefix, invite_code
 959                            ),
 960                            count: user.invite_count as u32,
 961                        },
 962                    )?;
 963                }
 964            }
 965        }
 966        Ok(())
 967    }
 968
 969    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 970        ServerSnapshot {
 971            connection_pool: ConnectionPoolGuard {
 972                guard: self.connection_pool.lock(),
 973                _not_send: PhantomData,
 974            },
 975            peer: &self.peer,
 976        }
 977    }
 978}
 979
 980impl<'a> Deref for ConnectionPoolGuard<'a> {
 981    type Target = ConnectionPool;
 982
 983    fn deref(&self) -> &Self::Target {
 984        &self.guard
 985    }
 986}
 987
 988impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 989    fn deref_mut(&mut self) -> &mut Self::Target {
 990        &mut self.guard
 991    }
 992}
 993
 994impl<'a> Drop for ConnectionPoolGuard<'a> {
 995    fn drop(&mut self) {
 996        #[cfg(test)]
 997        self.check_invariants();
 998    }
 999}
1000
1001fn broadcast<F>(
1002    sender_id: Option<ConnectionId>,
1003    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004    mut f: F,
1005) where
1006    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008    for receiver_id in receiver_ids {
1009        if Some(receiver_id) != sender_id {
1010            if let Err(error) = f(receiver_id) {
1011                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012            }
1013        }
1014    }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020    fn name() -> &'static HeaderName {
1021        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023    }
1024
1025    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026    where
1027        Self: Sized,
1028        I: Iterator<Item = &'i axum::http::HeaderValue>,
1029    {
1030        let version = values
1031            .next()
1032            .ok_or_else(axum::headers::Error::invalid)?
1033            .to_str()
1034            .map_err(|_| axum::headers::Error::invalid())?
1035            .parse()
1036            .map_err(|_| axum::headers::Error::invalid())?;
1037        Ok(Self(version))
1038    }
1039
1040    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041        values.extend([self.0.to_string().parse().unwrap()]);
1042    }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047    fn name() -> &'static HeaderName {
1048        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050    }
1051
1052    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053    where
1054        Self: Sized,
1055        I: Iterator<Item = &'i axum::http::HeaderValue>,
1056    {
1057        let version = values
1058            .next()
1059            .ok_or_else(axum::headers::Error::invalid)?
1060            .to_str()
1061            .map_err(|_| axum::headers::Error::invalid())?
1062            .parse()
1063            .map_err(|_| axum::headers::Error::invalid())?;
1064        Ok(Self(version))
1065    }
1066
1067    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068        values.extend([self.0.to_string().parse().unwrap()]);
1069    }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073    Router::new()
1074        .route("/rpc", get(handle_websocket_request))
1075        .layer(
1076            ServiceBuilder::new()
1077                .layer(Extension(server.app_state.clone()))
1078                .layer(middleware::from_fn(auth::validate_header)),
1079        )
1080        .route("/metrics", get(handle_metrics))
1081        .layer(Extension(server))
1082}
1083
1084pub async fn handle_websocket_request(
1085    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088    Extension(server): Extension<Arc<Server>>,
1089    Extension(principal): Extension<Principal>,
1090    ws: WebSocketUpgrade,
1091) -> axum::response::Response {
1092    if protocol_version != rpc::PROTOCOL_VERSION {
1093        return (
1094            StatusCode::UPGRADE_REQUIRED,
1095            "client must be upgraded".to_string(),
1096        )
1097            .into_response();
1098    }
1099
1100    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1101        return (
1102            StatusCode::UPGRADE_REQUIRED,
1103            "no version header found".to_string(),
1104        )
1105            .into_response();
1106    };
1107
1108    if !version.can_collaborate() {
1109        return (
1110            StatusCode::UPGRADE_REQUIRED,
1111            "client must be upgraded".to_string(),
1112        )
1113            .into_response();
1114    }
1115
1116    let socket_address = socket_address.to_string();
1117    ws.on_upgrade(move |socket| {
1118        let socket = socket
1119            .map_ok(to_tungstenite_message)
1120            .err_into()
1121            .with(|message| async move { Ok(to_axum_message(message)) });
1122        let connection = Connection::new(Box::pin(socket));
1123        async move {
1124            server
1125                .handle_connection(
1126                    connection,
1127                    socket_address,
1128                    principal,
1129                    version,
1130                    None,
1131                    Executor::Production,
1132                )
1133                .await;
1134        }
1135    })
1136}
1137
1138pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1139    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1140    let connections_metric = CONNECTIONS_METRIC
1141        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1142
1143    let connections = server
1144        .connection_pool
1145        .lock()
1146        .connections()
1147        .filter(|connection| !connection.admin)
1148        .count();
1149    connections_metric.set(connections as _);
1150
1151    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1153        register_int_gauge!(
1154            "shared_projects",
1155            "number of open projects with one or more guests"
1156        )
1157        .unwrap()
1158    });
1159
1160    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1161    shared_projects_metric.set(shared_projects as _);
1162
1163    let encoder = prometheus::TextEncoder::new();
1164    let metric_families = prometheus::gather();
1165    let encoded_metrics = encoder
1166        .encode_to_string(&metric_families)
1167        .map_err(|err| anyhow!("{}", err))?;
1168    Ok(encoded_metrics)
1169}
1170
/// Cleans up after a connection has ended.
///
/// The connection is disconnected from the peer and removed from the connection pool right
/// away, and the database is told the connection was lost (a failure there is only traced).
/// The function then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapsing: for a user session, the user leaves their room and channel
///   buffers, a pending call is declined if the user has no remaining online connection, and
///   the user's contacts are updated;
/// - a change on `teardown`: nothing further is done.
///
/// Errors from removing the pool entry and from the final contacts update are returned;
/// the other cleanup steps only trace their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout arm first, then the teardown arm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only decline a pending call when the user has no other live connection.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1218
1219/// Acknowledges a ping from a client, used to keep the connection alive.
1220async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1221    response.send(proto::Ack {})?;
1222    Ok(())
1223}
1224
1225/// Creates a new room for calling (outside of channels)
1226async fn create_room(
1227    _request: proto::CreateRoom,
1228    response: Response<proto::CreateRoom>,
1229    session: UserSession,
1230) -> Result<()> {
1231    let live_kit_room = nanoid::nanoid!(30);
1232
1233    let live_kit_connection_info = util::maybe!(async {
1234        let live_kit = session.live_kit_client.as_ref();
1235        let live_kit = live_kit?;
1236        let user_id = session.user_id().to_string();
1237
1238        let token = live_kit
1239            .room_token(&live_kit_room, &user_id.to_string())
1240            .trace_err()?;
1241
1242        Some(proto::LiveKitConnectionInfo {
1243            server_url: live_kit.url().into(),
1244            token,
1245            can_publish: true,
1246        })
1247    })
1248    .await;
1249
1250    let room = session
1251        .db()
1252        .await
1253        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1254        .await?;
1255
1256    response.send(proto::CreateRoomResponse {
1257        room: Some(room.clone()),
1258        live_kit_connection_info,
1259    })?;
1260
1261    update_user_contacts(session.user_id(), &session).await?;
1262    Ok(())
1263}
1264
1265/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1266async fn join_room(
1267    request: proto::JoinRoom,
1268    response: Response<proto::JoinRoom>,
1269    session: UserSession,
1270) -> Result<()> {
1271    let room_id = RoomId::from_proto(request.id);
1272
1273    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1274
1275    if let Some(channel_id) = channel_id {
1276        return join_channel_internal(channel_id, Box::new(response), session).await;
1277    }
1278
1279    let joined_room = {
1280        let room = session
1281            .db()
1282            .await
1283            .join_room(room_id, session.user_id(), session.connection_id)
1284            .await?;
1285        room_updated(&room.room, &session.peer);
1286        room.into_inner()
1287    };
1288
1289    for connection_id in session
1290        .connection_pool()
1291        .await
1292        .user_connection_ids(session.user_id())
1293    {
1294        session
1295            .peer
1296            .send(
1297                connection_id,
1298                proto::CallCanceled {
1299                    room_id: room_id.to_proto(),
1300                },
1301            )
1302            .trace_err();
1303    }
1304
1305    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1306        if let Some(token) = live_kit
1307            .room_token(
1308                &joined_room.room.live_kit_room,
1309                &session.user_id().to_string(),
1310            )
1311            .trace_err()
1312        {
1313            Some(proto::LiveKitConnectionInfo {
1314                server_url: live_kit.url().into(),
1315                token,
1316                can_publish: true,
1317            })
1318        } else {
1319            None
1320        }
1321    } else {
1322        None
1323    };
1324
1325    response.send(proto::JoinRoomResponse {
1326        room: Some(joined_room.room),
1327        channel_id: None,
1328        live_kit_connection_info,
1329    })?;
1330
1331    update_user_contacts(session.user_id(), &session).await?;
1332    Ok(())
1333}
1334
1335/// Rejoin room is used to reconnect to a room after connection errors.
1336async fn rejoin_room(
1337    request: proto::RejoinRoom,
1338    response: Response<proto::RejoinRoom>,
1339    session: UserSession,
1340) -> Result<()> {
1341    let room;
1342    let channel;
1343    {
1344        let mut rejoined_room = session
1345            .db()
1346            .await
1347            .rejoin_room(request, session.user_id(), session.connection_id)
1348            .await?;
1349
1350        response.send(proto::RejoinRoomResponse {
1351            room: Some(rejoined_room.room.clone()),
1352            reshared_projects: rejoined_room
1353                .reshared_projects
1354                .iter()
1355                .map(|project| proto::ResharedProject {
1356                    id: project.id.to_proto(),
1357                    collaborators: project
1358                        .collaborators
1359                        .iter()
1360                        .map(|collaborator| collaborator.to_proto())
1361                        .collect(),
1362                })
1363                .collect(),
1364            rejoined_projects: rejoined_room
1365                .rejoined_projects
1366                .iter()
1367                .map(|rejoined_project| proto::RejoinedProject {
1368                    id: rejoined_project.id.to_proto(),
1369                    worktrees: rejoined_project
1370                        .worktrees
1371                        .iter()
1372                        .map(|worktree| proto::WorktreeMetadata {
1373                            id: worktree.id,
1374                            root_name: worktree.root_name.clone(),
1375                            visible: worktree.visible,
1376                            abs_path: worktree.abs_path.clone(),
1377                        })
1378                        .collect(),
1379                    collaborators: rejoined_project
1380                        .collaborators
1381                        .iter()
1382                        .map(|collaborator| collaborator.to_proto())
1383                        .collect(),
1384                    language_servers: rejoined_project.language_servers.clone(),
1385                })
1386                .collect(),
1387        })?;
1388        room_updated(&rejoined_room.room, &session.peer);
1389
1390        for project in &rejoined_room.reshared_projects {
1391            for collaborator in &project.collaborators {
1392                session
1393                    .peer
1394                    .send(
1395                        collaborator.connection_id,
1396                        proto::UpdateProjectCollaborator {
1397                            project_id: project.id.to_proto(),
1398                            old_peer_id: Some(project.old_connection_id.into()),
1399                            new_peer_id: Some(session.connection_id.into()),
1400                        },
1401                    )
1402                    .trace_err();
1403            }
1404
1405            broadcast(
1406                Some(session.connection_id),
1407                project
1408                    .collaborators
1409                    .iter()
1410                    .map(|collaborator| collaborator.connection_id),
1411                |connection_id| {
1412                    session.peer.forward_send(
1413                        session.connection_id,
1414                        connection_id,
1415                        proto::UpdateProject {
1416                            project_id: project.id.to_proto(),
1417                            worktrees: project.worktrees.clone(),
1418                        },
1419                    )
1420                },
1421            );
1422        }
1423
1424        for project in &rejoined_room.rejoined_projects {
1425            for collaborator in &project.collaborators {
1426                session
1427                    .peer
1428                    .send(
1429                        collaborator.connection_id,
1430                        proto::UpdateProjectCollaborator {
1431                            project_id: project.id.to_proto(),
1432                            old_peer_id: Some(project.old_connection_id.into()),
1433                            new_peer_id: Some(session.connection_id.into()),
1434                        },
1435                    )
1436                    .trace_err();
1437            }
1438        }
1439
1440        for project in &mut rejoined_room.rejoined_projects {
1441            for worktree in mem::take(&mut project.worktrees) {
1442                #[cfg(any(test, feature = "test-support"))]
1443                const MAX_CHUNK_SIZE: usize = 2;
1444                #[cfg(not(any(test, feature = "test-support")))]
1445                const MAX_CHUNK_SIZE: usize = 256;
1446
1447                // Stream this worktree's entries.
1448                let message = proto::UpdateWorktree {
1449                    project_id: project.id.to_proto(),
1450                    worktree_id: worktree.id,
1451                    abs_path: worktree.abs_path.clone(),
1452                    root_name: worktree.root_name,
1453                    updated_entries: worktree.updated_entries,
1454                    removed_entries: worktree.removed_entries,
1455                    scan_id: worktree.scan_id,
1456                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1457                    updated_repositories: worktree.updated_repositories,
1458                    removed_repositories: worktree.removed_repositories,
1459                };
1460                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1461                    session.peer.send(session.connection_id, update.clone())?;
1462                }
1463
1464                // Stream this worktree's diagnostics.
1465                for summary in worktree.diagnostic_summaries {
1466                    session.peer.send(
1467                        session.connection_id,
1468                        proto::UpdateDiagnosticSummary {
1469                            project_id: project.id.to_proto(),
1470                            worktree_id: worktree.id,
1471                            summary: Some(summary),
1472                        },
1473                    )?;
1474                }
1475
1476                for settings_file in worktree.settings_files {
1477                    session.peer.send(
1478                        session.connection_id,
1479                        proto::UpdateWorktreeSettings {
1480                            project_id: project.id.to_proto(),
1481                            worktree_id: worktree.id,
1482                            path: settings_file.path,
1483                            content: Some(settings_file.content),
1484                        },
1485                    )?;
1486                }
1487            }
1488
1489            for language_server in &project.language_servers {
1490                session.peer.send(
1491                    session.connection_id,
1492                    proto::UpdateLanguageServer {
1493                        project_id: project.id.to_proto(),
1494                        language_server_id: language_server.id,
1495                        variant: Some(
1496                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1497                                proto::LspDiskBasedDiagnosticsUpdated {},
1498                            ),
1499                        ),
1500                    },
1501                )?;
1502            }
1503        }
1504
1505        let rejoined_room = rejoined_room.into_inner();
1506
1507        room = rejoined_room.room;
1508        channel = rejoined_room.channel;
1509    }
1510
1511    if let Some(channel) = channel {
1512        channel_updated(
1513            &channel,
1514            &room,
1515            &session.peer,
1516            &*session.connection_pool().await,
1517        );
1518    }
1519
1520    update_user_contacts(session.user_id(), &session).await?;
1521    Ok(())
1522}
1523
1524/// leave room disconnects from the room.
1525async fn leave_room(
1526    _: proto::LeaveRoom,
1527    response: Response<proto::LeaveRoom>,
1528    session: UserSession,
1529) -> Result<()> {
1530    leave_room_for_session(&session).await?;
1531    response.send(proto::Ack {})?;
1532    Ok(())
1533}
1534
1535/// Updates the permissions of someone else in the room.
1536async fn set_room_participant_role(
1537    request: proto::SetRoomParticipantRole,
1538    response: Response<proto::SetRoomParticipantRole>,
1539    session: UserSession,
1540) -> Result<()> {
1541    let user_id = UserId::from_proto(request.user_id);
1542    let role = ChannelRole::from(request.role());
1543
1544    let (live_kit_room, can_publish) = {
1545        let room = session
1546            .db()
1547            .await
1548            .set_room_participant_role(
1549                session.user_id(),
1550                RoomId::from_proto(request.room_id),
1551                user_id,
1552                role,
1553            )
1554            .await?;
1555
1556        let live_kit_room = room.live_kit_room.clone();
1557        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1558        room_updated(&room, &session.peer);
1559        (live_kit_room, can_publish)
1560    };
1561
1562    if let Some(live_kit) = session.live_kit_client.as_ref() {
1563        live_kit
1564            .update_participant(
1565                live_kit_room.clone(),
1566                request.user_id.to_string(),
1567                live_kit_server::proto::ParticipantPermission {
1568                    can_subscribe: true,
1569                    can_publish,
1570                    can_publish_data: can_publish,
1571                    hidden: false,
1572                    recorder: false,
1573                },
1574            )
1575            .await
1576            .trace_err();
1577    }
1578
1579    response.send(proto::Ack {})?;
1580    Ok(())
1581}
1582
1583/// Call someone else into the current room
1584async fn call(
1585    request: proto::Call,
1586    response: Response<proto::Call>,
1587    session: UserSession,
1588) -> Result<()> {
1589    let room_id = RoomId::from_proto(request.room_id);
1590    let calling_user_id = session.user_id();
1591    let calling_connection_id = session.connection_id;
1592    let called_user_id = UserId::from_proto(request.called_user_id);
1593    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1594    if !session
1595        .db()
1596        .await
1597        .has_contact(calling_user_id, called_user_id)
1598        .await?
1599    {
1600        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1601    }
1602
1603    let incoming_call = {
1604        let (room, incoming_call) = &mut *session
1605            .db()
1606            .await
1607            .call(
1608                room_id,
1609                calling_user_id,
1610                calling_connection_id,
1611                called_user_id,
1612                initial_project_id,
1613            )
1614            .await?;
1615        room_updated(&room, &session.peer);
1616        mem::take(incoming_call)
1617    };
1618    update_user_contacts(called_user_id, &session).await?;
1619
1620    let mut calls = session
1621        .connection_pool()
1622        .await
1623        .user_connection_ids(called_user_id)
1624        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1625        .collect::<FuturesUnordered<_>>();
1626
1627    while let Some(call_response) = calls.next().await {
1628        match call_response.as_ref() {
1629            Ok(_) => {
1630                response.send(proto::Ack {})?;
1631                return Ok(());
1632            }
1633            Err(_) => {
1634                call_response.trace_err();
1635            }
1636        }
1637    }
1638
1639    {
1640        let room = session
1641            .db()
1642            .await
1643            .call_failed(room_id, called_user_id)
1644            .await?;
1645        room_updated(&room, &session.peer);
1646    }
1647    update_user_contacts(called_user_id, &session).await?;
1648
1649    Err(anyhow!("failed to ring user"))?
1650}
1651
1652/// Cancel an outgoing call.
1653async fn cancel_call(
1654    request: proto::CancelCall,
1655    response: Response<proto::CancelCall>,
1656    session: UserSession,
1657) -> Result<()> {
1658    let called_user_id = UserId::from_proto(request.called_user_id);
1659    let room_id = RoomId::from_proto(request.room_id);
1660    {
1661        let room = session
1662            .db()
1663            .await
1664            .cancel_call(room_id, session.connection_id, called_user_id)
1665            .await?;
1666        room_updated(&room, &session.peer);
1667    }
1668
1669    for connection_id in session
1670        .connection_pool()
1671        .await
1672        .user_connection_ids(called_user_id)
1673    {
1674        session
1675            .peer
1676            .send(
1677                connection_id,
1678                proto::CallCanceled {
1679                    room_id: room_id.to_proto(),
1680                },
1681            )
1682            .trace_err();
1683    }
1684    response.send(proto::Ack {})?;
1685
1686    update_user_contacts(called_user_id, &session).await?;
1687    Ok(())
1688}
1689
1690/// Decline an incoming call.
1691async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1692    let room_id = RoomId::from_proto(message.room_id);
1693    {
1694        let room = session
1695            .db()
1696            .await
1697            .decline_call(Some(room_id), session.user_id())
1698            .await?
1699            .ok_or_else(|| anyhow!("failed to decline call"))?;
1700        room_updated(&room, &session.peer);
1701    }
1702
1703    for connection_id in session
1704        .connection_pool()
1705        .await
1706        .user_connection_ids(session.user_id())
1707    {
1708        session
1709            .peer
1710            .send(
1711                connection_id,
1712                proto::CallCanceled {
1713                    room_id: room_id.to_proto(),
1714                },
1715            )
1716            .trace_err();
1717    }
1718    update_user_contacts(session.user_id(), &session).await?;
1719    Ok(())
1720}
1721
1722/// Updates other participants in the room with your current location.
1723async fn update_participant_location(
1724    request: proto::UpdateParticipantLocation,
1725    response: Response<proto::UpdateParticipantLocation>,
1726    session: UserSession,
1727) -> Result<()> {
1728    let room_id = RoomId::from_proto(request.room_id);
1729    let location = request
1730        .location
1731        .ok_or_else(|| anyhow!("invalid location"))?;
1732
1733    let db = session.db().await;
1734    let room = db
1735        .update_room_participant_location(room_id, session.connection_id, location)
1736        .await?;
1737
1738    room_updated(&room, &session.peer);
1739    response.send(proto::Ack {})?;
1740    Ok(())
1741}
1742
1743/// Share a project into the room.
1744async fn share_project(
1745    request: proto::ShareProject,
1746    response: Response<proto::ShareProject>,
1747    session: Session,
1748) -> Result<()> {
1749    let (project_id, room) = &*session
1750        .db()
1751        .await
1752        .share_project(
1753            RoomId::from_proto(request.room_id),
1754            session.connection_id,
1755            &request.worktrees,
1756        )
1757        .await?;
1758    response.send(proto::ShareProjectResponse {
1759        project_id: project_id.to_proto(),
1760    })?;
1761    room_updated(&room, &session.peer);
1762
1763    Ok(())
1764}
1765
1766/// Unshare a project from the room.
1767async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1768    let project_id = ProjectId::from_proto(message.project_id);
1769
1770    let (room, guest_connection_ids) = &*session
1771        .db()
1772        .await
1773        .unshare_project(project_id, session.connection_id)
1774        .await?;
1775
1776    broadcast(
1777        Some(session.connection_id),
1778        guest_connection_ids.iter().copied(),
1779        |conn_id| session.peer.send(conn_id, message.clone()),
1780    );
1781    room_updated(&room, &session.peer);
1782
1783    Ok(())
1784}
1785
1786/// Join someone elses shared project.
1787async fn join_project(
1788    request: proto::JoinProject,
1789    response: Response<proto::JoinProject>,
1790    session: UserSession,
1791) -> Result<()> {
1792    let project_id = ProjectId::from_proto(request.project_id);
1793
1794    tracing::info!(%project_id, "join project");
1795
1796    let (project, replica_id) = &mut *session
1797        .db()
1798        .await
1799        .join_project_in_room(project_id, session.connection_id)
1800        .await?;
1801
1802    join_project_internal(response, session, project, replica_id)
1803}
1804
/// Abstraction over the response handles that `join_project_internal` can
/// answer: both request kinds implemented below are answered with a
/// `proto::JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call so this delegates to `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully-qualified call so this delegates to `Response`'s own `send`.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1818
1819fn join_project_internal(
1820    response: impl JoinProjectInternalResponse,
1821    session: UserSession,
1822    project: &mut Project,
1823    replica_id: &ReplicaId,
1824) -> Result<()> {
1825    let collaborators = project
1826        .collaborators
1827        .iter()
1828        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829        .map(|collaborator| collaborator.to_proto())
1830        .collect::<Vec<_>>();
1831    let project_id = project.id;
1832    let guest_user_id = session.user_id();
1833
1834    let worktrees = project
1835        .worktrees
1836        .iter()
1837        .map(|(id, worktree)| proto::WorktreeMetadata {
1838            id: *id,
1839            root_name: worktree.root_name.clone(),
1840            visible: worktree.visible,
1841            abs_path: worktree.abs_path.clone(),
1842        })
1843        .collect::<Vec<_>>();
1844
1845    for collaborator in &collaborators {
1846        session
1847            .peer
1848            .send(
1849                collaborator.peer_id.unwrap().into(),
1850                proto::AddProjectCollaborator {
1851                    project_id: project_id.to_proto(),
1852                    collaborator: Some(proto::Collaborator {
1853                        peer_id: Some(session.connection_id.into()),
1854                        replica_id: replica_id.0 as u32,
1855                        user_id: guest_user_id.to_proto(),
1856                    }),
1857                },
1858            )
1859            .trace_err();
1860    }
1861
1862    // First, we send the metadata associated with each worktree.
1863    response.send(proto::JoinProjectResponse {
1864        project_id: project.id.0 as u64,
1865        worktrees: worktrees.clone(),
1866        replica_id: replica_id.0 as u32,
1867        collaborators: collaborators.clone(),
1868        language_servers: project.language_servers.clone(),
1869        role: project.role.into(), // todo
1870    })?;
1871
1872    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1873        #[cfg(any(test, feature = "test-support"))]
1874        const MAX_CHUNK_SIZE: usize = 2;
1875        #[cfg(not(any(test, feature = "test-support")))]
1876        const MAX_CHUNK_SIZE: usize = 256;
1877
1878        // Stream this worktree's entries.
1879        let message = proto::UpdateWorktree {
1880            project_id: project_id.to_proto(),
1881            worktree_id,
1882            abs_path: worktree.abs_path.clone(),
1883            root_name: worktree.root_name,
1884            updated_entries: worktree.entries,
1885            removed_entries: Default::default(),
1886            scan_id: worktree.scan_id,
1887            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1888            updated_repositories: worktree.repository_entries.into_values().collect(),
1889            removed_repositories: Default::default(),
1890        };
1891        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1892            session.peer.send(session.connection_id, update.clone())?;
1893        }
1894
1895        // Stream this worktree's diagnostics.
1896        for summary in worktree.diagnostic_summaries {
1897            session.peer.send(
1898                session.connection_id,
1899                proto::UpdateDiagnosticSummary {
1900                    project_id: project_id.to_proto(),
1901                    worktree_id: worktree.id,
1902                    summary: Some(summary),
1903                },
1904            )?;
1905        }
1906
1907        for settings_file in worktree.settings_files {
1908            session.peer.send(
1909                session.connection_id,
1910                proto::UpdateWorktreeSettings {
1911                    project_id: project_id.to_proto(),
1912                    worktree_id: worktree.id,
1913                    path: settings_file.path,
1914                    content: Some(settings_file.content),
1915                },
1916            )?;
1917        }
1918    }
1919
1920    for language_server in &project.language_servers {
1921        session.peer.send(
1922            session.connection_id,
1923            proto::UpdateLanguageServer {
1924                project_id: project_id.to_proto(),
1925                language_server_id: language_server.id,
1926                variant: Some(
1927                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1928                        proto::LspDiskBasedDiagnosticsUpdated {},
1929                    ),
1930                ),
1931            },
1932        )?;
1933    }
1934
1935    Ok(())
1936}
1937
1938/// Leave someone elses shared project.
1939async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1940    let sender_id = session.connection_id;
1941    let project_id = ProjectId::from_proto(request.project_id);
1942    let db = session.db().await;
1943    if db.is_hosted_project(project_id).await? {
1944        let project = db.leave_hosted_project(project_id, sender_id).await?;
1945        project_left(&project, &session);
1946        return Ok(());
1947    }
1948
1949    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1950    tracing::info!(
1951        %project_id,
1952        host_user_id = ?project.host_user_id,
1953        host_connection_id = ?project.host_connection_id,
1954        "leave project"
1955    );
1956
1957    project_left(&project, &session);
1958    room_updated(&room, &session.peer);
1959
1960    Ok(())
1961}
1962
1963async fn join_hosted_project(
1964    request: proto::JoinHostedProject,
1965    response: Response<proto::JoinHostedProject>,
1966    session: UserSession,
1967) -> Result<()> {
1968    let (mut project, replica_id) = session
1969        .db()
1970        .await
1971        .join_hosted_project(
1972            ProjectId(request.project_id as i32),
1973            session.user_id(),
1974            session.connection_id,
1975        )
1976        .await?;
1977
1978    join_project_internal(response, session, &mut project, &replica_id)
1979}
1980
1981/// Updates other participants with changes to the project
1982async fn update_project(
1983    request: proto::UpdateProject,
1984    response: Response<proto::UpdateProject>,
1985    session: Session,
1986) -> Result<()> {
1987    let project_id = ProjectId::from_proto(request.project_id);
1988    let (room, guest_connection_ids) = &*session
1989        .db()
1990        .await
1991        .update_project(project_id, session.connection_id, &request.worktrees)
1992        .await?;
1993    broadcast(
1994        Some(session.connection_id),
1995        guest_connection_ids.iter().copied(),
1996        |connection_id| {
1997            session
1998                .peer
1999                .forward_send(session.connection_id, connection_id, request.clone())
2000        },
2001    );
2002    room_updated(&room, &session.peer);
2003    response.send(proto::Ack {})?;
2004
2005    Ok(())
2006}
2007
2008/// Updates other participants with changes to the worktree
2009async fn update_worktree(
2010    request: proto::UpdateWorktree,
2011    response: Response<proto::UpdateWorktree>,
2012    session: Session,
2013) -> Result<()> {
2014    let guest_connection_ids = session
2015        .db()
2016        .await
2017        .update_worktree(&request, session.connection_id)
2018        .await?;
2019
2020    broadcast(
2021        Some(session.connection_id),
2022        guest_connection_ids.iter().copied(),
2023        |connection_id| {
2024            session
2025                .peer
2026                .forward_send(session.connection_id, connection_id, request.clone())
2027        },
2028    );
2029    response.send(proto::Ack {})?;
2030    Ok(())
2031}
2032
2033/// Updates other participants with changes to the diagnostics
2034async fn update_diagnostic_summary(
2035    message: proto::UpdateDiagnosticSummary,
2036    session: Session,
2037) -> Result<()> {
2038    let guest_connection_ids = session
2039        .db()
2040        .await
2041        .update_diagnostic_summary(&message, session.connection_id)
2042        .await?;
2043
2044    broadcast(
2045        Some(session.connection_id),
2046        guest_connection_ids.iter().copied(),
2047        |connection_id| {
2048            session
2049                .peer
2050                .forward_send(session.connection_id, connection_id, message.clone())
2051        },
2052    );
2053
2054    Ok(())
2055}
2056
2057/// Updates other participants with changes to the worktree settings
2058async fn update_worktree_settings(
2059    message: proto::UpdateWorktreeSettings,
2060    session: Session,
2061) -> Result<()> {
2062    let guest_connection_ids = session
2063        .db()
2064        .await
2065        .update_worktree_settings(&message, session.connection_id)
2066        .await?;
2067
2068    broadcast(
2069        Some(session.connection_id),
2070        guest_connection_ids.iter().copied(),
2071        |connection_id| {
2072            session
2073                .peer
2074                .forward_send(session.connection_id, connection_id, message.clone())
2075        },
2076    );
2077
2078    Ok(())
2079}
2080
2081/// Notify other participants that a  language server has started.
2082async fn start_language_server(
2083    request: proto::StartLanguageServer,
2084    session: Session,
2085) -> Result<()> {
2086    let guest_connection_ids = session
2087        .db()
2088        .await
2089        .start_language_server(&request, session.connection_id)
2090        .await?;
2091
2092    broadcast(
2093        Some(session.connection_id),
2094        guest_connection_ids.iter().copied(),
2095        |connection_id| {
2096            session
2097                .peer
2098                .forward_send(session.connection_id, connection_id, request.clone())
2099        },
2100    );
2101    Ok(())
2102}
2103
2104/// Notify other participants that a language server has changed.
2105async fn update_language_server(
2106    request: proto::UpdateLanguageServer,
2107    session: Session,
2108) -> Result<()> {
2109    let project_id = ProjectId::from_proto(request.project_id);
2110    let project_connection_ids = session
2111        .db()
2112        .await
2113        .project_connection_ids(project_id, session.connection_id)
2114        .await?;
2115    broadcast(
2116        Some(session.connection_id),
2117        project_connection_ids.iter().copied(),
2118        |connection_id| {
2119            session
2120                .peer
2121                .forward_send(session.connection_id, connection_id, request.clone())
2122        },
2123    );
2124    Ok(())
2125}
2126
2127/// forward a project request to the host. These requests should be read only
2128/// as guests are allowed to send them.
2129async fn forward_read_only_project_request<T>(
2130    request: T,
2131    response: Response<T>,
2132    session: Session,
2133) -> Result<()>
2134where
2135    T: EntityMessage + RequestMessage,
2136{
2137    let project_id = ProjectId::from_proto(request.remote_entity_id());
2138    let host_connection_id = session
2139        .db()
2140        .await
2141        .host_for_read_only_project_request(project_id, session.connection_id)
2142        .await?;
2143    let payload = session
2144        .peer
2145        .forward_request(session.connection_id, host_connection_id, request)
2146        .await?;
2147    response.send(payload)?;
2148    Ok(())
2149}
2150
2151/// forward a project request to the host. These requests are disallowed
2152/// for guests.
2153async fn forward_mutating_project_request<T>(
2154    request: T,
2155    response: Response<T>,
2156    session: Session,
2157) -> Result<()>
2158where
2159    T: EntityMessage + RequestMessage,
2160{
2161    let project_id = ProjectId::from_proto(request.remote_entity_id());
2162    let host_connection_id = session
2163        .db()
2164        .await
2165        .host_for_mutating_project_request(project_id, session.connection_id)
2166        .await?;
2167    let payload = session
2168        .peer
2169        .forward_request(session.connection_id, host_connection_id, request)
2170        .await?;
2171    response.send(payload)?;
2172    Ok(())
2173}
2174
2175/// Notify other participants that a new buffer has been created
2176async fn create_buffer_for_peer(
2177    request: proto::CreateBufferForPeer,
2178    session: Session,
2179) -> Result<()> {
2180    session
2181        .db()
2182        .await
2183        .check_user_is_project_host(
2184            ProjectId::from_proto(request.project_id),
2185            session.connection_id,
2186        )
2187        .await?;
2188    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2189    session
2190        .peer
2191        .forward_send(session.connection_id, peer_id.into(), request)?;
2192    Ok(())
2193}
2194
2195/// Notify other participants that a buffer has been updated. This is
2196/// allowed for guests as long as the update is limited to selections.
2197async fn update_buffer(
2198    request: proto::UpdateBuffer,
2199    response: Response<proto::UpdateBuffer>,
2200    session: Session,
2201) -> Result<()> {
2202    let project_id = ProjectId::from_proto(request.project_id);
2203    let mut guest_connection_ids;
2204    let mut host_connection_id = None;
2205
2206    let mut requires_write_permission = false;
2207
2208    for op in request.operations.iter() {
2209        match op.variant {
2210            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2211            Some(_) => requires_write_permission = true,
2212        }
2213    }
2214
2215    {
2216        let collaborators = session
2217            .db()
2218            .await
2219            .project_collaborators_for_buffer_update(
2220                project_id,
2221                session.connection_id,
2222                requires_write_permission,
2223            )
2224            .await?;
2225        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2226        for collaborator in collaborators.iter() {
2227            if collaborator.is_host {
2228                host_connection_id = Some(collaborator.connection_id);
2229            } else {
2230                guest_connection_ids.push(collaborator.connection_id);
2231            }
2232        }
2233    }
2234    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2235
2236    broadcast(
2237        Some(session.connection_id),
2238        guest_connection_ids,
2239        |connection_id| {
2240            session
2241                .peer
2242                .forward_send(session.connection_id, connection_id, request.clone())
2243        },
2244    );
2245    if host_connection_id != session.connection_id {
2246        session
2247            .peer
2248            .forward_request(session.connection_id, host_connection_id, request.clone())
2249            .await?;
2250    }
2251
2252    response.send(proto::Ack {})?;
2253    Ok(())
2254}
2255
2256/// Notify other participants that a project has been updated.
2257async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2258    request: T,
2259    session: Session,
2260) -> Result<()> {
2261    let project_id = ProjectId::from_proto(request.remote_entity_id());
2262    let project_connection_ids = session
2263        .db()
2264        .await
2265        .project_connection_ids(project_id, session.connection_id)
2266        .await?;
2267
2268    broadcast(
2269        Some(session.connection_id),
2270        project_connection_ids.iter().copied(),
2271        |connection_id| {
2272            session
2273                .peer
2274                .forward_send(session.connection_id, connection_id, request.clone())
2275        },
2276    );
2277    Ok(())
2278}
2279
2280/// Start following another user in a call.
2281async fn follow(
2282    request: proto::Follow,
2283    response: Response<proto::Follow>,
2284    session: UserSession,
2285) -> Result<()> {
2286    let room_id = RoomId::from_proto(request.room_id);
2287    let project_id = request.project_id.map(ProjectId::from_proto);
2288    let leader_id = request
2289        .leader_id
2290        .ok_or_else(|| anyhow!("invalid leader id"))?
2291        .into();
2292    let follower_id = session.connection_id;
2293
2294    session
2295        .db()
2296        .await
2297        .check_room_participants(room_id, leader_id, session.connection_id)
2298        .await?;
2299
2300    let response_payload = session
2301        .peer
2302        .forward_request(session.connection_id, leader_id, request)
2303        .await?;
2304    response.send(response_payload)?;
2305
2306    if let Some(project_id) = project_id {
2307        let room = session
2308            .db()
2309            .await
2310            .follow(room_id, project_id, leader_id, follower_id)
2311            .await?;
2312        room_updated(&room, &session.peer);
2313    }
2314
2315    Ok(())
2316}
2317
2318/// Stop following another user in a call.
2319async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2320    let room_id = RoomId::from_proto(request.room_id);
2321    let project_id = request.project_id.map(ProjectId::from_proto);
2322    let leader_id = request
2323        .leader_id
2324        .ok_or_else(|| anyhow!("invalid leader id"))?
2325        .into();
2326    let follower_id = session.connection_id;
2327
2328    session
2329        .db()
2330        .await
2331        .check_room_participants(room_id, leader_id, session.connection_id)
2332        .await?;
2333
2334    session
2335        .peer
2336        .forward_send(session.connection_id, leader_id, request)?;
2337
2338    if let Some(project_id) = project_id {
2339        let room = session
2340            .db()
2341            .await
2342            .unfollow(room_id, project_id, leader_id, follower_id)
2343            .await?;
2344        room_updated(&room, &session.peer);
2345    }
2346
2347    Ok(())
2348}
2349
2350/// Notify everyone following you of your current location.
2351async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2352    let room_id = RoomId::from_proto(request.room_id);
2353    let database = session.db.lock().await;
2354
2355    let connection_ids = if let Some(project_id) = request.project_id {
2356        let project_id = ProjectId::from_proto(project_id);
2357        database
2358            .project_connection_ids(project_id, session.connection_id)
2359            .await?
2360    } else {
2361        database
2362            .room_connection_ids(room_id, session.connection_id)
2363            .await?
2364    };
2365
2366    // For now, don't send view update messages back to that view's current leader.
2367    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2368        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2369        _ => None,
2370    });
2371
2372    for connection_id in connection_ids.iter().cloned() {
2373        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2374            session
2375                .peer
2376                .forward_send(session.connection_id, connection_id, request.clone())?;
2377        }
2378    }
2379    Ok(())
2380}
2381
2382/// Get public data about users.
2383async fn get_users(
2384    request: proto::GetUsers,
2385    response: Response<proto::GetUsers>,
2386    session: Session,
2387) -> Result<()> {
2388    let user_ids = request
2389        .user_ids
2390        .into_iter()
2391        .map(UserId::from_proto)
2392        .collect();
2393    let users = session
2394        .db()
2395        .await
2396        .get_users_by_ids(user_ids)
2397        .await?
2398        .into_iter()
2399        .map(|user| proto::User {
2400            id: user.id.to_proto(),
2401            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2402            github_login: user.github_login,
2403        })
2404        .collect();
2405    response.send(proto::UsersResponse { users })?;
2406    Ok(())
2407}
2408
2409/// Search for users (to invite) buy Github login
2410async fn fuzzy_search_users(
2411    request: proto::FuzzySearchUsers,
2412    response: Response<proto::FuzzySearchUsers>,
2413    session: UserSession,
2414) -> Result<()> {
2415    let query = request.query;
2416    let users = match query.len() {
2417        0 => vec![],
2418        1 | 2 => session
2419            .db()
2420            .await
2421            .get_user_by_github_login(&query)
2422            .await?
2423            .into_iter()
2424            .collect(),
2425        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2426    };
2427    let users = users
2428        .into_iter()
2429        .filter(|user| user.id != session.user_id())
2430        .map(|user| proto::User {
2431            id: user.id.to_proto(),
2432            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2433            github_login: user.github_login,
2434        })
2435        .collect();
2436    response.send(proto::UsersResponse { users })?;
2437    Ok(())
2438}
2439
2440/// Send a contact request to another user.
2441async fn request_contact(
2442    request: proto::RequestContact,
2443    response: Response<proto::RequestContact>,
2444    session: UserSession,
2445) -> Result<()> {
2446    let requester_id = session.user_id();
2447    let responder_id = UserId::from_proto(request.responder_id);
2448    if requester_id == responder_id {
2449        return Err(anyhow!("cannot add yourself as a contact"))?;
2450    }
2451
2452    let notifications = session
2453        .db()
2454        .await
2455        .send_contact_request(requester_id, responder_id)
2456        .await?;
2457
2458    // Update outgoing contact requests of requester
2459    let mut update = proto::UpdateContacts::default();
2460    update.outgoing_requests.push(responder_id.to_proto());
2461    for connection_id in session
2462        .connection_pool()
2463        .await
2464        .user_connection_ids(requester_id)
2465    {
2466        session.peer.send(connection_id, update.clone())?;
2467    }
2468
2469    // Update incoming contact requests of responder
2470    let mut update = proto::UpdateContacts::default();
2471    update
2472        .incoming_requests
2473        .push(proto::IncomingContactRequest {
2474            requester_id: requester_id.to_proto(),
2475        });
2476    let connection_pool = session.connection_pool().await;
2477    for connection_id in connection_pool.user_connection_ids(responder_id) {
2478        session.peer.send(connection_id, update.clone())?;
2479    }
2480
2481    send_notifications(&connection_pool, &session.peer, notifications);
2482
2483    response.send(proto::Ack {})?;
2484    Ok(())
2485}
2486
2487/// Accept or decline a contact request
2488async fn respond_to_contact_request(
2489    request: proto::RespondToContactRequest,
2490    response: Response<proto::RespondToContactRequest>,
2491    session: UserSession,
2492) -> Result<()> {
2493    let responder_id = session.user_id();
2494    let requester_id = UserId::from_proto(request.requester_id);
2495    let db = session.db().await;
2496    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2497        db.dismiss_contact_notification(responder_id, requester_id)
2498            .await?;
2499    } else {
2500        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2501
2502        let notifications = db
2503            .respond_to_contact_request(responder_id, requester_id, accept)
2504            .await?;
2505        let requester_busy = db.is_user_busy(requester_id).await?;
2506        let responder_busy = db.is_user_busy(responder_id).await?;
2507
2508        let pool = session.connection_pool().await;
2509        // Update responder with new contact
2510        let mut update = proto::UpdateContacts::default();
2511        if accept {
2512            update
2513                .contacts
2514                .push(contact_for_user(requester_id, requester_busy, &pool));
2515        }
2516        update
2517            .remove_incoming_requests
2518            .push(requester_id.to_proto());
2519        for connection_id in pool.user_connection_ids(responder_id) {
2520            session.peer.send(connection_id, update.clone())?;
2521        }
2522
2523        // Update requester with new contact
2524        let mut update = proto::UpdateContacts::default();
2525        if accept {
2526            update
2527                .contacts
2528                .push(contact_for_user(responder_id, responder_busy, &pool));
2529        }
2530        update
2531            .remove_outgoing_requests
2532            .push(responder_id.to_proto());
2533
2534        for connection_id in pool.user_connection_ids(requester_id) {
2535            session.peer.send(connection_id, update.clone())?;
2536        }
2537
2538        send_notifications(&pool, &session.peer, notifications);
2539    }
2540
2541    response.send(proto::Ack {})?;
2542    Ok(())
2543}
2544
2545/// Remove a contact.
2546async fn remove_contact(
2547    request: proto::RemoveContact,
2548    response: Response<proto::RemoveContact>,
2549    session: UserSession,
2550) -> Result<()> {
2551    let requester_id = session.user_id();
2552    let responder_id = UserId::from_proto(request.user_id);
2553    let db = session.db().await;
2554    let (contact_accepted, deleted_notification_id) =
2555        db.remove_contact(requester_id, responder_id).await?;
2556
2557    let pool = session.connection_pool().await;
2558    // Update outgoing contact requests of requester
2559    let mut update = proto::UpdateContacts::default();
2560    if contact_accepted {
2561        update.remove_contacts.push(responder_id.to_proto());
2562    } else {
2563        update
2564            .remove_outgoing_requests
2565            .push(responder_id.to_proto());
2566    }
2567    for connection_id in pool.user_connection_ids(requester_id) {
2568        session.peer.send(connection_id, update.clone())?;
2569    }
2570
2571    // Update incoming contact requests of responder
2572    let mut update = proto::UpdateContacts::default();
2573    if contact_accepted {
2574        update.remove_contacts.push(requester_id.to_proto());
2575    } else {
2576        update
2577            .remove_incoming_requests
2578            .push(requester_id.to_proto());
2579    }
2580    for connection_id in pool.user_connection_ids(responder_id) {
2581        session.peer.send(connection_id, update.clone())?;
2582        if let Some(notification_id) = deleted_notification_id {
2583            session.peer.send(
2584                connection_id,
2585                proto::DeleteNotification {
2586                    notification_id: notification_id.to_proto(),
2587                },
2588            )?;
2589        }
2590    }
2591
2592    response.send(proto::Ack {})?;
2593    Ok(())
2594}
2595
2596/// Creates a new channel.
2597async fn create_channel(
2598    request: proto::CreateChannel,
2599    response: Response<proto::CreateChannel>,
2600    session: UserSession,
2601) -> Result<()> {
2602    let db = session.db().await;
2603
2604    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2605    let (channel, membership) = db
2606        .create_channel(&request.name, parent_id, session.user_id())
2607        .await?;
2608
2609    let root_id = channel.root_id();
2610    let channel = Channel::from_model(channel);
2611
2612    response.send(proto::CreateChannelResponse {
2613        channel: Some(channel.to_proto()),
2614        parent_id: request.parent_id,
2615    })?;
2616
2617    let mut connection_pool = session.connection_pool().await;
2618    if let Some(membership) = membership {
2619        connection_pool.subscribe_to_channel(
2620            membership.user_id,
2621            membership.channel_id,
2622            membership.role,
2623        );
2624        let update = proto::UpdateUserChannels {
2625            channel_memberships: vec![proto::ChannelMembership {
2626                channel_id: membership.channel_id.to_proto(),
2627                role: membership.role.into(),
2628            }],
2629            ..Default::default()
2630        };
2631        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2632            session.peer.send(connection_id, update.clone())?;
2633        }
2634    }
2635
2636    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2637        if !role.can_see_channel(channel.visibility) {
2638            continue;
2639        }
2640
2641        let update = proto::UpdateChannels {
2642            channels: vec![channel.to_proto()],
2643            ..Default::default()
2644        };
2645        session.peer.send(connection_id, update.clone())?;
2646    }
2647
2648    Ok(())
2649}
2650
2651/// Delete a channel
2652async fn delete_channel(
2653    request: proto::DeleteChannel,
2654    response: Response<proto::DeleteChannel>,
2655    session: UserSession,
2656) -> Result<()> {
2657    let db = session.db().await;
2658
2659    let channel_id = request.channel_id;
2660    let (root_channel, removed_channels) = db
2661        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2662        .await?;
2663    response.send(proto::Ack {})?;
2664
2665    // Notify members of removed channels
2666    let mut update = proto::UpdateChannels::default();
2667    update
2668        .delete_channels
2669        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2670
2671    let connection_pool = session.connection_pool().await;
2672    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2673        session.peer.send(connection_id, update.clone())?;
2674    }
2675
2676    Ok(())
2677}
2678
2679/// Invite someone to join a channel.
2680async fn invite_channel_member(
2681    request: proto::InviteChannelMember,
2682    response: Response<proto::InviteChannelMember>,
2683    session: UserSession,
2684) -> Result<()> {
2685    let db = session.db().await;
2686    let channel_id = ChannelId::from_proto(request.channel_id);
2687    let invitee_id = UserId::from_proto(request.user_id);
2688    let InviteMemberResult {
2689        channel,
2690        notifications,
2691    } = db
2692        .invite_channel_member(
2693            channel_id,
2694            invitee_id,
2695            session.user_id(),
2696            request.role().into(),
2697        )
2698        .await?;
2699
2700    let update = proto::UpdateChannels {
2701        channel_invitations: vec![channel.to_proto()],
2702        ..Default::default()
2703    };
2704
2705    let connection_pool = session.connection_pool().await;
2706    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2707        session.peer.send(connection_id, update.clone())?;
2708    }
2709
2710    send_notifications(&connection_pool, &session.peer, notifications);
2711
2712    response.send(proto::Ack {})?;
2713    Ok(())
2714}
2715
2716/// remove someone from a channel
2717async fn remove_channel_member(
2718    request: proto::RemoveChannelMember,
2719    response: Response<proto::RemoveChannelMember>,
2720    session: UserSession,
2721) -> Result<()> {
2722    let db = session.db().await;
2723    let channel_id = ChannelId::from_proto(request.channel_id);
2724    let member_id = UserId::from_proto(request.user_id);
2725
2726    let RemoveChannelMemberResult {
2727        membership_update,
2728        notification_id,
2729    } = db
2730        .remove_channel_member(channel_id, member_id, session.user_id())
2731        .await?;
2732
2733    let mut connection_pool = session.connection_pool().await;
2734    notify_membership_updated(
2735        &mut connection_pool,
2736        membership_update,
2737        member_id,
2738        &session.peer,
2739    );
2740    for connection_id in connection_pool.user_connection_ids(member_id) {
2741        if let Some(notification_id) = notification_id {
2742            session
2743                .peer
2744                .send(
2745                    connection_id,
2746                    proto::DeleteNotification {
2747                        notification_id: notification_id.to_proto(),
2748                    },
2749                )
2750                .trace_err();
2751        }
2752    }
2753
2754    response.send(proto::Ack {})?;
2755    Ok(())
2756}
2757
2758/// Toggle the channel between public and private.
2759/// Care is taken to maintain the invariant that public channels only descend from public channels,
2760/// (though members-only channels can appear at any point in the hierarchy).
2761async fn set_channel_visibility(
2762    request: proto::SetChannelVisibility,
2763    response: Response<proto::SetChannelVisibility>,
2764    session: UserSession,
2765) -> Result<()> {
2766    let db = session.db().await;
2767    let channel_id = ChannelId::from_proto(request.channel_id);
2768    let visibility = request.visibility().into();
2769
2770    let channel_model = db
2771        .set_channel_visibility(channel_id, visibility, session.user_id())
2772        .await?;
2773    let root_id = channel_model.root_id();
2774    let channel = Channel::from_model(channel_model);
2775
2776    let mut connection_pool = session.connection_pool().await;
2777    for (user_id, role) in connection_pool
2778        .channel_user_ids(root_id)
2779        .collect::<Vec<_>>()
2780        .into_iter()
2781    {
2782        let update = if role.can_see_channel(channel.visibility) {
2783            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2784            proto::UpdateChannels {
2785                channels: vec![channel.to_proto()],
2786                ..Default::default()
2787            }
2788        } else {
2789            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2790            proto::UpdateChannels {
2791                delete_channels: vec![channel.id.to_proto()],
2792                ..Default::default()
2793            }
2794        };
2795
2796        for connection_id in connection_pool.user_connection_ids(user_id) {
2797            session.peer.send(connection_id, update.clone())?;
2798        }
2799    }
2800
2801    response.send(proto::Ack {})?;
2802    Ok(())
2803}
2804
2805/// Alter the role for a user in the channel.
2806async fn set_channel_member_role(
2807    request: proto::SetChannelMemberRole,
2808    response: Response<proto::SetChannelMemberRole>,
2809    session: UserSession,
2810) -> Result<()> {
2811    let db = session.db().await;
2812    let channel_id = ChannelId::from_proto(request.channel_id);
2813    let member_id = UserId::from_proto(request.user_id);
2814    let result = db
2815        .set_channel_member_role(
2816            channel_id,
2817            session.user_id(),
2818            member_id,
2819            request.role().into(),
2820        )
2821        .await?;
2822
2823    match result {
2824        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2825            let mut connection_pool = session.connection_pool().await;
2826            notify_membership_updated(
2827                &mut connection_pool,
2828                membership_update,
2829                member_id,
2830                &session.peer,
2831            )
2832        }
2833        db::SetMemberRoleResult::InviteUpdated(channel) => {
2834            let update = proto::UpdateChannels {
2835                channel_invitations: vec![channel.to_proto()],
2836                ..Default::default()
2837            };
2838
2839            for connection_id in session
2840                .connection_pool()
2841                .await
2842                .user_connection_ids(member_id)
2843            {
2844                session.peer.send(connection_id, update.clone())?;
2845            }
2846        }
2847    }
2848
2849    response.send(proto::Ack {})?;
2850    Ok(())
2851}
2852
2853/// Change the name of a channel
2854async fn rename_channel(
2855    request: proto::RenameChannel,
2856    response: Response<proto::RenameChannel>,
2857    session: UserSession,
2858) -> Result<()> {
2859    let db = session.db().await;
2860    let channel_id = ChannelId::from_proto(request.channel_id);
2861    let channel_model = db
2862        .rename_channel(channel_id, session.user_id(), &request.name)
2863        .await?;
2864    let root_id = channel_model.root_id();
2865    let channel = Channel::from_model(channel_model);
2866
2867    response.send(proto::RenameChannelResponse {
2868        channel: Some(channel.to_proto()),
2869    })?;
2870
2871    let connection_pool = session.connection_pool().await;
2872    let update = proto::UpdateChannels {
2873        channels: vec![channel.to_proto()],
2874        ..Default::default()
2875    };
2876    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2877        if role.can_see_channel(channel.visibility) {
2878            session.peer.send(connection_id, update.clone())?;
2879        }
2880    }
2881
2882    Ok(())
2883}
2884
2885/// Move a channel to a new parent.
2886async fn move_channel(
2887    request: proto::MoveChannel,
2888    response: Response<proto::MoveChannel>,
2889    session: UserSession,
2890) -> Result<()> {
2891    let channel_id = ChannelId::from_proto(request.channel_id);
2892    let to = ChannelId::from_proto(request.to);
2893
2894    let (root_id, channels) = session
2895        .db()
2896        .await
2897        .move_channel(channel_id, to, session.user_id())
2898        .await?;
2899
2900    let connection_pool = session.connection_pool().await;
2901    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2902        let channels = channels
2903            .iter()
2904            .filter_map(|channel| {
2905                if role.can_see_channel(channel.visibility) {
2906                    Some(channel.to_proto())
2907                } else {
2908                    None
2909                }
2910            })
2911            .collect::<Vec<_>>();
2912        if channels.is_empty() {
2913            continue;
2914        }
2915
2916        let update = proto::UpdateChannels {
2917            channels,
2918            ..Default::default()
2919        };
2920
2921        session.peer.send(connection_id, update.clone())?;
2922    }
2923
2924    response.send(Ack {})?;
2925    Ok(())
2926}
2927
2928/// Get the list of channel members
2929async fn get_channel_members(
2930    request: proto::GetChannelMembers,
2931    response: Response<proto::GetChannelMembers>,
2932    session: UserSession,
2933) -> Result<()> {
2934    let db = session.db().await;
2935    let channel_id = ChannelId::from_proto(request.channel_id);
2936    let members = db
2937        .get_channel_participant_details(channel_id, session.user_id())
2938        .await?;
2939    response.send(proto::GetChannelMembersResponse { members })?;
2940    Ok(())
2941}
2942
2943/// Accept or decline a channel invitation.
2944async fn respond_to_channel_invite(
2945    request: proto::RespondToChannelInvite,
2946    response: Response<proto::RespondToChannelInvite>,
2947    session: UserSession,
2948) -> Result<()> {
2949    let db = session.db().await;
2950    let channel_id = ChannelId::from_proto(request.channel_id);
2951    let RespondToChannelInvite {
2952        membership_update,
2953        notifications,
2954    } = db
2955        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2956        .await?;
2957
2958    let mut connection_pool = session.connection_pool().await;
2959    if let Some(membership_update) = membership_update {
2960        notify_membership_updated(
2961            &mut connection_pool,
2962            membership_update,
2963            session.user_id(),
2964            &session.peer,
2965        );
2966    } else {
2967        let update = proto::UpdateChannels {
2968            remove_channel_invitations: vec![channel_id.to_proto()],
2969            ..Default::default()
2970        };
2971
2972        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2973            session.peer.send(connection_id, update.clone())?;
2974        }
2975    };
2976
2977    send_notifications(&connection_pool, &session.peer, notifications);
2978
2979    response.send(proto::Ack {})?;
2980
2981    Ok(())
2982}
2983
2984/// Join the channels' room
2985async fn join_channel(
2986    request: proto::JoinChannel,
2987    response: Response<proto::JoinChannel>,
2988    session: UserSession,
2989) -> Result<()> {
2990    let channel_id = ChannelId::from_proto(request.channel_id);
2991    join_channel_internal(channel_id, Box::new(response), session).await
2992}
2993
2994trait JoinChannelInternalResponse {
2995    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2996}
2997impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2998    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2999        Response::<proto::JoinChannel>::send(self, result)
3000    }
3001}
3002impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3003    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3004        Response::<proto::JoinRoom>::send(self, result)
3005    }
3006}
3007
3008async fn join_channel_internal(
3009    channel_id: ChannelId,
3010    response: Box<impl JoinChannelInternalResponse>,
3011    session: UserSession,
3012) -> Result<()> {
3013    let joined_room = {
3014        leave_room_for_session(&session).await?;
3015        let db = session.db().await;
3016
3017        let (joined_room, membership_updated, role) = db
3018            .join_channel(channel_id, session.user_id(), session.connection_id)
3019            .await?;
3020
3021        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3022            let (can_publish, token) = if role == ChannelRole::Guest {
3023                (
3024                    false,
3025                    live_kit
3026                        .guest_token(
3027                            &joined_room.room.live_kit_room,
3028                            &session.user_id().to_string(),
3029                        )
3030                        .trace_err()?,
3031                )
3032            } else {
3033                (
3034                    true,
3035                    live_kit
3036                        .room_token(
3037                            &joined_room.room.live_kit_room,
3038                            &session.user_id().to_string(),
3039                        )
3040                        .trace_err()?,
3041                )
3042            };
3043
3044            Some(LiveKitConnectionInfo {
3045                server_url: live_kit.url().into(),
3046                token,
3047                can_publish,
3048            })
3049        });
3050
3051        response.send(proto::JoinRoomResponse {
3052            room: Some(joined_room.room.clone()),
3053            channel_id: joined_room
3054                .channel
3055                .as_ref()
3056                .map(|channel| channel.id.to_proto()),
3057            live_kit_connection_info,
3058        })?;
3059
3060        let mut connection_pool = session.connection_pool().await;
3061        if let Some(membership_updated) = membership_updated {
3062            notify_membership_updated(
3063                &mut connection_pool,
3064                membership_updated,
3065                session.user_id(),
3066                &session.peer,
3067            );
3068        }
3069
3070        room_updated(&joined_room.room, &session.peer);
3071
3072        joined_room
3073    };
3074
3075    channel_updated(
3076        &joined_room
3077            .channel
3078            .ok_or_else(|| anyhow!("channel not returned"))?,
3079        &joined_room.room,
3080        &session.peer,
3081        &*session.connection_pool().await,
3082    );
3083
3084    update_user_contacts(session.user_id(), &session).await?;
3085    Ok(())
3086}
3087
3088/// Start editing the channel notes
3089async fn join_channel_buffer(
3090    request: proto::JoinChannelBuffer,
3091    response: Response<proto::JoinChannelBuffer>,
3092    session: UserSession,
3093) -> Result<()> {
3094    let db = session.db().await;
3095    let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097    let open_response = db
3098        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3099        .await?;
3100
3101    let collaborators = open_response.collaborators.clone();
3102    response.send(open_response)?;
3103
3104    let update = UpdateChannelBufferCollaborators {
3105        channel_id: channel_id.to_proto(),
3106        collaborators: collaborators.clone(),
3107    };
3108    channel_buffer_updated(
3109        session.connection_id,
3110        collaborators
3111            .iter()
3112            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3113        &update,
3114        &session.peer,
3115    );
3116
3117    Ok(())
3118}
3119
3120/// Edit the channel notes
3121async fn update_channel_buffer(
3122    request: proto::UpdateChannelBuffer,
3123    session: UserSession,
3124) -> Result<()> {
3125    let db = session.db().await;
3126    let channel_id = ChannelId::from_proto(request.channel_id);
3127
3128    let (collaborators, non_collaborators, epoch, version) = db
3129        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3130        .await?;
3131
3132    channel_buffer_updated(
3133        session.connection_id,
3134        collaborators,
3135        &proto::UpdateChannelBuffer {
3136            channel_id: channel_id.to_proto(),
3137            operations: request.operations,
3138        },
3139        &session.peer,
3140    );
3141
3142    let pool = &*session.connection_pool().await;
3143
3144    broadcast(
3145        None,
3146        non_collaborators
3147            .iter()
3148            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3149        |peer_id| {
3150            session.peer.send(
3151                peer_id,
3152                proto::UpdateChannels {
3153                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3154                        channel_id: channel_id.to_proto(),
3155                        epoch: epoch as u64,
3156                        version: version.clone(),
3157                    }],
3158                    ..Default::default()
3159                },
3160            )
3161        },
3162    );
3163
3164    Ok(())
3165}
3166
3167/// Rejoin the channel notes after a connection blip
3168async fn rejoin_channel_buffers(
3169    request: proto::RejoinChannelBuffers,
3170    response: Response<proto::RejoinChannelBuffers>,
3171    session: UserSession,
3172) -> Result<()> {
3173    let db = session.db().await;
3174    let buffers = db
3175        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3176        .await?;
3177
3178    for rejoined_buffer in &buffers {
3179        let collaborators_to_notify = rejoined_buffer
3180            .buffer
3181            .collaborators
3182            .iter()
3183            .filter_map(|c| Some(c.peer_id?.into()));
3184        channel_buffer_updated(
3185            session.connection_id,
3186            collaborators_to_notify,
3187            &proto::UpdateChannelBufferCollaborators {
3188                channel_id: rejoined_buffer.buffer.channel_id,
3189                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3190            },
3191            &session.peer,
3192        );
3193    }
3194
3195    response.send(proto::RejoinChannelBuffersResponse {
3196        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3197    })?;
3198
3199    Ok(())
3200}
3201
3202/// Stop editing the channel notes
3203async fn leave_channel_buffer(
3204    request: proto::LeaveChannelBuffer,
3205    response: Response<proto::LeaveChannelBuffer>,
3206    session: UserSession,
3207) -> Result<()> {
3208    let db = session.db().await;
3209    let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211    let left_buffer = db
3212        .leave_channel_buffer(channel_id, session.connection_id)
3213        .await?;
3214
3215    response.send(Ack {})?;
3216
3217    channel_buffer_updated(
3218        session.connection_id,
3219        left_buffer.connections,
3220        &proto::UpdateChannelBufferCollaborators {
3221            channel_id: channel_id.to_proto(),
3222            collaborators: left_buffer.collaborators,
3223        },
3224        &session.peer,
3225    );
3226
3227    Ok(())
3228}
3229
3230fn channel_buffer_updated<T: EnvelopedMessage>(
3231    sender_id: ConnectionId,
3232    collaborators: impl IntoIterator<Item = ConnectionId>,
3233    message: &T,
3234    peer: &Peer,
3235) {
3236    broadcast(Some(sender_id), collaborators, |peer_id| {
3237        peer.send(peer_id, message.clone())
3238    });
3239}
3240
3241fn send_notifications(
3242    connection_pool: &ConnectionPool,
3243    peer: &Peer,
3244    notifications: db::NotificationBatch,
3245) {
3246    for (user_id, notification) in notifications {
3247        for connection_id in connection_pool.user_connection_ids(user_id) {
3248            if let Err(error) = peer.send(
3249                connection_id,
3250                proto::AddNotification {
3251                    notification: Some(notification.clone()),
3252                },
3253            ) {
3254                tracing::error!(
3255                    "failed to send notification to {:?} {}",
3256                    connection_id,
3257                    error
3258                );
3259            }
3260        }
3261    }
3262}
3263
3264/// Send a message to the channel
3265async fn send_channel_message(
3266    request: proto::SendChannelMessage,
3267    response: Response<proto::SendChannelMessage>,
3268    session: UserSession,
3269) -> Result<()> {
3270    // Validate the message body.
3271    let body = request.body.trim().to_string();
3272    if body.len() > MAX_MESSAGE_LEN {
3273        return Err(anyhow!("message is too long"))?;
3274    }
3275    if body.is_empty() {
3276        return Err(anyhow!("message can't be blank"))?;
3277    }
3278
3279    // TODO: adjust mentions if body is trimmed
3280
3281    let timestamp = OffsetDateTime::now_utc();
3282    let nonce = request
3283        .nonce
3284        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3285
3286    let channel_id = ChannelId::from_proto(request.channel_id);
3287    let CreatedChannelMessage {
3288        message_id,
3289        participant_connection_ids,
3290        channel_members,
3291        notifications,
3292    } = session
3293        .db()
3294        .await
3295        .create_channel_message(
3296            channel_id,
3297            session.user_id(),
3298            &body,
3299            &request.mentions,
3300            timestamp,
3301            nonce.clone().into(),
3302            match request.reply_to_message_id {
3303                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3304                None => None,
3305            },
3306        )
3307        .await?;
3308
3309    let message = proto::ChannelMessage {
3310        sender_id: session.user_id().to_proto(),
3311        id: message_id.to_proto(),
3312        body,
3313        mentions: request.mentions,
3314        timestamp: timestamp.unix_timestamp() as u64,
3315        nonce: Some(nonce),
3316        reply_to_message_id: request.reply_to_message_id,
3317        edited_at: None,
3318    };
3319    broadcast(
3320        Some(session.connection_id),
3321        participant_connection_ids,
3322        |connection| {
3323            session.peer.send(
3324                connection,
3325                proto::ChannelMessageSent {
3326                    channel_id: channel_id.to_proto(),
3327                    message: Some(message.clone()),
3328                },
3329            )
3330        },
3331    );
3332    response.send(proto::SendChannelMessageResponse {
3333        message: Some(message),
3334    })?;
3335
3336    let pool = &*session.connection_pool().await;
3337    broadcast(
3338        None,
3339        channel_members
3340            .iter()
3341            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3342        |peer_id| {
3343            session.peer.send(
3344                peer_id,
3345                proto::UpdateChannels {
3346                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3347                        channel_id: channel_id.to_proto(),
3348                        message_id: message_id.to_proto(),
3349                    }],
3350                    ..Default::default()
3351                },
3352            )
3353        },
3354    );
3355    send_notifications(pool, &session.peer, notifications);
3356
3357    Ok(())
3358}
3359
3360/// Delete a channel message
3361async fn remove_channel_message(
3362    request: proto::RemoveChannelMessage,
3363    response: Response<proto::RemoveChannelMessage>,
3364    session: UserSession,
3365) -> Result<()> {
3366    let channel_id = ChannelId::from_proto(request.channel_id);
3367    let message_id = MessageId::from_proto(request.message_id);
3368    let connection_ids = session
3369        .db()
3370        .await
3371        .remove_channel_message(channel_id, message_id, session.user_id())
3372        .await?;
3373    broadcast(Some(session.connection_id), connection_ids, |connection| {
3374        session.peer.send(connection, request.clone())
3375    });
3376    response.send(proto::Ack {})?;
3377    Ok(())
3378}
3379
3380async fn update_channel_message(
3381    request: proto::UpdateChannelMessage,
3382    response: Response<proto::UpdateChannelMessage>,
3383    session: UserSession,
3384) -> Result<()> {
3385    let channel_id = ChannelId::from_proto(request.channel_id);
3386    let message_id = MessageId::from_proto(request.message_id);
3387    let updated_at = OffsetDateTime::now_utc();
3388    let UpdatedChannelMessage {
3389        message_id,
3390        participant_connection_ids,
3391        notifications,
3392        reply_to_message_id,
3393        timestamp,
3394    } = session
3395        .db()
3396        .await
3397        .update_channel_message(
3398            channel_id,
3399            message_id,
3400            session.user_id(),
3401            request.body.as_str(),
3402            &request.mentions,
3403            updated_at,
3404        )
3405        .await?;
3406
3407    let nonce = request
3408        .nonce
3409        .clone()
3410        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3411
3412    let message = proto::ChannelMessage {
3413        sender_id: session.user_id().to_proto(),
3414        id: message_id.to_proto(),
3415        body: request.body.clone(),
3416        mentions: request.mentions.clone(),
3417        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3418        nonce: Some(nonce),
3419        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3420        edited_at: Some(updated_at.unix_timestamp() as u64),
3421    };
3422
3423    response.send(proto::Ack {})?;
3424
3425    let pool = &*session.connection_pool().await;
3426    broadcast(
3427        Some(session.connection_id),
3428        participant_connection_ids,
3429        |connection| {
3430            session.peer.send(
3431                connection,
3432                proto::ChannelMessageUpdate {
3433                    channel_id: channel_id.to_proto(),
3434                    message: Some(message.clone()),
3435                },
3436            )
3437        },
3438    );
3439
3440    send_notifications(pool, &session.peer, notifications);
3441
3442    Ok(())
3443}
3444
3445/// Mark a channel message as read
3446async fn acknowledge_channel_message(
3447    request: proto::AckChannelMessage,
3448    session: UserSession,
3449) -> Result<()> {
3450    let channel_id = ChannelId::from_proto(request.channel_id);
3451    let message_id = MessageId::from_proto(request.message_id);
3452    let notifications = session
3453        .db()
3454        .await
3455        .observe_channel_message(channel_id, session.user_id(), message_id)
3456        .await?;
3457    send_notifications(
3458        &*session.connection_pool().await,
3459        &session.peer,
3460        notifications,
3461    );
3462    Ok(())
3463}
3464
3465/// Mark a buffer version as synced
3466async fn acknowledge_buffer_version(
3467    request: proto::AckBufferOperation,
3468    session: UserSession,
3469) -> Result<()> {
3470    let buffer_id = BufferId::from_proto(request.buffer_id);
3471    session
3472        .db()
3473        .await
3474        .observe_buffer_version(
3475            buffer_id,
3476            session.user_id(),
3477            request.epoch as i32,
3478            &request.version,
3479        )
3480        .await?;
3481    Ok(())
3482}
3483
3484struct CompleteWithLanguageModelRateLimit;
3485
3486impl RateLimit for CompleteWithLanguageModelRateLimit {
3487    fn capacity() -> usize {
3488        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3489            .ok()
3490            .and_then(|v| v.parse().ok())
3491            .unwrap_or(120) // Picked arbitrarily
3492    }
3493
3494    fn refill_duration() -> chrono::Duration {
3495        chrono::Duration::hours(1)
3496    }
3497
3498    fn db_name() -> &'static str {
3499        "complete-with-language-model"
3500    }
3501}
3502
3503async fn complete_with_language_model(
3504    request: proto::CompleteWithLanguageModel,
3505    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3506    session: Session,
3507    open_ai_api_key: Option<Arc<str>>,
3508    google_ai_api_key: Option<Arc<str>>,
3509) -> Result<()> {
3510    let Some(session) = session.for_user() else {
3511        return Err(anyhow!("user not found"))?;
3512    };
3513    authorize_access_to_language_models(&session).await?;
3514    session
3515        .rate_limiter
3516        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3517        .await?;
3518
3519    if request.model.starts_with("gpt") {
3520        let api_key =
3521            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3522        complete_with_open_ai(request, response, session, api_key).await?;
3523    } else if request.model.starts_with("gemini") {
3524        let api_key = google_ai_api_key
3525            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3526        complete_with_google_ai(request, response, session, api_key).await?;
3527    }
3528
3529    Ok(())
3530}
3531
3532async fn complete_with_open_ai(
3533    request: proto::CompleteWithLanguageModel,
3534    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3535    session: UserSession,
3536    api_key: Arc<str>,
3537) -> Result<()> {
3538    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3539
3540    let mut completion_stream = open_ai::stream_completion(
3541        &session.http_client,
3542        OPEN_AI_API_URL,
3543        &api_key,
3544        crate::ai::language_model_request_to_open_ai(request)?,
3545    )
3546    .await
3547    .context("open_ai::stream_completion request failed")?;
3548
3549    while let Some(event) = completion_stream.next().await {
3550        let event = event?;
3551        response.send(proto::LanguageModelResponse {
3552            choices: event
3553                .choices
3554                .into_iter()
3555                .map(|choice| proto::LanguageModelChoiceDelta {
3556                    index: choice.index,
3557                    delta: Some(proto::LanguageModelResponseMessage {
3558                        role: choice.delta.role.map(|role| match role {
3559                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3560                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3561                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3562                        } as i32),
3563                        content: choice.delta.content,
3564                    }),
3565                    finish_reason: choice.finish_reason,
3566                })
3567                .collect(),
3568        })?;
3569    }
3570
3571    Ok(())
3572}
3573
3574async fn complete_with_google_ai(
3575    request: proto::CompleteWithLanguageModel,
3576    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3577    session: UserSession,
3578    api_key: Arc<str>,
3579) -> Result<()> {
3580    let mut stream = google_ai::stream_generate_content(
3581        &session.http_client,
3582        google_ai::API_URL,
3583        api_key.as_ref(),
3584        crate::ai::language_model_request_to_google_ai(request)?,
3585    )
3586    .await
3587    .context("google_ai::stream_generate_content request failed")?;
3588
3589    while let Some(event) = stream.next().await {
3590        let event = event?;
3591        response.send(proto::LanguageModelResponse {
3592            choices: event
3593                .candidates
3594                .unwrap_or_default()
3595                .into_iter()
3596                .map(|candidate| proto::LanguageModelChoiceDelta {
3597                    index: candidate.index as u32,
3598                    delta: Some(proto::LanguageModelResponseMessage {
3599                        role: Some(match candidate.content.role {
3600                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3601                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3602                        } as i32),
3603                        content: Some(
3604                            candidate
3605                                .content
3606                                .parts
3607                                .into_iter()
3608                                .filter_map(|part| match part {
3609                                    google_ai::Part::TextPart(part) => Some(part.text),
3610                                    google_ai::Part::InlineDataPart(_) => None,
3611                                })
3612                                .collect(),
3613                        ),
3614                    }),
3615                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3616                })
3617                .collect(),
3618        })?;
3619    }
3620
3621    Ok(())
3622}
3623
3624struct CountTokensWithLanguageModelRateLimit;
3625
3626impl RateLimit for CountTokensWithLanguageModelRateLimit {
3627    fn capacity() -> usize {
3628        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3629            .ok()
3630            .and_then(|v| v.parse().ok())
3631            .unwrap_or(600) // Picked arbitrarily
3632    }
3633
3634    fn refill_duration() -> chrono::Duration {
3635        chrono::Duration::hours(1)
3636    }
3637
3638    fn db_name() -> &'static str {
3639        "count-tokens-with-language-model"
3640    }
3641}
3642
3643async fn count_tokens_with_language_model(
3644    request: proto::CountTokensWithLanguageModel,
3645    response: Response<proto::CountTokensWithLanguageModel>,
3646    session: UserSession,
3647    google_ai_api_key: Option<Arc<str>>,
3648) -> Result<()> {
3649    authorize_access_to_language_models(&session).await?;
3650
3651    if !request.model.starts_with("gemini") {
3652        return Err(anyhow!(
3653            "counting tokens for model: {:?} is not supported",
3654            request.model
3655        ))?;
3656    }
3657
3658    session
3659        .rate_limiter
3660        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3661        .await?;
3662
3663    let api_key = google_ai_api_key
3664        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3665    let tokens_response = google_ai::count_tokens(
3666        &session.http_client,
3667        google_ai::API_URL,
3668        &api_key,
3669        crate::ai::count_tokens_request_to_google_ai(request)?,
3670    )
3671    .await?;
3672    response.send(proto::CountTokensResponse {
3673        token_count: tokens_response.total_tokens as u32,
3674    })?;
3675    Ok(())
3676}
3677
3678async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3679    let db = session.db().await;
3680    let flags = db.get_user_flags(session.user_id()).await?;
3681    if flags.iter().any(|flag| flag == "language-models") {
3682        Ok(())
3683    } else {
3684        Err(anyhow!("permission denied"))?
3685    }
3686}
3687
3688/// Start receiving chat updates for a channel
3689async fn join_channel_chat(
3690    request: proto::JoinChannelChat,
3691    response: Response<proto::JoinChannelChat>,
3692    session: UserSession,
3693) -> Result<()> {
3694    let channel_id = ChannelId::from_proto(request.channel_id);
3695
3696    let db = session.db().await;
3697    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3698        .await?;
3699    let messages = db
3700        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3701        .await?;
3702    response.send(proto::JoinChannelChatResponse {
3703        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3704        messages,
3705    })?;
3706    Ok(())
3707}
3708
3709/// Stop receiving chat updates for a channel
3710async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3711    let channel_id = ChannelId::from_proto(request.channel_id);
3712    session
3713        .db()
3714        .await
3715        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3716        .await?;
3717    Ok(())
3718}
3719
3720/// Retrieve the chat history for a channel
3721async fn get_channel_messages(
3722    request: proto::GetChannelMessages,
3723    response: Response<proto::GetChannelMessages>,
3724    session: UserSession,
3725) -> Result<()> {
3726    let channel_id = ChannelId::from_proto(request.channel_id);
3727    let messages = session
3728        .db()
3729        .await
3730        .get_channel_messages(
3731            channel_id,
3732            session.user_id(),
3733            MESSAGE_COUNT_PER_PAGE,
3734            Some(MessageId::from_proto(request.before_message_id)),
3735        )
3736        .await?;
3737    response.send(proto::GetChannelMessagesResponse {
3738        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3739        messages,
3740    })?;
3741    Ok(())
3742}
3743
3744/// Retrieve specific chat messages
3745async fn get_channel_messages_by_id(
3746    request: proto::GetChannelMessagesById,
3747    response: Response<proto::GetChannelMessagesById>,
3748    session: UserSession,
3749) -> Result<()> {
3750    let message_ids = request
3751        .message_ids
3752        .iter()
3753        .map(|id| MessageId::from_proto(*id))
3754        .collect::<Vec<_>>();
3755    let messages = session
3756        .db()
3757        .await
3758        .get_channel_messages_by_id(session.user_id(), &message_ids)
3759        .await?;
3760    response.send(proto::GetChannelMessagesResponse {
3761        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3762        messages,
3763    })?;
3764    Ok(())
3765}
3766
3767/// Retrieve the current users notifications
3768async fn get_notifications(
3769    request: proto::GetNotifications,
3770    response: Response<proto::GetNotifications>,
3771    session: UserSession,
3772) -> Result<()> {
3773    let notifications = session
3774        .db()
3775        .await
3776        .get_notifications(
3777            session.user_id(),
3778            NOTIFICATION_COUNT_PER_PAGE,
3779            request
3780                .before_id
3781                .map(|id| db::NotificationId::from_proto(id)),
3782        )
3783        .await?;
3784    response.send(proto::GetNotificationsResponse {
3785        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3786        notifications,
3787    })?;
3788    Ok(())
3789}
3790
3791/// Mark notifications as read
3792async fn mark_notification_as_read(
3793    request: proto::MarkNotificationRead,
3794    response: Response<proto::MarkNotificationRead>,
3795    session: UserSession,
3796) -> Result<()> {
3797    let database = &session.db().await;
3798    let notifications = database
3799        .mark_notification_as_read_by_id(
3800            session.user_id(),
3801            NotificationId::from_proto(request.notification_id),
3802        )
3803        .await?;
3804    send_notifications(
3805        &*session.connection_pool().await,
3806        &session.peer,
3807        notifications,
3808    );
3809    response.send(proto::Ack {})?;
3810    Ok(())
3811}
3812
3813/// Get the current users information
3814async fn get_private_user_info(
3815    _request: proto::GetPrivateUserInfo,
3816    response: Response<proto::GetPrivateUserInfo>,
3817    session: UserSession,
3818) -> Result<()> {
3819    let db = session.db().await;
3820
3821    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3822    let user = db
3823        .get_user_by_id(session.user_id())
3824        .await?
3825        .ok_or_else(|| anyhow!("user not found"))?;
3826    let flags = db.get_user_flags(session.user_id()).await?;
3827
3828    response.send(proto::GetPrivateUserInfoResponse {
3829        metrics_id,
3830        staff: user.admin,
3831        flags,
3832    })?;
3833    Ok(())
3834}
3835
3836fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3837    match message {
3838        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3839        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3840        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3841        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3842        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3843            code: frame.code.into(),
3844            reason: frame.reason,
3845        })),
3846    }
3847}
3848
3849fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3850    match message {
3851        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3852        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3853        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3854        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3855        AxumMessage::Close(frame) => {
3856            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3857                code: frame.code.into(),
3858                reason: frame.reason,
3859            }))
3860        }
3861    }
3862}
3863
3864fn notify_membership_updated(
3865    connection_pool: &mut ConnectionPool,
3866    result: MembershipUpdated,
3867    user_id: UserId,
3868    peer: &Peer,
3869) {
3870    for membership in &result.new_channels.channel_memberships {
3871        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3872    }
3873    for channel_id in &result.removed_channels {
3874        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3875    }
3876
3877    let user_channels_update = proto::UpdateUserChannels {
3878        channel_memberships: result
3879            .new_channels
3880            .channel_memberships
3881            .iter()
3882            .map(|cm| proto::ChannelMembership {
3883                channel_id: cm.channel_id.to_proto(),
3884                role: cm.role.into(),
3885            })
3886            .collect(),
3887        ..Default::default()
3888    };
3889
3890    let mut update = build_channels_update(result.new_channels, vec![]);
3891    update.delete_channels = result
3892        .removed_channels
3893        .into_iter()
3894        .map(|id| id.to_proto())
3895        .collect();
3896    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3897
3898    for connection_id in connection_pool.user_connection_ids(user_id) {
3899        peer.send(connection_id, user_channels_update.clone())
3900            .trace_err();
3901        peer.send(connection_id, update.clone()).trace_err();
3902    }
3903}
3904
3905fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3906    proto::UpdateUserChannels {
3907        channel_memberships: channels
3908            .channel_memberships
3909            .iter()
3910            .map(|m| proto::ChannelMembership {
3911                channel_id: m.channel_id.to_proto(),
3912                role: m.role.into(),
3913            })
3914            .collect(),
3915        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3916        observed_channel_message_id: channels.observed_channel_messages.clone(),
3917    }
3918}
3919
3920fn build_channels_update(
3921    channels: ChannelsForUser,
3922    channel_invites: Vec<db::Channel>,
3923) -> proto::UpdateChannels {
3924    let mut update = proto::UpdateChannels::default();
3925
3926    for channel in channels.channels {
3927        update.channels.push(channel.to_proto());
3928    }
3929
3930    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3931    update.latest_channel_message_ids = channels.latest_channel_messages;
3932
3933    for (channel_id, participants) in channels.channel_participants {
3934        update
3935            .channel_participants
3936            .push(proto::ChannelParticipants {
3937                channel_id: channel_id.to_proto(),
3938                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3939            });
3940    }
3941
3942    for channel in channel_invites {
3943        update.channel_invitations.push(channel.to_proto());
3944    }
3945    for project in channels.hosted_projects {
3946        update.hosted_projects.push(project);
3947    }
3948
3949    update
3950}
3951
3952fn build_initial_contacts_update(
3953    contacts: Vec<db::Contact>,
3954    pool: &ConnectionPool,
3955) -> proto::UpdateContacts {
3956    let mut update = proto::UpdateContacts::default();
3957
3958    for contact in contacts {
3959        match contact {
3960            db::Contact::Accepted { user_id, busy } => {
3961                update.contacts.push(contact_for_user(user_id, busy, &pool));
3962            }
3963            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3964            db::Contact::Incoming { user_id } => {
3965                update
3966                    .incoming_requests
3967                    .push(proto::IncomingContactRequest {
3968                        requester_id: user_id.to_proto(),
3969                    })
3970            }
3971        }
3972    }
3973
3974    update
3975}
3976
3977fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3978    proto::Contact {
3979        user_id: user_id.to_proto(),
3980        online: pool.is_user_online(user_id),
3981        busy,
3982    }
3983}
3984
3985fn room_updated(room: &proto::Room, peer: &Peer) {
3986    broadcast(
3987        None,
3988        room.participants
3989            .iter()
3990            .filter_map(|participant| Some(participant.peer_id?.into())),
3991        |peer_id| {
3992            peer.send(
3993                peer_id,
3994                proto::RoomUpdated {
3995                    room: Some(room.clone()),
3996                },
3997            )
3998        },
3999    );
4000}
4001
4002fn channel_updated(
4003    channel: &db::channel::Model,
4004    room: &proto::Room,
4005    peer: &Peer,
4006    pool: &ConnectionPool,
4007) {
4008    let participants = room
4009        .participants
4010        .iter()
4011        .map(|p| p.user_id)
4012        .collect::<Vec<_>>();
4013
4014    broadcast(
4015        None,
4016        pool.channel_connection_ids(channel.root_id())
4017            .filter_map(|(channel_id, role)| {
4018                role.can_see_channel(channel.visibility).then(|| channel_id)
4019            }),
4020        |peer_id| {
4021            peer.send(
4022                peer_id,
4023                proto::UpdateChannels {
4024                    channel_participants: vec![proto::ChannelParticipants {
4025                        channel_id: channel.id.to_proto(),
4026                        participant_user_ids: participants.clone(),
4027                    }],
4028                    ..Default::default()
4029                },
4030            )
4031        },
4032    );
4033}
4034
4035async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4036    let db = session.db().await;
4037
4038    let contacts = db.get_contacts(user_id).await?;
4039    let busy = db.is_user_busy(user_id).await?;
4040
4041    let pool = session.connection_pool().await;
4042    let updated_contact = contact_for_user(user_id, busy, &pool);
4043    for contact in contacts {
4044        if let db::Contact::Accepted {
4045            user_id: contact_user_id,
4046            ..
4047        } = contact
4048        {
4049            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4050                session
4051                    .peer
4052                    .send(
4053                        contact_conn_id,
4054                        proto::UpdateContacts {
4055                            contacts: vec![updated_contact.clone()],
4056                            remove_contacts: Default::default(),
4057                            incoming_requests: Default::default(),
4058                            remove_incoming_requests: Default::default(),
4059                            outgoing_requests: Default::default(),
4060                            remove_outgoing_requests: Default::default(),
4061                        },
4062                    )
4063                    .trace_err();
4064            }
4065        }
4066    }
4067    Ok(())
4068}
4069
4070async fn leave_room_for_session(session: &UserSession) -> Result<()> {
4071    let mut contacts_to_update = HashSet::default();
4072
4073    let room_id;
4074    let canceled_calls_to_user_ids;
4075    let live_kit_room;
4076    let delete_live_kit_room;
4077    let room;
4078    let channel;
4079
4080    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
4081        contacts_to_update.insert(session.user_id());
4082
4083        for project in left_room.left_projects.values() {
4084            project_left(project, session);
4085        }
4086
4087        room_id = RoomId::from_proto(left_room.room.id);
4088        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4089        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4090        delete_live_kit_room = left_room.deleted;
4091        room = mem::take(&mut left_room.room);
4092        channel = mem::take(&mut left_room.channel);
4093
4094        room_updated(&room, &session.peer);
4095    } else {
4096        return Ok(());
4097    }
4098
4099    if let Some(channel) = channel {
4100        channel_updated(
4101            &channel,
4102            &room,
4103            &session.peer,
4104            &*session.connection_pool().await,
4105        );
4106    }
4107
4108    {
4109        let pool = session.connection_pool().await;
4110        for canceled_user_id in canceled_calls_to_user_ids {
4111            for connection_id in pool.user_connection_ids(canceled_user_id) {
4112                session
4113                    .peer
4114                    .send(
4115                        connection_id,
4116                        proto::CallCanceled {
4117                            room_id: room_id.to_proto(),
4118                        },
4119                    )
4120                    .trace_err();
4121            }
4122            contacts_to_update.insert(canceled_user_id);
4123        }
4124    }
4125
4126    for contact_user_id in contacts_to_update {
4127        update_user_contacts(contact_user_id, &session).await?;
4128    }
4129
4130    if let Some(live_kit) = session.live_kit_client.as_ref() {
4131        live_kit
4132            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4133            .await
4134            .trace_err();
4135
4136        if delete_live_kit_room {
4137            live_kit.delete_room(live_kit_room).await.trace_err();
4138        }
4139    }
4140
4141    Ok(())
4142}
4143
4144async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4145    let left_channel_buffers = session
4146        .db()
4147        .await
4148        .leave_channel_buffers(session.connection_id)
4149        .await?;
4150
4151    for left_buffer in left_channel_buffers {
4152        channel_buffer_updated(
4153            session.connection_id,
4154            left_buffer.connections,
4155            &proto::UpdateChannelBufferCollaborators {
4156                channel_id: left_buffer.channel_id.to_proto(),
4157                collaborators: left_buffer.collaborators,
4158            },
4159            &session.peer,
4160        );
4161    }
4162
4163    Ok(())
4164}
4165
4166fn project_left(project: &db::LeftProject, session: &UserSession) {
4167    for connection_id in &project.connection_ids {
4168        if project.host_user_id == Some(session.user_id()) {
4169            session
4170                .peer
4171                .send(
4172                    *connection_id,
4173                    proto::UnshareProject {
4174                        project_id: project.id.to_proto(),
4175                    },
4176                )
4177                .trace_err();
4178        } else {
4179            session
4180                .peer
4181                .send(
4182                    *connection_id,
4183                    proto::RemoveProjectCollaborator {
4184                        project_id: project.id.to_proto(),
4185                        peer_id: Some(session.connection_id.into()),
4186                    },
4187                )
4188                .trace_err();
4189        }
4190    }
4191}
4192
4193pub trait ResultExt {
4194    type Ok;
4195
4196    fn trace_err(self) -> Option<Self::Ok>;
4197}
4198
4199impl<T, E> ResultExt for Result<T, E>
4200where
4201    E: std::fmt::Debug,
4202{
4203    type Ok = T;
4204
4205    fn trace_err(self) -> Option<T> {
4206        match self {
4207            Ok(value) => Some(value),
4208            Err(error) => {
4209                tracing::error!("{:?}", error);
4210                None
4211            }
4212        }
4213    }
4214}