rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use serde::{Serialize, Serializer};
  50use std::{
  51    any::TypeId,
  52    future::Future,
  53    marker::PhantomData,
  54    mem,
  55    net::SocketAddr,
  56    ops::{Deref, DerefMut},
  57    rc::Rc,
  58    sync::{
  59        atomic::{AtomicBool, Ordering::SeqCst},
  60        Arc, OnceLock,
  61    },
  62    time::{Duration, Instant},
  63};
  64use time::OffsetDateTime;
  65use tokio::sync::{watch, Semaphore};
  66use tower::ServiceBuilder;
  67use tracing::{
  68    field::{self},
  69    info_span, instrument, Instrument,
  70};
  71use util::{http::IsahcHttpClient, SemanticVersion};
  72
  73pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  74
  75// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  76pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  77
  78const MESSAGE_COUNT_PER_PAGE: usize = 100;
  79const MAX_MESSAGE_LEN: usize = 1024;
  80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  81
  82type MessageHandler =
  83    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  84
  85struct Response<R> {
  86    peer: Arc<Peer>,
  87    receipt: Receipt<R>,
  88    responded: Arc<AtomicBool>,
  89}
  90
  91impl<R: RequestMessage> Response<R> {
  92    fn send(self, payload: R::Response) -> Result<()> {
  93        self.responded.store(true, SeqCst);
  94        self.peer.respond(self.receipt, payload)?;
  95        Ok(())
  96    }
  97}
  98
  99struct StreamingResponse<R: RequestMessage> {
 100    peer: Arc<Peer>,
 101    receipt: Receipt<R>,
 102}
 103
 104impl<R: RequestMessage> StreamingResponse<R> {
 105    fn send(&self, payload: R::Response) -> Result<()> {
 106        self.peer.respond(self.receipt, payload)?;
 107        Ok(())
 108    }
 109}
 110
 111#[derive(Clone, Debug)]
 112pub enum Principal {
 113    User(User),
 114    Impersonated { user: User, admin: User },
 115    DevServer(dev_server::Model),
 116}
 117
 118impl Principal {
 119    fn update_span(&self, span: &tracing::Span) {
 120        match &self {
 121            Principal::User(user) => {
 122                span.record("user_id", &user.id.0);
 123                span.record("login", &user.github_login);
 124            }
 125            Principal::Impersonated { user, admin } => {
 126                span.record("user_id", &user.id.0);
 127                span.record("login", &user.github_login);
 128                span.record("impersonator", &admin.github_login);
 129            }
 130            Principal::DevServer(dev_server) => {
 131                span.record("dev_server_id", &dev_server.id.0);
 132            }
 133        }
 134    }
 135}
 136
 137#[derive(Clone)]
 138struct Session {
 139    principal: Principal,
 140    connection_id: ConnectionId,
 141    db: Arc<tokio::sync::Mutex<DbHandle>>,
 142    peer: Arc<Peer>,
 143    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 144    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 145    http_client: IsahcHttpClient,
 146    rate_limiter: Arc<RateLimiter>,
 147    _executor: Executor,
 148}
 149
 150impl Session {
 151    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 152        #[cfg(test)]
 153        tokio::task::yield_now().await;
 154        let guard = self.db.lock().await;
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        guard
 158    }
 159
 160    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 161        #[cfg(test)]
 162        tokio::task::yield_now().await;
 163        let guard = self.connection_pool.lock();
 164        ConnectionPoolGuard {
 165            guard,
 166            _not_send: PhantomData,
 167        }
 168    }
 169
 170    fn for_user(self) -> Option<UserSession> {
 171        UserSession::new(self)
 172    }
 173
 174    fn user_id(&self) -> Option<UserId> {
 175        match &self.principal {
 176            Principal::User(user) => Some(user.id),
 177            Principal::Impersonated { user, .. } => Some(user.id),
 178            Principal::DevServer(_) => None,
 179        }
 180    }
 181}
 182
 183impl Debug for Session {
 184    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 185        let mut result = f.debug_struct("Session");
 186        match &self.principal {
 187            Principal::User(user) => {
 188                result.field("user", &user.github_login);
 189            }
 190            Principal::Impersonated { user, admin } => {
 191                result.field("user", &user.github_login);
 192                result.field("impersonator", &admin.github_login);
 193            }
 194            Principal::DevServer(dev_server) => {
 195                result.field("dev_server", &dev_server.id);
 196            }
 197        }
 198        result.field("connection_id", &self.connection_id).finish()
 199    }
 200}
 201
 202struct UserSession(Session);
 203
 204impl UserSession {
 205    pub fn new(s: Session) -> Option<Self> {
 206        s.user_id().map(|_| UserSession(s))
 207    }
 208    pub fn user_id(&self) -> UserId {
 209        self.0.user_id().unwrap()
 210    }
 211}
 212
 213impl Deref for UserSession {
 214    type Target = Session;
 215
 216    fn deref(&self) -> &Self::Target {
 217        &self.0
 218    }
 219}
 220impl DerefMut for UserSession {
 221    fn deref_mut(&mut self) -> &mut Self::Target {
 222        &mut self.0
 223    }
 224}
 225
 226fn user_handler<M: RequestMessage, Fut>(
 227    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 228) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 229where
 230    Fut: Send + Future<Output = Result<()>>,
 231{
 232    let handler = Arc::new(handler);
 233    move |message, response, session| {
 234        let handler = handler.clone();
 235        Box::pin(async move {
 236            if let Some(user_session) = session.for_user() {
 237                Ok(handler(message, response, user_session).await?)
 238            } else {
 239                Err(Error::Internal(anyhow!("must be a user")))
 240            }
 241        })
 242    }
 243}
 244
 245fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 246    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 247) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 248where
 249    InnertRetFut: Send + Future<Output = Result<()>>,
 250{
 251    let handler = Arc::new(handler);
 252    move |message, session| {
 253        let handler = handler.clone();
 254        Box::pin(async move {
 255            if let Some(user_session) = session.for_user() {
 256                Ok(handler(message, user_session).await?)
 257            } else {
 258                Err(Error::Internal(anyhow!("must be a user")))
 259            }
 260        })
 261    }
 262}
 263
 264struct DbHandle(Arc<Database>);
 265
 266impl Deref for DbHandle {
 267    type Target = Database;
 268
 269    fn deref(&self) -> &Self::Target {
 270        self.0.as_ref()
 271    }
 272}
 273
 274pub struct Server {
 275    id: parking_lot::Mutex<ServerId>,
 276    peer: Arc<Peer>,
 277    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 278    app_state: Arc<AppState>,
 279    handlers: HashMap<TypeId, MessageHandler>,
 280    teardown: watch::Sender<bool>,
 281}
 282
 283pub(crate) struct ConnectionPoolGuard<'a> {
 284    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 285    _not_send: PhantomData<Rc<()>>,
 286}
 287
 288#[derive(Serialize)]
 289pub struct ServerSnapshot<'a> {
 290    peer: &'a Peer,
 291    #[serde(serialize_with = "serialize_deref")]
 292    connection_pool: ConnectionPoolGuard<'a>,
 293}
 294
 295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 296where
 297    S: Serializer,
 298    T: Deref<Target = U>,
 299    U: Serialize,
 300{
 301    Serialize::serialize(value.deref(), serializer)
 302}
 303
 304impl Server {
 305    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 306        let mut server = Self {
 307            id: parking_lot::Mutex::new(id),
 308            peer: Peer::new(id.0 as u32),
 309            app_state: app_state.clone(),
 310            connection_pool: Default::default(),
 311            handlers: Default::default(),
 312            teardown: watch::channel(false).0,
 313        };
 314
 315        server
 316            .add_request_handler(ping)
 317            .add_request_handler(user_handler(create_room))
 318            .add_request_handler(user_handler(join_room))
 319            .add_request_handler(user_handler(rejoin_room))
 320            .add_request_handler(user_handler(leave_room))
 321            .add_request_handler(user_handler(set_room_participant_role))
 322            .add_request_handler(user_handler(call))
 323            .add_request_handler(user_handler(cancel_call))
 324            .add_message_handler(user_message_handler(decline_call))
 325            .add_request_handler(user_handler(update_participant_location))
 326            .add_request_handler(share_project)
 327            .add_message_handler(unshare_project)
 328            .add_request_handler(user_handler(join_project))
 329            .add_request_handler(user_handler(join_hosted_project))
 330            .add_message_handler(user_message_handler(leave_project))
 331            .add_request_handler(update_project)
 332            .add_request_handler(update_worktree)
 333            .add_message_handler(start_language_server)
 334            .add_message_handler(update_language_server)
 335            .add_message_handler(update_diagnostic_summary)
 336            .add_message_handler(update_worktree_settings)
 337            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 338            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 341            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 342            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 344            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 346            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 347            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 348            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 349            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 350            .add_request_handler(
 351                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 352            )
 353            .add_request_handler(
 354                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 355            )
 356            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 357            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 358            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 361            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 362            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 363            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 368            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 369            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 370            .add_message_handler(create_buffer_for_peer)
 371            .add_request_handler(update_buffer)
 372            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 376            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 377            .add_request_handler(get_users)
 378            .add_request_handler(user_handler(fuzzy_search_users))
 379            .add_request_handler(user_handler(request_contact))
 380            .add_request_handler(user_handler(remove_contact))
 381            .add_request_handler(user_handler(respond_to_contact_request))
 382            .add_request_handler(user_handler(create_channel))
 383            .add_request_handler(user_handler(delete_channel))
 384            .add_request_handler(user_handler(invite_channel_member))
 385            .add_request_handler(user_handler(remove_channel_member))
 386            .add_request_handler(user_handler(set_channel_member_role))
 387            .add_request_handler(user_handler(set_channel_visibility))
 388            .add_request_handler(user_handler(rename_channel))
 389            .add_request_handler(user_handler(join_channel_buffer))
 390            .add_request_handler(user_handler(leave_channel_buffer))
 391            .add_message_handler(user_message_handler(update_channel_buffer))
 392            .add_request_handler(user_handler(rejoin_channel_buffers))
 393            .add_request_handler(user_handler(get_channel_members))
 394            .add_request_handler(user_handler(respond_to_channel_invite))
 395            .add_request_handler(user_handler(join_channel))
 396            .add_request_handler(user_handler(join_channel_chat))
 397            .add_message_handler(user_message_handler(leave_channel_chat))
 398            .add_request_handler(user_handler(send_channel_message))
 399            .add_request_handler(user_handler(remove_channel_message))
 400            .add_request_handler(user_handler(update_channel_message))
 401            .add_request_handler(user_handler(get_channel_messages))
 402            .add_request_handler(user_handler(get_channel_messages_by_id))
 403            .add_request_handler(user_handler(get_notifications))
 404            .add_request_handler(user_handler(mark_notification_as_read))
 405            .add_request_handler(user_handler(move_channel))
 406            .add_request_handler(user_handler(follow))
 407            .add_message_handler(user_message_handler(unfollow))
 408            .add_message_handler(user_message_handler(update_followers))
 409            .add_request_handler(user_handler(get_private_user_info))
 410            .add_message_handler(user_message_handler(acknowledge_channel_message))
 411            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 412            .add_streaming_request_handler({
 413                let app_state = app_state.clone();
 414                move |request, response, session| {
 415                    complete_with_language_model(
 416                        request,
 417                        response,
 418                        session,
 419                        app_state.config.openai_api_key.clone(),
 420                        app_state.config.google_ai_api_key.clone(),
 421                    )
 422                }
 423            })
 424            .add_request_handler({
 425                let app_state = app_state.clone();
 426                user_handler(move |request, response, session| {
 427                    count_tokens_with_language_model(
 428                        request,
 429                        response,
 430                        session,
 431                        app_state.config.google_ai_api_key.clone(),
 432                    )
 433                })
 434            });
 435
 436        Arc::new(server)
 437    }
 438
    /// Spawns a detached background task that reclaims resources left behind by stale
    /// servers, then returns `Ok(())` immediately (the task's outcome is not reported here).
    ///
    /// The task:
    /// 1. waits for the `CLEANUP_TIMEOUT` sleep future;
    /// 2. fetches stale room and channel ids via `stale_server_resource_ids` for this
    ///    server's `zed_environment` and id (NOTE(review): the exact staleness criteria live
    ///    in the database layer — confirm there);
    /// 3. for each stale channel, clears stale buffer collaborators and sends the refreshed
    ///    collaborator list to each remaining connection;
    /// 4. for each stale room, clears stale participants, broadcasts the room (and channel)
    ///    update, sends `CallCanceled` to users whose calls were canceled, pushes a refreshed
    ///    contact entry to the accepted contacts of every affected user, and, when a LiveKit
    ///    client is configured and the room has no participants left, deletes the LiveKit room;
    /// 5. calls `delete_stale_servers`.
    ///
    /// Each database or send failure is logged via `trace_err` and skipped, so one failing
    /// room or channel does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell everyone still
                    // connected to the buffer about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when the room could not be refreshed: nobody to
                        // notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and canceled callees need fresh
                            // contact entries pushed to their contacts.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the connection-pool guard is released before the
                        // `.await`s that follow.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups must succeed; otherwise this user is skipped.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when the stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 584
 585    pub fn teardown(&self) {
 586        self.peer.teardown();
 587        self.connection_pool.lock().reset();
 588        let _ = self.teardown.send(true);
 589    }
 590
 591    #[cfg(test)]
 592    pub fn reset(&self, id: ServerId) {
 593        self.teardown();
 594        *self.id.lock() = id;
 595        self.peer.reset(id.0 as u32);
 596        let _ = self.teardown.send(false);
 597    }
 598
 599    #[cfg(test)]
 600    pub fn id(&self) -> ServerId {
 601        *self.id.lock()
 602    }
 603
 604    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 605    where
 606        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 607        Fut: 'static + Send + Future<Output = Result<()>>,
 608        M: EnvelopedMessage,
 609    {
 610        let prev_handler = self.handlers.insert(
 611            TypeId::of::<M>(),
 612            Box::new(move |envelope, session| {
 613                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 614                let received_at = envelope.received_at;
 615                    tracing::info!(
 616                        "message received"
 617                    );
 618                let start_time = Instant::now();
 619                let future = (handler)(*envelope, session);
 620                async move {
 621                    let result = future.await;
 622                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 623                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 624                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 625                    match result {
 626                        Err(error) => {
 627                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 628                        }
 629                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 630                    }
 631                }
 632                .boxed()
 633            }),
 634        );
 635        if prev_handler.is_some() {
 636            panic!("registered a handler for the same message twice");
 637        }
 638        self
 639    }
 640
 641    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 642    where
 643        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 644        Fut: 'static + Send + Future<Output = Result<()>>,
 645        M: EnvelopedMessage,
 646    {
 647        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 648        self
 649    }
 650
 651    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 652    where
 653        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 654        Fut: Send + Future<Output = Result<()>>,
 655        M: RequestMessage,
 656    {
 657        let handler = Arc::new(handler);
 658        self.add_handler(move |envelope, session| {
 659            let receipt = envelope.receipt();
 660            let handler = handler.clone();
 661            async move {
 662                let peer = session.peer.clone();
 663                let responded = Arc::new(AtomicBool::default());
 664                let response = Response {
 665                    peer: peer.clone(),
 666                    responded: responded.clone(),
 667                    receipt,
 668                };
 669                match (handler)(envelope.payload, response, session).await {
 670                    Ok(()) => {
 671                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 672                            Ok(())
 673                        } else {
 674                            Err(anyhow!("handler did not send a response"))?
 675                        }
 676                    }
 677                    Err(error) => {
 678                        let proto_err = match &error {
 679                            Error::Internal(err) => err.to_proto(),
 680                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 681                        };
 682                        peer.respond_with_error(receipt, proto_err)?;
 683                        Err(error)
 684                    }
 685                }
 686            }
 687        })
 688    }
 689
 690    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 691    where
 692        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 693        Fut: Send + Future<Output = Result<()>>,
 694        M: RequestMessage,
 695    {
 696        let handler = Arc::new(handler);
 697        self.add_handler(move |envelope, session| {
 698            let receipt = envelope.receipt();
 699            let handler = handler.clone();
 700            async move {
 701                let peer = session.peer.clone();
 702                let response = StreamingResponse {
 703                    peer: peer.clone(),
 704                    receipt,
 705                };
 706                match (handler)(envelope.payload, response, session).await {
 707                    Ok(()) => {
 708                        peer.end_stream(receipt)?;
 709                        Ok(())
 710                    }
 711                    Err(error) => {
 712                        let proto_err = match &error {
 713                            Error::Internal(err) => err.to_proto(),
 714                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 715                        };
 716                        peer.respond_with_error(receipt, proto_err)?;
 717                        Err(error)
 718                    }
 719                }
 720            }
 721        })
 722    }
 723
    /// Returns a future that services a single client connection.
    ///
    /// When first polled, the future returns at once if the server is already
    /// tearing down. Otherwise it does the following, in order:
    /// - registers `connection` with the peer;
    /// - builds a [`Session`];
    /// - sends the initial client update through `send_initial_client_update`,
    ///   which also receives `send_connection_id`.
    ///
    /// It then dispatches incoming messages to the registered handlers. This continues
    /// until the connection's I/O future completes, the incoming stream ends, or
    /// teardown is signalled.
    ///
    /// If the loop ends because I/O completed or the stream closed, unfinished
    /// foreground handlers are dropped and `connection_lost` runs. If teardown is
    /// signalled during the loop, the future returns without calling `connection_lost`.
    ///
    /// NOTE(review): the early returns after `add_connection` skip both
    /// `peer.disconnect` and `connection_lost`. These are the returns on HTTP client
    /// creation failure and on initial-update failure. Confirm the registered
    /// connection (and any pool entry added by `send_initial_client_update`) is
    /// cleaned up elsewhere.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
        principal.update_span(&span);

        // Subscribed at call time (not on first poll) so teardown changes from here on are observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) hold a permit at once.
            // The permit is acquired *before* reading the next message, so while all
            // permits are taken no further messages are pulled from `incoming_rx`.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased selection polls the branches in this order: teardown, connection I/O,
                // progress on in-flight foreground handlers, and finally the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future and is released only
                                // once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached from this loop; foreground
                                // handlers are polled by the loop itself.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops (and so cancels) foreground handlers that have not finished.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 849
 850    async fn send_initial_client_update(
 851        &self,
 852        connection_id: ConnectionId,
 853        principal: &Principal,
 854        zed_version: ZedVersion,
 855        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 856        session: &Session,
 857    ) -> Result<()> {
 858        self.peer.send(
 859            connection_id,
 860            proto::Hello {
 861                peer_id: Some(connection_id.into()),
 862            },
 863        )?;
 864        tracing::info!("sent hello message");
 865
 866        let Principal::User(user) = principal else {
 867            return Ok(());
 868        };
 869
 870        if let Some(send_connection_id) = send_connection_id.take() {
 871            let _ = send_connection_id.send(connection_id);
 872        }
 873
 874        if !user.connected_once {
 875            self.peer.send(connection_id, proto::ShowContacts {})?;
 876            self.app_state
 877                .db
 878                .set_user_connected_once(user.id, true)
 879                .await?;
 880        }
 881
 882        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 883            self.app_state.db.get_contacts(user.id),
 884            self.app_state.db.get_channels_for_user(user.id),
 885            self.app_state.db.get_channel_invites_for_user(user.id),
 886        )
 887        .await?;
 888
 889        {
 890            let mut pool = self.connection_pool.lock();
 891            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 892            for membership in &channels_for_user.channel_memberships {
 893                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 894            }
 895            self.peer.send(
 896                connection_id,
 897                build_initial_contacts_update(contacts, &pool),
 898            )?;
 899            self.peer.send(
 900                connection_id,
 901                build_update_user_channels(&channels_for_user),
 902            )?;
 903            self.peer.send(
 904                connection_id,
 905                build_channels_update(channels_for_user, channel_invites),
 906            )?;
 907        }
 908
 909        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 910            self.peer.send(connection_id, incoming_call)?;
 911        }
 912
 913        update_user_contacts(user.id, &session).await?;
 914        Ok(())
 915    }
 916
 917    pub async fn invite_code_redeemed(
 918        self: &Arc<Self>,
 919        inviter_id: UserId,
 920        invitee_id: UserId,
 921    ) -> Result<()> {
 922        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 923            if let Some(code) = &user.invite_code {
 924                let pool = self.connection_pool.lock();
 925                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 926                for connection_id in pool.user_connection_ids(inviter_id) {
 927                    self.peer.send(
 928                        connection_id,
 929                        proto::UpdateContacts {
 930                            contacts: vec![invitee_contact.clone()],
 931                            ..Default::default()
 932                        },
 933                    )?;
 934                    self.peer.send(
 935                        connection_id,
 936                        proto::UpdateInviteInfo {
 937                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 938                            count: user.invite_count as u32,
 939                        },
 940                    )?;
 941                }
 942            }
 943        }
 944        Ok(())
 945    }
 946
 947    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 948        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 949            if let Some(invite_code) = &user.invite_code {
 950                let pool = self.connection_pool.lock();
 951                for connection_id in pool.user_connection_ids(user_id) {
 952                    self.peer.send(
 953                        connection_id,
 954                        proto::UpdateInviteInfo {
 955                            url: format!(
 956                                "{}{}",
 957                                self.app_state.config.invite_link_prefix, invite_code
 958                            ),
 959                            count: user.invite_count as u32,
 960                        },
 961                    )?;
 962                }
 963            }
 964        }
 965        Ok(())
 966    }
 967
 968    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 969        ServerSnapshot {
 970            connection_pool: ConnectionPoolGuard {
 971                guard: self.connection_pool.lock(),
 972                _not_send: PhantomData,
 973            },
 974            peer: &self.peer,
 975        }
 976    }
 977}
 978
 979impl<'a> Deref for ConnectionPoolGuard<'a> {
 980    type Target = ConnectionPool;
 981
 982    fn deref(&self) -> &Self::Target {
 983        &self.guard
 984    }
 985}
 986
 987impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 988    fn deref_mut(&mut self) -> &mut Self::Target {
 989        &mut self.guard
 990    }
 991}
 992
 993impl<'a> Drop for ConnectionPoolGuard<'a> {
 994    fn drop(&mut self) {
 995        #[cfg(test)]
 996        self.check_invariants();
 997    }
 998}
 999
