rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    future::Future,
  53    marker::PhantomData,
  54    mem,
  55    net::SocketAddr,
  56    ops::{Deref, DerefMut},
  57    rc::Rc,
  58    sync::{
  59        atomic::{AtomicBool, Ordering::SeqCst},
  60        Arc, OnceLock,
  61    },
  62    time::{Duration, Instant},
  63};
  64use time::OffsetDateTime;
  65use tokio::sync::{watch, Semaphore};
  66use tower::ServiceBuilder;
  67use tracing::{
  68    field::{self},
  69    info_span, instrument, Instrument,
  70};
  71use util::{http::IsahcHttpClient, SemanticVersion};
  72
  73pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  74
  75// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  76pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  77
  78const MESSAGE_COUNT_PER_PAGE: usize = 100;
  79const MAX_MESSAGE_LEN: usize = 1024;
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81
  82type MessageHandler =
  83    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  84
  85struct Response<R> {
  86    peer: Arc<Peer>,
  87    receipt: Receipt<R>,
  88    responded: Arc<AtomicBool>,
  89}
  90
  91impl<R: RequestMessage> Response<R> {
  92    fn send(self, payload: R::Response) -> Result<()> {
  93        self.responded.store(true, SeqCst);
  94        self.peer.respond(self.receipt, payload)?;
  95        Ok(())
  96    }
  97}
  98
  99struct StreamingResponse<R: RequestMessage> {
 100    peer: Arc<Peer>,
 101    receipt: Receipt<R>,
 102}
 103
 104impl<R: RequestMessage> StreamingResponse<R> {
 105    fn send(&self, payload: R::Response) -> Result<()> {
 106        self.peer.respond(self.receipt, payload)?;
 107        Ok(())
 108    }
 109}
 110
 111#[derive(Clone, Debug)]
 112pub enum Principal {
 113    User(User),
 114    Impersonated { user: User, admin: User },
 115    DevServer(dev_server::Model),
 116}
 117
 118impl Principal {
 119    fn update_span(&self, span: &tracing::Span) {
 120        match &self {
 121            Principal::User(user) => {
 122                span.record("user_id", &user.id.0);
 123                span.record("login", &user.github_login);
 124            }
 125            Principal::Impersonated { user, admin } => {
 126                span.record("user_id", &user.id.0);
 127                span.record("login", &user.github_login);
 128                span.record("impersonator", &admin.github_login);
 129            }
 130            Principal::DevServer(dev_server) => {
 131                span.record("dev_server_id", &dev_server.id.0);
 132            }
 133        }
 134    }
 135}
 136
 137#[derive(Clone)]
 138struct Session {
 139    principal: Principal,
 140    connection_id: ConnectionId,
 141    db: Arc<tokio::sync::Mutex<DbHandle>>,
 142    peer: Arc<Peer>,
 143    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 144    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 145    http_client: IsahcHttpClient,
 146    rate_limiter: Arc<RateLimiter>,
 147    _executor: Executor,
 148}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn user_id(&self) -> Option<UserId> {
 175        match &self.principal {
 176            Principal::User(user) => Some(user.id),
 177            Principal::Impersonated { user, .. } => Some(user.id),
 178            Principal::DevServer(_) => None,
 179        }
 180    }
 181}
 182
 183impl Debug for Session {
 184    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 185        let mut result = f.debug_struct("Session");
 186        match &self.principal {
 187            Principal::User(user) => {
 188                result.field("user", &user.github_login);
 189            }
 190            Principal::Impersonated { user, admin } => {
 191                result.field("user", &user.github_login);
 192                result.field("impersonator", &admin.github_login);
 193            }
 194            Principal::DevServer(dev_server) => {
 195                result.field("dev_server", &dev_server.id);
 196            }
 197        }
 198        result.field("connection_id", &self.connection_id).finish()
 199    }
 200}
 201
 202struct UserSession(Session);
 203
 204impl UserSession {
 205    pub fn new(s: Session) -> Option<Self> {
 206        s.user_id().map(|_| UserSession(s))
 207    }
 208    pub fn user_id(&self) -> UserId {
 209        self.0.user_id().unwrap()
 210    }
 211}
 212
 213impl Deref for UserSession {
 214    type Target = Session;
 215
 216    fn deref(&self) -> &Self::Target {
 217        &self.0
 218    }
 219}
 220impl DerefMut for UserSession {
 221    fn deref_mut(&mut self) -> &mut Self::Target {
 222        &mut self.0
 223    }
 224}
 225
 226fn user_handler<M: RequestMessage, Fut>(
 227    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 228) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 229where
 230    Fut: Send + Future<Output = Result<()>>,
 231{
 232    let handler = Arc::new(handler);
 233    move |message, response, session| {
 234        let handler = handler.clone();
 235        Box::pin(async move {
 236            if let Some(user_session) = session.for_user() {
 237                Ok(handler(message, response, user_session).await?)
 238            } else {
 239                Err(Error::Internal(anyhow!("must be a user")))
 240            }
 241        })
 242    }
 243}
 244
 245fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 246    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 247) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 248where
 249    InnertRetFut: Send + Future<Output = Result<()>>,
 250{
 251    let handler = Arc::new(handler);
 252    move |message, session| {
 253        let handler = handler.clone();
 254        Box::pin(async move {
 255            if let Some(user_session) = session.for_user() {
 256                Ok(handler(message, user_session).await?)
 257            } else {
 258                Err(Error::Internal(anyhow!("must be a user")))
 259            }
 260        })
 261    }
 262}
 263
 264struct DbHandle(Arc<Database>);
 265
 266impl Deref for DbHandle {
 267    type Target = Database;
 268
 269    fn deref(&self) -> &Self::Target {
 270        self.0.as_ref()
 271    }
 272}
 273
/// The collab RPC server: owns the peer, the shared connection pool, the app
/// state, and the table of registered message handlers.
pub struct Server {
    /// Behind a mutex because the test-only `reset` replaces the id in place.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown` (and back to `false` by the test-only
    /// `reset`); `handle_connection` refuses new connections while it is `true`.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a `Send` future
/// cannot hold it across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of server state.
///
/// `connection_pool` is serialized through its guard via `serialize_deref`,
/// i.e. as the `ConnectionPool` itself. Fields serialize in declaration order.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 294
 295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 296where
 297    S: Serializer,
 298    T: Deref<Target = U>,
 299    U: Serialize,
 300{
 301    Serialize::serialize(value.deref(), serializer)
 302}
 303
 304impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// The peer is created with `id.0 as u32`.
    /// NOTE(review): `as` silently truncates/wraps if `ServerId`'s inner value
    /// does not fit in a `u32` — confirm the inner type and range of `ServerId`.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject
    /// sessions whose principal is not a user. The unwrapped ones receive the
    /// raw `Session` and are therefore reachable by any principal, including
    /// dev servers.
    ///
    /// # Panics
    ///
    /// Panics (in `add_handler`) if two handlers are registered for the same
    /// message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is stored; connections obtain receivers via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Not wrapped in `user_handler`: receives the raw `Session`. The API keys
            // are cloned out of `app_state.config` on every request.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 437
    /// Schedules cleanup of resources recorded against stale servers.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. The task waits
    /// for `CLEANUP_TIMEOUT`, then fetches the room and channel ids returned by
    /// `stale_server_resource_ids` for this `zed_environment` and `server_id`:
    /// - for each channel, stale buffer collaborators are cleared and the
    ///   refreshed collaborator list is sent to the remaining connections;
    /// - for each room, stale participants are cleared; the room (and its
    ///   channel, if any) is re-broadcast; users whose calls were canceled get
    ///   `CallCanceled`; the accepted contacts of every affected user receive an
    ///   `UpdateContacts`; and the LiveKit room is deleted if no participants
    ///   remain and a LiveKit client is configured.
    ///
    /// Finally `delete_stale_servers` is called. Each failure is logged via
    /// `trace_err` and skipped; nothing aborts the task.
    pub async fn start(&self) -> Result<()> {
        // Clone everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults stay in effect if clearing the room fails below.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's new contact state to their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 583
 584    pub fn teardown(&self) {
 585        self.peer.teardown();
 586        self.connection_pool.lock().reset();
 587        let _ = self.teardown.send(true);
 588    }
 589
 590    #[cfg(test)]
 591    pub fn reset(&self, id: ServerId) {
 592        self.teardown();
 593        *self.id.lock() = id;
 594        self.peer.reset(id.0 as u32);
 595        let _ = self.teardown.send(false);
 596    }
 597
 598    #[cfg(test)]
 599    pub fn id(&self) -> ServerId {
 600        *self.id.lock()
 601    }
 602
 603    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 604    where
 605        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 606        Fut: 'static + Send + Future<Output = Result<()>>,
 607        M: EnvelopedMessage,
 608    {
 609        let prev_handler = self.handlers.insert(
 610            TypeId::of::<M>(),
 611            Box::new(move |envelope, session| {
 612                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 613                let received_at = envelope.received_at;
 614                    tracing::info!(
 615                        "message received"
 616                    );
 617                let start_time = Instant::now();
 618                let future = (handler)(*envelope, session);
 619                async move {
 620                    let result = future.await;
 621                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 622                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 623                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 624                    match result {
 625                        Err(error) => {
 626                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 627                        }
 628                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 629                    }
 630                }
 631                .boxed()
 632            }),
 633        );
 634        if prev_handler.is_some() {
 635            panic!("registered a handler for the same message twice");
 636        }
 637        self
 638    }
 639
 640    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 641    where
 642        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 643        Fut: 'static + Send + Future<Output = Result<()>>,
 644        M: EnvelopedMessage,
 645    {
 646        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 647        self
 648    }
 649
 650    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 651    where
 652        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 653        Fut: Send + Future<Output = Result<()>>,
 654        M: RequestMessage,
 655    {
 656        let handler = Arc::new(handler);
 657        self.add_handler(move |envelope, session| {
 658            let receipt = envelope.receipt();
 659            let handler = handler.clone();
 660            async move {
 661                let peer = session.peer.clone();
 662                let responded = Arc::new(AtomicBool::default());
 663                let response = Response {
 664                    peer: peer.clone(),
 665                    responded: responded.clone(),
 666                    receipt,
 667                };
 668                match (handler)(envelope.payload, response, session).await {
 669                    Ok(()) => {
 670                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 671                            Ok(())
 672                        } else {
 673                            Err(anyhow!("handler did not send a response"))?
 674                        }
 675                    }
 676                    Err(error) => {
 677                        let proto_err = match &error {
 678                            Error::Internal(err) => err.to_proto(),
 679                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 680                        };
 681                        peer.respond_with_error(receipt, proto_err)?;
 682                        Err(error)
 683                    }
 684                }
 685            }
 686        })
 687    }
 688
 689    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 690    where
 691        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 692        Fut: Send + Future<Output = Result<()>>,
 693        M: RequestMessage,
 694    {
 695        let handler = Arc::new(handler);
 696        self.add_handler(move |envelope, session| {
 697            let receipt = envelope.receipt();
 698            let handler = handler.clone();
 699            async move {
 700                let peer = session.peer.clone();
 701                let response = StreamingResponse {
 702                    peer: peer.clone(),
 703                    receipt,
 704                };
 705                match (handler)(envelope.payload, response, session).await {
 706                    Ok(()) => {
 707                        peer.end_stream(receipt)?;
 708                        Ok(())
 709                    }
 710                    Err(error) => {
 711                        let proto_err = match &error {
 712                            Error::Internal(err) => err.to_proto(),
 713                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 714                        };
 715                        peer.respond_with_error(receipt, proto_err)?;
 716                        Err(error)
 717                    }
 718                }
 719            }
 720        })
 721    }
 722
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a `handle connection` span. The span carries the
    /// peer address, the principal's fields, and (once assigned) the connection id. The future:
    /// 1. returns immediately if the server is already tearing down;
    /// 2. registers the connection with the peer and builds the per-connection [`Session`];
    /// 3. sends the initial client update via `send_initial_client_update`, which may hand the
    ///    assigned id to `send_connection_id`;
    /// 4. dispatches incoming messages to the registered handlers until one of three things
    ///    happens: the connection's I/O future completes, the incoming stream ends, or teardown
    ///    is signalled;
    /// 5. runs `connection_lost` in the first two cases, but not on teardown.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // `impersonator` and `connection_id` are declared empty here and filled in later
        // (`update_span` / `record`), since tracing spans can only record pre-declared fields.
        let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after `send_initial_client_update`)
            // happens after `add_connection` but skips `connection_lost`, so nothing in this
            // function disconnects the peer connection registered above — confirm intended.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers hold a permit at once. A permit is acquired *before* the next
            // message is read, so reading pauses while every permit is in use.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // Recreated every iteration. If another `select_biased!` arm wins, this future
                // (and any permit it already acquired) is dropped.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: arms are polled in the order written, so teardown wins over I/O
                // completion, which wins over handler progress and new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler finishes, which
                                // ties the concurrency limit to handler lifetime.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor. Foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 848
    /// Sends the messages a freshly connected client receives before regular traffic.
    ///
    /// `Hello` (carrying the assigned peer id) is always sent. For any principal other than
    /// `Principal::User`, that is all. For a user, the function then does the following, in
    /// order:
    /// - hands `connection_id` to `send_connection_id`, if one was supplied;
    /// - on the user's first connection (`!connected_once`), sends `ShowContacts` and sets
    ///   `connected_once` in the database;
    /// - loads contacts, channels and channel invites concurrently;
    /// - under the connection-pool lock, registers the connection and its channel
    ///   subscriptions, then sends the initial contacts, user-channels and channels updates;
    /// - forwards any pending incoming call, then refreshes the user's contacts.
    ///
    /// # Errors
    /// Returns the first send or database error; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");

        // NOTE(review): for non-`User` principals we return before `send_connection_id` is
        // signalled, so such callers never learn the connection id — confirm intended.
        let Principal::User(user) = principal else {
            return Ok(());
        };

        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiver may already be gone; that is not an error for this connection.
            let _ = send_connection_id.send(connection_id);
        }

        if !user.connected_once {
            self.peer.send(connection_id, proto::ShowContacts {})?;
            self.app_state
                .db
                .set_user_connected_once(user.id, true)
                .await?;
        }

        let (contacts, channels_for_user, channel_invites) = future::try_join3(
            self.app_state.db.get_contacts(user.id),
            self.app_state.db.get_channels_for_user(user.id),
            self.app_state.db.get_channel_invites_for_user(user.id),
        )
        .await?;

        // Scoped so the pool lock covers registration plus the three initial updates, and is
        // released before the awaits below.
        {
            let mut pool = self.connection_pool.lock();
            pool.add_connection(connection_id, user.id, user.admin, zed_version);
            for membership in &channels_for_user.channel_memberships {
                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
            }
            self.peer.send(
                connection_id,
                build_initial_contacts_update(contacts, &pool),
            )?;
            self.peer.send(
                connection_id,
                build_update_user_channels(&channels_for_user),
            )?;
            self.peer.send(
                connection_id,
                build_channels_update(channels_for_user, channel_invites),
            )?;
        }

        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
            self.peer.send(connection_id, incoming_call)?;
        }

        update_user_contacts(user.id, &session).await?;
        Ok(())
    }
 915
 916    pub async fn invite_code_redeemed(
 917        self: &Arc<Self>,
 918        inviter_id: UserId,
 919        invitee_id: UserId,
 920    ) -> Result<()> {
 921        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 922            if let Some(code) = &user.invite_code {
 923                let pool = self.connection_pool.lock();
 924                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 925                for connection_id in pool.user_connection_ids(inviter_id) {
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateContacts {
 929                            contacts: vec![invitee_contact.clone()],
 930                            ..Default::default()
 931                        },
 932                    )?;
 933                    self.peer.send(
 934                        connection_id,
 935                        proto::UpdateInviteInfo {
 936                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 937                            count: user.invite_count as u32,
 938                        },
 939                    )?;
 940                }
 941            }
 942        }
 943        Ok(())
 944    }
 945
 946    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 947        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 948            if let Some(invite_code) = &user.invite_code {
 949                let pool = self.connection_pool.lock();
 950                for connection_id in pool.user_connection_ids(user_id) {
 951                    self.peer.send(
 952                        connection_id,
 953                        proto::UpdateInviteInfo {
 954                            url: format!(
 955                                "{}{}",
 956                                self.app_state.config.invite_link_prefix, invite_code
 957                            ),
 958                            count: user.invite_count as u32,
 959                        },
 960                    )?;
 961                }
 962            }
 963        }
 964        Ok(())
 965    }
 966
 967    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 968        ServerSnapshot {
 969            connection_pool: ConnectionPoolGuard {
 970                guard: self.connection_pool.lock(),
 971                _not_send: PhantomData,
 972            },
 973            peer: &self.peer,
 974        }
 975    }
 976}
 977
 978impl<'a> Deref for ConnectionPoolGuard<'a> {
 979    type Target = ConnectionPool;
 980
 981    fn deref(&self) -> &Self::Target {
 982        &self.guard
 983    }
 984}
 985
 986impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 987    fn deref_mut(&mut self) -> &mut Self::Target {
 988        &mut self.guard
 989    }
 990}
 991