1000fn broadcast<F>(
1001    sender_id: Option<ConnectionId>,
1002    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1003    mut f: F,
1004) where
1005    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1006{
1007    for receiver_id in receiver_ids {
1008        if Some(receiver_id) != sender_id {
1009            if let Err(error) = f(receiver_id) {
1010                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1011            }
1012        }
1013    }
1014}
1015
1016pub struct ProtocolVersion(u32);
1017
1018impl Header for ProtocolVersion {
1019    fn name() -> &'static HeaderName {
1020        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1021        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1022    }
1023
1024    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1025    where
1026        Self: Sized,
1027        I: Iterator<Item = &'i axum::http::HeaderValue>,
1028    {
1029        let version = values
1030            .next()
1031            .ok_or_else(axum::headers::Error::invalid)?
1032            .to_str()
1033            .map_err(|_| axum::headers::Error::invalid())?
1034            .parse()
1035            .map_err(|_| axum::headers::Error::invalid())?;
1036        Ok(Self(version))
1037    }
1038
1039    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1040        values.extend([self.0.to_string().parse().unwrap()]);
1041    }
1042}
1043
1044pub struct AppVersionHeader(SemanticVersion);
1045impl Header for AppVersionHeader {
1046    fn name() -> &'static HeaderName {
1047        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1048        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1049    }
1050
1051    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1052    where
1053        Self: Sized,
1054        I: Iterator<Item = &'i axum::http::HeaderValue>,
1055    {
1056        let version = values
1057            .next()
1058            .ok_or_else(axum::headers::Error::invalid)?
1059            .to_str()
1060            .map_err(|_| axum::headers::Error::invalid())?
1061            .parse()
1062            .map_err(|_| axum::headers::Error::invalid())?;
1063        Ok(Self(version))
1064    }
1065
1066    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1067        values.extend([self.0.to_string().parse().unwrap()]);
1068    }
1069}
1070
1071pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1072    Router::new()
1073        .route("/rpc", get(handle_websocket_request))
1074        .layer(
1075            ServiceBuilder::new()
1076                .layer(Extension(server.app_state.clone()))
1077                .layer(middleware::from_fn(auth::validate_header)),
1078        )
1079        .route("/metrics", get(handle_metrics))
1080        .layer(Extension(server))
1081}
1082
1083pub async fn handle_websocket_request(
1084    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1085    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1086    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1087    Extension(server): Extension<Arc<Server>>,
1088    Extension(principal): Extension<Principal>,
1089    ws: WebSocketUpgrade,
1090) -> axum::response::Response {
1091    if protocol_version != rpc::PROTOCOL_VERSION {
1092        return (
1093            StatusCode::UPGRADE_REQUIRED,
1094            "client must be upgraded".to_string(),
1095        )
1096            .into_response();
1097    }
1098
1099    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1100        return (
1101            StatusCode::UPGRADE_REQUIRED,
1102            "no version header found".to_string(),
1103        )
1104            .into_response();
1105    };
1106
1107    if !version.can_collaborate() {
1108        return (
1109            StatusCode::UPGRADE_REQUIRED,
1110            "client must be upgraded".to_string(),
1111        )
1112            .into_response();
1113    }
1114
1115    let socket_address = socket_address.to_string();
1116    ws.on_upgrade(move |socket| {
1117        let socket = socket
1118            .map_ok(to_tungstenite_message)
1119            .err_into()
1120            .with(|message| async move { Ok(to_axum_message(message)) });
1121        let connection = Connection::new(Box::pin(socket));
1122        async move {
1123            server
1124                .handle_connection(
1125                    connection,
1126                    socket_address,
1127                    principal,
1128                    version,
1129                    None,
1130                    Executor::Production,
1131                )
1132                .await;
1133        }
1134    })
1135}
1136
1137pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1138    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139    let connections_metric = CONNECTIONS_METRIC
1140        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1141
1142    let connections = server
1143        .connection_pool
1144        .lock()
1145        .connections()
1146        .filter(|connection| !connection.admin)
1147        .count();
1148    connections_metric.set(connections as _);
1149
1150    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1152        register_int_gauge!(
1153            "shared_projects",
1154            "number of open projects with one or more guests"
1155        )
1156        .unwrap()
1157    });
1158
1159    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1160    shared_projects_metric.set(shared_projects as _);
1161
1162    let encoder = prometheus::TextEncoder::new();
1163    let metric_families = prometheus::gather();
1164    let encoded_metrics = encoder
1165        .encode_to_string(&metric_families)
1166        .map_err(|err| anyhow!("{}", err))?;
1167    Ok(encoded_metrics)
1168}
1169
1170#[instrument(err, skip(executor))]
1171async fn connection_lost(
1172    session: Session,
1173    mut teardown: watch::Receiver<bool>,
1174    executor: Executor,
1175) -> Result<()> {
1176    session.peer.disconnect(session.connection_id);
1177    session
1178        .connection_pool()
1179        .await
1180        .remove_connection(session.connection_id)?;
1181
1182    session
1183        .db()
1184        .await
1185        .connection_lost(session.connection_id)
1186        .await
1187        .trace_err();
1188
1189    futures::select_biased! {
1190        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1191            if let Some(session) = session.for_user() {
1192                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1193                leave_room_for_session(&session).await.trace_err();
1194                leave_channel_buffers_for_session(&session)
1195                    .await
1196                    .trace_err();
1197
1198                if !session
1199                    .connection_pool()
1200                    .await
1201                    .is_user_online(session.user_id())
1202                {
1203                    let db = session.db().await;
1204                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1205                        room_updated(&room, &session.peer);
1206                    }
1207                }
1208
1209                update_user_contacts(session.user_id(), &session).await?;
1210            }
1211        }
1212        _ = teardown.changed().fuse() => {}
1213    }
1214
1215    Ok(())
1216}
1217
1218/// Acknowledges a ping from a client, used to keep the connection alive.
1219async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1220    response.send(proto::Ack {})?;
1221    Ok(())
1222}
1223
1224/// Creates a new room for calling (outside of channels)
1225async fn create_room(
1226    _request: proto::CreateRoom,
1227    response: Response<proto::CreateRoom>,
1228    session: UserSession,
1229) -> Result<()> {
1230    let live_kit_room = nanoid::nanoid!(30);
1231
1232    let live_kit_connection_info = util::maybe!(async {
1233        let live_kit = session.live_kit_client.as_ref();
1234        let live_kit = live_kit?;
1235        let user_id = session.user_id().to_string();
1236
1237        let token = live_kit
1238            .room_token(&live_kit_room, &user_id.to_string())
1239            .trace_err()?;
1240
1241        Some(proto::LiveKitConnectionInfo {
1242            server_url: live_kit.url().into(),
1243            token,
1244            can_publish: true,
1245        })
1246    })
1247    .await;
1248
1249    let room = session
1250        .db()
1251        .await
1252        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1253        .await?;
1254
1255    response.send(proto::CreateRoomResponse {
1256        room: Some(room.clone()),
1257        live_kit_connection_info,
1258    })?;
1259
1260    update_user_contacts(session.user_id(), &session).await?;
1261    Ok(())
1262}
1263
1264/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1265async fn join_room(
1266    request: proto::JoinRoom,
1267    response: Response<proto::JoinRoom>,
1268    session: UserSession,
1269) -> Result<()> {
1270    let room_id = RoomId::from_proto(request.id);
1271
1272    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1273
1274    if let Some(channel_id) = channel_id {
1275        return join_channel_internal(channel_id, Box::new(response), session).await;
1276    }
1277
1278    let joined_room = {
1279        let room = session
1280            .db()
1281            .await
1282            .join_room(room_id, session.user_id(), session.connection_id)
1283            .await?;
1284        room_updated(&room.room, &session.peer);
1285        room.into_inner()
1286    };
1287
1288    for connection_id in session
1289        .connection_pool()
1290        .await
1291        .user_connection_ids(session.user_id())
1292    {
1293        session
1294            .peer
1295            .send(
1296                connection_id,
1297                proto::CallCanceled {
1298                    room_id: room_id.to_proto(),
1299                },
1300            )
1301            .trace_err();
1302    }
1303
1304    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1305        if let Some(token) = live_kit
1306            .room_token(
1307                &joined_room.room.live_kit_room,
1308                &session.user_id().to_string(),
1309            )
1310            .trace_err()
1311        {
1312            Some(proto::LiveKitConnectionInfo {
1313                server_url: live_kit.url().into(),
1314                token,
1315                can_publish: true,
1316            })
1317        } else {
1318            None
1319        }
1320    } else {
1321        None
1322    };
1323
1324    response.send(proto::JoinRoomResponse {
1325        room: Some(joined_room.room),
1326        channel_id: None,
1327        live_kit_connection_info,
1328    })?;
1329
1330    update_user_contacts(session.user_id(), &session).await?;
1331    Ok(())
1332}
1333
1334/// Rejoin room is used to reconnect to a room after connection errors.
1335async fn rejoin_room(
1336    request: proto::RejoinRoom,
1337    response: Response<proto::RejoinRoom>,
1338    session: UserSession,
1339) -> Result<()> {
1340    let room;
1341    let channel;
1342    {
1343        let mut rejoined_room = session
1344            .db()
1345            .await
1346            .rejoin_room(request, session.user_id(), session.connection_id)
1347            .await?;
1348
1349        response.send(proto::RejoinRoomResponse {
1350            room: Some(rejoined_room.room.clone()),
1351            reshared_projects: rejoined_room
1352                .reshared_projects
1353                .iter()
1354                .map(|project| proto::ResharedProject {
1355                    id: project.id.to_proto(),
1356                    collaborators: project
1357                        .collaborators
1358                        .iter()
1359                        .map(|collaborator| collaborator.to_proto())
1360                        .collect(),
1361                })
1362                .collect(),
1363            rejoined_projects: rejoined_room
1364                .rejoined_projects
1365                .iter()
1366                .map(|rejoined_project| proto::RejoinedProject {
1367                    id: rejoined_project.id.to_proto(),
1368                    worktrees: rejoined_project
1369                        .worktrees
1370                        .iter()
1371                        .map(|worktree| proto::WorktreeMetadata {
1372                            id: worktree.id,
1373                            root_name: worktree.root_name.clone(),
1374                            visible: worktree.visible,
1375                            abs_path: worktree.abs_path.clone(),
1376                        })
1377                        .collect(),
1378                    collaborators: rejoined_project
1379                        .collaborators
1380                        .iter()
1381                        .map(|collaborator| collaborator.to_proto())
1382                        .collect(),
1383                    language_servers: rejoined_project.language_servers.clone(),
1384                })
1385                .collect(),
1386        })?;
1387        room_updated(&rejoined_room.room, &session.peer);
1388
1389        for project in &rejoined_room.reshared_projects {
1390            for collaborator in &project.collaborators {
1391                session
1392                    .peer
1393                    .send(
1394                        collaborator.connection_id,
1395                        proto::UpdateProjectCollaborator {
1396                            project_id: project.id.to_proto(),
1397                            old_peer_id: Some(project.old_connection_id.into()),
1398                            new_peer_id: Some(session.connection_id.into()),
1399                        },
1400                    )
1401                    .trace_err();
1402            }
1403
1404            broadcast(
1405                Some(session.connection_id),
1406                project
1407                    .collaborators
1408                    .iter()
1409                    .map(|collaborator| collaborator.connection_id),
1410                |connection_id| {
1411                    session.peer.forward_send(
1412                        session.connection_id,
1413                        connection_id,
1414                        proto::UpdateProject {
1415                            project_id: project.id.to_proto(),
1416                            worktrees: project.worktrees.clone(),
1417                        },
1418                    )
1419                },
1420            );
1421        }
1422
1423        for project in &rejoined_room.rejoined_projects {
1424            for collaborator in &project.collaborators {
1425                session
1426                    .peer
1427                    .send(
1428                        collaborator.connection_id,
1429                        proto::UpdateProjectCollaborator {
1430                            project_id: project.id.to_proto(),
1431                            old_peer_id: Some(project.old_connection_id.into()),
1432                            new_peer_id: Some(session.connection_id.into()),
1433                        },
1434                    )
1435                    .trace_err();
1436            }
1437        }
1438
1439        for project in &mut rejoined_room.rejoined_projects {
1440            for worktree in mem::take(&mut project.worktrees) {
1441                #[cfg(any(test, feature = "test-support"))]
1442                const MAX_CHUNK_SIZE: usize = 2;
1443                #[cfg(not(any(test, feature = "test-support")))]
1444                const MAX_CHUNK_SIZE: usize = 256;
1445
1446                // Stream this worktree's entries.
1447                let message = proto::UpdateWorktree {
1448                    project_id: project.id.to_proto(),
1449                    worktree_id: worktree.id,
1450                    abs_path: worktree.abs_path.clone(),
1451                    root_name: worktree.root_name,
1452                    updated_entries: worktree.updated_entries,
1453                    removed_entries: worktree.removed_entries,
1454                    scan_id: worktree.scan_id,
1455                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1456                    updated_repositories: worktree.updated_repositories,
1457                    removed_repositories: worktree.removed_repositories,
1458                };
1459                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1460                    session.peer.send(session.connection_id, update.clone())?;
1461                }
1462
1463                // Stream this worktree's diagnostics.
1464                for summary in worktree.diagnostic_summaries {
1465                    session.peer.send(
1466                        session.connection_id,
1467                        proto::UpdateDiagnosticSummary {
1468                            project_id: project.id.to_proto(),
1469                            worktree_id: worktree.id,
1470                            summary: Some(summary),
1471                        },
1472                    )?;
1473                }
1474
1475                for settings_file in worktree.settings_files {
1476                    session.peer.send(
1477                        session.connection_id,
1478                        proto::UpdateWorktreeSettings {
1479                            project_id: project.id.to_proto(),
1480                            worktree_id: worktree.id,
1481                            path: settings_file.path,
1482                            content: Some(settings_file.content),
1483                        },
1484                    )?;
1485                }
1486            }
1487
1488            for language_server in &project.language_servers {
1489                session.peer.send(
1490                    session.connection_id,
1491                    proto::UpdateLanguageServer {
1492                        project_id: project.id.to_proto(),
1493                        language_server_id: language_server.id,
1494                        variant: Some(
1495                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1496                                proto::LspDiskBasedDiagnosticsUpdated {},
1497                            ),
1498                        ),
1499                    },
1500                )?;
1501            }
1502        }
1503
1504        let rejoined_room = rejoined_room.into_inner();
1505
1506        room = rejoined_room.room;
1507        channel = rejoined_room.channel;
1508    }
1509
1510    if let Some(channel) = channel {
1511        channel_updated(
1512            &channel,
1513            &room,
1514            &session.peer,
1515            &*session.connection_pool().await,
1516        );
1517    }
1518
1519    update_user_contacts(session.user_id(), &session).await?;
1520    Ok(())
1521}
1522
1523/// leave room disconnects from the room.
1524async fn leave_room(
1525    _: proto::LeaveRoom,
1526    response: Response<proto::LeaveRoom>,
1527    session: UserSession,
1528) -> Result<()> {
1529    leave_room_for_session(&session).await?;
1530    response.send(proto::Ack {})?;
1531    Ok(())
1532}
1533
1534/// Updates the permissions of someone else in the room.
1535async fn set_room_participant_role(
1536    request: proto::SetRoomParticipantRole,
1537    response: Response<proto::SetRoomParticipantRole>,
1538    session: UserSession,
1539) -> Result<()> {
1540    let user_id = UserId::from_proto(request.user_id);
1541    let role = ChannelRole::from(request.role());
1542
1543    let (live_kit_room, can_publish) = {
1544        let room = session
1545            .db()
1546            .await
1547            .set_room_participant_role(
1548                session.user_id(),
1549                RoomId::from_proto(request.room_id),
1550                user_id,
1551                role,
1552            )
1553            .await?;
1554
1555        let live_kit_room = room.live_kit_room.clone();
1556        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1557        room_updated(&room, &session.peer);
1558        (live_kit_room, can_publish)
1559    };
1560
1561    if let Some(live_kit) = session.live_kit_client.as_ref() {
1562        live_kit
1563            .update_participant(
1564                live_kit_room.clone(),
1565                request.user_id.to_string(),
1566                live_kit_server::proto::ParticipantPermission {
1567                    can_subscribe: true,
1568                    can_publish,
1569                    can_publish_data: can_publish,
1570                    hidden: false,
1571                    recorder: false,
1572                },
1573            )
1574            .await
1575            .trace_err();
1576    }
1577
1578    response.send(proto::Ack {})?;
1579    Ok(())
1580}
1581
1582/// Call someone else into the current room
1583async fn call(
1584    request: proto::Call,
1585    response: Response<proto::Call>,
1586    session: UserSession,
1587) -> Result<()> {
1588    let room_id = RoomId::from_proto(request.room_id);
1589    let calling_user_id = session.user_id();
1590    let calling_connection_id = session.connection_id;
1591    let called_user_id = UserId::from_proto(request.called_user_id);
1592    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1593    if !session
1594        .db()
1595        .await
1596        .has_contact(calling_user_id, called_user_id)
1597        .await?
1598    {
1599        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1600    }
1601
1602    let incoming_call = {
1603        let (room, incoming_call) = &mut *session
1604            .db()
1605            .await
1606            .call(
1607                room_id,
1608                calling_user_id,
1609                calling_connection_id,
1610                called_user_id,
1611                initial_project_id,
1612            )
1613            .await?;
1614        room_updated(&room, &session.peer);
1615        mem::take(incoming_call)
1616    };
1617    update_user_contacts(called_user_id, &session).await?;
1618
1619    let mut calls = session
1620        .connection_pool()
1621        .await
1622        .user_connection_ids(called_user_id)
1623        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1624        .collect::<FuturesUnordered<_>>();
1625
1626    while let Some(call_response) = calls.next().await {
1627        match call_response.as_ref() {
1628            Ok(_) => {
1629                response.send(proto::Ack {})?;
1630                return Ok(());
1631            }
1632            Err(_) => {
1633                call_response.trace_err();
1634            }
1635        }
1636    }
1637
1638    {
1639        let room = session
1640            .db()
1641            .await
1642            .call_failed(room_id, called_user_id)
1643            .await?;
1644        room_updated(&room, &session.peer);
1645    }
1646    update_user_contacts(called_user_id, &session).await?;
1647
1648    Err(anyhow!("failed to ring user"))?
1649}
1650
1651/// Cancel an outgoing call.
1652async fn cancel_call(
1653    request: proto::CancelCall,
1654    response: Response<proto::CancelCall>,
1655    session: UserSession,
1656) -> Result<()> {
1657    let called_user_id = UserId::from_proto(request.called_user_id);
1658    let room_id = RoomId::from_proto(request.room_id);
1659    {
1660        let room = session
1661            .db()
1662            .await
1663            .cancel_call(room_id, session.connection_id, called_user_id)
1664            .await?;
1665        room_updated(&room, &session.peer);
1666    }
1667
1668    for connection_id in session
1669        .connection_pool()
1670        .await
1671        .user_connection_ids(called_user_id)
1672    {
1673        session
1674            .peer
1675            .send(
1676                connection_id,
1677                proto::CallCanceled {
1678                    room_id: room_id.to_proto(),
1679                },
1680            )
1681            .trace_err();
1682    }
1683    response.send(proto::Ack {})?;
1684
1685    update_user_contacts(called_user_id, &session).await?;
1686    Ok(())
1687}
1688
1689/// Decline an incoming call.
1690async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1691    let room_id = RoomId::from_proto(message.room_id);
1692    {
1693        let room = session
1694            .db()
1695            .await
1696            .decline_call(Some(room_id), session.user_id())
1697            .await?
1698            .ok_or_else(|| anyhow!("failed to decline call"))?;
1699        room_updated(&room, &session.peer);
1700    }
1701
1702    for connection_id in session
1703        .connection_pool()
1704        .await
1705        .user_connection_ids(session.user_id())
1706    {
1707        session
1708            .peer
1709            .send(
1710                connection_id,
1711                proto::CallCanceled {
1712                    room_id: room_id.to_proto(),
1713                },
1714            )
1715            .trace_err();
1716    }
1717    update_user_contacts(session.user_id(), &session).await?;
1718    Ok(())
1719}
1720
1721/// Updates other participants in the room with your current location.
1722async fn update_participant_location(
1723    request: proto::UpdateParticipantLocation,
1724    response: Response<proto::UpdateParticipantLocation>,
1725    session: UserSession,
1726) -> Result<()> {
1727    let room_id = RoomId::from_proto(request.room_id);
1728    let location = request
1729        .location
1730        .ok_or_else(|| anyhow!("invalid location"))?;
1731
1732    let db = session.db().await;
1733    let room = db
1734        .update_room_participant_location(room_id, session.connection_id, location)
1735        .await?;
1736
1737    room_updated(&room, &session.peer);
1738    response.send(proto::Ack {})?;
1739    Ok(())
1740}
1741
1742/// Share a project into the room.
1743async fn share_project(
1744    request: proto::ShareProject,
1745    response: Response<proto::ShareProject>,
1746    session: Session,
1747) -> Result<()> {
1748    let (project_id, room) = &*session
1749        .db()
1750        .await
1751        .share_project(
1752            RoomId::from_proto(request.room_id),
1753            session.connection_id,
1754            &request.worktrees,
1755        )
1756        .await?;
1757    response.send(proto::ShareProjectResponse {
1758        project_id: project_id.to_proto(),
1759    })?;
1760    room_updated(&room, &session.peer);
1761
1762    Ok(())
1763}
1764
1765/// Unshare a project from the room.
1766async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1767    let project_id = ProjectId::from_proto(message.project_id);
1768
1769    let (room, guest_connection_ids) = &*session
1770        .db()
1771        .await
1772        .unshare_project(project_id, session.connection_id)
1773        .await?;
1774
1775    broadcast(
1776        Some(session.connection_id),
1777        guest_connection_ids.iter().copied(),
1778        |conn_id| session.peer.send(conn_id, message.clone()),
1779    );
1780    room_updated(&room, &session.peer);
1781
1782    Ok(())
1783}
1784
1785/// Join someone elses shared project.
1786async fn join_project(
1787    request: proto::JoinProject,
1788    response: Response<proto::JoinProject>,
1789    session: UserSession,
1790) -> Result<()> {
1791    let project_id = ProjectId::from_proto(request.project_id);
1792
1793    tracing::info!(%project_id, "join project");
1794
1795    let (project, replica_id) = &mut *session
1796        .db()
1797        .await
1798        .join_project_in_room(project_id, session.connection_id)
1799        .await?;
1800
1801    join_project_internal(response, session, project, replica_id)
1802}
1803
1804trait JoinProjectInternalResponse {
1805    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1806}
1807impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1808    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1809        Response::<proto::JoinProject>::send(self, result)
1810    }
1811}
1812impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1813    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1814        Response::<proto::JoinHostedProject>::send(self, result)
1815    }
1816}
1817
1818fn join_project_internal(
1819    response: impl JoinProjectInternalResponse,
1820    session: UserSession,
1821    project: &mut Project,
1822    replica_id: &ReplicaId,
1823) -> Result<()> {
1824    let collaborators = project
1825        .collaborators
1826        .iter()
1827        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1828        .map(|collaborator| collaborator.to_proto())
1829        .collect::<Vec<_>>();
1830    let project_id = project.id;
1831    let guest_user_id = session.user_id();
1832
1833    let worktrees = project
1834        .worktrees
1835        .iter()
1836        .map(|(id, worktree)| proto::WorktreeMetadata {
1837            id: *id,
1838            root_name: worktree.root_name.clone(),
1839            visible: worktree.visible,
1840            abs_path: worktree.abs_path.clone(),
1841        })
1842        .collect::<Vec<_>>();
1843
1844    for collaborator in &collaborators {
1845        session
1846            .peer
1847            .send(
1848                collaborator.peer_id.unwrap().into(),
1849                proto::AddProjectCollaborator {
1850                    project_id: project_id.to_proto(),
1851                    collaborator: Some(proto::Collaborator {
1852                        peer_id: Some(session.connection_id.into()),
1853                        replica_id: replica_id.0 as u32,
1854                        user_id: guest_user_id.to_proto(),
1855                    }),
1856                },
1857            )
1858            .trace_err();
1859    }
1860
1861    // First, we send the metadata associated with each worktree.
1862    response.send(proto::JoinProjectResponse {
1863        project_id: project.id.0 as u64,
1864        worktrees: worktrees.clone(),
1865        replica_id: replica_id.0 as u32,
1866        collaborators: collaborators.clone(),
1867        language_servers: project.language_servers.clone(),
1868        role: project.role.into(), // todo
1869    })?;
1870
1871    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1872        #[cfg(any(test, feature = "test-support"))]
1873        const MAX_CHUNK_SIZE: usize = 2;
1874        #[cfg(not(any(test, feature = "test-support")))]
1875        const MAX_CHUNK_SIZE: usize = 256;
1876
1877        // Stream this worktree's entries.
1878        let message = proto::UpdateWorktree {
1879            project_id: project_id.to_proto(),
1880            worktree_id,
1881            abs_path: worktree.abs_path.clone(),
1882            root_name: worktree.root_name,
1883            updated_entries: worktree.entries,
1884            removed_entries: Default::default(),
1885            scan_id: worktree.scan_id,
1886            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1887            updated_repositories: worktree.repository_entries.into_values().collect(),
1888            removed_repositories: Default::default(),
1889        };
1890        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1891            session.peer.send(session.connection_id, update.clone())?;
1892        }
1893
1894        // Stream this worktree's diagnostics.
1895        for summary in worktree.diagnostic_summaries {
1896            session.peer.send(
1897                session.connection_id,
1898                proto::UpdateDiagnosticSummary {
1899                    project_id: project_id.to_proto(),
1900                    worktree_id: worktree.id,
1901                    summary: Some(summary),
1902                },
1903            )?;
1904        }
1905
1906        for settings_file in worktree.settings_files {
1907            session.peer.send(
1908                session.connection_id,
1909                proto::UpdateWorktreeSettings {
1910                    project_id: project_id.to_proto(),
1911                    worktree_id: worktree.id,
1912                    path: settings_file.path,
1913                    content: Some(settings_file.content),
1914                },
1915            )?;
1916        }
1917    }
1918
1919    for language_server in &project.language_servers {
1920        session.peer.send(
1921            session.connection_id,
1922            proto::UpdateLanguageServer {
1923                project_id: project_id.to_proto(),
1924                language_server_id: language_server.id,
1925                variant: Some(
1926                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1927                        proto::LspDiskBasedDiagnosticsUpdated {},
1928                    ),
1929                ),
1930            },
1931        )?;
1932    }
1933
1934    Ok(())
1935}
1936
1937/// Leave someone elses shared project.
1938async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1939    let sender_id = session.connection_id;
1940    let project_id = ProjectId::from_proto(request.project_id);
1941    let db = session.db().await;
1942    if db.is_hosted_project(project_id).await? {
1943        let project = db.leave_hosted_project(project_id, sender_id).await?;
1944        project_left(&project, &session);
1945        return Ok(());
1946    }
1947
1948    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1949    tracing::info!(
1950        %project_id,
1951        host_user_id = ?project.host_user_id,
1952        host_connection_id = ?project.host_connection_id,
1953        "leave project"
1954    );
1955
1956    project_left(&project, &session);
1957    room_updated(&room, &session.peer);
1958
1959    Ok(())
1960}
1961
1962async fn join_hosted_project(
1963    request: proto::JoinHostedProject,
1964    response: Response<proto::JoinHostedProject>,
1965    session: UserSession,
1966) -> Result<()> {
1967    let (mut project, replica_id) = session
1968        .db()
1969        .await
1970        .join_hosted_project(
1971            ProjectId(request.project_id as i32),
1972            session.user_id(),
1973            session.connection_id,
1974        )
1975        .await?;
1976
1977    join_project_internal(response, session, &mut project, &replica_id)
1978}
1979
1980/// Updates other participants with changes to the project
1981async fn update_project(
1982    request: proto::UpdateProject,
1983    response: Response<proto::UpdateProject>,
1984    session: Session,
1985) -> Result<()> {
1986    let project_id = ProjectId::from_proto(request.project_id);
1987    let (room, guest_connection_ids) = &*session
1988        .db()
1989        .await
1990        .update_project(project_id, session.connection_id, &request.worktrees)
1991        .await?;
1992    broadcast(
1993        Some(session.connection_id),
1994        guest_connection_ids.iter().copied(),
1995        |connection_id| {
1996            session
1997                .peer
1998                .forward_send(session.connection_id, connection_id, request.clone())
1999        },
2000    );
2001    room_updated(&room, &session.peer);
2002    response.send(proto::Ack {})?;
2003
2004    Ok(())
2005}
2006
2007/// Updates other participants with changes to the worktree
2008async fn update_worktree(
2009    request: proto::UpdateWorktree,
2010    response: Response<proto::UpdateWorktree>,
2011    session: Session,
2012) -> Result<()> {
2013    let guest_connection_ids = session
2014        .db()
2015        .await
2016        .update_worktree(&request, session.connection_id)
2017        .await?;
2018
2019    broadcast(
2020        Some(session.connection_id),
2021        guest_connection_ids.iter().copied(),
2022        |connection_id| {
2023            session
2024                .peer
2025                .forward_send(session.connection_id, connection_id, request.clone())
2026        },
2027    );
2028    response.send(proto::Ack {})?;
2029    Ok(())
2030}
2031
2032/// Updates other participants with changes to the diagnostics
2033async fn update_diagnostic_summary(
2034    message: proto::UpdateDiagnosticSummary,
2035    session: Session,
2036) -> Result<()> {
2037    let guest_connection_ids = session
2038        .db()
2039        .await
2040        .update_diagnostic_summary(&message, session.connection_id)
2041        .await?;
2042
2043    broadcast(
2044        Some(session.connection_id),
2045        guest_connection_ids.iter().copied(),
2046        |connection_id| {
2047            session
2048                .peer
2049                .forward_send(session.connection_id, connection_id, message.clone())
2050        },
2051    );
2052
2053    Ok(())
2054}
2055
2056/// Updates other participants with changes to the worktree settings
2057async fn update_worktree_settings(
2058    message: proto::UpdateWorktreeSettings,
2059    session: Session,
2060) -> Result<()> {
2061    let guest_connection_ids = session
2062        .db()
2063        .await
2064        .update_worktree_settings(&message, session.connection_id)
2065        .await?;
2066
2067    broadcast(
2068        Some(session.connection_id),
2069        guest_connection_ids.iter().copied(),
2070        |connection_id| {
2071            session
2072                .peer
2073                .forward_send(session.connection_id, connection_id, message.clone())
2074        },
2075    );
2076
2077    Ok(())
2078}
2079
2080/// Notify other participants that a  language server has started.
2081async fn start_language_server(
2082    request: proto::StartLanguageServer,
2083    session: Session,
2084) -> Result<()> {
2085    let guest_connection_ids = session
2086        .db()
2087        .await
2088        .start_language_server(&request, session.connection_id)
2089        .await?;
2090
2091    broadcast(
2092        Some(session.connection_id),
2093        guest_connection_ids.iter().copied(),
2094        |connection_id| {
2095            session
2096                .peer
2097                .forward_send(session.connection_id, connection_id, request.clone())
2098        },
2099    );
2100    Ok(())
2101}
2102
2103/// Notify other participants that a language server has changed.
2104async fn update_language_server(
2105    request: proto::UpdateLanguageServer,
2106    session: Session,
2107) -> Result<()> {
2108    let project_id = ProjectId::from_proto(request.project_id);
2109    let project_connection_ids = session
2110        .db()
2111        .await
2112        .project_connection_ids(project_id, session.connection_id)
2113        .await?;
2114    broadcast(
2115        Some(session.connection_id),
2116        project_connection_ids.iter().copied(),
2117        |connection_id| {
2118            session
2119                .peer
2120                .forward_send(session.connection_id, connection_id, request.clone())
2121        },
2122    );
2123    Ok(())
2124}
2125
2126/// forward a project request to the host. These requests should be read only
2127/// as guests are allowed to send them.
2128async fn forward_read_only_project_request<T>(
2129    request: T,
2130    response: Response<T>,
2131    session: Session,
2132) -> Result<()>
2133where
2134    T: EntityMessage + RequestMessage,
2135{
2136    let project_id = ProjectId::from_proto(request.remote_entity_id());
2137    let host_connection_id = session
2138        .db()
2139        .await
2140        .host_for_read_only_project_request(project_id, session.connection_id)
2141        .await?;
2142    let payload = session
2143        .peer
2144        .forward_request(session.connection_id, host_connection_id, request)
2145        .await?;
2146    response.send(payload)?;
2147    Ok(())
2148}
2149
2150/// forward a project request to the host. These requests are disallowed
2151/// for guests.
2152async fn forward_mutating_project_request<T>(
2153    request: T,
2154    response: Response<T>,
2155    session: Session,
2156) -> Result<()>
2157where
2158    T: EntityMessage + RequestMessage,
2159{
2160    let project_id = ProjectId::from_proto(request.remote_entity_id());
2161    let host_connection_id = session
2162        .db()
2163        .await
2164        .host_for_mutating_project_request(project_id, session.connection_id)
2165        .await?;
2166    let payload = session
2167        .peer
2168        .forward_request(session.connection_id, host_connection_id, request)
2169        .await?;
2170    response.send(payload)?;
2171    Ok(())
2172}
2173
2174/// Notify other participants that a new buffer has been created
2175async fn create_buffer_for_peer(
2176    request: proto::CreateBufferForPeer,
2177    session: Session,
2178) -> Result<()> {
2179    session
2180        .db()
2181        .await
2182        .check_user_is_project_host(
2183            ProjectId::from_proto(request.project_id),
2184            session.connection_id,
2185        )
2186        .await?;
2187    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2188    session
2189        .peer
2190        .forward_send(session.connection_id, peer_id.into(), request)?;
2191    Ok(())
2192}
2193
2194/// Notify other participants that a buffer has been updated. This is
2195/// allowed for guests as long as the update is limited to selections.
2196async fn update_buffer(
2197    request: proto::UpdateBuffer,
2198    response: Response<proto::UpdateBuffer>,
2199    session: Session,
2200) -> Result<()> {
2201    let project_id = ProjectId::from_proto(request.project_id);
2202    let mut guest_connection_ids;
2203    let mut host_connection_id = None;
2204
2205    let mut requires_write_permission = false;
2206
2207    for op in request.operations.iter() {
2208        match op.variant {
2209            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2210            Some(_) => requires_write_permission = true,
2211        }
2212    }
2213
2214    {
2215        let collaborators = session
2216            .db()
2217            .await
2218            .project_collaborators_for_buffer_update(
2219                project_id,
2220                session.connection_id,
2221                requires_write_permission,
2222            )
2223            .await?;
2224        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2225        for collaborator in collaborators.iter() {
2226            if collaborator.is_host {
2227                host_connection_id = Some(collaborator.connection_id);
2228            } else {
2229                guest_connection_ids.push(collaborator.connection_id);
2230            }
2231        }
2232    }
2233    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2234
2235    broadcast(
2236        Some(session.connection_id),
2237        guest_connection_ids,
2238        |connection_id| {
2239            session
2240                .peer
2241                .forward_send(session.connection_id, connection_id, request.clone())
2242        },
2243    );
2244    if host_connection_id != session.connection_id {
2245        session
2246            .peer
2247            .forward_request(session.connection_id, host_connection_id, request.clone())
2248            .await?;
2249    }
2250
2251    response.send(proto::Ack {})?;
2252    Ok(())
2253}
2254
2255/// Notify other participants that a project has been updated.
2256async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2257    request: T,
2258    session: Session,
2259) -> Result<()> {
2260    let project_id = ProjectId::from_proto(request.remote_entity_id());
2261    let project_connection_ids = session
2262        .db()
2263        .await
2264        .project_connection_ids(project_id, session.connection_id)
2265        .await?;
2266
2267    broadcast(
2268        Some(session.connection_id),
2269        project_connection_ids.iter().copied(),
2270        |connection_id| {
2271            session
2272                .peer
2273                .forward_send(session.connection_id, connection_id, request.clone())
2274        },
2275    );
2276    Ok(())
2277}
2278
2279/// Start following another user in a call.
2280async fn follow(
2281    request: proto::Follow,
2282    response: Response<proto::Follow>,
2283    session: UserSession,
2284) -> Result<()> {
2285    let room_id = RoomId::from_proto(request.room_id);
2286    let project_id = request.project_id.map(ProjectId::from_proto);
2287    let leader_id = request
2288        .leader_id
2289        .ok_or_else(|| anyhow!("invalid leader id"))?
2290        .into();
2291    let follower_id = session.connection_id;
2292
2293    session
2294        .db()
2295        .await
2296        .check_room_participants(room_id, leader_id, session.connection_id)
2297        .await?;
2298
2299    let response_payload = session
2300        .peer
2301        .forward_request(session.connection_id, leader_id, request)
2302        .await?;
2303    response.send(response_payload)?;
2304
2305    if let Some(project_id) = project_id {
2306        let room = session
2307            .db()
2308            .await
2309            .follow(room_id, project_id, leader_id, follower_id)
2310            .await?;
2311        room_updated(&room, &session.peer);
2312    }
2313
2314    Ok(())
2315}
2316
2317/// Stop following another user in a call.
2318async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2319    let room_id = RoomId::from_proto(request.room_id);
2320    let project_id = request.project_id.map(ProjectId::from_proto);
2321    let leader_id = request
2322        .leader_id
2323        .ok_or_else(|| anyhow!("invalid leader id"))?
2324        .into();
2325    let follower_id = session.connection_id;
2326
2327    session
2328        .db()
2329        .await
2330        .check_room_participants(room_id, leader_id, session.connection_id)
2331        .await?;
2332
2333    session
2334        .peer
2335        .forward_send(session.connection_id, leader_id, request)?;
2336
2337    if let Some(project_id) = project_id {
2338        let room = session
2339            .db()
2340            .await
2341            .unfollow(room_id, project_id, leader_id, follower_id)
2342            .await?;
2343        room_updated(&room, &session.peer);
2344    }
2345
2346    Ok(())
2347}
2348
2349/// Notify everyone following you of your current location.
2350async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2351    let room_id = RoomId::from_proto(request.room_id);
2352    let database = session.db.lock().await;
2353
2354    let connection_ids = if let Some(project_id) = request.project_id {
2355        let project_id = ProjectId::from_proto(project_id);
2356        database
2357            .project_connection_ids(project_id, session.connection_id)
2358            .await?
2359    } else {
2360        database
2361            .room_connection_ids(room_id, session.connection_id)
2362            .await?
2363    };
2364
2365    // For now, don't send view update messages back to that view's current leader.
2366    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2367        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2368        _ => None,
2369    });
2370
2371    for connection_id in connection_ids.iter().cloned() {
2372        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2373            session
2374                .peer
2375                .forward_send(session.connection_id, connection_id, request.clone())?;
2376        }
2377    }
2378    Ok(())
2379}
2380
2381/// Get public data about users.
2382async fn get_users(
2383    request: proto::GetUsers,
2384    response: Response<proto::GetUsers>,
2385    session: Session,
2386) -> Result<()> {
2387    let user_ids = request
2388        .user_ids
2389        .into_iter()
2390        .map(UserId::from_proto)
2391        .collect();
2392    let users = session
2393        .db()
2394        .await
2395        .get_users_by_ids(user_ids)
2396        .await?
2397        .into_iter()
2398        .map(|user| proto::User {
2399            id: user.id.to_proto(),
2400            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2401            github_login: user.github_login,
2402        })
2403        .collect();
2404    response.send(proto::UsersResponse { users })?;
2405    Ok(())
2406}
2407
2408/// Search for users (to invite) buy Github login
2409async fn fuzzy_search_users(
2410    request: proto::FuzzySearchUsers,
2411    response: Response<proto::FuzzySearchUsers>,
2412    session: UserSession,
2413) -> Result<()> {
2414    let query = request.query;
2415    let users = match query.len() {
2416        0 => vec![],
2417        1 | 2 => session
2418            .db()
2419            .await
2420            .get_user_by_github_login(&query)
2421            .await?
2422            .into_iter()
2423            .collect(),
2424        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2425    };
2426    let users = users
2427        .into_iter()
2428        .filter(|user| user.id != session.user_id())
2429        .map(|user| proto::User {
2430            id: user.id.to_proto(),
2431            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2432            github_login: user.github_login,
2433        })
2434        .collect();
2435    response.send(proto::UsersResponse { users })?;
2436    Ok(())
2437}
2438
2439/// Send a contact request to another user.
2440async fn request_contact(
2441    request: proto::RequestContact,
2442    response: Response<proto::RequestContact>,
2443    session: UserSession,
2444) -> Result<()> {
2445    let requester_id = session.user_id();
2446    let responder_id = UserId::from_proto(request.responder_id);
2447    if requester_id == responder_id {
2448        return Err(anyhow!("cannot add yourself as a contact"))?;
2449    }
2450
2451    let notifications = session
2452        .db()
2453        .await
2454        .send_contact_request(requester_id, responder_id)
2455        .await?;
2456
2457    // Update outgoing contact requests of requester
2458    let mut update = proto::UpdateContacts::default();
2459    update.outgoing_requests.push(responder_id.to_proto());
2460    for connection_id in session
2461        .connection_pool()
2462        .await
2463        .user_connection_ids(requester_id)
2464    {
2465        session.peer.send(connection_id, update.clone())?;
2466    }
2467
2468    // Update incoming contact requests of responder
2469    let mut update = proto::UpdateContacts::default();
2470    update
2471        .incoming_requests
2472        .push(proto::IncomingContactRequest {
2473            requester_id: requester_id.to_proto(),
2474        });
2475    let connection_pool = session.connection_pool().await;
2476    for connection_id in connection_pool.user_connection_ids(responder_id) {
2477        session.peer.send(connection_id, update.clone())?;
2478    }
2479
2480    send_notifications(&connection_pool, &session.peer, notifications);
2481
2482    response.send(proto::Ack {})?;
2483    Ok(())
2484}
2485
2486/// Accept or decline a contact request
2487async fn respond_to_contact_request(
2488    request: proto::RespondToContactRequest,
2489    response: Response<proto::RespondToContactRequest>,
2490    session: UserSession,
2491) -> Result<()> {
2492    let responder_id = session.user_id();
2493    let requester_id = UserId::from_proto(request.requester_id);
2494    let db = session.db().await;
2495    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2496        db.dismiss_contact_notification(responder_id, requester_id)
2497            .await?;
2498    } else {
2499        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2500
2501        let notifications = db
2502            .respond_to_contact_request(responder_id, requester_id, accept)
2503            .await?;
2504        let requester_busy = db.is_user_busy(requester_id).await?;
2505        let responder_busy = db.is_user_busy(responder_id).await?;
2506
2507        let pool = session.connection_pool().await;
2508        // Update responder with new contact
2509        let mut update = proto::UpdateContacts::default();
2510        if accept {
2511            update
2512                .contacts
2513                .push(contact_for_user(requester_id, requester_busy, &pool));
2514        }
2515        update
2516            .remove_incoming_requests
2517            .push(requester_id.to_proto());
2518        for connection_id in pool.user_connection_ids(responder_id) {
2519            session.peer.send(connection_id, update.clone())?;
2520        }
2521
2522        // Update requester with new contact
2523        let mut update = proto::UpdateContacts::default();
2524        if accept {
2525            update
2526                .contacts
2527                .push(contact_for_user(responder_id, responder_busy, &pool));
2528        }
2529        update
2530            .remove_outgoing_requests
2531            .push(responder_id.to_proto());
2532
2533        for connection_id in pool.user_connection_ids(requester_id) {
2534            session.peer.send(connection_id, update.clone())?;
2535        }
2536
2537        send_notifications(&pool, &session.peer, notifications);
2538    }
2539
2540    response.send(proto::Ack {})?;
2541    Ok(())
2542}
2543
2544/// Remove a contact.
2545async fn remove_contact(
2546    request: proto::RemoveContact,
2547    response: Response<proto::RemoveContact>,
2548    session: UserSession,
2549) -> Result<()> {
2550    let requester_id = session.user_id();
2551    let responder_id = UserId::from_proto(request.user_id);
2552    let db = session.db().await;
2553    let (contact_accepted, deleted_notification_id) =
2554        db.remove_contact(requester_id, responder_id).await?;
2555
2556    let pool = session.connection_pool().await;
2557    // Update outgoing contact requests of requester
2558    let mut update = proto::UpdateContacts::default();
2559    if contact_accepted {
2560        update.remove_contacts.push(responder_id.to_proto());
2561    } else {
2562        update
2563            .remove_outgoing_requests
2564            .push(responder_id.to_proto());
2565    }
2566    for connection_id in pool.user_connection_ids(requester_id) {
2567        session.peer.send(connection_id, update.clone())?;
2568    }
2569
2570    // Update incoming contact requests of responder
2571    let mut update = proto::UpdateContacts::default();
2572    if contact_accepted {
2573        update.remove_contacts.push(requester_id.to_proto());
2574    } else {
2575        update
2576            .remove_incoming_requests
2577            .push(requester_id.to_proto());
2578    }
2579    for connection_id in pool.user_connection_ids(responder_id) {
2580        session.peer.send(connection_id, update.clone())?;
2581        if let Some(notification_id) = deleted_notification_id {
2582            session.peer.send(
2583                connection_id,
2584                proto::DeleteNotification {
2585                    notification_id: notification_id.to_proto(),
2586                },
2587            )?;
2588        }
2589    }
2590
2591    response.send(proto::Ack {})?;
2592    Ok(())
2593}
2594
2595/// Creates a new channel.
2596async fn create_channel(
2597    request: proto::CreateChannel,
2598    response: Response<proto::CreateChannel>,
2599    session: UserSession,
2600) -> Result<()> {
2601    let db = session.db().await;
2602
2603    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2604    let (channel, membership) = db
2605        .create_channel(&request.name, parent_id, session.user_id())
2606        .await?;
2607
2608    let root_id = channel.root_id();
2609    let channel = Channel::from_model(channel);
2610
2611    response.send(proto::CreateChannelResponse {
2612        channel: Some(channel.to_proto()),
2613        parent_id: request.parent_id,
2614    })?;
2615
2616    let mut connection_pool = session.connection_pool().await;
2617    if let Some(membership) = membership {
2618        connection_pool.subscribe_to_channel(
2619            membership.user_id,
2620            membership.channel_id,
2621            membership.role,
2622        );
2623        let update = proto::UpdateUserChannels {
2624            channel_memberships: vec![proto::ChannelMembership {
2625                channel_id: membership.channel_id.to_proto(),
2626                role: membership.role.into(),
2627            }],
2628            ..Default::default()
2629        };
2630        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2631            session.peer.send(connection_id, update.clone())?;
2632        }
2633    }
2634
2635    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2636        if !role.can_see_channel(channel.visibility) {
2637            continue;
2638        }
2639
2640        let update = proto::UpdateChannels {
2641            channels: vec![channel.to_proto()],
2642            ..Default::default()
2643        };
2644        session.peer.send(connection_id, update.clone())?;
2645    }
2646
2647    Ok(())
2648}
2649
2650/// Delete a channel
2651async fn delete_channel(
2652    request: proto::DeleteChannel,
2653    response: Response<proto::DeleteChannel>,
2654    session: UserSession,
2655) -> Result<()> {
2656    let db = session.db().await;
2657
2658    let channel_id = request.channel_id;
2659    let (root_channel, removed_channels) = db
2660        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2661        .await?;
2662    response.send(proto::Ack {})?;
2663
2664    // Notify members of removed channels
2665    let mut update = proto::UpdateChannels::default();
2666    update
2667        .delete_channels
2668        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2669
2670    let connection_pool = session.connection_pool().await;
2671    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2672        session.peer.send(connection_id, update.clone())?;
2673    }
2674
2675    Ok(())
2676}
2677
2678/// Invite someone to join a channel.
2679async fn invite_channel_member(
2680    request: proto::InviteChannelMember,
2681    response: Response<proto::InviteChannelMember>,
2682    session: UserSession,
2683) -> Result<()> {
2684    let db = session.db().await;
2685    let channel_id = ChannelId::from_proto(request.channel_id);
2686    let invitee_id = UserId::from_proto(request.user_id);
2687    let InviteMemberResult {
2688        channel,
2689        notifications,
2690    } = db
2691        .invite_channel_member(
2692            channel_id,
2693            invitee_id,
2694            session.user_id(),
2695            request.role().into(),
2696        )
2697        .await?;
2698
2699    let update = proto::UpdateChannels {
2700        channel_invitations: vec![channel.to_proto()],
2701        ..Default::default()
2702    };
2703
2704    let connection_pool = session.connection_pool().await;
2705    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2706        session.peer.send(connection_id, update.clone())?;
2707    }
2708
2709    send_notifications(&connection_pool, &session.peer, notifications);
2710
2711    response.send(proto::Ack {})?;
2712    Ok(())
2713}
2714
2715/// remove someone from a channel
2716async fn remove_channel_member(
2717    request: proto::RemoveChannelMember,
2718    response: Response<proto::RemoveChannelMember>,
2719    session: UserSession,
2720) -> Result<()> {
2721    let db = session.db().await;
2722    let channel_id = ChannelId::from_proto(request.channel_id);
2723    let member_id = UserId::from_proto(request.user_id);
2724
2725    let RemoveChannelMemberResult {
2726        membership_update,
2727        notification_id,
2728    } = db
2729        .remove_channel_member(channel_id, member_id, session.user_id())
2730        .await?;
2731
2732    let mut connection_pool = session.connection_pool().await;
2733    notify_membership_updated(
2734        &mut connection_pool,
2735        membership_update,
2736        member_id,
2737        &session.peer,
2738    );
2739    for connection_id in connection_pool.user_connection_ids(member_id) {
2740        if let Some(notification_id) = notification_id {
2741            session
2742                .peer
2743                .send(
2744                    connection_id,
2745                    proto::DeleteNotification {
2746                        notification_id: notification_id.to_proto(),
2747                    },
2748                )
2749                .trace_err();
2750        }
2751    }
2752
2753    response.send(proto::Ack {})?;
2754    Ok(())
2755}
2756
2757/// Toggle the channel between public and private.
2758/// Care is taken to maintain the invariant that public channels only descend from public channels,
2759/// (though members-only channels can appear at any point in the hierarchy).
2760async fn set_channel_visibility(
2761    request: proto::SetChannelVisibility,
2762    response: Response<proto::SetChannelVisibility>,
2763    session: UserSession,
2764) -> Result<()> {
2765    let db = session.db().await;
2766    let channel_id = ChannelId::from_proto(request.channel_id);
2767    let visibility = request.visibility().into();
2768
2769    let channel_model = db
2770        .set_channel_visibility(channel_id, visibility, session.user_id())
2771        .await?;
2772    let root_id = channel_model.root_id();
2773    let channel = Channel::from_model(channel_model);
2774
2775    let mut connection_pool = session.connection_pool().await;
2776    for (user_id, role) in connection_pool
2777        .channel_user_ids(root_id)
2778        .collect::<Vec<_>>()
2779        .into_iter()
2780    {
2781        let update = if role.can_see_channel(channel.visibility) {
2782            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2783            proto::UpdateChannels {
2784                channels: vec![channel.to_proto()],
2785                ..Default::default()
2786            }
2787        } else {
2788            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2789            proto::UpdateChannels {
2790                delete_channels: vec![channel.id.to_proto()],
2791                ..Default::default()
2792            }
2793        };
2794
2795        for connection_id in connection_pool.user_connection_ids(user_id) {
2796            session.peer.send(connection_id, update.clone())?;
2797        }
2798    }
2799
2800    response.send(proto::Ack {})?;
2801    Ok(())
2802}
2803
2804/// Alter the role for a user in the channel.
2805async fn set_channel_member_role(
2806    request: proto::SetChannelMemberRole,
2807    response: Response<proto::SetChannelMemberRole>,
2808    session: UserSession,
2809) -> Result<()> {
2810    let db = session.db().await;
2811    let channel_id = ChannelId::from_proto(request.channel_id);
2812    let member_id = UserId::from_proto(request.user_id);
2813    let result = db
2814        .set_channel_member_role(
2815            channel_id,
2816            session.user_id(),
2817            member_id,
2818            request.role().into(),
2819        )
2820        .await?;
2821
2822    match result {
2823        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2824            let mut connection_pool = session.connection_pool().await;
2825            notify_membership_updated(
2826                &mut connection_pool,
2827                membership_update,
2828                member_id,
2829                &session.peer,
2830            )
2831        }
2832        db::SetMemberRoleResult::InviteUpdated(channel) => {
2833            let update = proto::UpdateChannels {
2834                channel_invitations: vec![channel.to_proto()],
2835                ..Default::default()
2836            };
2837
2838            for connection_id in session
2839                .connection_pool()
2840                .await
2841                .user_connection_ids(member_id)
2842            {
2843                session.peer.send(connection_id, update.clone())?;
2844            }
2845        }
2846    }
2847
2848    response.send(proto::Ack {})?;
2849    Ok(())
2850}
2851
2852/// Change the name of a channel
2853async fn rename_channel(
2854    request: proto::RenameChannel,
2855    response: Response<proto::RenameChannel>,
2856    session: UserSession,
2857) -> Result<()> {
2858    let db = session.db().await;
2859    let channel_id = ChannelId::from_proto(request.channel_id);
2860    let channel_model = db
2861        .rename_channel(channel_id, session.user_id(), &request.name)
2862        .await?;
2863    let root_id = channel_model.root_id();
2864    let channel = Channel::from_model(channel_model);
2865
2866    response.send(proto::RenameChannelResponse {
2867        channel: Some(channel.to_proto()),
2868    })?;
2869
2870    let connection_pool = session.connection_pool().await;
2871    let update = proto::UpdateChannels {
2872        channels: vec![channel.to_proto()],
2873        ..Default::default()
2874    };
2875    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2876        if role.can_see_channel(channel.visibility) {
2877            session.peer.send(connection_id, update.clone())?;
2878        }
2879    }
2880
2881    Ok(())
2882}
2883
2884/// Move a channel to a new parent.
2885async fn move_channel(
2886    request: proto::MoveChannel,
2887    response: Response<proto::MoveChannel>,
2888    session: UserSession,
2889) -> Result<()> {
2890    let channel_id = ChannelId::from_proto(request.channel_id);
2891    let to = ChannelId::from_proto(request.to);
2892
2893    let (root_id, channels) = session
2894        .db()
2895        .await
2896        .move_channel(channel_id, to, session.user_id())
2897        .await?;
2898
2899    let connection_pool = session.connection_pool().await;
2900    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2901        let channels = channels
2902            .iter()
2903            .filter_map(|channel| {
2904                if role.can_see_channel(channel.visibility) {
2905                    Some(channel.to_proto())
2906                } else {
2907                    None
2908                }
2909            })
2910            .collect::<Vec<_>>();
2911        if channels.is_empty() {
2912            continue;
2913        }
2914
2915        let update = proto::UpdateChannels {
2916            channels,
2917            ..Default::default()
2918        };
2919
2920        session.peer.send(connection_id, update.clone())?;
2921    }
2922
2923    response.send(Ack {})?;
2924    Ok(())
2925}
2926
2927/// Get the list of channel members
2928async fn get_channel_members(
2929    request: proto::GetChannelMembers,
2930    response: Response<proto::GetChannelMembers>,
2931    session: UserSession,
2932) -> Result<()> {
2933    let db = session.db().await;
2934    let channel_id = ChannelId::from_proto(request.channel_id);
2935    let members = db
2936        .get_channel_participant_details(channel_id, session.user_id())
2937        .await?;
2938    response.send(proto::GetChannelMembersResponse { members })?;
2939    Ok(())
2940}
2941
2942/// Accept or decline a channel invitation.
2943async fn respond_to_channel_invite(
2944    request: proto::RespondToChannelInvite,
2945    response: Response<proto::RespondToChannelInvite>,
2946    session: UserSession,
2947) -> Result<()> {
2948    let db = session.db().await;
2949    let channel_id = ChannelId::from_proto(request.channel_id);
2950    let RespondToChannelInvite {
2951        membership_update,
2952        notifications,
2953    } = db
2954        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2955        .await?;
2956
2957    let mut connection_pool = session.connection_pool().await;
2958    if let Some(membership_update) = membership_update {
2959        notify_membership_updated(
2960            &mut connection_pool,
2961            membership_update,
2962            session.user_id(),
2963            &session.peer,
2964        );
2965    } else {
2966        let update = proto::UpdateChannels {
2967            remove_channel_invitations: vec![channel_id.to_proto()],
2968            ..Default::default()
2969        };
2970
2971        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2972            session.peer.send(connection_id, update.clone())?;
2973        }
2974    };
2975
2976    send_notifications(&connection_pool, &session.peer, notifications);
2977
2978    response.send(proto::Ack {})?;
2979
2980    Ok(())
2981}
2982
2983/// Join the channels' room
2984async fn join_channel(
2985    request: proto::JoinChannel,
2986    response: Response<proto::JoinChannel>,
2987    session: UserSession,
2988) -> Result<()> {
2989    let channel_id = ChannelId::from_proto(request.channel_id);
2990    join_channel_internal(channel_id, Box::new(response), session).await
2991}
2992
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// request kind.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the underlying request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send`
        // (presumably an inherent method), not this trait method again.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified for the same reason as the `JoinChannel` impl.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3006