/// In test builds, validates the connection pool's invariants whenever a guard is dropped.
///
/// `drop` runs before the struct's fields are dropped, so the check sees the pool while the
/// lock is still held. This presumably exists so tests catch mutations made through the guard
/// that leave the pool inconsistent.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled out of non-test builds, where this `drop` does nothing.
        #[cfg(test)]
        self.check_invariants();
    }
}
 998
 999fn broadcast<F>(
1000    sender_id: Option<ConnectionId>,
1001    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1002    mut f: F,
1003) where
1004    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1005{
1006    for receiver_id in receiver_ids {
1007        if Some(receiver_id) != sender_id {
1008            if let Err(error) = f(receiver_id) {
1009                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1010            }
1011        }
1012    }
1013}
1014
1015pub struct ProtocolVersion(u32);
1016
1017impl Header for ProtocolVersion {
1018    fn name() -> &'static HeaderName {
1019        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1020        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1021    }
1022
1023    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1024    where
1025        Self: Sized,
1026        I: Iterator<Item = &'i axum::http::HeaderValue>,
1027    {
1028        let version = values
1029            .next()
1030            .ok_or_else(axum::headers::Error::invalid)?
1031            .to_str()
1032            .map_err(|_| axum::headers::Error::invalid())?
1033            .parse()
1034            .map_err(|_| axum::headers::Error::invalid())?;
1035        Ok(Self(version))
1036    }
1037
1038    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1039        values.extend([self.0.to_string().parse().unwrap()]);
1040    }
1041}
1042
1043pub struct AppVersionHeader(SemanticVersion);
1044impl Header for AppVersionHeader {
1045    fn name() -> &'static HeaderName {
1046        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1047        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1048    }
1049
1050    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1051    where
1052        Self: Sized,
1053        I: Iterator<Item = &'i axum::http::HeaderValue>,
1054    {
1055        let version = values
1056            .next()
1057            .ok_or_else(axum::headers::Error::invalid)?
1058            .to_str()
1059            .map_err(|_| axum::headers::Error::invalid())?
1060            .parse()
1061            .map_err(|_| axum::headers::Error::invalid())?;
1062        Ok(Self(version))
1063    }
1064
1065    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1066        values.extend([self.0.to_string().parse().unwrap()]);
1067    }
1068}
1069
1070pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1071    Router::new()
1072        .route("/rpc", get(handle_websocket_request))
1073        .layer(
1074            ServiceBuilder::new()
1075                .layer(Extension(server.app_state.clone()))
1076                .layer(middleware::from_fn(auth::validate_header)),
1077        )
1078        .route("/metrics", get(handle_metrics))
1079        .layer(Extension(server))
1080}
1081
1082pub async fn handle_websocket_request(
1083    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1084    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1085    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1086    Extension(server): Extension<Arc<Server>>,
1087    Extension(principal): Extension<Principal>,
1088    ws: WebSocketUpgrade,
1089) -> axum::response::Response {
1090    if protocol_version != rpc::PROTOCOL_VERSION {
1091        return (
1092            StatusCode::UPGRADE_REQUIRED,
1093            "client must be upgraded".to_string(),
1094        )
1095            .into_response();
1096    }
1097
1098    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1099        return (
1100            StatusCode::UPGRADE_REQUIRED,
1101            "no version header found".to_string(),
1102        )
1103            .into_response();
1104    };
1105
1106    if !version.can_collaborate() {
1107        return (
1108            StatusCode::UPGRADE_REQUIRED,
1109            "client must be upgraded".to_string(),
1110        )
1111            .into_response();
1112    }
1113
1114    let socket_address = socket_address.to_string();
1115    ws.on_upgrade(move |socket| {
1116        let socket = socket
1117            .map_ok(to_tungstenite_message)
1118            .err_into()
1119            .with(|message| async move { Ok(to_axum_message(message)) });
1120        let connection = Connection::new(Box::pin(socket));
1121        async move {
1122            server
1123                .handle_connection(
1124                    connection,
1125                    socket_address,
1126                    principal,
1127                    version,
1128                    None,
1129                    Executor::Production,
1130                )
1131                .await;
1132        }
1133    })
1134}
1135
1136pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1137    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1138    let connections_metric = CONNECTIONS_METRIC
1139        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1140
1141    let connections = server
1142        .connection_pool
1143        .lock()
1144        .connections()
1145        .filter(|connection| !connection.admin)
1146        .count();
1147    connections_metric.set(connections as _);
1148
1149    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1150    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1151        register_int_gauge!(
1152            "shared_projects",
1153            "number of open projects with one or more guests"
1154        )
1155        .unwrap()
1156    });
1157
1158    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1159    shared_projects_metric.set(shared_projects as _);
1160
1161    let encoder = prometheus::TextEncoder::new();
1162    let metric_families = prometheus::gather();
1163    let encoded_metrics = encoder
1164        .encode_to_string(&metric_families)
1165        .map_err(|err| anyhow!("{}", err))?;
1166    Ok(encoded_metrics)
1167}
1168
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnects the connection from the peer;
/// - removes it from the connection pool (an error here is returned);
/// - reports it to the database via `connection_lost` (errors are only logged).
///
/// The function then waits up to `RECONNECT_TIMEOUT`, presumably giving the client a window
/// to reconnect. If the timer fires and the session belongs to a user, it:
/// - leaves the user's room and channel buffers (errors logged);
/// - declines the user's call via `decline_call` if none of the user's connections are online;
/// - refreshes the user's contacts.
///
/// If server teardown is signalled first, the deferred cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Only user sessions hold room / channel-buffer state to release.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Decline the call only when no other connection keeps the user online.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1216
1217/// Acknowledges a ping from a client, used to keep the connection alive.
1218async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1219    response.send(proto::Ack {})?;
1220    Ok(())
1221}
1222
1223/// Creates a new room for calling (outside of channels)
1224async fn create_room(
1225    _request: proto::CreateRoom,
1226    response: Response<proto::CreateRoom>,
1227    session: UserSession,
1228) -> Result<()> {
1229    let live_kit_room = nanoid::nanoid!(30);
1230
1231    let live_kit_connection_info = util::maybe!(async {
1232        let live_kit = session.live_kit_client.as_ref();
1233        let live_kit = live_kit?;
1234        let user_id = session.user_id().to_string();
1235
1236        let token = live_kit
1237            .room_token(&live_kit_room, &user_id.to_string())
1238            .trace_err()?;
1239
1240        Some(proto::LiveKitConnectionInfo {
1241            server_url: live_kit.url().into(),
1242            token,
1243            can_publish: true,
1244        })
1245    })
1246    .await;
1247
1248    let room = session
1249        .db()
1250        .await
1251        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1252        .await?;
1253
1254    response.send(proto::CreateRoomResponse {
1255        room: Some(room.clone()),
1256        live_kit_connection_info,
1257    })?;
1258
1259    update_user_contacts(session.user_id(), &session).await?;
1260    Ok(())
1261}
1262
1263/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1264async fn join_room(
1265    request: proto::JoinRoom,
1266    response: Response<proto::JoinRoom>,
1267    session: UserSession,
1268) -> Result<()> {
1269    let room_id = RoomId::from_proto(request.id);
1270
1271    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1272
1273    if let Some(channel_id) = channel_id {
1274        return join_channel_internal(channel_id, Box::new(response), session).await;
1275    }
1276
1277    let joined_room = {
1278        let room = session
1279            .db()
1280            .await
1281            .join_room(room_id, session.user_id(), session.connection_id)
1282            .await?;
1283        room_updated(&room.room, &session.peer);
1284        room.into_inner()
1285    };
1286
1287    for connection_id in session
1288        .connection_pool()
1289        .await
1290        .user_connection_ids(session.user_id())
1291    {
1292        session
1293            .peer
1294            .send(
1295                connection_id,
1296                proto::CallCanceled {
1297                    room_id: room_id.to_proto(),
1298                },
1299            )
1300            .trace_err();
1301    }
1302
1303    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1304        if let Some(token) = live_kit
1305            .room_token(
1306                &joined_room.room.live_kit_room,
1307                &session.user_id().to_string(),
1308            )
1309            .trace_err()
1310        {
1311            Some(proto::LiveKitConnectionInfo {
1312                server_url: live_kit.url().into(),
1313                token,
1314                can_publish: true,
1315            })
1316        } else {
1317            None
1318        }
1319    } else {
1320        None
1321    };
1322
1323    response.send(proto::JoinRoomResponse {
1324        room: Some(joined_room.room),
1325        channel_id: None,
1326        live_kit_connection_info,
1327    })?;
1328
1329    update_user_contacts(session.user_id(), &session).await?;
1330    Ok(())
1331}
1332
1333/// Rejoin room is used to reconnect to a room after connection errors.
1334async fn rejoin_room(
1335    request: proto::RejoinRoom,
1336    response: Response<proto::RejoinRoom>,
1337    session: UserSession,
1338) -> Result<()> {
1339    let room;
1340    let channel;
1341    {
1342        let mut rejoined_room = session
1343            .db()
1344            .await
1345            .rejoin_room(request, session.user_id(), session.connection_id)
1346            .await?;
1347
1348        response.send(proto::RejoinRoomResponse {
1349            room: Some(rejoined_room.room.clone()),
1350            reshared_projects: rejoined_room
1351                .reshared_projects
1352                .iter()
1353                .map(|project| proto::ResharedProject {
1354                    id: project.id.to_proto(),
1355                    collaborators: project
1356                        .collaborators
1357                        .iter()
1358                        .map(|collaborator| collaborator.to_proto())
1359                        .collect(),
1360                })
1361                .collect(),
1362            rejoined_projects: rejoined_room
1363                .rejoined_projects
1364                .iter()
1365                .map(|rejoined_project| proto::RejoinedProject {
1366                    id: rejoined_project.id.to_proto(),
1367                    worktrees: rejoined_project
1368                        .worktrees
1369                        .iter()
1370                        .map(|worktree| proto::WorktreeMetadata {
1371                            id: worktree.id,
1372                            root_name: worktree.root_name.clone(),
1373                            visible: worktree.visible,
1374                            abs_path: worktree.abs_path.clone(),
1375                        })
1376                        .collect(),
1377                    collaborators: rejoined_project
1378                        .collaborators
1379                        .iter()
1380                        .map(|collaborator| collaborator.to_proto())
1381                        .collect(),
1382                    language_servers: rejoined_project.language_servers.clone(),
1383                })
1384                .collect(),
1385        })?;
1386        room_updated(&rejoined_room.room, &session.peer);
1387
1388        for project in &rejoined_room.reshared_projects {
1389            for collaborator in &project.collaborators {
1390                session
1391                    .peer
1392                    .send(
1393                        collaborator.connection_id,
1394                        proto::UpdateProjectCollaborator {
1395                            project_id: project.id.to_proto(),
1396                            old_peer_id: Some(project.old_connection_id.into()),
1397                            new_peer_id: Some(session.connection_id.into()),
1398                        },
1399                    )
1400                    .trace_err();
1401            }
1402
1403            broadcast(
1404                Some(session.connection_id),
1405                project
1406                    .collaborators
1407                    .iter()
1408                    .map(|collaborator| collaborator.connection_id),
1409                |connection_id| {
1410                    session.peer.forward_send(
1411                        session.connection_id,
1412                        connection_id,
1413                        proto::UpdateProject {
1414                            project_id: project.id.to_proto(),
1415                            worktrees: project.worktrees.clone(),
1416                        },
1417                    )
1418                },
1419            );
1420        }
1421
1422        for project in &rejoined_room.rejoined_projects {
1423            for collaborator in &project.collaborators {
1424                session
1425                    .peer
1426                    .send(
1427                        collaborator.connection_id,
1428                        proto::UpdateProjectCollaborator {
1429                            project_id: project.id.to_proto(),
1430                            old_peer_id: Some(project.old_connection_id.into()),
1431                            new_peer_id: Some(session.connection_id.into()),
1432                        },
1433                    )
1434                    .trace_err();
1435            }
1436        }
1437
1438        for project in &mut rejoined_room.rejoined_projects {
1439            for worktree in mem::take(&mut project.worktrees) {
1440                #[cfg(any(test, feature = "test-support"))]
1441                const MAX_CHUNK_SIZE: usize = 2;
1442                #[cfg(not(any(test, feature = "test-support")))]
1443                const MAX_CHUNK_SIZE: usize = 256;
1444
1445                // Stream this worktree's entries.
1446                let message = proto::UpdateWorktree {
1447                    project_id: project.id.to_proto(),
1448                    worktree_id: worktree.id,
1449                    abs_path: worktree.abs_path.clone(),
1450                    root_name: worktree.root_name,
1451                    updated_entries: worktree.updated_entries,
1452                    removed_entries: worktree.removed_entries,
1453                    scan_id: worktree.scan_id,
1454                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1455                    updated_repositories: worktree.updated_repositories,
1456                    removed_repositories: worktree.removed_repositories,
1457                };
1458                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1459                    session.peer.send(session.connection_id, update.clone())?;
1460                }
1461
1462                // Stream this worktree's diagnostics.
1463                for summary in worktree.diagnostic_summaries {
1464                    session.peer.send(
1465                        session.connection_id,
1466                        proto::UpdateDiagnosticSummary {
1467                            project_id: project.id.to_proto(),
1468                            worktree_id: worktree.id,
1469                            summary: Some(summary),
1470                        },
1471                    )?;
1472                }
1473
1474                for settings_file in worktree.settings_files {
1475                    session.peer.send(
1476                        session.connection_id,
1477                        proto::UpdateWorktreeSettings {
1478                            project_id: project.id.to_proto(),
1479                            worktree_id: worktree.id,
1480                            path: settings_file.path,
1481                            content: Some(settings_file.content),
1482                        },
1483                    )?;
1484                }
1485            }
1486
1487            for language_server in &project.language_servers {
1488                session.peer.send(
1489                    session.connection_id,
1490                    proto::UpdateLanguageServer {
1491                        project_id: project.id.to_proto(),
1492                        language_server_id: language_server.id,
1493                        variant: Some(
1494                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1495                                proto::LspDiskBasedDiagnosticsUpdated {},
1496                            ),
1497                        ),
1498                    },
1499                )?;
1500            }
1501        }
1502
1503        let rejoined_room = rejoined_room.into_inner();
1504
1505        room = rejoined_room.room;
1506        channel = rejoined_room.channel;
1507    }
1508
1509    if let Some(channel) = channel {
1510        channel_updated(
1511            &channel,
1512            &room,
1513            &session.peer,
1514            &*session.connection_pool().await,
1515        );
1516    }
1517
1518    update_user_contacts(session.user_id(), &session).await?;
1519    Ok(())
1520}
1521
1522/// leave room disconnects from the room.
1523async fn leave_room(
1524    _: proto::LeaveRoom,
1525    response: Response<proto::LeaveRoom>,
1526    session: UserSession,
1527) -> Result<()> {
1528    leave_room_for_session(&session).await?;
1529    response.send(proto::Ack {})?;
1530    Ok(())
1531}
1532
1533/// Updates the permissions of someone else in the room.
1534async fn set_room_participant_role(
1535    request: proto::SetRoomParticipantRole,
1536    response: Response<proto::SetRoomParticipantRole>,
1537    session: UserSession,
1538) -> Result<()> {
1539    let user_id = UserId::from_proto(request.user_id);
1540    let role = ChannelRole::from(request.role());
1541
1542    let (live_kit_room, can_publish) = {
1543        let room = session
1544            .db()
1545            .await
1546            .set_room_participant_role(
1547                session.user_id(),
1548                RoomId::from_proto(request.room_id),
1549                user_id,
1550                role,
1551            )
1552            .await?;
1553
1554        let live_kit_room = room.live_kit_room.clone();
1555        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1556        room_updated(&room, &session.peer);
1557        (live_kit_room, can_publish)
1558    };
1559
1560    if let Some(live_kit) = session.live_kit_client.as_ref() {
1561        live_kit
1562            .update_participant(
1563                live_kit_room.clone(),
1564                request.user_id.to_string(),
1565                live_kit_server::proto::ParticipantPermission {
1566                    can_subscribe: true,
1567                    can_publish,
1568                    can_publish_data: can_publish,
1569                    hidden: false,
1570                    recorder: false,
1571                },
1572            )
1573            .await
1574            .trace_err();
1575    }
1576
1577    response.send(proto::Ack {})?;
1578    Ok(())
1579}
1580
1581/// Call someone else into the current room
1582async fn call(
1583    request: proto::Call,
1584    response: Response<proto::Call>,
1585    session: UserSession,
1586) -> Result<()> {
1587    let room_id = RoomId::from_proto(request.room_id);
1588    let calling_user_id = session.user_id();
1589    let calling_connection_id = session.connection_id;
1590    let called_user_id = UserId::from_proto(request.called_user_id);
1591    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1592    if !session
1593        .db()
1594        .await
1595        .has_contact(calling_user_id, called_user_id)
1596        .await?
1597    {
1598        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1599    }
1600
1601    let incoming_call = {
1602        let (room, incoming_call) = &mut *session
1603            .db()
1604            .await
1605            .call(
1606                room_id,
1607                calling_user_id,
1608                calling_connection_id,
1609                called_user_id,
1610                initial_project_id,
1611            )
1612            .await?;
1613        room_updated(&room, &session.peer);
1614        mem::take(incoming_call)
1615    };
1616    update_user_contacts(called_user_id, &session).await?;
1617
1618    let mut calls = session
1619        .connection_pool()
1620        .await
1621        .user_connection_ids(called_user_id)
1622        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1623        .collect::<FuturesUnordered<_>>();
1624
1625    while let Some(call_response) = calls.next().await {
1626        match call_response.as_ref() {
1627            Ok(_) => {
1628                response.send(proto::Ack {})?;
1629                return Ok(());
1630            }
1631            Err(_) => {
1632                call_response.trace_err();
1633            }
1634        }
1635    }
1636
1637    {
1638        let room = session
1639            .db()
1640            .await
1641            .call_failed(room_id, called_user_id)
1642            .await?;
1643        room_updated(&room, &session.peer);
1644    }
1645    update_user_contacts(called_user_id, &session).await?;
1646
1647    Err(anyhow!("failed to ring user"))?
1648}
1649
1650/// Cancel an outgoing call.
1651async fn cancel_call(
1652    request: proto::CancelCall,
1653    response: Response<proto::CancelCall>,
1654    session: UserSession,
1655) -> Result<()> {
1656    let called_user_id = UserId::from_proto(request.called_user_id);
1657    let room_id = RoomId::from_proto(request.room_id);
1658    {
1659        let room = session
1660            .db()
1661            .await
1662            .cancel_call(room_id, session.connection_id, called_user_id)
1663            .await?;
1664        room_updated(&room, &session.peer);
1665    }
1666
1667    for connection_id in session
1668        .connection_pool()
1669        .await
1670        .user_connection_ids(called_user_id)
1671    {
1672        session
1673            .peer
1674            .send(
1675                connection_id,
1676                proto::CallCanceled {
1677                    room_id: room_id.to_proto(),
1678                },
1679            )
1680            .trace_err();
1681    }
1682    response.send(proto::Ack {})?;
1683
1684    update_user_contacts(called_user_id, &session).await?;
1685    Ok(())
1686}
1687
1688/// Decline an incoming call.
1689async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1690    let room_id = RoomId::from_proto(message.room_id);
1691    {
1692        let room = session
1693            .db()
1694            .await
1695            .decline_call(Some(room_id), session.user_id())
1696            .await?
1697            .ok_or_else(|| anyhow!("failed to decline call"))?;
1698        room_updated(&room, &session.peer);
1699    }
1700
1701    for connection_id in session
1702        .connection_pool()
1703        .await
1704        .user_connection_ids(session.user_id())
1705    {
1706        session
1707            .peer
1708            .send(
1709                connection_id,
1710                proto::CallCanceled {
1711                    room_id: room_id.to_proto(),
1712                },
1713            )
1714            .trace_err();
1715    }
1716    update_user_contacts(session.user_id(), &session).await?;
1717    Ok(())
1718}
1719
1720/// Updates other participants in the room with your current location.
1721async fn update_participant_location(
1722    request: proto::UpdateParticipantLocation,
1723    response: Response<proto::UpdateParticipantLocation>,
1724    session: UserSession,
1725) -> Result<()> {
1726    let room_id = RoomId::from_proto(request.room_id);
1727    let location = request
1728        .location
1729        .ok_or_else(|| anyhow!("invalid location"))?;
1730
1731    let db = session.db().await;
1732    let room = db
1733        .update_room_participant_location(room_id, session.connection_id, location)
1734        .await?;
1735
1736    room_updated(&room, &session.peer);
1737    response.send(proto::Ack {})?;
1738    Ok(())
1739}
1740
1741/// Share a project into the room.
1742async fn share_project(
1743    request: proto::ShareProject,
1744    response: Response<proto::ShareProject>,
1745    session: Session,
1746) -> Result<()> {
1747    let (project_id, room) = &*session
1748        .db()
1749        .await
1750        .share_project(
1751            RoomId::from_proto(request.room_id),
1752            session.connection_id,
1753            &request.worktrees,
1754        )
1755        .await?;
1756    response.send(proto::ShareProjectResponse {
1757        project_id: project_id.to_proto(),
1758    })?;
1759    room_updated(&room, &session.peer);
1760
1761    Ok(())
1762}
1763
1764/// Unshare a project from the room.
1765async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1766    let project_id = ProjectId::from_proto(message.project_id);
1767
1768    let (room, guest_connection_ids) = &*session
1769        .db()
1770        .await
1771        .unshare_project(project_id, session.connection_id)
1772        .await?;
1773
1774    broadcast(
1775        Some(session.connection_id),
1776        guest_connection_ids.iter().copied(),
1777        |conn_id| session.peer.send(conn_id, message.clone()),
1778    );
1779    room_updated(&room, &session.peer);
1780
1781    Ok(())
1782}
1783
1784/// Join someone elses shared project.
1785async fn join_project(
1786    request: proto::JoinProject,
1787    response: Response<proto::JoinProject>,
1788    session: UserSession,
1789) -> Result<()> {
1790    let project_id = ProjectId::from_proto(request.project_id);
1791
1792    tracing::info!(%project_id, "join project");
1793
1794    let (project, replica_id) = &mut *session
1795        .db()
1796        .await
1797        .join_project_in_room(project_id, session.connection_id)
1798        .await?;
1799
1800    join_project_internal(response, session, project, replica_id)
1801}
1802
/// A response handle that can answer with a `proto::JoinProjectResponse`.
///
/// Implemented for the response handles of both `proto::JoinProject` and
/// `proto::JoinHostedProject`, so `join_project_internal` can serve either
/// request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// In both impls the fully qualified `Response::<_>::send` path resolves to
// `Response`'s own (inherent) `send`: inherent methods take precedence over
// trait methods, so this delegates rather than recursing into the trait method.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1816
1817fn join_project_internal(
1818    response: impl JoinProjectInternalResponse,
1819    session: UserSession,
1820    project: &mut Project,
1821    replica_id: &ReplicaId,
1822) -> Result<()> {
1823    let collaborators = project
1824        .collaborators
1825        .iter()
1826        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1827        .map(|collaborator| collaborator.to_proto())
1828        .collect::<Vec<_>>();
1829    let project_id = project.id;
1830    let guest_user_id = session.user_id();
1831
1832    let worktrees = project
1833        .worktrees
1834        .iter()
1835        .map(|(id, worktree)| proto::WorktreeMetadata {
1836            id: *id,
1837            root_name: worktree.root_name.clone(),
1838            visible: worktree.visible,
1839            abs_path: worktree.abs_path.clone(),
1840        })
1841        .collect::<Vec<_>>();
1842
1843    for collaborator in &collaborators {
1844        session
1845            .peer
1846            .send(
1847                collaborator.peer_id.unwrap().into(),
1848                proto::AddProjectCollaborator {
1849                    project_id: project_id.to_proto(),
1850                    collaborator: Some(proto::Collaborator {
1851                        peer_id: Some(session.connection_id.into()),
1852                        replica_id: replica_id.0 as u32,
1853                        user_id: guest_user_id.to_proto(),
1854                    }),
1855                },
1856            )
1857            .trace_err();
1858    }
1859
1860    // First, we send the metadata associated with each worktree.
1861    response.send(proto::JoinProjectResponse {
1862        project_id: project.id.0 as u64,
1863        worktrees: worktrees.clone(),
1864        replica_id: replica_id.0 as u32,
1865        collaborators: collaborators.clone(),
1866        language_servers: project.language_servers.clone(),
1867        role: project.role.into(), // todo
1868    })?;
1869
1870    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1871        #[cfg(any(test, feature = "test-support"))]
1872        const MAX_CHUNK_SIZE: usize = 2;
1873        #[cfg(not(any(test, feature = "test-support")))]
1874        const MAX_CHUNK_SIZE: usize = 256;
1875
1876        // Stream this worktree's entries.
1877        let message = proto::UpdateWorktree {
1878            project_id: project_id.to_proto(),
1879            worktree_id,
1880            abs_path: worktree.abs_path.clone(),
1881            root_name: worktree.root_name,
1882            updated_entries: worktree.entries,
1883            removed_entries: Default::default(),
1884            scan_id: worktree.scan_id,
1885            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1886            updated_repositories: worktree.repository_entries.into_values().collect(),
1887            removed_repositories: Default::default(),
1888        };
1889        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1890            session.peer.send(session.connection_id, update.clone())?;
1891        }
1892
1893        // Stream this worktree's diagnostics.
1894        for summary in worktree.diagnostic_summaries {
1895            session.peer.send(
1896                session.connection_id,
1897                proto::UpdateDiagnosticSummary {
1898                    project_id: project_id.to_proto(),
1899                    worktree_id: worktree.id,
1900                    summary: Some(summary),
1901                },
1902            )?;
1903        }
1904
1905        for settings_file in worktree.settings_files {
1906            session.peer.send(
1907                session.connection_id,
1908                proto::UpdateWorktreeSettings {
1909                    project_id: project_id.to_proto(),
1910                    worktree_id: worktree.id,
1911                    path: settings_file.path,
1912                    content: Some(settings_file.content),
1913                },
1914            )?;
1915        }
1916    }
1917
1918    for language_server in &project.language_servers {
1919        session.peer.send(
1920            session.connection_id,
1921            proto::UpdateLanguageServer {
1922                project_id: project_id.to_proto(),
1923                language_server_id: language_server.id,
1924                variant: Some(
1925                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1926                        proto::LspDiskBasedDiagnosticsUpdated {},
1927                    ),
1928                ),
1929            },
1930        )?;
1931    }
1932
1933    Ok(())
1934}
1935
1936/// Leave someone elses shared project.
1937async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1938    let sender_id = session.connection_id;
1939    let project_id = ProjectId::from_proto(request.project_id);
1940    let db = session.db().await;
1941    if db.is_hosted_project(project_id).await? {
1942        let project = db.leave_hosted_project(project_id, sender_id).await?;
1943        project_left(&project, &session);
1944        return Ok(());
1945    }
1946
1947    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1948    tracing::info!(
1949        %project_id,
1950        host_user_id = ?project.host_user_id,
1951        host_connection_id = ?project.host_connection_id,
1952        "leave project"
1953    );
1954
1955    project_left(&project, &session);
1956    room_updated(&room, &session.peer);
1957
1958    Ok(())
1959}
1960
1961async fn join_hosted_project(
1962    request: proto::JoinHostedProject,
1963    response: Response<proto::JoinHostedProject>,
1964    session: UserSession,
1965) -> Result<()> {
1966    let (mut project, replica_id) = session
1967        .db()
1968        .await
1969        .join_hosted_project(
1970            ProjectId(request.project_id as i32),
1971            session.user_id(),
1972            session.connection_id,
1973        )
1974        .await?;
1975
1976    join_project_internal(response, session, &mut project, &replica_id)
1977}
1978
1979/// Updates other participants with changes to the project
1980async fn update_project(
1981    request: proto::UpdateProject,
1982    response: Response<proto::UpdateProject>,
1983    session: Session,
1984) -> Result<()> {
1985    let project_id = ProjectId::from_proto(request.project_id);
1986    let (room, guest_connection_ids) = &*session
1987        .db()
1988        .await
1989        .update_project(project_id, session.connection_id, &request.worktrees)
1990        .await?;
1991    broadcast(
1992        Some(session.connection_id),
1993        guest_connection_ids.iter().copied(),
1994        |connection_id| {
1995            session
1996                .peer
1997                .forward_send(session.connection_id, connection_id, request.clone())
1998        },
1999    );
2000    room_updated(&room, &session.peer);
2001    response.send(proto::Ack {})?;
2002
2003    Ok(())
2004}
2005
2006/// Updates other participants with changes to the worktree
2007async fn update_worktree(
2008    request: proto::UpdateWorktree,
2009    response: Response<proto::UpdateWorktree>,
2010    session: Session,
2011) -> Result<()> {
2012    let guest_connection_ids = session
2013        .db()
2014        .await
2015        .update_worktree(&request, session.connection_id)
2016        .await?;
2017
2018    broadcast(
2019        Some(session.connection_id),
2020        guest_connection_ids.iter().copied(),
2021        |connection_id| {
2022            session
2023                .peer
2024                .forward_send(session.connection_id, connection_id, request.clone())
2025        },
2026    );
2027    response.send(proto::Ack {})?;
2028    Ok(())
2029}
2030
2031/// Updates other participants with changes to the diagnostics
2032async fn update_diagnostic_summary(
2033    message: proto::UpdateDiagnosticSummary,
2034    session: Session,
2035) -> Result<()> {
2036    let guest_connection_ids = session
2037        .db()
2038        .await
2039        .update_diagnostic_summary(&message, session.connection_id)
2040        .await?;
2041
2042    broadcast(
2043        Some(session.connection_id),
2044        guest_connection_ids.iter().copied(),
2045        |connection_id| {
2046            session
2047                .peer
2048                .forward_send(session.connection_id, connection_id, message.clone())
2049        },
2050    );
2051
2052    Ok(())
2053}
2054
2055/// Updates other participants with changes to the worktree settings
2056async fn update_worktree_settings(
2057    message: proto::UpdateWorktreeSettings,
2058    session: Session,
2059) -> Result<()> {
2060    let guest_connection_ids = session
2061        .db()
2062        .await
2063        .update_worktree_settings(&message, session.connection_id)
2064        .await?;
2065
2066    broadcast(
2067        Some(session.connection_id),
2068        guest_connection_ids.iter().copied(),
2069        |connection_id| {
2070            session
2071                .peer
2072                .forward_send(session.connection_id, connection_id, message.clone())
2073        },
2074    );
2075
2076    Ok(())
2077}
2078
2079/// Notify other participants that a  language server has started.
2080async fn start_language_server(
2081    request: proto::StartLanguageServer,
2082    session: Session,
2083) -> Result<()> {
2084    let guest_connection_ids = session
2085        .db()
2086        .await
2087        .start_language_server(&request, session.connection_id)
2088        .await?;
2089
2090    broadcast(
2091        Some(session.connection_id),
2092        guest_connection_ids.iter().copied(),
2093        |connection_id| {
2094            session
2095                .peer
2096                .forward_send(session.connection_id, connection_id, request.clone())
2097        },
2098    );
2099    Ok(())
2100}
2101
2102/// Notify other participants that a language server has changed.
2103async fn update_language_server(
2104    request: proto::UpdateLanguageServer,
2105    session: Session,
2106) -> Result<()> {
2107    let project_id = ProjectId::from_proto(request.project_id);
2108    let project_connection_ids = session
2109        .db()
2110        .await
2111        .project_connection_ids(project_id, session.connection_id)
2112        .await?;
2113    broadcast(
2114        Some(session.connection_id),
2115        project_connection_ids.iter().copied(),
2116        |connection_id| {
2117            session
2118                .peer
2119                .forward_send(session.connection_id, connection_id, request.clone())
2120        },
2121    );
2122    Ok(())
2123}
2124
2125/// forward a project request to the host. These requests should be read only
2126/// as guests are allowed to send them.
2127async fn forward_read_only_project_request<T>(
2128    request: T,
2129    response: Response<T>,
2130    session: Session,
2131) -> Result<()>
2132where
2133    T: EntityMessage + RequestMessage,
2134{
2135    let project_id = ProjectId::from_proto(request.remote_entity_id());
2136    let host_connection_id = session
2137        .db()
2138        .await
2139        .host_for_read_only_project_request(project_id, session.connection_id)
2140        .await?;
2141    let payload = session
2142        .peer
2143        .forward_request(session.connection_id, host_connection_id, request)
2144        .await?;
2145    response.send(payload)?;
2146    Ok(())
2147}
2148
2149/// forward a project request to the host. These requests are disallowed
2150/// for guests.
2151async fn forward_mutating_project_request<T>(
2152    request: T,
2153    response: Response<T>,
2154    session: Session,
2155) -> Result<()>
2156where
2157    T: EntityMessage + RequestMessage,
2158{
2159    let project_id = ProjectId::from_proto(request.remote_entity_id());
2160    let host_connection_id = session
2161        .db()
2162        .await
2163        .host_for_mutating_project_request(project_id, session.connection_id)
2164        .await?;
2165    let payload = session
2166        .peer
2167        .forward_request(session.connection_id, host_connection_id, request)
2168        .await?;
2169    response.send(payload)?;
2170    Ok(())
2171}
2172
2173/// Notify other participants that a new buffer has been created
2174async fn create_buffer_for_peer(
2175    request: proto::CreateBufferForPeer,
2176    session: Session,
2177) -> Result<()> {
2178    session
2179        .db()
2180        .await
2181        .check_user_is_project_host(
2182            ProjectId::from_proto(request.project_id),
2183            session.connection_id,
2184        )
2185        .await?;
2186    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2187    session
2188        .peer
2189        .forward_send(session.connection_id, peer_id.into(), request)?;
2190    Ok(())
2191}
2192
2193/// Notify other participants that a buffer has been updated. This is
2194/// allowed for guests as long as the update is limited to selections.
2195async fn update_buffer(
2196    request: proto::UpdateBuffer,
2197    response: Response<proto::UpdateBuffer>,
2198    session: Session,
2199) -> Result<()> {
2200    let project_id = ProjectId::from_proto(request.project_id);
2201    let mut guest_connection_ids;
2202    let mut host_connection_id = None;
2203
2204    let mut requires_write_permission = false;
2205
2206    for op in request.operations.iter() {
2207        match op.variant {
2208            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2209            Some(_) => requires_write_permission = true,
2210        }
2211    }
2212
2213    {
2214        let collaborators = session
2215            .db()
2216            .await
2217            .project_collaborators_for_buffer_update(
2218                project_id,
2219                session.connection_id,
2220                requires_write_permission,
2221            )
2222            .await?;
2223        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2224        for collaborator in collaborators.iter() {
2225            if collaborator.is_host {
2226                host_connection_id = Some(collaborator.connection_id);
2227            } else {
2228                guest_connection_ids.push(collaborator.connection_id);
2229            }
2230        }
2231    }
2232    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2233
2234    broadcast(
2235        Some(session.connection_id),
2236        guest_connection_ids,
2237        |connection_id| {
2238            session
2239                .peer
2240                .forward_send(session.connection_id, connection_id, request.clone())
2241        },
2242    );
2243    if host_connection_id != session.connection_id {
2244        session
2245            .peer
2246            .forward_request(session.connection_id, host_connection_id, request.clone())
2247            .await?;
2248    }
2249
2250    response.send(proto::Ack {})?;
2251    Ok(())
2252}
2253
2254/// Notify other participants that a project has been updated.
2255async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2256    request: T,
2257    session: Session,
2258) -> Result<()> {
2259    let project_id = ProjectId::from_proto(request.remote_entity_id());
2260    let project_connection_ids = session
2261        .db()
2262        .await
2263        .project_connection_ids(project_id, session.connection_id)
2264        .await?;
2265
2266    broadcast(
2267        Some(session.connection_id),
2268        project_connection_ids.iter().copied(),
2269        |connection_id| {
2270            session
2271                .peer
2272                .forward_send(session.connection_id, connection_id, request.clone())
2273        },
2274    );
2275    Ok(())
2276}
2277
2278/// Start following another user in a call.
2279async fn follow(
2280    request: proto::Follow,
2281    response: Response<proto::Follow>,
2282    session: UserSession,
2283) -> Result<()> {
2284    let room_id = RoomId::from_proto(request.room_id);
2285    let project_id = request.project_id.map(ProjectId::from_proto);
2286    let leader_id = request
2287        .leader_id
2288        .ok_or_else(|| anyhow!("invalid leader id"))?
2289        .into();
2290    let follower_id = session.connection_id;
2291
2292    session
2293        .db()
2294        .await
2295        .check_room_participants(room_id, leader_id, session.connection_id)
2296        .await?;
2297
2298    let response_payload = session
2299        .peer
2300        .forward_request(session.connection_id, leader_id, request)
2301        .await?;
2302    response.send(response_payload)?;
2303
2304    if let Some(project_id) = project_id {
2305        let room = session
2306            .db()
2307            .await
2308            .follow(room_id, project_id, leader_id, follower_id)
2309            .await?;
2310        room_updated(&room, &session.peer);
2311    }
2312
2313    Ok(())
2314}
2315
2316/// Stop following another user in a call.
2317async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2318    let room_id = RoomId::from_proto(request.room_id);
2319    let project_id = request.project_id.map(ProjectId::from_proto);
2320    let leader_id = request
2321        .leader_id
2322        .ok_or_else(|| anyhow!("invalid leader id"))?
2323        .into();
2324    let follower_id = session.connection_id;
2325
2326    session
2327        .db()
2328        .await
2329        .check_room_participants(room_id, leader_id, session.connection_id)
2330        .await?;
2331
2332    session
2333        .peer
2334        .forward_send(session.connection_id, leader_id, request)?;
2335
2336    if let Some(project_id) = project_id {
2337        let room = session
2338            .db()
2339            .await
2340            .unfollow(room_id, project_id, leader_id, follower_id)
2341            .await?;
2342        room_updated(&room, &session.peer);
2343    }
2344
2345    Ok(())
2346}
2347
2348/// Notify everyone following you of your current location.
2349async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2350    let room_id = RoomId::from_proto(request.room_id);
2351    let database = session.db.lock().await;
2352
2353    let connection_ids = if let Some(project_id) = request.project_id {
2354        let project_id = ProjectId::from_proto(project_id);
2355        database
2356            .project_connection_ids(project_id, session.connection_id)
2357            .await?
2358    } else {
2359        database
2360            .room_connection_ids(room_id, session.connection_id)
2361            .await?
2362    };
2363
2364    // For now, don't send view update messages back to that view's current leader.
2365    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2366        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2367        _ => None,
2368    });
2369
2370    for connection_id in connection_ids.iter().cloned() {
2371        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2372            session
2373                .peer
2374                .forward_send(session.connection_id, connection_id, request.clone())?;
2375        }
2376    }
2377    Ok(())
2378}
2379
2380/// Get public data about users.
2381async fn get_users(
2382    request: proto::GetUsers,
2383    response: Response<proto::GetUsers>,
2384    session: Session,
2385) -> Result<()> {
2386    let user_ids = request
2387        .user_ids
2388        .into_iter()
2389        .map(UserId::from_proto)
2390        .collect();
2391    let users = session
2392        .db()
2393        .await
2394        .get_users_by_ids(user_ids)
2395        .await?
2396        .into_iter()
2397        .map(|user| proto::User {
2398            id: user.id.to_proto(),
2399            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2400            github_login: user.github_login,
2401        })
2402        .collect();
2403    response.send(proto::UsersResponse { users })?;
2404    Ok(())
2405}
2406
2407/// Search for users (to invite) buy Github login
2408async fn fuzzy_search_users(
2409    request: proto::FuzzySearchUsers,
2410    response: Response<proto::FuzzySearchUsers>,
2411    session: UserSession,
2412) -> Result<()> {
2413    let query = request.query;
2414    let users = match query.len() {
2415        0 => vec![],
2416        1 | 2 => session
2417            .db()
2418            .await
2419            .get_user_by_github_login(&query)
2420            .await?
2421            .into_iter()
2422            .collect(),
2423        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2424    };
2425    let users = users
2426        .into_iter()
2427        .filter(|user| user.id != session.user_id())
2428        .map(|user| proto::User {
2429            id: user.id.to_proto(),
2430            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2431            github_login: user.github_login,
2432        })
2433        .collect();
2434    response.send(proto::UsersResponse { users })?;
2435    Ok(())
2436}
2437
2438/// Send a contact request to another user.
2439async fn request_contact(
2440    request: proto::RequestContact,
2441    response: Response<proto::RequestContact>,
2442    session: UserSession,
2443) -> Result<()> {
2444    let requester_id = session.user_id();
2445    let responder_id = UserId::from_proto(request.responder_id);
2446    if requester_id == responder_id {
2447        return Err(anyhow!("cannot add yourself as a contact"))?;
2448    }
2449
2450    let notifications = session
2451        .db()
2452        .await
2453        .send_contact_request(requester_id, responder_id)
2454        .await?;
2455
2456    // Update outgoing contact requests of requester
2457    let mut update = proto::UpdateContacts::default();
2458    update.outgoing_requests.push(responder_id.to_proto());
2459    for connection_id in session
2460        .connection_pool()
2461        .await
2462        .user_connection_ids(requester_id)
2463    {
2464        session.peer.send(connection_id, update.clone())?;
2465    }
2466
2467    // Update incoming contact requests of responder
2468    let mut update = proto::UpdateContacts::default();
2469    update
2470        .incoming_requests
2471        .push(proto::IncomingContactRequest {
2472            requester_id: requester_id.to_proto(),
2473        });
2474    let connection_pool = session.connection_pool().await;
2475    for connection_id in connection_pool.user_connection_ids(responder_id) {
2476        session.peer.send(connection_id, update.clone())?;
2477    }
2478
2479    send_notifications(&connection_pool, &session.peer, notifications);
2480
2481    response.send(proto::Ack {})?;
2482    Ok(())
2483}
2484
2485/// Accept or decline a contact request
2486async fn respond_to_contact_request(
2487    request: proto::RespondToContactRequest,
2488    response: Response<proto::RespondToContactRequest>,
2489    session: UserSession,
2490) -> Result<()> {
2491    let responder_id = session.user_id();
2492    let requester_id = UserId::from_proto(request.requester_id);
2493    let db = session.db().await;
2494    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2495        db.dismiss_contact_notification(responder_id, requester_id)
2496            .await?;
2497    } else {
2498        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2499
2500        let notifications = db
2501            .respond_to_contact_request(responder_id, requester_id, accept)
2502            .await?;
2503        let requester_busy = db.is_user_busy(requester_id).await?;
2504        let responder_busy = db.is_user_busy(responder_id).await?;
2505
2506        let pool = session.connection_pool().await;
2507        // Update responder with new contact
2508        let mut update = proto::UpdateContacts::default();
2509        if accept {
2510            update
2511                .contacts
2512                .push(contact_for_user(requester_id, requester_busy, &pool));
2513        }
2514        update
2515            .remove_incoming_requests
2516            .push(requester_id.to_proto());
2517        for connection_id in pool.user_connection_ids(responder_id) {
2518            session.peer.send(connection_id, update.clone())?;
2519        }
2520
2521        // Update requester with new contact
2522        let mut update = proto::UpdateContacts::default();
2523        if accept {
2524            update
2525                .contacts
2526                .push(contact_for_user(responder_id, responder_busy, &pool));
2527        }
2528        update
2529            .remove_outgoing_requests
2530            .push(responder_id.to_proto());
2531
2532        for connection_id in pool.user_connection_ids(requester_id) {
2533            session.peer.send(connection_id, update.clone())?;
2534        }
2535
2536        send_notifications(&pool, &session.peer, notifications);
2537    }
2538
2539    response.send(proto::Ack {})?;
2540    Ok(())
2541}
2542
2543/// Remove a contact.
2544async fn remove_contact(
2545    request: proto::RemoveContact,
2546    response: Response<proto::RemoveContact>,
2547    session: UserSession,
2548) -> Result<()> {
2549    let requester_id = session.user_id();
2550    let responder_id = UserId::from_proto(request.user_id);
2551    let db = session.db().await;
2552    let (contact_accepted, deleted_notification_id) =
2553        db.remove_contact(requester_id, responder_id).await?;
2554
2555    let pool = session.connection_pool().await;
2556    // Update outgoing contact requests of requester
2557    let mut update = proto::UpdateContacts::default();
2558    if contact_accepted {
2559        update.remove_contacts.push(responder_id.to_proto());
2560    } else {
2561        update
2562            .remove_outgoing_requests
2563            .push(responder_id.to_proto());
2564    }
2565    for connection_id in pool.user_connection_ids(requester_id) {
2566        session.peer.send(connection_id, update.clone())?;
2567    }
2568
2569    // Update incoming contact requests of responder
2570    let mut update = proto::UpdateContacts::default();
2571    if contact_accepted {
2572        update.remove_contacts.push(requester_id.to_proto());
2573    } else {
2574        update
2575            .remove_incoming_requests
2576            .push(requester_id.to_proto());
2577    }
2578    for connection_id in pool.user_connection_ids(responder_id) {
2579        session.peer.send(connection_id, update.clone())?;
2580        if let Some(notification_id) = deleted_notification_id {
2581            session.peer.send(
2582                connection_id,
2583                proto::DeleteNotification {
2584                    notification_id: notification_id.to_proto(),
2585                },
2586            )?;
2587        }
2588    }
2589
2590    response.send(proto::Ack {})?;
2591    Ok(())
2592}
2593
2594/// Creates a new channel.
2595async fn create_channel(
2596    request: proto::CreateChannel,
2597    response: Response<proto::CreateChannel>,
2598    session: UserSession,
2599) -> Result<()> {
2600    let db = session.db().await;
2601
2602    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2603    let (channel, membership) = db
2604        .create_channel(&request.name, parent_id, session.user_id())
2605        .await?;
2606
2607    let root_id = channel.root_id();
2608    let channel = Channel::from_model(channel);
2609
2610    response.send(proto::CreateChannelResponse {
2611        channel: Some(channel.to_proto()),
2612        parent_id: request.parent_id,
2613    })?;
2614
2615    let mut connection_pool = session.connection_pool().await;
2616    if let Some(membership) = membership {
2617        connection_pool.subscribe_to_channel(
2618            membership.user_id,
2619            membership.channel_id,
2620            membership.role,
2621        );
2622        let update = proto::UpdateUserChannels {
2623            channel_memberships: vec![proto::ChannelMembership {
2624                channel_id: membership.channel_id.to_proto(),
2625                role: membership.role.into(),
2626            }],
2627            ..Default::default()
2628        };
2629        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2630            session.peer.send(connection_id, update.clone())?;
2631        }
2632    }
2633
2634    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2635        if !role.can_see_channel(channel.visibility) {
2636            continue;
2637        }
2638
2639        let update = proto::UpdateChannels {
2640            channels: vec![channel.to_proto()],
2641            ..Default::default()
2642        };
2643        session.peer.send(connection_id, update.clone())?;
2644    }
2645
2646    Ok(())
2647}
2648
2649/// Delete a channel
2650async fn delete_channel(
2651    request: proto::DeleteChannel,
2652    response: Response<proto::DeleteChannel>,
2653    session: UserSession,
2654) -> Result<()> {
2655    let db = session.db().await;
2656
2657    let channel_id = request.channel_id;
2658    let (root_channel, removed_channels) = db
2659        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2660        .await?;
2661    response.send(proto::Ack {})?;
2662
2663    // Notify members of removed channels
2664    let mut update = proto::UpdateChannels::default();
2665    update
2666        .delete_channels
2667        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2668
2669    let connection_pool = session.connection_pool().await;
2670    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2671        session.peer.send(connection_id, update.clone())?;
2672    }
2673
2674    Ok(())
2675}
2676
2677/// Invite someone to join a channel.
2678async fn invite_channel_member(
2679    request: proto::InviteChannelMember,
2680    response: Response<proto::InviteChannelMember>,
2681    session: UserSession,
2682) -> Result<()> {
2683    let db = session.db().await;
2684    let channel_id = ChannelId::from_proto(request.channel_id);
2685    let invitee_id = UserId::from_proto(request.user_id);
2686    let InviteMemberResult {
2687        channel,
2688        notifications,
2689    } = db
2690        .invite_channel_member(
2691            channel_id,
2692            invitee_id,
2693            session.user_id(),
2694            request.role().into(),
2695        )
2696        .await?;
2697
2698    let update = proto::UpdateChannels {
2699        channel_invitations: vec![channel.to_proto()],
2700        ..Default::default()
2701    };
2702
2703    let connection_pool = session.connection_pool().await;
2704    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2705        session.peer.send(connection_id, update.clone())?;
2706    }
2707
2708    send_notifications(&connection_pool, &session.peer, notifications);
2709
2710    response.send(proto::Ack {})?;
2711    Ok(())
2712}
2713
2714/// remove someone from a channel
2715async fn remove_channel_member(
2716    request: proto::RemoveChannelMember,
2717    response: Response<proto::RemoveChannelMember>,
2718    session: UserSession,
2719) -> Result<()> {
2720    let db = session.db().await;
2721    let channel_id = ChannelId::from_proto(request.channel_id);
2722    let member_id = UserId::from_proto(request.user_id);
2723
2724    let RemoveChannelMemberResult {
2725        membership_update,
2726        notification_id,
2727    } = db
2728        .remove_channel_member(channel_id, member_id, session.user_id())
2729        .await?;
2730
2731    let mut connection_pool = session.connection_pool().await;
2732    notify_membership_updated(
2733        &mut connection_pool,
2734        membership_update,
2735        member_id,
2736        &session.peer,
2737    );
2738    for connection_id in connection_pool.user_connection_ids(member_id) {
2739        if let Some(notification_id) = notification_id {
2740            session
2741                .peer
2742                .send(
2743                    connection_id,
2744                    proto::DeleteNotification {
2745                        notification_id: notification_id.to_proto(),
2746                    },
2747                )
2748                .trace_err();
2749        }
2750    }
2751
2752    response.send(proto::Ack {})?;
2753    Ok(())
2754}
2755
2756/// Toggle the channel between public and private.
2757/// Care is taken to maintain the invariant that public channels only descend from public channels,
2758/// (though members-only channels can appear at any point in the hierarchy).
2759async fn set_channel_visibility(
2760    request: proto::SetChannelVisibility,
2761    response: Response<proto::SetChannelVisibility>,
2762    session: UserSession,
2763) -> Result<()> {
2764    let db = session.db().await;
2765    let channel_id = ChannelId::from_proto(request.channel_id);
2766    let visibility = request.visibility().into();
2767
2768    let channel_model = db
2769        .set_channel_visibility(channel_id, visibility, session.user_id())
2770        .await?;
2771    let root_id = channel_model.root_id();
2772    let channel = Channel::from_model(channel_model);
2773
2774    let mut connection_pool = session.connection_pool().await;
2775    for (user_id, role) in connection_pool
2776        .channel_user_ids(root_id)
2777        .collect::<Vec<_>>()
2778        .into_iter()
2779    {
2780        let update = if role.can_see_channel(channel.visibility) {
2781            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2782            proto::UpdateChannels {
2783                channels: vec![channel.to_proto()],
2784                ..Default::default()
2785            }
2786        } else {
2787            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2788            proto::UpdateChannels {
2789                delete_channels: vec![channel.id.to_proto()],
2790                ..Default::default()
2791            }
2792        };
2793
2794        for connection_id in connection_pool.user_connection_ids(user_id) {
2795            session.peer.send(connection_id, update.clone())?;
2796        }
2797    }
2798
2799    response.send(proto::Ack {})?;
2800    Ok(())
2801}
2802
2803/// Alter the role for a user in the channel.
2804async fn set_channel_member_role(
2805    request: proto::SetChannelMemberRole,
2806    response: Response<proto::SetChannelMemberRole>,
2807    session: UserSession,
2808) -> Result<()> {
2809    let db = session.db().await;
2810    let channel_id = ChannelId::from_proto(request.channel_id);
2811    let member_id = UserId::from_proto(request.user_id);
2812    let result = db
2813        .set_channel_member_role(
2814            channel_id,
2815            session.user_id(),
2816            member_id,
2817            request.role().into(),
2818        )
2819        .await?;
2820
2821    match result {
2822        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2823            let mut connection_pool = session.connection_pool().await;
2824            notify_membership_updated(
2825                &mut connection_pool,
2826                membership_update,
2827                member_id,
2828                &session.peer,
2829            )
2830        }
2831        db::SetMemberRoleResult::InviteUpdated(channel) => {
2832            let update = proto::UpdateChannels {
2833                channel_invitations: vec![channel.to_proto()],
2834                ..Default::default()
2835            };
2836
2837            for connection_id in session
2838                .connection_pool()
2839                .await
2840                .user_connection_ids(member_id)
2841            {
2842                session.peer.send(connection_id, update.clone())?;
2843            }
2844        }
2845    }
2846
2847    response.send(proto::Ack {})?;
2848    Ok(())
2849}
2850
2851/// Change the name of a channel
2852async fn rename_channel(
2853    request: proto::RenameChannel,
2854    response: Response<proto::RenameChannel>,
2855    session: UserSession,
2856) -> Result<()> {
2857    let db = session.db().await;
2858    let channel_id = ChannelId::from_proto(request.channel_id);
2859    let channel_model = db
2860        .rename_channel(channel_id, session.user_id(), &request.name)
2861        .await?;
2862    let root_id = channel_model.root_id();
2863    let channel = Channel::from_model(channel_model);
2864
2865    response.send(proto::RenameChannelResponse {
2866        channel: Some(channel.to_proto()),
2867    })?;
2868
2869    let connection_pool = session.connection_pool().await;
2870    let update = proto::UpdateChannels {
2871        channels: vec![channel.to_proto()],
2872        ..Default::default()
2873    };
2874    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2875        if role.can_see_channel(channel.visibility) {
2876            session.peer.send(connection_id, update.clone())?;
2877        }
2878    }
2879
2880    Ok(())
2881}
2882
2883/// Move a channel to a new parent.
2884async fn move_channel(
2885    request: proto::MoveChannel,
2886    response: Response<proto::MoveChannel>,
2887    session: UserSession,
2888) -> Result<()> {
2889    let channel_id = ChannelId::from_proto(request.channel_id);
2890    let to = ChannelId::from_proto(request.to);
2891
2892    let (root_id, channels) = session
2893        .db()
2894        .await
2895        .move_channel(channel_id, to, session.user_id())
2896        .await?;
2897
2898    let connection_pool = session.connection_pool().await;
2899    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2900        let channels = channels
2901            .iter()
2902            .filter_map(|channel| {
2903                if role.can_see_channel(channel.visibility) {
2904                    Some(channel.to_proto())
2905                } else {
2906                    None
2907                }
2908            })
2909            .collect::<Vec<_>>();
2910        if channels.is_empty() {
2911            continue;
2912        }
2913
2914        let update = proto::UpdateChannels {
2915            channels,
2916            ..Default::default()
2917        };
2918
2919        session.peer.send(connection_id, update.clone())?;
2920    }
2921
2922    response.send(Ack {})?;
2923    Ok(())
2924}
2925
2926/// Get the list of channel members
2927async fn get_channel_members(
2928    request: proto::GetChannelMembers,
2929    response: Response<proto::GetChannelMembers>,
2930    session: UserSession,
2931) -> Result<()> {
2932    let db = session.db().await;
2933    let channel_id = ChannelId::from_proto(request.channel_id);
2934    let members = db
2935        .get_channel_participant_details(channel_id, session.user_id())
2936        .await?;
2937    response.send(proto::GetChannelMembersResponse { members })?;
2938    Ok(())
2939}
2940
2941/// Accept or decline a channel invitation.
2942async fn respond_to_channel_invite(
2943    request: proto::RespondToChannelInvite,
2944    response: Response<proto::RespondToChannelInvite>,
2945    session: UserSession,
2946) -> Result<()> {
2947    let db = session.db().await;
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let RespondToChannelInvite {
2950        membership_update,
2951        notifications,
2952    } = db
2953        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2954        .await?;
2955
2956    let mut connection_pool = session.connection_pool().await;
2957    if let Some(membership_update) = membership_update {
2958        notify_membership_updated(
2959            &mut connection_pool,
2960            membership_update,
2961            session.user_id(),
2962            &session.peer,
2963        );
2964    } else {
2965        let update = proto::UpdateChannels {
2966            remove_channel_invitations: vec![channel_id.to_proto()],
2967            ..Default::default()
2968        };
2969
2970        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2971            session.peer.send(connection_id, update.clone())?;
2972        }
2973    };
2974
2975    send_notifications(&connection_pool, &session.peer, notifications);
2976
2977    response.send(proto::Ack {})?;
2978
2979    Ok(())
2980}
2981
2982/// Join the channels' room
2983async fn join_channel(
2984    request: proto::JoinChannel,
2985    response: Response<proto::JoinChannel>,
2986    session: UserSession,
2987) -> Result<()> {
2988    let channel_id = ChannelId::from_proto(request.channel_id);
2989    join_channel_internal(channel_id, Box::new(response), session).await
2990}
2991
2992trait JoinChannelInternalResponse {
2993    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2994}
2995impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2996    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2997        Response::<proto::JoinChannel>::send(self, result)
2998    }
2999}
3000impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3001    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3002        Response::<proto::JoinRoom>::send(self, result)
3003    }
3004}
3005
/// Shared implementation of channel joining.
///
/// Used by `join_channel`. Via the `Response<proto::JoinRoom>` impl of
/// `JoinChannelInternalResponse`, it is presumably also used by the room-joining handler.
///
/// Steps, in order:
/// 1. Call `leave_room_for_session` for this session.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    change, and the user's role in the channel.
/// 3. If a LiveKit client is configured, build connection info:
///    - guests get a guest token and `can_publish = false`;
///    - every other role gets a room token and `can_publish = true`.
/// 4. Send the response, propagate any membership change, and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Propagates failures from leaving the previous room, the database join, sending the
/// response, and `update_user_contacts`. Also fails with "channel not returned" if the joined
/// room has no channel attached; that check runs only after the response has been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards, so both are dropped
    // before the pool is re-acquired for `channel_updated` and before `update_user_contacts`
    // runs. Presumably this avoids holding them while that helper takes its own locks; TODO
    // confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured. Also `None` when token generation
        // fails: `trace_err` logs the error and the closure's `?` then yields `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may not publish and receive a guest token; everyone else gets a room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3085
3086/// Start editing the channel notes
3087async fn join_channel_buffer(
3088    request: proto::JoinChannelBuffer,
3089    response: Response<proto::JoinChannelBuffer>,
3090    session: UserSession,
3091) -> Result<()> {
3092    let db = session.db().await;
3093    let channel_id = ChannelId::from_proto(request.channel_id);
3094
3095    let open_response = db
3096        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3097        .await?;
3098
3099    let collaborators = open_response.collaborators.clone();
3100    response.send(open_response)?;
3101
3102    let update = UpdateChannelBufferCollaborators {
3103        channel_id: channel_id.to_proto(),
3104        collaborators: collaborators.clone(),
3105    };
3106    channel_buffer_updated(
3107        session.connection_id,
3108        collaborators
3109            .iter()
3110            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3111        &update,
3112        &session.peer,
3113    );
3114
3115    Ok(())
3116}
3117
3118/// Edit the channel notes
3119async fn update_channel_buffer(
3120    request: proto::UpdateChannelBuffer,
3121    session: UserSession,
3122) -> Result<()> {
3123    let db = session.db().await;
3124    let channel_id = ChannelId::from_proto(request.channel_id);
3125
3126    let (collaborators, non_collaborators, epoch, version) = db
3127        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3128        .await?;
3129
3130    channel_buffer_updated(
3131        session.connection_id,
3132        collaborators,
3133        &proto::UpdateChannelBuffer {
3134            channel_id: channel_id.to_proto(),
3135            operations: request.operations,
3136        },
3137        &session.peer,
3138    );
3139
3140    let pool = &*session.connection_pool().await;
3141
3142    broadcast(
3143        None,
3144        non_collaborators
3145            .iter()
3146            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3147        |peer_id| {
3148            session.peer.send(
3149                peer_id,
3150                proto::UpdateChannels {
3151                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3152                        channel_id: channel_id.to_proto(),
3153                        epoch: epoch as u64,
3154                        version: version.clone(),
3155                    }],
3156                    ..Default::default()
3157                },
3158            )
3159        },
3160    );
3161
3162    Ok(())
3163}
3164
3165/// Rejoin the channel notes after a connection blip
3166async fn rejoin_channel_buffers(
3167    request: proto::RejoinChannelBuffers,
3168    response: Response<proto::RejoinChannelBuffers>,
3169    session: UserSession,
3170) -> Result<()> {
3171    let db = session.db().await;
3172    let buffers = db
3173        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3174        .await?;
3175
3176    for rejoined_buffer in &buffers {
3177        let collaborators_to_notify = rejoined_buffer
3178            .buffer
3179            .collaborators
3180            .iter()
3181            .filter_map(|c| Some(c.peer_id?.into()));
3182        channel_buffer_updated(
3183            session.connection_id,
3184            collaborators_to_notify,
3185            &proto::UpdateChannelBufferCollaborators {
3186                channel_id: rejoined_buffer.buffer.channel_id,
3187                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3188            },
3189            &session.peer,
3190        );
3191    }
3192
3193    response.send(proto::RejoinChannelBuffersResponse {
3194        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3195    })?;
3196
3197    Ok(())
3198}
3199
3200/// Stop editing the channel notes
3201async fn leave_channel_buffer(
3202    request: proto::LeaveChannelBuffer,
3203    response: Response<proto::LeaveChannelBuffer>,
3204    session: UserSession,
3205) -> Result<()> {
3206    let db = session.db().await;
3207    let channel_id = ChannelId::from_proto(request.channel_id);
3208
3209    let left_buffer = db
3210        .leave_channel_buffer(channel_id, session.connection_id)
3211        .await?;
3212
3213    response.send(Ack {})?;
3214
3215    channel_buffer_updated(
3216        session.connection_id,
3217        left_buffer.connections,
3218        &proto::UpdateChannelBufferCollaborators {
3219            channel_id: channel_id.to_proto(),
3220            collaborators: left_buffer.collaborators,
3221        },
3222        &session.peer,
3223    );
3224
3225    Ok(())
3226}
3227
3228fn channel_buffer_updated<T: EnvelopedMessage>(
3229    sender_id: ConnectionId,
3230    collaborators: impl IntoIterator<Item = ConnectionId>,
3231    message: &T,
3232    peer: &Peer,
3233) {
3234    broadcast(Some(sender_id), collaborators, |peer_id| {
3235        peer.send(peer_id, message.clone())
3236    });
3237}
3238
3239fn send_notifications(
3240    connection_pool: &ConnectionPool,
3241    peer: &Peer,
3242    notifications: db::NotificationBatch,
3243) {
3244    for (user_id, notification) in notifications {
3245        for connection_id in connection_pool.user_connection_ids(user_id) {
3246            if let Err(error) = peer.send(
3247                connection_id,
3248                proto::AddNotification {
3249                    notification: Some(notification.clone()),
3250                },
3251            ) {
3252                tracing::error!(
3253                    "failed to send notification to {:?} {}",
3254                    connection_id,
3255                    error
3256                );
3257            }
3258        }
3259    }
3260}
3261
3262/// Send a message to the channel
3263async fn send_channel_message(
3264    request: proto::SendChannelMessage,
3265    response: Response<proto::SendChannelMessage>,
3266    session: UserSession,
3267) -> Result<()> {
3268    // Validate the message body.
3269    let body = request.body.trim().to_string();
3270    if body.len() > MAX_MESSAGE_LEN {
3271        return Err(anyhow!("message is too long"))?;
3272    }
3273    if body.is_empty() {
3274        return Err(anyhow!("message can't be blank"))?;
3275    }
3276
3277    // TODO: adjust mentions if body is trimmed
3278
3279    let timestamp = OffsetDateTime::now_utc();
3280    let nonce = request
3281        .nonce
3282        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3283
3284    let channel_id = ChannelId::from_proto(request.channel_id);
3285    let CreatedChannelMessage {
3286        message_id,
3287        participant_connection_ids,
3288        channel_members,
3289        notifications,
3290    } = session
3291        .db()
3292        .await
3293        .create_channel_message(
3294            channel_id,
3295            session.user_id(),
3296            &body,
3297            &request.mentions,
3298            timestamp,
3299            nonce.clone().into(),
3300            match request.reply_to_message_id {
3301                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3302                None => None,
3303            },
3304        )
3305        .await?;
3306
3307    let message = proto::ChannelMessage {
3308        sender_id: session.user_id().to_proto(),
3309        id: message_id.to_proto(),
3310        body,
3311        mentions: request.mentions,
3312        timestamp: timestamp.unix_timestamp() as u64,
3313        nonce: Some(nonce),
3314        reply_to_message_id: request.reply_to_message_id,
3315        edited_at: None,
3316    };
3317    broadcast(
3318        Some(session.connection_id),
3319        participant_connection_ids,
3320        |connection| {
3321            session.peer.send(
3322                connection,
3323                proto::ChannelMessageSent {
3324                    channel_id: channel_id.to_proto(),
3325                    message: Some(message.clone()),
3326                },
3327            )
3328        },
3329    );
3330    response.send(proto::SendChannelMessageResponse {
3331        message: Some(message),
3332    })?;
3333
3334    let pool = &*session.connection_pool().await;
3335    broadcast(
3336        None,
3337        channel_members
3338            .iter()
3339            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3340        |peer_id| {
3341            session.peer.send(
3342                peer_id,
3343                proto::UpdateChannels {
3344                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3345                        channel_id: channel_id.to_proto(),
3346                        message_id: message_id.to_proto(),
3347                    }],
3348                    ..Default::default()
3349                },
3350            )
3351        },
3352    );
3353    send_notifications(pool, &session.peer, notifications);
3354
3355    Ok(())
3356}
3357
3358/// Delete a channel message
3359async fn remove_channel_message(
3360    request: proto::RemoveChannelMessage,
3361    response: Response<proto::RemoveChannelMessage>,
3362    session: UserSession,
3363) -> Result<()> {
3364    let channel_id = ChannelId::from_proto(request.channel_id);
3365    let message_id = MessageId::from_proto(request.message_id);
3366    let connection_ids = session
3367        .db()
3368        .await
3369        .remove_channel_message(channel_id, message_id, session.user_id())
3370        .await?;
3371    broadcast(Some(session.connection_id), connection_ids, |connection| {
3372        session.peer.send(connection, request.clone())
3373    });
3374    response.send(proto::Ack {})?;
3375    Ok(())
3376}
3377
3378async fn update_channel_message(
3379    request: proto::UpdateChannelMessage,
3380    response: Response<proto::UpdateChannelMessage>,
3381    session: UserSession,
3382) -> Result<()> {
3383    let channel_id = ChannelId::from_proto(request.channel_id);
3384    let message_id = MessageId::from_proto(request.message_id);
3385    let updated_at = OffsetDateTime::now_utc();
3386    let UpdatedChannelMessage {
3387        message_id,
3388        participant_connection_ids,
3389        notifications,
3390        reply_to_message_id,
3391        timestamp,
3392    } = session
3393        .db()
3394        .await
3395        .update_channel_message(
3396            channel_id,
3397            message_id,
3398            session.user_id(),
3399            request.body.as_str(),
3400            &request.mentions,
3401            updated_at,
3402        )
3403        .await?;
3404
3405    let nonce = request
3406        .nonce
3407        .clone()
3408        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3409
3410    let message = proto::ChannelMessage {
3411        sender_id: session.user_id().to_proto(),
3412        id: message_id.to_proto(),
3413        body: request.body.clone(),
3414        mentions: request.mentions.clone(),
3415        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3416        nonce: Some(nonce),
3417        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3418        edited_at: Some(updated_at.unix_timestamp() as u64),
3419    };
3420
3421    response.send(proto::Ack {})?;
3422
3423    let pool = &*session.connection_pool().await;
3424    broadcast(
3425        Some(session.connection_id),
3426        participant_connection_ids,
3427        |connection| {
3428            session.peer.send(
3429                connection,
3430                proto::ChannelMessageUpdate {
3431                    channel_id: channel_id.to_proto(),
3432                    message: Some(message.clone()),
3433                },
3434            )
3435        },
3436    );
3437
3438    send_notifications(pool, &session.peer, notifications);
3439
3440    Ok(())
3441}
3442
3443/// Mark a channel message as read
3444async fn acknowledge_channel_message(
3445    request: proto::AckChannelMessage,
3446    session: UserSession,
3447) -> Result<()> {
3448    let channel_id = ChannelId::from_proto(request.channel_id);
3449    let message_id = MessageId::from_proto(request.message_id);
3450    let notifications = session
3451        .db()
3452        .await
3453        .observe_channel_message(channel_id, session.user_id(), message_id)
3454        .await?;
3455    send_notifications(
3456        &*session.connection_pool().await,
3457        &session.peer,
3458        notifications,
3459    );
3460    Ok(())
3461}
3462
3463/// Mark a buffer version as synced
3464async fn acknowledge_buffer_version(
3465    request: proto::AckBufferOperation,
3466    session: UserSession,
3467) -> Result<()> {
3468    let buffer_id = BufferId::from_proto(request.buffer_id);
3469    session
3470        .db()
3471        .await
3472        .observe_buffer_version(
3473            buffer_id,
3474            session.user_id(),
3475            request.epoch as i32,
3476            &request.version,
3477        )
3478        .await?;
3479    Ok(())
3480}
3481
3482struct CompleteWithLanguageModelRateLimit;
3483
3484impl RateLimit for CompleteWithLanguageModelRateLimit {
3485    fn capacity() -> usize {
3486        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3487            .ok()
3488            .and_then(|v| v.parse().ok())
3489            .unwrap_or(120) // Picked arbitrarily
3490    }
3491
3492    fn refill_duration() -> chrono::Duration {
3493        chrono::Duration::hours(1)
3494    }
3495
3496    fn db_name() -> &'static str {
3497        "complete-with-language-model"
3498    }
3499}
3500
3501async fn complete_with_language_model(
3502    request: proto::CompleteWithLanguageModel,
3503    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3504    session: Session,
3505    open_ai_api_key: Option<Arc<str>>,
3506    google_ai_api_key: Option<Arc<str>>,
3507) -> Result<()> {
3508    let Some(session) = session.for_user() else {
3509        return Err(anyhow!("user not found"))?;
3510    };
3511    authorize_access_to_language_models(&session).await?;
3512    session
3513        .rate_limiter
3514        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3515        .await?;
3516
3517    if request.model.starts_with("gpt") {
3518        let api_key =
3519            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3520        complete_with_open_ai(request, response, session, api_key).await?;
3521    } else if request.model.starts_with("gemini") {
3522        let api_key = google_ai_api_key
3523            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3524        complete_with_google_ai(request, response, session, api_key).await?;
3525    }
3526
3527    Ok(())
3528}
3529
3530async fn complete_with_open_ai(
3531    request: proto::CompleteWithLanguageModel,
3532    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3533    session: UserSession,
3534    api_key: Arc<str>,
3535) -> Result<()> {
3536    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3537
3538    let mut completion_stream = open_ai::stream_completion(
3539        &session.http_client,
3540        OPEN_AI_API_URL,
3541        &api_key,
3542        crate::ai::language_model_request_to_open_ai(request)?,
3543    )
3544    .await
3545    .context("open_ai::stream_completion request failed")?;
3546
3547    while let Some(event) = completion_stream.next().await {
3548        let event = event?;
3549        response.send(proto::LanguageModelResponse {
3550            choices: event
3551                .choices
3552                .into_iter()
3553                .map(|choice| proto::LanguageModelChoiceDelta {
3554                    index: choice.index,
3555                    delta: Some(proto::LanguageModelResponseMessage {
3556                        role: choice.delta.role.map(|role| match role {
3557                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3558                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3559                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3560                        } as i32),
3561                        content: choice.delta.content,
3562                    }),
3563                    finish_reason: choice.finish_reason,
3564                })
3565                .collect(),
3566        })?;
3567    }
3568
3569    Ok(())
3570}
3571
3572async fn complete_with_google_ai(
3573    request: proto::CompleteWithLanguageModel,
3574    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3575    session: UserSession,
3576    api_key: Arc<str>,
3577) -> Result<()> {
3578    let mut stream = google_ai::stream_generate_content(
3579        &session.http_client,
3580        google_ai::API_URL,
3581        api_key.as_ref(),
3582        crate::ai::language_model_request_to_google_ai(request)?,
3583    )
3584    .await
3585    .context("google_ai::stream_generate_content request failed")?;
3586
3587    while let Some(event) = stream.next().await {
3588        let event = event?;
3589        response.send(proto::LanguageModelResponse {
3590            choices: event
3591                .candidates
3592                .unwrap_or_default()
3593                .into_iter()
3594                .map(|candidate| proto::LanguageModelChoiceDelta {
3595                    index: candidate.index as u32,
3596                    delta: Some(proto::LanguageModelResponseMessage {
3597                        role: Some(match candidate.content.role {
3598                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3599                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3600                        } as i32),
3601                        content: Some(
3602                            candidate
3603                                .content
3604                                .parts
3605                                .into_iter()
3606                                .filter_map(|part| match part {
3607                                    google_ai::Part::TextPart(part) => Some(part.text),
3608                                    google_ai::Part::InlineDataPart(_) => None,
3609                                })
3610                                .collect(),
3611                        ),
3612                    }),
3613                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3614                })
3615                .collect(),
3616        })?;
3617    }
3618
3619    Ok(())
3620}
3621
3622struct CountTokensWithLanguageModelRateLimit;
3623
3624impl RateLimit for CountTokensWithLanguageModelRateLimit {
3625    fn capacity() -> usize {
3626        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3627            .ok()
3628            .and_then(|v| v.parse().ok())
3629            .unwrap_or(600) // Picked arbitrarily
3630    }
3631
3632    fn refill_duration() -> chrono::Duration {
3633        chrono::Duration::hours(1)
3634    }
3635
3636    fn db_name() -> &'static str {
3637        "count-tokens-with-language-model"
3638    }
3639}
3640
3641async fn count_tokens_with_language_model(
3642    request: proto::CountTokensWithLanguageModel,
3643    response: Response<proto::CountTokensWithLanguageModel>,
3644    session: UserSession,
3645    google_ai_api_key: Option<Arc<str>>,
3646) -> Result<()> {
3647    authorize_access_to_language_models(&session).await?;
3648
3649    if !request.model.starts_with("gemini") {
3650        return Err(anyhow!(
3651            "counting tokens for model: {:?} is not supported",
3652            request.model
3653        ))?;
3654    }
3655
3656    session
3657        .rate_limiter
3658        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3659        .await?;
3660
3661    let api_key = google_ai_api_key
3662        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3663    let tokens_response = google_ai::count_tokens(
3664        &session.http_client,
3665        google_ai::API_URL,
3666        &api_key,
3667        crate::ai::count_tokens_request_to_google_ai(request)?,
3668    )
3669    .await?;
3670    response.send(proto::CountTokensResponse {
3671        token_count: tokens_response.total_tokens as u32,
3672    })?;
3673    Ok(())
3674}
3675
3676async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3677    let db = session.db().await;
3678    let flags = db.get_user_flags(session.user_id()).await?;
3679    if flags.iter().any(|flag| flag == "language-models") {
3680        Ok(())
3681    } else {
3682        Err(anyhow!("permission denied"))?
3683    }
3684}
3685
3686/// Start receiving chat updates for a channel
3687async fn join_channel_chat(
3688    request: proto::JoinChannelChat,
3689    response: Response<proto::JoinChannelChat>,
3690    session: UserSession,
3691) -> Result<()> {
3692    let channel_id = ChannelId::from_proto(request.channel_id);
3693
3694    let db = session.db().await;
3695    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3696        .await?;
3697    let messages = db
3698        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3699        .await?;
3700    response.send(proto::JoinChannelChatResponse {
3701        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3702        messages,
3703    })?;
3704    Ok(())
3705}
3706
3707/// Stop receiving chat updates for a channel
3708async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3709    let channel_id = ChannelId::from_proto(request.channel_id);
3710    session
3711        .db()
3712        .await
3713        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3714        .await?;
3715    Ok(())
3716}
3717
3718/// Retrieve the chat history for a channel
3719async fn get_channel_messages(
3720    request: proto::GetChannelMessages,
3721    response: Response<proto::GetChannelMessages>,
3722    session: UserSession,
3723) -> Result<()> {
3724    let channel_id = ChannelId::from_proto(request.channel_id);
3725    let messages = session
3726        .db()
3727        .await
3728        .get_channel_messages(
3729            channel_id,
3730            session.user_id(),
3731            MESSAGE_COUNT_PER_PAGE,
3732            Some(MessageId::from_proto(request.before_message_id)),
3733        )
3734        .await?;
3735    response.send(proto::GetChannelMessagesResponse {
3736        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3737        messages,
3738    })?;
3739    Ok(())
3740}
3741
3742/// Retrieve specific chat messages
3743async fn get_channel_messages_by_id(
3744    request: proto::GetChannelMessagesById,
3745    response: Response<proto::GetChannelMessagesById>,
3746    session: UserSession,
3747) -> Result<()> {
3748    let message_ids = request
3749        .message_ids
3750        .iter()
3751        .map(|id| MessageId::from_proto(*id))
3752        .collect::<Vec<_>>();
3753    let messages = session
3754        .db()
3755        .await
3756        .get_channel_messages_by_id(session.user_id(), &message_ids)
3757        .await?;
3758    response.send(proto::GetChannelMessagesResponse {
3759        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3760        messages,
3761    })?;
3762    Ok(())
3763}
3764
3765/// Retrieve the current users notifications
3766async fn get_notifications(
3767    request: proto::GetNotifications,
3768    response: Response<proto::GetNotifications>,
3769    session: UserSession,
3770) -> Result<()> {
3771    let notifications = session
3772        .db()
3773        .await
3774        .get_notifications(
3775            session.user_id(),
3776            NOTIFICATION_COUNT_PER_PAGE,
3777            request
3778                .before_id
3779                .map(|id| db::NotificationId::from_proto(id)),
3780        )
3781        .await?;
3782    response.send(proto::GetNotificationsResponse {
3783        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3784        notifications,
3785    })?;
3786    Ok(())
3787}
3788
3789/// Mark notifications as read
3790async fn mark_notification_as_read(
3791    request: proto::MarkNotificationRead,
3792    response: Response<proto::MarkNotificationRead>,
3793    session: UserSession,
3794) -> Result<()> {
3795    let database = &session.db().await;
3796    let notifications = database
3797        .mark_notification_as_read_by_id(
3798            session.user_id(),
3799            NotificationId::from_proto(request.notification_id),
3800        )
3801        .await?;
3802    send_notifications(
3803        &*session.connection_pool().await,
3804        &session.peer,
3805        notifications,
3806    );
3807    response.send(proto::Ack {})?;
3808    Ok(())
3809}
3810
3811/// Get the current users information
3812async fn get_private_user_info(
3813    _request: proto::GetPrivateUserInfo,
3814    response: Response<proto::GetPrivateUserInfo>,
3815    session: UserSession,
3816) -> Result<()> {
3817    let db = session.db().await;
3818
3819    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3820    let user = db
3821        .get_user_by_id(session.user_id())
3822        .await?
3823        .ok_or_else(|| anyhow!("user not found"))?;
3824    let flags = db.get_user_flags(session.user_id()).await?;
3825
3826    response.send(proto::GetPrivateUserInfoResponse {
3827        metrics_id,
3828        staff: user.admin,
3829        flags,
3830    })?;
3831    Ok(())
3832}
3833
3834fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3835    match message {
3836        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3837        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3838        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3839        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3840        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3841            code: frame.code.into(),
3842            reason: frame.reason,
3843        })),
3844    }
3845}
3846
3847fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3848    match message {
3849        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3850        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3851        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3852        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3853        AxumMessage::Close(frame) => {
3854            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3855                code: frame.code.into(),
3856                reason: frame.reason,
3857            }))
3858        }
3859    }
3860}
3861
3862fn notify_membership_updated(
3863    connection_pool: &mut ConnectionPool,
3864    result: MembershipUpdated,
3865    user_id: UserId,
3866    peer: &Peer,
3867) {
3868    for membership in &result.new_channels.channel_memberships {
3869        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3870    }
3871    for channel_id in &result.removed_channels {
3872        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3873    }
3874
3875    let user_channels_update = proto::UpdateUserChannels {
3876        channel_memberships: result
3877            .new_channels
3878            .channel_memberships
3879            .iter()
3880            .map(|cm| proto::ChannelMembership {
3881                channel_id: cm.channel_id.to_proto(),
3882                role: cm.role.into(),
3883            })
3884            .collect(),
3885        ..Default::default()
3886    };
3887
3888    let mut update = build_channels_update(result.new_channels, vec![]);
3889    update.delete_channels = result
3890        .removed_channels
3891        .into_iter()
3892        .map(|id| id.to_proto())
3893        .collect();
3894    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3895
3896    for connection_id in connection_pool.user_connection_ids(user_id) {
3897        peer.send(connection_id, user_channels_update.clone())
3898            .trace_err();
3899        peer.send(connection_id, update.clone()).trace_err();
3900    }
3901}
3902
3903fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3904    proto::UpdateUserChannels {
3905        channel_memberships: channels
3906            .channel_memberships
3907            .iter()
3908            .map(|m| proto::ChannelMembership {
3909                channel_id: m.channel_id.to_proto(),
3910                role: m.role.into(),
3911            })
3912            .collect(),
3913        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3914        observed_channel_message_id: channels.observed_channel_messages.clone(),
3915    }
3916}
3917
3918fn build_channels_update(
3919    channels: ChannelsForUser,
3920    channel_invites: Vec<db::Channel>,
3921) -> proto::UpdateChannels {
3922    let mut update = proto::UpdateChannels::default();
3923
3924    for channel in channels.channels {
3925        update.channels.push(channel.to_proto());
3926    }
3927
3928    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3929    update.latest_channel_message_ids = channels.latest_channel_messages;
3930
3931    for (channel_id, participants) in channels.channel_participants {
3932        update
3933            .channel_participants
3934            .push(proto::ChannelParticipants {
3935                channel_id: channel_id.to_proto(),
3936                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3937            });
3938    }
3939
3940    for channel in channel_invites {
3941        update.channel_invitations.push(channel.to_proto());
3942    }
3943    for project in channels.hosted_projects {
3944        update.hosted_projects.push(project);
3945    }
3946
3947    update
3948}
3949
3950fn build_initial_contacts_update(
3951    contacts: Vec<db::Contact>,
3952    pool: &ConnectionPool,
3953) -> proto::UpdateContacts {
3954    let mut update = proto::UpdateContacts::default();
3955
3956    for contact in contacts {
3957        match contact {
3958            db::Contact::Accepted { user_id, busy } => {
3959                update.contacts.push(contact_for_user(user_id, busy, &pool));
3960            }
3961            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3962            db::Contact::Incoming { user_id } => {
3963                update
3964                    .incoming_requests
3965                    .push(proto::IncomingContactRequest {
3966                        requester_id: user_id.to_proto(),
3967                    })
3968            }
3969        }
3970    }
3971
3972    update
3973}
3974
3975fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3976    proto::Contact {
3977        user_id: user_id.to_proto(),
3978        online: pool.is_user_online(user_id),
3979        busy,
3980    }
3981}
3982
3983fn room_updated(room: &proto::Room, peer: &Peer) {
3984    broadcast(
3985        None,
3986        room.participants
3987            .iter()
3988            .filter_map(|participant| Some(participant.peer_id?.into())),
3989        |peer_id| {
3990            peer.send(
3991                peer_id,
3992                proto::RoomUpdated {
3993                    room: Some(room.clone()),
3994                },
3995            )
3996        },
3997    );
3998}
3999
4000fn channel_updated(
4001    channel: &db::channel::Model,
4002    room: &proto::Room,
4003    peer: &Peer,
4004    pool: &ConnectionPool,
4005) {
4006    let participants = room
4007        .participants
4008        .iter()
4009        .map(|p| p.user_id)
4010        .collect::<Vec<_>>();
4011
4012    broadcast(
4013        None,
4014        pool.channel_connection_ids(channel.root_id())
4015            .filter_map(|(channel_id, role)| {
4016                role.can_see_channel(channel.visibility).then(|| channel_id)
4017            }),
4018        |peer_id| {
4019            peer.send(
4020                peer_id,
4021                proto::UpdateChannels {
4022                    channel_participants: vec![proto::ChannelParticipants {
4023                        channel_id: channel.id.to_proto(),
4024                        participant_user_ids: participants.clone(),
4025                    }],
4026                    ..Default::default()
4027                },
4028            )
4029        },
4030    );
4031}
4032
4033async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4034    let db = session.db().await;
4035
4036    let contacts = db.get_contacts(user_id).await?;
4037    let busy = db.is_user_busy(user_id).await?;
4038
4039    let pool = session.connection_pool().await;
4040    let updated_contact = contact_for_user(user_id, busy, &pool);
4041    for contact in contacts {
4042        if let db::Contact::Accepted {
4043            user_id: contact_user_id,
4044            ..
4045        } = contact
4046        {
4047            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4048                session
4049                    .peer
4050                    .send(
4051                        contact_conn_id,
4052                        proto::UpdateContacts {
4053                            contacts: vec![updated_contact.clone()],
4054                            remove_contacts: Default::default(),
4055                            incoming_requests: Default::default(),
4056                            remove_incoming_requests: Default::default(),
4057                            outgoing_requests: Default::default(),
4058                            remove_outgoing_requests: Default::default(),
4059                        },
4060                    )
4061                    .trace_err();
4062            }
4063        }
4064    }
4065    Ok(())
4066}
4067
/// Remove the session's connection from the room it is in, if any, and fan out
/// the consequences.
///
/// When the database reports that the connection left a room, this:
/// - notifies every project the user left (via `project_left`),
/// - broadcasts the updated room to its participants (`room_updated`),
/// - refreshes the channel's participant list when the room belongs to a channel,
/// - sends `CallCanceled` to every user whose pending call was canceled,
/// - pushes contact updates for this user and for each of those canceled users,
/// - removes the user from the LiveKit room, and deletes that room when
///   `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately when the connection was not in a room.
/// LiveKit and peer-send failures are logged (`trace_err`), not propagated.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below, so the values
    // remain usable after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of `left_room.room` before the room itself is moved out below.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this connection-pool guard is dropped before
    // `update_user_contacts`, which calls `connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4141