/// Join the room of `channel_id` on behalf of `session`, replying via `response`.
///
/// Sequence:
/// 1. `leave_room_for_session` runs first.
/// 2. The database join yields the joined room, an optional membership update,
///    and the user's channel role.
/// 3. LiveKit connection info is built, if a LiveKit client is configured.
/// 4. The `JoinRoomResponse` is sent.
/// 5. Membership and room updates are broadcast.
/// 6. After the inner scope ends, the channel update is broadcast and the
///    user's contacts are refreshed.
///
/// Errors: propagates failures from leaving, joining, sending the response and
/// updating contacts. It also fails with "channel not returned" if the joined
/// room carries no channel; note that this check happens after the response
/// was already sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards so both are
    // dropped before `session.connection_pool()` is awaited again below
    // (presumably to avoid holding/re-acquiring the same lock).
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a
        // room token with publishing allowed. If token creation fails,
        // `trace_err()?` makes this closure return `None`, so the response
        // simply carries no LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3086
3087/// Start editing the channel notes
3088async fn join_channel_buffer(
3089    request: proto::JoinChannelBuffer,
3090    response: Response<proto::JoinChannelBuffer>,
3091    session: UserSession,
3092) -> Result<()> {
3093    let db = session.db().await;
3094    let channel_id = ChannelId::from_proto(request.channel_id);
3095
3096    let open_response = db
3097        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3098        .await?;
3099
3100    let collaborators = open_response.collaborators.clone();
3101    response.send(open_response)?;
3102
3103    let update = UpdateChannelBufferCollaborators {
3104        channel_id: channel_id.to_proto(),
3105        collaborators: collaborators.clone(),
3106    };
3107    channel_buffer_updated(
3108        session.connection_id,
3109        collaborators
3110            .iter()
3111            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3112        &update,
3113        &session.peer,
3114    );
3115
3116    Ok(())
3117}
3118
3119/// Edit the channel notes
3120async fn update_channel_buffer(
3121    request: proto::UpdateChannelBuffer,
3122    session: UserSession,
3123) -> Result<()> {
3124    let db = session.db().await;
3125    let channel_id = ChannelId::from_proto(request.channel_id);
3126
3127    let (collaborators, non_collaborators, epoch, version) = db
3128        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3129        .await?;
3130
3131    channel_buffer_updated(
3132        session.connection_id,
3133        collaborators,
3134        &proto::UpdateChannelBuffer {
3135            channel_id: channel_id.to_proto(),
3136            operations: request.operations,
3137        },
3138        &session.peer,
3139    );
3140
3141    let pool = &*session.connection_pool().await;
3142
3143    broadcast(
3144        None,
3145        non_collaborators
3146            .iter()
3147            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3148        |peer_id| {
3149            session.peer.send(
3150                peer_id,
3151                proto::UpdateChannels {
3152                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3153                        channel_id: channel_id.to_proto(),
3154                        epoch: epoch as u64,
3155                        version: version.clone(),
3156                    }],
3157                    ..Default::default()
3158                },
3159            )
3160        },
3161    );
3162
3163    Ok(())
3164}
3165
3166/// Rejoin the channel notes after a connection blip
3167async fn rejoin_channel_buffers(
3168    request: proto::RejoinChannelBuffers,
3169    response: Response<proto::RejoinChannelBuffers>,
3170    session: UserSession,
3171) -> Result<()> {
3172    let db = session.db().await;
3173    let buffers = db
3174        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3175        .await?;
3176
3177    for rejoined_buffer in &buffers {
3178        let collaborators_to_notify = rejoined_buffer
3179            .buffer
3180            .collaborators
3181            .iter()
3182            .filter_map(|c| Some(c.peer_id?.into()));
3183        channel_buffer_updated(
3184            session.connection_id,
3185            collaborators_to_notify,
3186            &proto::UpdateChannelBufferCollaborators {
3187                channel_id: rejoined_buffer.buffer.channel_id,
3188                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3189            },
3190            &session.peer,
3191        );
3192    }
3193
3194    response.send(proto::RejoinChannelBuffersResponse {
3195        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3196    })?;
3197
3198    Ok(())
3199}
3200
3201/// Stop editing the channel notes
3202async fn leave_channel_buffer(
3203    request: proto::LeaveChannelBuffer,
3204    response: Response<proto::LeaveChannelBuffer>,
3205    session: UserSession,
3206) -> Result<()> {
3207    let db = session.db().await;
3208    let channel_id = ChannelId::from_proto(request.channel_id);
3209
3210    let left_buffer = db
3211        .leave_channel_buffer(channel_id, session.connection_id)
3212        .await?;
3213
3214    response.send(Ack {})?;
3215
3216    channel_buffer_updated(
3217        session.connection_id,
3218        left_buffer.connections,
3219        &proto::UpdateChannelBufferCollaborators {
3220            channel_id: channel_id.to_proto(),
3221            collaborators: left_buffer.collaborators,
3222        },
3223        &session.peer,
3224    );
3225
3226    Ok(())
3227}
3228
3229fn channel_buffer_updated<T: EnvelopedMessage>(
3230    sender_id: ConnectionId,
3231    collaborators: impl IntoIterator<Item = ConnectionId>,
3232    message: &T,
3233    peer: &Peer,
3234) {
3235    broadcast(Some(sender_id), collaborators, |peer_id| {
3236        peer.send(peer_id, message.clone())
3237    });
3238}
3239
3240fn send_notifications(
3241    connection_pool: &ConnectionPool,
3242    peer: &Peer,
3243    notifications: db::NotificationBatch,
3244) {
3245    for (user_id, notification) in notifications {
3246        for connection_id in connection_pool.user_connection_ids(user_id) {
3247            if let Err(error) = peer.send(
3248                connection_id,
3249                proto::AddNotification {
3250                    notification: Some(notification.clone()),
3251                },
3252            ) {
3253                tracing::error!(
3254                    "failed to send notification to {:?} {}",
3255                    connection_id,
3256                    error
3257                );
3258            }
3259        }
3260    }
3261}
3262
3263/// Send a message to the channel
3264async fn send_channel_message(
3265    request: proto::SendChannelMessage,
3266    response: Response<proto::SendChannelMessage>,
3267    session: UserSession,
3268) -> Result<()> {
3269    // Validate the message body.
3270    let body = request.body.trim().to_string();
3271    if body.len() > MAX_MESSAGE_LEN {
3272        return Err(anyhow!("message is too long"))?;
3273    }
3274    if body.is_empty() {
3275        return Err(anyhow!("message can't be blank"))?;
3276    }
3277
3278    // TODO: adjust mentions if body is trimmed
3279
3280    let timestamp = OffsetDateTime::now_utc();
3281    let nonce = request
3282        .nonce
3283        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3284
3285    let channel_id = ChannelId::from_proto(request.channel_id);
3286    let CreatedChannelMessage {
3287        message_id,
3288        participant_connection_ids,
3289        channel_members,
3290        notifications,
3291    } = session
3292        .db()
3293        .await
3294        .create_channel_message(
3295            channel_id,
3296            session.user_id(),
3297            &body,
3298            &request.mentions,
3299            timestamp,
3300            nonce.clone().into(),
3301            match request.reply_to_message_id {
3302                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3303                None => None,
3304            },
3305        )
3306        .await?;
3307
3308    let message = proto::ChannelMessage {
3309        sender_id: session.user_id().to_proto(),
3310        id: message_id.to_proto(),
3311        body,
3312        mentions: request.mentions,
3313        timestamp: timestamp.unix_timestamp() as u64,
3314        nonce: Some(nonce),
3315        reply_to_message_id: request.reply_to_message_id,
3316        edited_at: None,
3317    };
3318    broadcast(
3319        Some(session.connection_id),
3320        participant_connection_ids,
3321        |connection| {
3322            session.peer.send(
3323                connection,
3324                proto::ChannelMessageSent {
3325                    channel_id: channel_id.to_proto(),
3326                    message: Some(message.clone()),
3327                },
3328            )
3329        },
3330    );
3331    response.send(proto::SendChannelMessageResponse {
3332        message: Some(message),
3333    })?;
3334
3335    let pool = &*session.connection_pool().await;
3336    broadcast(
3337        None,
3338        channel_members
3339            .iter()
3340            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3341        |peer_id| {
3342            session.peer.send(
3343                peer_id,
3344                proto::UpdateChannels {
3345                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3346                        channel_id: channel_id.to_proto(),
3347                        message_id: message_id.to_proto(),
3348                    }],
3349                    ..Default::default()
3350                },
3351            )
3352        },
3353    );
3354    send_notifications(pool, &session.peer, notifications);
3355
3356    Ok(())
3357}
3358
3359/// Delete a channel message
3360async fn remove_channel_message(
3361    request: proto::RemoveChannelMessage,
3362    response: Response<proto::RemoveChannelMessage>,
3363    session: UserSession,
3364) -> Result<()> {
3365    let channel_id = ChannelId::from_proto(request.channel_id);
3366    let message_id = MessageId::from_proto(request.message_id);
3367    let connection_ids = session
3368        .db()
3369        .await
3370        .remove_channel_message(channel_id, message_id, session.user_id())
3371        .await?;
3372    broadcast(Some(session.connection_id), connection_ids, |connection| {
3373        session.peer.send(connection, request.clone())
3374    });
3375    response.send(proto::Ack {})?;
3376    Ok(())
3377}
3378
3379async fn update_channel_message(
3380    request: proto::UpdateChannelMessage,
3381    response: Response<proto::UpdateChannelMessage>,
3382    session: UserSession,
3383) -> Result<()> {
3384    let channel_id = ChannelId::from_proto(request.channel_id);
3385    let message_id = MessageId::from_proto(request.message_id);
3386    let updated_at = OffsetDateTime::now_utc();
3387    let UpdatedChannelMessage {
3388        message_id,
3389        participant_connection_ids,
3390        notifications,
3391        reply_to_message_id,
3392        timestamp,
3393    } = session
3394        .db()
3395        .await
3396        .update_channel_message(
3397            channel_id,
3398            message_id,
3399            session.user_id(),
3400            request.body.as_str(),
3401            &request.mentions,
3402            updated_at,
3403        )
3404        .await?;
3405
3406    let nonce = request
3407        .nonce
3408        .clone()
3409        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3410
3411    let message = proto::ChannelMessage {
3412        sender_id: session.user_id().to_proto(),
3413        id: message_id.to_proto(),
3414        body: request.body.clone(),
3415        mentions: request.mentions.clone(),
3416        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3417        nonce: Some(nonce),
3418        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3419        edited_at: Some(updated_at.unix_timestamp() as u64),
3420    };
3421
3422    response.send(proto::Ack {})?;
3423
3424    let pool = &*session.connection_pool().await;
3425    broadcast(
3426        Some(session.connection_id),
3427        participant_connection_ids,
3428        |connection| {
3429            session.peer.send(
3430                connection,
3431                proto::ChannelMessageUpdate {
3432                    channel_id: channel_id.to_proto(),
3433                    message: Some(message.clone()),
3434                },
3435            )
3436        },
3437    );
3438
3439    send_notifications(pool, &session.peer, notifications);
3440
3441    Ok(())
3442}
3443
3444/// Mark a channel message as read
3445async fn acknowledge_channel_message(
3446    request: proto::AckChannelMessage,
3447    session: UserSession,
3448) -> Result<()> {
3449    let channel_id = ChannelId::from_proto(request.channel_id);
3450    let message_id = MessageId::from_proto(request.message_id);
3451    let notifications = session
3452        .db()
3453        .await
3454        .observe_channel_message(channel_id, session.user_id(), message_id)
3455        .await?;
3456    send_notifications(
3457        &*session.connection_pool().await,
3458        &session.peer,
3459        notifications,
3460    );
3461    Ok(())
3462}
3463
3464/// Mark a buffer version as synced
3465async fn acknowledge_buffer_version(
3466    request: proto::AckBufferOperation,
3467    session: UserSession,
3468) -> Result<()> {
3469    let buffer_id = BufferId::from_proto(request.buffer_id);
3470    session
3471        .db()
3472        .await
3473        .observe_buffer_version(
3474            buffer_id,
3475            session.user_id(),
3476            request.epoch as i32,
3477            &request.version,
3478        )
3479        .await?;
3480    Ok(())
3481}
3482
3483struct CompleteWithLanguageModelRateLimit;
3484
3485impl RateLimit for CompleteWithLanguageModelRateLimit {
3486    fn capacity() -> usize {
3487        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3488            .ok()
3489            .and_then(|v| v.parse().ok())
3490            .unwrap_or(120) // Picked arbitrarily
3491    }
3492
3493    fn refill_duration() -> chrono::Duration {
3494        chrono::Duration::hours(1)
3495    }
3496
3497    fn db_name() -> &'static str {
3498        "complete-with-language-model"
3499    }
3500}
3501
3502async fn complete_with_language_model(
3503    request: proto::CompleteWithLanguageModel,
3504    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3505    session: Session,
3506    open_ai_api_key: Option<Arc<str>>,
3507    google_ai_api_key: Option<Arc<str>>,
3508) -> Result<()> {
3509    let Some(session) = session.for_user() else {
3510        return Err(anyhow!("user not found"))?;
3511    };
3512    authorize_access_to_language_models(&session).await?;
3513    session
3514        .rate_limiter
3515        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3516        .await?;
3517
3518    if request.model.starts_with("gpt") {
3519        let api_key =
3520            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3521        complete_with_open_ai(request, response, session, api_key).await?;
3522    } else if request.model.starts_with("gemini") {
3523        let api_key = google_ai_api_key
3524            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3525        complete_with_google_ai(request, response, session, api_key).await?;
3526    }
3527
3528    Ok(())
3529}
3530
3531async fn complete_with_open_ai(
3532    request: proto::CompleteWithLanguageModel,
3533    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3534    session: UserSession,
3535    api_key: Arc<str>,
3536) -> Result<()> {
3537    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3538
3539    let mut completion_stream = open_ai::stream_completion(
3540        &session.http_client,
3541        OPEN_AI_API_URL,
3542        &api_key,
3543        crate::ai::language_model_request_to_open_ai(request)?,
3544    )
3545    .await
3546    .context("open_ai::stream_completion request failed")?;
3547
3548    while let Some(event) = completion_stream.next().await {
3549        let event = event?;
3550        response.send(proto::LanguageModelResponse {
3551            choices: event
3552                .choices
3553                .into_iter()
3554                .map(|choice| proto::LanguageModelChoiceDelta {
3555                    index: choice.index,
3556                    delta: Some(proto::LanguageModelResponseMessage {
3557                        role: choice.delta.role.map(|role| match role {
3558                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3559                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3560                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3561                        } as i32),
3562                        content: choice.delta.content,
3563                    }),
3564                    finish_reason: choice.finish_reason,
3565                })
3566                .collect(),
3567        })?;
3568    }
3569
3570    Ok(())
3571}
3572
3573async fn complete_with_google_ai(
3574    request: proto::CompleteWithLanguageModel,
3575    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3576    session: UserSession,
3577    api_key: Arc<str>,
3578) -> Result<()> {
3579    let mut stream = google_ai::stream_generate_content(
3580        &session.http_client,
3581        google_ai::API_URL,
3582        api_key.as_ref(),
3583        crate::ai::language_model_request_to_google_ai(request)?,
3584    )
3585    .await
3586    .context("google_ai::stream_generate_content request failed")?;
3587
3588    while let Some(event) = stream.next().await {
3589        let event = event?;
3590        response.send(proto::LanguageModelResponse {
3591            choices: event
3592                .candidates
3593                .unwrap_or_default()
3594                .into_iter()
3595                .map(|candidate| proto::LanguageModelChoiceDelta {
3596                    index: candidate.index as u32,
3597                    delta: Some(proto::LanguageModelResponseMessage {
3598                        role: Some(match candidate.content.role {
3599                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3600                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3601                        } as i32),
3602                        content: Some(
3603                            candidate
3604                                .content
3605                                .parts
3606                                .into_iter()
3607                                .filter_map(|part| match part {
3608                                    google_ai::Part::TextPart(part) => Some(part.text),
3609                                    google_ai::Part::InlineDataPart(_) => None,
3610                                })
3611                                .collect(),
3612                        ),
3613                    }),
3614                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3615                })
3616                .collect(),
3617        })?;
3618    }
3619
3620    Ok(())
3621}
3622
3623struct CountTokensWithLanguageModelRateLimit;
3624
3625impl RateLimit for CountTokensWithLanguageModelRateLimit {
3626    fn capacity() -> usize {
3627        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3628            .ok()
3629            .and_then(|v| v.parse().ok())
3630            .unwrap_or(600) // Picked arbitrarily
3631    }
3632
3633    fn refill_duration() -> chrono::Duration {
3634        chrono::Duration::hours(1)
3635    }
3636
3637    fn db_name() -> &'static str {
3638        "count-tokens-with-language-model"
3639    }
3640}
3641
3642async fn count_tokens_with_language_model(
3643    request: proto::CountTokensWithLanguageModel,
3644    response: Response<proto::CountTokensWithLanguageModel>,
3645    session: UserSession,
3646    google_ai_api_key: Option<Arc<str>>,
3647) -> Result<()> {
3648    authorize_access_to_language_models(&session).await?;
3649
3650    if !request.model.starts_with("gemini") {
3651        return Err(anyhow!(
3652            "counting tokens for model: {:?} is not supported",
3653            request.model
3654        ))?;
3655    }
3656
3657    session
3658        .rate_limiter
3659        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3660        .await?;
3661
3662    let api_key = google_ai_api_key
3663        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3664    let tokens_response = google_ai::count_tokens(
3665        &session.http_client,
3666        google_ai::API_URL,
3667        &api_key,
3668        crate::ai::count_tokens_request_to_google_ai(request)?,
3669    )
3670    .await?;
3671    response.send(proto::CountTokensResponse {
3672        token_count: tokens_response.total_tokens as u32,
3673    })?;
3674    Ok(())
3675}
3676
3677async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3678    let db = session.db().await;
3679    let flags = db.get_user_flags(session.user_id()).await?;
3680    if flags.iter().any(|flag| flag == "language-models") {
3681        Ok(())
3682    } else {
3683        Err(anyhow!("permission denied"))?
3684    }
3685}
3686
3687/// Start receiving chat updates for a channel
3688async fn join_channel_chat(
3689    request: proto::JoinChannelChat,
3690    response: Response<proto::JoinChannelChat>,
3691    session: UserSession,
3692) -> Result<()> {
3693    let channel_id = ChannelId::from_proto(request.channel_id);
3694
3695    let db = session.db().await;
3696    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3697        .await?;
3698    let messages = db
3699        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3700        .await?;
3701    response.send(proto::JoinChannelChatResponse {
3702        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3703        messages,
3704    })?;
3705    Ok(())
3706}
3707
3708/// Stop receiving chat updates for a channel
3709async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3710    let channel_id = ChannelId::from_proto(request.channel_id);
3711    session
3712        .db()
3713        .await
3714        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3715        .await?;
3716    Ok(())
3717}
3718
3719/// Retrieve the chat history for a channel
3720async fn get_channel_messages(
3721    request: proto::GetChannelMessages,
3722    response: Response<proto::GetChannelMessages>,
3723    session: UserSession,
3724) -> Result<()> {
3725    let channel_id = ChannelId::from_proto(request.channel_id);
3726    let messages = session
3727        .db()
3728        .await
3729        .get_channel_messages(
3730            channel_id,
3731            session.user_id(),
3732            MESSAGE_COUNT_PER_PAGE,
3733            Some(MessageId::from_proto(request.before_message_id)),
3734        )
3735        .await?;
3736    response.send(proto::GetChannelMessagesResponse {
3737        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3738        messages,
3739    })?;
3740    Ok(())
3741}
3742
3743/// Retrieve specific chat messages
3744async fn get_channel_messages_by_id(
3745    request: proto::GetChannelMessagesById,
3746    response: Response<proto::GetChannelMessagesById>,
3747    session: UserSession,
3748) -> Result<()> {
3749    let message_ids = request
3750        .message_ids
3751        .iter()
3752        .map(|id| MessageId::from_proto(*id))
3753        .collect::<Vec<_>>();
3754    let messages = session
3755        .db()
3756        .await
3757        .get_channel_messages_by_id(session.user_id(), &message_ids)
3758        .await?;
3759    response.send(proto::GetChannelMessagesResponse {
3760        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3761        messages,
3762    })?;
3763    Ok(())
3764}
3765
3766/// Retrieve the current users notifications
3767async fn get_notifications(
3768    request: proto::GetNotifications,
3769    response: Response<proto::GetNotifications>,
3770    session: UserSession,
3771) -> Result<()> {
3772    let notifications = session
3773        .db()
3774        .await
3775        .get_notifications(
3776            session.user_id(),
3777            NOTIFICATION_COUNT_PER_PAGE,
3778            request
3779                .before_id
3780                .map(|id| db::NotificationId::from_proto(id)),
3781        )
3782        .await?;
3783    response.send(proto::GetNotificationsResponse {
3784        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3785        notifications,
3786    })?;
3787    Ok(())
3788}
3789
3790/// Mark notifications as read
3791async fn mark_notification_as_read(
3792    request: proto::MarkNotificationRead,
3793    response: Response<proto::MarkNotificationRead>,
3794    session: UserSession,
3795) -> Result<()> {
3796    let database = &session.db().await;
3797    let notifications = database
3798        .mark_notification_as_read_by_id(
3799            session.user_id(),
3800            NotificationId::from_proto(request.notification_id),
3801        )
3802        .await?;
3803    send_notifications(
3804        &*session.connection_pool().await,
3805        &session.peer,
3806        notifications,
3807    );
3808    response.send(proto::Ack {})?;
3809    Ok(())
3810}
3811
3812/// Get the current users information
3813async fn get_private_user_info(
3814    _request: proto::GetPrivateUserInfo,
3815    response: Response<proto::GetPrivateUserInfo>,
3816    session: UserSession,
3817) -> Result<()> {
3818    let db = session.db().await;
3819
3820    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3821    let user = db
3822        .get_user_by_id(session.user_id())
3823        .await?
3824        .ok_or_else(|| anyhow!("user not found"))?;
3825    let flags = db.get_user_flags(session.user_id()).await?;
3826
3827    response.send(proto::GetPrivateUserInfoResponse {
3828        metrics_id,
3829        staff: user.admin,
3830        flags,
3831    })?;
3832    Ok(())
3833}
3834
3835fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3836    match message {
3837        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3838        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3839        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3840        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3841        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3842            code: frame.code.into(),
3843            reason: frame.reason,
3844        })),
3845    }
3846}
3847
3848fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3849    match message {
3850        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3851        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3852        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3853        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3854        AxumMessage::Close(frame) => {
3855            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3856                code: frame.code.into(),
3857                reason: frame.reason,
3858            }))
3859        }
3860    }
3861}
3862
3863fn notify_membership_updated(
3864    connection_pool: &mut ConnectionPool,
3865    result: MembershipUpdated,
3866    user_id: UserId,
3867    peer: &Peer,
3868) {
3869    for membership in &result.new_channels.channel_memberships {
3870        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3871    }
3872    for channel_id in &result.removed_channels {
3873        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3874    }
3875
3876    let user_channels_update = proto::UpdateUserChannels {
3877        channel_memberships: result
3878            .new_channels
3879            .channel_memberships
3880            .iter()
3881            .map(|cm| proto::ChannelMembership {
3882                channel_id: cm.channel_id.to_proto(),
3883                role: cm.role.into(),
3884            })
3885            .collect(),
3886        ..Default::default()
3887    };
3888
3889    let mut update = build_channels_update(result.new_channels, vec![]);
3890    update.delete_channels = result
3891        .removed_channels
3892        .into_iter()
3893        .map(|id| id.to_proto())
3894        .collect();
3895    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3896
3897    for connection_id in connection_pool.user_connection_ids(user_id) {
3898        peer.send(connection_id, user_channels_update.clone())
3899            .trace_err();
3900        peer.send(connection_id, update.clone()).trace_err();
3901    }
3902}
3903
3904fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3905    proto::UpdateUserChannels {
3906        channel_memberships: channels
3907            .channel_memberships
3908            .iter()
3909            .map(|m| proto::ChannelMembership {
3910                channel_id: m.channel_id.to_proto(),
3911                role: m.role.into(),
3912            })
3913            .collect(),
3914        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3915        observed_channel_message_id: channels.observed_channel_messages.clone(),
3916    }
3917}
3918
3919fn build_channels_update(
3920    channels: ChannelsForUser,
3921    channel_invites: Vec<db::Channel>,
3922) -> proto::UpdateChannels {
3923    let mut update = proto::UpdateChannels::default();
3924
3925    for channel in channels.channels {
3926        update.channels.push(channel.to_proto());
3927    }
3928
3929    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3930    update.latest_channel_message_ids = channels.latest_channel_messages;
3931
3932    for (channel_id, participants) in channels.channel_participants {
3933        update
3934            .channel_participants
3935            .push(proto::ChannelParticipants {
3936                channel_id: channel_id.to_proto(),
3937                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3938            });
3939    }
3940
3941    for channel in channel_invites {
3942        update.channel_invitations.push(channel.to_proto());
3943    }
3944    for project in channels.hosted_projects {
3945        update.hosted_projects.push(project);
3946    }
3947
3948    update
3949}
3950
3951fn build_initial_contacts_update(
3952    contacts: Vec<db::Contact>,
3953    pool: &ConnectionPool,
3954) -> proto::UpdateContacts {
3955    let mut update = proto::UpdateContacts::default();
3956
3957    for contact in contacts {
3958        match contact {
3959            db::Contact::Accepted { user_id, busy } => {
3960                update.contacts.push(contact_for_user(user_id, busy, &pool));
3961            }
3962            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3963            db::Contact::Incoming { user_id } => {
3964                update
3965                    .incoming_requests
3966                    .push(proto::IncomingContactRequest {
3967                        requester_id: user_id.to_proto(),
3968                    })
3969            }
3970        }
3971    }
3972
3973    update
3974}
3975
3976fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3977    proto::Contact {
3978        user_id: user_id.to_proto(),
3979        online: pool.is_user_online(user_id),
3980        busy,
3981    }
3982}
3983
3984fn room_updated(room: &proto::Room, peer: &Peer) {
3985    broadcast(
3986        None,
3987        room.participants
3988            .iter()
3989            .filter_map(|participant| Some(participant.peer_id?.into())),
3990        |peer_id| {
3991            peer.send(
3992                peer_id,
3993                proto::RoomUpdated {
3994                    room: Some(room.clone()),
3995                },
3996            )
3997        },
3998    );
3999}
4000
4001fn channel_updated(
4002    channel: &db::channel::Model,
4003    room: &proto::Room,
4004    peer: &Peer,
4005    pool: &ConnectionPool,
4006) {
4007    let participants = room
4008        .participants
4009        .iter()
4010        .map(|p| p.user_id)
4011        .collect::<Vec<_>>();
4012
4013    broadcast(
4014        None,
4015        pool.channel_connection_ids(channel.root_id())
4016            .filter_map(|(channel_id, role)| {
4017                role.can_see_channel(channel.visibility).then(|| channel_id)
4018            }),
4019        |peer_id| {
4020            peer.send(
4021                peer_id,
4022                proto::UpdateChannels {
4023                    channel_participants: vec![proto::ChannelParticipants {
4024                        channel_id: channel.id.to_proto(),
4025                        participant_user_ids: participants.clone(),
4026                    }],
4027                    ..Default::default()
4028                },
4029            )
4030        },
4031    );
4032}
4033
4034async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4035    let db = session.db().await;
4036
4037    let contacts = db.get_contacts(user_id).await?;
4038    let busy = db.is_user_busy(user_id).await?;
4039
4040    let pool = session.connection_pool().await;
4041    let updated_contact = contact_for_user(user_id, busy, &pool);
4042    for contact in contacts {
4043        if let db::Contact::Accepted {
4044            user_id: contact_user_id,
4045            ..
4046        } = contact
4047        {
4048            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4049                session
4050                    .peer
4051                    .send(
4052                        contact_conn_id,
4053                        proto::UpdateContacts {
4054                            contacts: vec![updated_contact.clone()],
4055                            remove_contacts: Default::default(),
4056                            incoming_requests: Default::default(),
4057                            remove_incoming_requests: Default::default(),
4058                            outgoing_requests: Default::default(),
4059                            remove_outgoing_requests: Default::default(),
4060                        },
4061                    )
4062                    .trace_err();
4063            }
4064        }
4065    }
4066    Ok(())
4067}
4068
/// Leave the room (if any) that this session's connection is currently in.
///
/// Returns `Ok(())` without doing anything when the database reports that no
/// room was left. Otherwise it:
/// - notifies collaborators of every project the user left (via
///   `project_left`) and broadcasts the updated room;
/// - if the room belongs to a channel, broadcasts that channel's participants;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled by this departure;
/// - refreshes contact status for the leaving user and those users;
/// - removes the user from the LiveKit room, and deletes that room when the
///   database reports it as `deleted`.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit and
/// message-send failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned inside the `if let` below so
    // the extracted values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): values are moved out of `left_room` with `mem::take` so that
    // `left_room` (and the temporary database handle from this scrutinee) are
    // dropped at the end of this `if let`, before any of the awaits below.
    // This presumably avoids holding them across further network/database
    // work; confirm before restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries
        // the `Default` value for `live_kit_room`.
        // NOTE(review): confirm that hiding it from clients is intentional.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before
    // `update_user_contacts`, which acquires the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4142
4143async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4144    let left_channel_buffers = session
4145        .db()
4146        .await
4147        .leave_channel_buffers(session.connection_id)
4148        .await?;
4149
4150    for left_buffer in left_channel_buffers {
4151        channel_buffer_updated(
4152            session.connection_id,
4153            left_buffer.connections,
4154            &proto::UpdateChannelBufferCollaborators {
4155                channel_id: left_buffer.channel_id.to_proto(),
4156                collaborators: left_buffer.collaborators,
4157            },
4158            &session.peer,
4159        );
4160    }
4161
4162    Ok(())
4163}
4164
4165fn project_left(project: &db::LeftProject, session: &UserSession) {
4166    for connection_id in &project.connection_ids {
4167        if project.host_user_id == Some(session.user_id()) {
4168            session
4169                .peer
4170                .send(
4171                    *connection_id,
4172                    proto::UnshareProject {
4173                        project_id: project.id.to_proto(),
4174                    },
4175                )
4176                .trace_err();
4177        } else {
4178            session
4179                .peer
4180                .send(
4181                    *connection_id,
4182                    proto::RemoveProjectCollaborator {
4183                        project_id: project.id.to_proto(),
4184                        peer_id: Some(session.connection_id.into()),
4185                    },
4186                )
4187                .trace_err();
4188        }
4189    }
4190}
4191
4192pub trait ResultExt {
4193    type Ok;
4194
4195    fn trace_err(self) -> Option<Self::Ok>;
4196}
4197
4198impl<T, E> ResultExt for Result<T, E>
4199where
4200    E: std::fmt::Debug,
4201{
4202    type Ok = T;
4203
4204    fn trace_err(self) -> Option<T> {
4205        match self {
4206            Ok(value) => Some(value),
4207            Err(error) => {
4208                tracing::error!("{:?}", error);
4209                None
4210            }
4211        }
4212    }
4213}