4142async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4143    let left_channel_buffers = session
4144        .db()
4145        .await
4146        .leave_channel_buffers(session.connection_id)
4147        .await?;
4148
4149    for left_buffer in left_channel_buffers {
4150        channel_buffer_updated(
4151            session.connection_id,
4152            left_buffer.connections,
4153            &proto::UpdateChannelBufferCollaborators {
4154                channel_id: left_buffer.channel_id.to_proto(),
4155                collaborators: left_buffer.collaborators,
4156            },
4157            &session.peer,
4158        );
4159    }
4160
4161    Ok(())
4162}
4163
4164fn project_left(project: &db::LeftProject, session: &UserSession) {
4165    for connection_id in &project.connection_ids {
4166        if project.host_user_id == Some(session.user_id()) {
4167            session
4168                .peer
4169                .send(
4170                    *connection_id,
4171                    proto::UnshareProject {
4172                        project_id: project.id.to_proto(),
4173                    },
4174                )
4175                .trace_err();
4176        } else {
4177            session
4178                .peer
4179                .send(
4180                    *connection_id,
4181                    proto::RemoveProjectCollaborator {
4182                        project_id: project.id.to_proto(),
4183                        peer_id: Some(session.connection_id.into()),
4184                    },
4185                )
4186                .trace_err();
4187        }
4188    }
4189}
4190
/// Extension for results whose failures should be logged rather than propagated.
///
/// Implemented in this module for `Result<T, E>` where `E: Debug`.
pub trait ResultExt {
    /// The success value carried by the result.
    type Ok;

    /// Convert into an `Option`, logging the error (if any) instead of returning it.
    ///
    /// Used throughout this file for best-effort sends whose failure should not
    /// abort the caller.
    fn trace_err(self) -> Option<Self::Ok>;
}
4196
4197impl<T, E> ResultExt for Result<T, E>
4198where
4199    E: std::fmt::Debug,
4200{
4201    type Ok = T;
4202
4203    fn trace_err(self) -> Option<T> {
4204        match self {
4205            Ok(value) => Some(value),
4206            Err(error) => {
4207                tracing::error!("{:?}", error);
4208                None
4209            }
4210        }
4211    }
4212}