rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use semantic_version::SemanticVersion;
  50use serde::{Serialize, Serializer};
  51use std::{
  52    any::TypeId,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc, OnceLock,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{
  69    field::{self},
  70    info_span, instrument, Instrument,
  71};
  72use util::http::IsahcHttpClient;
  73
  74pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  75
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
  77pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  78
  79const MESSAGE_COUNT_PER_PAGE: usize = 100;
  80const MAX_MESSAGE_LEN: usize = 1024;
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the `Session` the message arrived on
/// and returns a boxed, `'static` future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100struct StreamingResponse<R: RequestMessage> {
 101    peer: Arc<Peer>,
 102    receipt: Receipt<R>,
 103}
 104
 105impl<R: RequestMessage> StreamingResponse<R> {
 106    fn send(&self, payload: R::Response) -> Result<()> {
 107        self.peer.respond(self.receipt, payload)?;
 108        Ok(())
 109    }
 110}
 111
 112#[derive(Clone, Debug)]
 113pub enum Principal {
 114    User(User),
 115    Impersonated { user: User, admin: User },
 116    DevServer(dev_server::Model),
 117}
 118
 119impl Principal {
 120    fn update_span(&self, span: &tracing::Span) {
 121        match &self {
 122            Principal::User(user) => {
 123                span.record("user_id", &user.id.0);
 124                span.record("login", &user.github_login);
 125            }
 126            Principal::Impersonated { user, admin } => {
 127                span.record("user_id", &user.id.0);
 128                span.record("login", &user.github_login);
 129                span.record("impersonator", &admin.github_login);
 130            }
 131            Principal::DevServer(dev_server) => {
 132                span.record("dev_server_id", &dev_server.id.0);
 133            }
 134        }
 135    }
 136}
 137
 138#[derive(Clone)]
 139struct Session {
 140    principal: Principal,
 141    connection_id: ConnectionId,
 142    db: Arc<tokio::sync::Mutex<DbHandle>>,
 143    peer: Arc<Peer>,
 144    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 145    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 146    http_client: IsahcHttpClient,
 147    rate_limiter: Arc<RateLimiter>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn user_id(&self) -> Option<UserId> {
 176        match &self.principal {
 177            Principal::User(user) => Some(user.id),
 178            Principal::Impersonated { user, .. } => Some(user.id),
 179            Principal::DevServer(_) => None,
 180        }
 181    }
 182}
 183
 184impl Debug for Session {
 185    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 186        let mut result = f.debug_struct("Session");
 187        match &self.principal {
 188            Principal::User(user) => {
 189                result.field("user", &user.github_login);
 190            }
 191            Principal::Impersonated { user, admin } => {
 192                result.field("user", &user.github_login);
 193                result.field("impersonator", &admin.github_login);
 194            }
 195            Principal::DevServer(dev_server) => {
 196                result.field("dev_server", &dev_server.id);
 197            }
 198        }
 199        result.field("connection_id", &self.connection_id).finish()
 200    }
 201}
 202
 203struct UserSession(Session);
 204
 205impl UserSession {
 206    pub fn new(s: Session) -> Option<Self> {
 207        s.user_id().map(|_| UserSession(s))
 208    }
 209    pub fn user_id(&self) -> UserId {
 210        self.0.user_id().unwrap()
 211    }
 212}
 213
 214impl Deref for UserSession {
 215    type Target = Session;
 216
 217    fn deref(&self) -> &Self::Target {
 218        &self.0
 219    }
 220}
 221impl DerefMut for UserSession {
 222    fn deref_mut(&mut self) -> &mut Self::Target {
 223        &mut self.0
 224    }
 225}
 226
 227fn user_handler<M: RequestMessage, Fut>(
 228    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 230where
 231    Fut: Send + Future<Output = Result<()>>,
 232{
 233    let handler = Arc::new(handler);
 234    move |message, response, session| {
 235        let handler = handler.clone();
 236        Box::pin(async move {
 237            if let Some(user_session) = session.for_user() {
 238                Ok(handler(message, response, user_session).await?)
 239            } else {
 240                Err(Error::Internal(anyhow!("must be a user")))
 241            }
 242        })
 243    }
 244}
 245
 246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 247    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 249where
 250    InnertRetFut: Send + Future<Output = Result<()>>,
 251{
 252    let handler = Arc::new(handler);
 253    move |message, session| {
 254        let handler = handler.clone();
 255        Box::pin(async move {
 256            if let Some(user_session) = session.for_user() {
 257                Ok(handler(message, user_session).await?)
 258            } else {
 259                Err(Error::Internal(anyhow!("must be a user")))
 260            }
 261        })
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Current server id; behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel carrying `true` once `teardown` has been called.
    teardown: watch::Sender<bool>,
}
 283
/// Wrapper around a locked `ConnectionPool` guard.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the whole guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
 288
/// Serializable view of the server: a borrowed peer plus a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state: app_state.clone(),
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(user_handler(create_room))
 319            .add_request_handler(user_handler(join_room))
 320            .add_request_handler(user_handler(rejoin_room))
 321            .add_request_handler(user_handler(leave_room))
 322            .add_request_handler(user_handler(set_room_participant_role))
 323            .add_request_handler(user_handler(call))
 324            .add_request_handler(user_handler(cancel_call))
 325            .add_message_handler(user_message_handler(decline_call))
 326            .add_request_handler(user_handler(update_participant_location))
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(user_handler(join_project))
 330            .add_request_handler(user_handler(join_hosted_project))
 331            .add_message_handler(user_message_handler(leave_project))
 332            .add_request_handler(update_project)
 333            .add_request_handler(update_worktree)
 334            .add_message_handler(start_language_server)
 335            .add_message_handler(update_language_server)
 336            .add_message_handler(update_diagnostic_summary)
 337            .add_message_handler(update_worktree_settings)
 338            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 342            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 346            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 347            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 348            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 350            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 353            )
 354            .add_request_handler(
 355                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 356            )
 357            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 358            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 369            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 370            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 371            .add_message_handler(create_buffer_for_peer)
 372            .add_request_handler(update_buffer)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 376            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 377            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 378            .add_request_handler(get_users)
 379            .add_request_handler(user_handler(fuzzy_search_users))
 380            .add_request_handler(user_handler(request_contact))
 381            .add_request_handler(user_handler(remove_contact))
 382            .add_request_handler(user_handler(respond_to_contact_request))
 383            .add_request_handler(user_handler(create_channel))
 384            .add_request_handler(user_handler(delete_channel))
 385            .add_request_handler(user_handler(invite_channel_member))
 386            .add_request_handler(user_handler(remove_channel_member))
 387            .add_request_handler(user_handler(set_channel_member_role))
 388            .add_request_handler(user_handler(set_channel_visibility))
 389            .add_request_handler(user_handler(rename_channel))
 390            .add_request_handler(user_handler(join_channel_buffer))
 391            .add_request_handler(user_handler(leave_channel_buffer))
 392            .add_message_handler(user_message_handler(update_channel_buffer))
 393            .add_request_handler(user_handler(rejoin_channel_buffers))
 394            .add_request_handler(user_handler(get_channel_members))
 395            .add_request_handler(user_handler(respond_to_channel_invite))
 396            .add_request_handler(user_handler(join_channel))
 397            .add_request_handler(user_handler(join_channel_chat))
 398            .add_message_handler(user_message_handler(leave_channel_chat))
 399            .add_request_handler(user_handler(send_channel_message))
 400            .add_request_handler(user_handler(remove_channel_message))
 401            .add_request_handler(user_handler(update_channel_message))
 402            .add_request_handler(user_handler(get_channel_messages))
 403            .add_request_handler(user_handler(get_channel_messages_by_id))
 404            .add_request_handler(user_handler(get_notifications))
 405            .add_request_handler(user_handler(mark_notification_as_read))
 406            .add_request_handler(user_handler(move_channel))
 407            .add_request_handler(user_handler(follow))
 408            .add_message_handler(user_message_handler(unfollow))
 409            .add_message_handler(user_message_handler(update_followers))
 410            .add_request_handler(user_handler(get_private_user_info))
 411            .add_message_handler(user_message_handler(acknowledge_channel_message))
 412            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 413            .add_streaming_request_handler({
 414                let app_state = app_state.clone();
 415                move |request, response, session| {
 416                    complete_with_language_model(
 417                        request,
 418                        response,
 419                        session,
 420                        app_state.config.openai_api_key.clone(),
 421                        app_state.config.google_ai_api_key.clone(),
 422                        app_state.config.anthropic_api_key.clone(),
 423                    )
 424                }
 425            })
 426            .add_request_handler({
 427                let app_state = app_state.clone();
 428                user_handler(move |request, response, session| {
 429                    count_tokens_with_language_model(
 430                        request,
 431                        response,
 432                        session,
 433                        app_state.config.google_ai_api_key.clone(),
 434                    )
 435                })
 436            });
 437
 438        Arc::new(server)
 439    }
 440
    /// Spawns a detached background task that cleans up resources left behind by
    /// stale servers, then returns `Ok(())` immediately.
    ///
    /// The task waits for `CLEANUP_TIMEOUT`, asks the database for stale room and
    /// channel ids, refreshes each of them and notifies the affected connections, and
    /// finally calls `delete_stale_servers`. Every fallible step goes through
    /// `trace_err()`, so a failure skips that step instead of aborting the task
    /// (presumably the error is logged there — confirm against `trace_err`).
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: clear stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: clear stale participants, then fan out the consequences.
                    for room_id in room_ids {
                        // These keep their defaults when refreshing the room fails, which
                        // turns the follow-up steps below into no-ops.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            // Move the values out of the refreshed room so they outlive it.
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Proceed only when both lookups succeeded; the pool lock is
                            // taken after the awaits above.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only when the refreshed room ended up
                        // with no participants and a LiveKit client is configured.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even when retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 586
    /// Tears down the peer, resets the connection pool and publishes `true` on the
    /// teardown watch channel.
    ///
    /// The result of the channel send is deliberately ignored.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
 592
    /// Test-only: tears the server down, then brings it back under a new `id`.
    ///
    /// After `teardown`, the stored id is replaced, the peer is reset with the new id
    /// and the teardown flag is set back to `false`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
 600
    /// Test-only: returns a copy of the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
 605
 606    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 607    where
 608        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 609        Fut: 'static + Send + Future<Output = Result<()>>,
 610        M: EnvelopedMessage,
 611    {
 612        let prev_handler = self.handlers.insert(
 613            TypeId::of::<M>(),
 614            Box::new(move |envelope, session| {
 615                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 616                let received_at = envelope.received_at;
 617                    tracing::info!(
 618                        "message received"
 619                    );
 620                let start_time = Instant::now();
 621                let future = (handler)(*envelope, session);
 622                async move {
 623                    let result = future.await;
 624                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 625                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 626                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 627                    match result {
 628                        Err(error) => {
 629                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 630                        }
 631                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 632                    }
 633                }
 634                .boxed()
 635            }),
 636        );
 637        if prev_handler.is_some() {
 638            panic!("registered a handler for the same message twice");
 639        }
 640        self
 641    }
 642
 643    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 644    where
 645        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 646        Fut: 'static + Send + Future<Output = Result<()>>,
 647        M: EnvelopedMessage,
 648    {
 649        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 650        self
 651    }
 652
 653    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 654    where
 655        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 656        Fut: Send + Future<Output = Result<()>>,
 657        M: RequestMessage,
 658    {
 659        let handler = Arc::new(handler);
 660        self.add_handler(move |envelope, session| {
 661            let receipt = envelope.receipt();
 662            let handler = handler.clone();
 663            async move {
 664                let peer = session.peer.clone();
 665                let responded = Arc::new(AtomicBool::default());
 666                let response = Response {
 667                    peer: peer.clone(),
 668                    responded: responded.clone(),
 669                    receipt,
 670                };
 671                match (handler)(envelope.payload, response, session).await {
 672                    Ok(()) => {
 673                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 674                            Ok(())
 675                        } else {
 676                            Err(anyhow!("handler did not send a response"))?
 677                        }
 678                    }
 679                    Err(error) => {
 680                        let proto_err = match &error {
 681                            Error::Internal(err) => err.to_proto(),
 682                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 683                        };
 684                        peer.respond_with_error(receipt, proto_err)?;
 685                        Err(error)
 686                    }
 687                }
 688            }
 689        })
 690    }
 691
 692    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 693    where
 694        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 695        Fut: Send + Future<Output = Result<()>>,
 696        M: RequestMessage,
 697    {
 698        let handler = Arc::new(handler);
 699        self.add_handler(move |envelope, session| {
 700            let receipt = envelope.receipt();
 701            let handler = handler.clone();
 702            async move {
 703                let peer = session.peer.clone();
 704                let response = StreamingResponse {
 705                    peer: peer.clone(),
 706                    receipt,
 707                };
 708                match (handler)(envelope.payload, response, session).await {
 709                    Ok(()) => {
 710                        peer.end_stream(receipt)?;
 711                        Ok(())
 712                    }
 713                    Err(error) => {
 714                        let proto_err = match &error {
 715                            Error::Internal(err) => err.to_proto(),
 716                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 717                        };
 718                        peer.respond_with_error(receipt, proto_err)?;
 719                        Err(error)
 720                    }
 721                }
 722            }
 723        })
 724    }
 725
    /// Returns a future that drives a single client connection from start to finish.
    ///
    /// The future registers `connection` with the peer, builds a [`Session`] for it, sends
    /// the initial client update, and then dispatches every incoming message to the handler
    /// registered for its payload type. The dispatch loop ends when the connection's I/O
    /// future completes or the incoming message stream is exhausted; `connection_lost` is
    /// then awaited to sign the session out. If the teardown signal changes while the loop
    /// is running, the future returns immediately without calling `connection_lost`.
    ///
    /// `send_connection_id` is passed through to `send_initial_client_update`. The whole
    /// future is instrumented with a "handle connection" tracing span.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // `impersonator` and `connection_id` start out empty; `connection_id` is recorded
        // below once the peer has assigned one.
        let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to start if teardown has already been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // The closure gives the peer a way to create sleep futures on our executor.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after the initial client update)
            // happens after the connection was added to the peer but skips
            // `connection_lost` — confirm the connection is cleaned up elsewhere.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            // Fuse and pin the I/O future so it can be polled repeatedly by the select below.
            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps the number of in-flight message handlers for this connection: a permit is
            // acquired before the next message is read and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls its branches in the order written, so teardown and
                // I/O completion take priority over running handlers and reading messages.
                futures::select_biased! {
                    // Teardown: stop right away, skipping the sign-out below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drive pending foreground handlers; their output is not needed here.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before moving it into `instrument` below.
                                drop(span_enter);

                                // Hold the permit until the handler has run to completion.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached on the executor.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    // Foreground messages are polled by this loop.
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop any foreground handlers that have not finished before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 851
 852    async fn send_initial_client_update(
 853        &self,
 854        connection_id: ConnectionId,
 855        principal: &Principal,
 856        zed_version: ZedVersion,
 857        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 858        session: &Session,
 859    ) -> Result<()> {
 860        self.peer.send(
 861            connection_id,
 862            proto::Hello {
 863                peer_id: Some(connection_id.into()),
 864            },
 865        )?;
 866        tracing::info!("sent hello message");
 867
 868        let Principal::User(user) = principal else {
 869            return Ok(());
 870        };
 871
 872        if let Some(send_connection_id) = send_connection_id.take() {
 873            let _ = send_connection_id.send(connection_id);
 874        }
 875
 876        if !user.connected_once {
 877            self.peer.send(connection_id, proto::ShowContacts {})?;
 878            self.app_state
 879                .db
 880                .set_user_connected_once(user.id, true)
 881                .await?;
 882        }
 883
 884        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 885            self.app_state.db.get_contacts(user.id),
 886            self.app_state.db.get_channels_for_user(user.id),
 887            self.app_state.db.get_channel_invites_for_user(user.id),
 888        )
 889        .await?;
 890
 891        {
 892            let mut pool = self.connection_pool.lock();
 893            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 894            for membership in &channels_for_user.channel_memberships {
 895                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 896            }
 897            self.peer.send(
 898                connection_id,
 899                build_initial_contacts_update(contacts, &pool),
 900            )?;
 901            self.peer.send(
 902                connection_id,
 903                build_update_user_channels(&channels_for_user),
 904            )?;
 905            self.peer.send(
 906                connection_id,
 907                build_channels_update(channels_for_user, channel_invites),
 908            )?;
 909        }
 910
 911        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 912            self.peer.send(connection_id, incoming_call)?;
 913        }
 914
 915        update_user_contacts(user.id, &session).await?;
 916        Ok(())
 917    }
 918
 919    pub async fn invite_code_redeemed(
 920        self: &Arc<Self>,
 921        inviter_id: UserId,
 922        invitee_id: UserId,
 923    ) -> Result<()> {
 924        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 925            if let Some(code) = &user.invite_code {
 926                let pool = self.connection_pool.lock();
 927                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 928                for connection_id in pool.user_connection_ids(inviter_id) {
 929                    self.peer.send(
 930                        connection_id,
 931                        proto::UpdateContacts {
 932                            contacts: vec![invitee_contact.clone()],
 933                            ..Default::default()
 934                        },
 935                    )?;
 936                    self.peer.send(
 937                        connection_id,
 938                        proto::UpdateInviteInfo {
 939                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 940                            count: user.invite_count as u32,
 941                        },
 942                    )?;
 943                }
 944            }
 945        }
 946        Ok(())
 947    }
 948
 949    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 950        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 951            if let Some(invite_code) = &user.invite_code {
 952                let pool = self.connection_pool.lock();
 953                for connection_id in pool.user_connection_ids(user_id) {
 954                    self.peer.send(
 955                        connection_id,
 956                        proto::UpdateInviteInfo {
 957                            url: format!(
 958                                "{}{}",
 959                                self.app_state.config.invite_link_prefix, invite_code
 960                            ),
 961                            count: user.invite_count as u32,
 962                        },
 963                    )?;
 964                }
 965            }
 966        }
 967        Ok(())
 968    }
 969
 970    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 971        ServerSnapshot {
 972            connection_pool: ConnectionPoolGuard {
 973                guard: self.connection_pool.lock(),
 974                _not_send: PhantomData,
 975            },
 976            peer: &self.peer,
 977        }
 978    }
 979}
 980
 981impl<'a> Deref for ConnectionPoolGuard<'a> {
 982    type Target = ConnectionPool;
 983
 984    fn deref(&self) -> &Self::Target {
 985        &self.guard
 986    }
 987}
 988
 989impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 990    fn deref_mut(&mut self) -> &mut Self::Target {
 991        &mut self.guard
 992    }
 993}
 994
/// Runs the pool's invariant check whenever a guard is released, in test builds only.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // This call is compiled in only under `cfg(test)`; in other builds the body is empty.
        #[cfg(test)]
        self.check_invariants();
    }
}
1001
1002fn broadcast<F>(
1003    sender_id: Option<ConnectionId>,
1004    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1005    mut f: F,
1006) where
1007    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1008{
1009    for receiver_id in receiver_ids {
1010        if Some(receiver_id) != sender_id {
1011            if let Err(error) = f(receiver_id) {
1012                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1013            }
1014        }
1015    }
1016}
1017
1018pub struct ProtocolVersion(u32);
1019
1020impl Header for ProtocolVersion {
1021    fn name() -> &'static HeaderName {
1022        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1023        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1024    }
1025
1026    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1027    where
1028        Self: Sized,
1029        I: Iterator<Item = &'i axum::http::HeaderValue>,
1030    {
1031        let version = values
1032            .next()
1033            .ok_or_else(axum::headers::Error::invalid)?
1034            .to_str()
1035            .map_err(|_| axum::headers::Error::invalid())?
1036            .parse()
1037            .map_err(|_| axum::headers::Error::invalid())?;
1038        Ok(Self(version))
1039    }
1040
1041    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1042        values.extend([self.0.to_string().parse().unwrap()]);
1043    }
1044}
1045
1046pub struct AppVersionHeader(SemanticVersion);
1047impl Header for AppVersionHeader {
1048    fn name() -> &'static HeaderName {
1049        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1050        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1051    }
1052
1053    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1054    where
1055        Self: Sized,
1056        I: Iterator<Item = &'i axum::http::HeaderValue>,
1057    {
1058        let version = values
1059            .next()
1060            .ok_or_else(axum::headers::Error::invalid)?
1061            .to_str()
1062            .map_err(|_| axum::headers::Error::invalid())?
1063            .parse()
1064            .map_err(|_| axum::headers::Error::invalid())?;
1065        Ok(Self(version))
1066    }
1067
1068    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1069        values.extend([self.0.to_string().parse().unwrap()]);
1070    }
1071}
1072
1073pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1074    Router::new()
1075        .route("/rpc", get(handle_websocket_request))
1076        .layer(
1077            ServiceBuilder::new()
1078                .layer(Extension(server.app_state.clone()))
1079                .layer(middleware::from_fn(auth::validate_header)),
1080        )
1081        .route("/metrics", get(handle_metrics))
1082        .layer(Extension(server))
1083}
1084
1085pub async fn handle_websocket_request(
1086    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089    Extension(server): Extension<Arc<Server>>,
1090    Extension(principal): Extension<Principal>,
1091    ws: WebSocketUpgrade,
1092) -> axum::response::Response {
1093    if protocol_version != rpc::PROTOCOL_VERSION {
1094        return (
1095            StatusCode::UPGRADE_REQUIRED,
1096            "client must be upgraded".to_string(),
1097        )
1098            .into_response();
1099    }
1100
1101    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1102        return (
1103            StatusCode::UPGRADE_REQUIRED,
1104            "no version header found".to_string(),
1105        )
1106            .into_response();
1107    };
1108
1109    if !version.can_collaborate() {
1110        return (
1111            StatusCode::UPGRADE_REQUIRED,
1112            "client must be upgraded".to_string(),
1113        )
1114            .into_response();
1115    }
1116
1117    let socket_address = socket_address.to_string();
1118    ws.on_upgrade(move |socket| {
1119        let socket = socket
1120            .map_ok(to_tungstenite_message)
1121            .err_into()
1122            .with(|message| async move { Ok(to_axum_message(message)) });
1123        let connection = Connection::new(Box::pin(socket));
1124        async move {
1125            server
1126                .handle_connection(
1127                    connection,
1128                    socket_address,
1129                    principal,
1130                    version,
1131                    None,
1132                    Executor::Production,
1133                )
1134                .await;
1135        }
1136    })
1137}
1138
1139pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1140    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1141    let connections_metric = CONNECTIONS_METRIC
1142        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1143
1144    let connections = server
1145        .connection_pool
1146        .lock()
1147        .connections()
1148        .filter(|connection| !connection.admin)
1149        .count();
1150    connections_metric.set(connections as _);
1151
1152    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1154        register_int_gauge!(
1155            "shared_projects",
1156            "number of open projects with one or more guests"
1157        )
1158        .unwrap()
1159    });
1160
1161    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1162    shared_projects_metric.set(shared_projects as _);
1163
1164    let encoder = prometheus::TextEncoder::new();
1165    let metric_families = prometheus::gather();
1166    let encoded_metrics = encoder
1167        .encode_to_string(&metric_families)
1168        .map_err(|err| anyhow!("{}", err))?;
1169    Ok(encoded_metrics)
1170}
1171
/// Cleans up after a connection has ended.
///
/// The connection is disconnected from the peer and removed from the connection pool
/// right away, and the database is told the connection was lost (a failure there is
/// handed to `trace_err` instead of being returned). The function then waits for
/// whichever comes first:
///
/// - `RECONNECT_TIMEOUT` elapsing: if the session belongs to a user, the user leaves
///   their room and channel buffers, any call for them is declined when they have no
///   other connection online, and their contacts are updated.
/// - the teardown signal changing: nothing further is done.
///
/// Returns an error if removing the connection from the pool fails or if the final
/// contacts update fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout branch is
    // checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Only sessions that `for_user` resolves to a user have resources to release.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Decline a call only when the user has no connection left in the pool.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        // The server is tearing down: skip the per-user cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1219
1220/// Acknowledges a ping from a client, used to keep the connection alive.
1221async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1222    response.send(proto::Ack {})?;
1223    Ok(())
1224}
1225
1226/// Creates a new room for calling (outside of channels)
1227async fn create_room(
1228    _request: proto::CreateRoom,
1229    response: Response<proto::CreateRoom>,
1230    session: UserSession,
1231) -> Result<()> {
1232    let live_kit_room = nanoid::nanoid!(30);
1233
1234    let live_kit_connection_info = util::maybe!(async {
1235        let live_kit = session.live_kit_client.as_ref();
1236        let live_kit = live_kit?;
1237        let user_id = session.user_id().to_string();
1238
1239        let token = live_kit
1240            .room_token(&live_kit_room, &user_id.to_string())
1241            .trace_err()?;
1242
1243        Some(proto::LiveKitConnectionInfo {
1244            server_url: live_kit.url().into(),
1245            token,
1246            can_publish: true,
1247        })
1248    })
1249    .await;
1250
1251    let room = session
1252        .db()
1253        .await
1254        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1255        .await?;
1256
1257    response.send(proto::CreateRoomResponse {
1258        room: Some(room.clone()),
1259        live_kit_connection_info,
1260    })?;
1261
1262    update_user_contacts(session.user_id(), &session).await?;
1263    Ok(())
1264}
1265
1266/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1267async fn join_room(
1268    request: proto::JoinRoom,
1269    response: Response<proto::JoinRoom>,
1270    session: UserSession,
1271) -> Result<()> {
1272    let room_id = RoomId::from_proto(request.id);
1273
1274    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1275
1276    if let Some(channel_id) = channel_id {
1277        return join_channel_internal(channel_id, Box::new(response), session).await;
1278    }
1279
1280    let joined_room = {
1281        let room = session
1282            .db()
1283            .await
1284            .join_room(room_id, session.user_id(), session.connection_id)
1285            .await?;
1286        room_updated(&room.room, &session.peer);
1287        room.into_inner()
1288    };
1289
1290    for connection_id in session
1291        .connection_pool()
1292        .await
1293        .user_connection_ids(session.user_id())
1294    {
1295        session
1296            .peer
1297            .send(
1298                connection_id,
1299                proto::CallCanceled {
1300                    room_id: room_id.to_proto(),
1301                },
1302            )
1303            .trace_err();
1304    }
1305
1306    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1307        if let Some(token) = live_kit
1308            .room_token(
1309                &joined_room.room.live_kit_room,
1310                &session.user_id().to_string(),
1311            )
1312            .trace_err()
1313        {
1314            Some(proto::LiveKitConnectionInfo {
1315                server_url: live_kit.url().into(),
1316                token,
1317                can_publish: true,
1318            })
1319        } else {
1320            None
1321        }
1322    } else {
1323        None
1324    };
1325
1326    response.send(proto::JoinRoomResponse {
1327        room: Some(joined_room.room),
1328        channel_id: None,
1329        live_kit_connection_info,
1330    })?;
1331
1332    update_user_contacts(session.user_id(), &session).await?;
1333    Ok(())
1334}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database is asked to rejoin the room for this user and connection. The response
/// is sent first and lists the room, the reshared projects and the rejoined projects.
/// After that:
///
/// - room participants are notified through `room_updated`;
/// - collaborators of every reshared and rejoined project are told that the project's old
///   peer id has been replaced by this connection's id;
/// - collaborators of reshared projects additionally receive the project's worktrees;
/// - for every rejoined project, this connection is sent the worktree entries (in
///   chunks), diagnostic summaries, settings files and a disk-based-diagnostics update
///   per language server;
/// - if the room belongs to a channel, `channel_updated` is called;
/// - the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive `rejoined_room`, which is confined to it.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the request before broadcasting or streaming anything else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this project's old peer id is now this connection.
            // Failures are handed to `trace_err` and do not abort the rejoin.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            // Same peer-id replacement notice, for projects this connection rejoined.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        for project in &mut rejoined_room.rejoined_projects {
            // Take ownership of the worktrees so their fields can be moved into messages.
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size so updates get split into several messages.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                // Unlike the collaborator notices above, a failed send here aborts the rejoin.
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a disk-based-diagnostics-updated notice for each language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // NOTE(review): `rejoined_room` looks like a guard over database state that
        // `into_inner` unwraps — confirm against the db module.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Only rooms that belong to a channel trigger a channel update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1524
1525/// leave room disconnects from the room.
1526async fn leave_room(
1527    _: proto::LeaveRoom,
1528    response: Response<proto::LeaveRoom>,
1529    session: UserSession,
1530) -> Result<()> {
1531    leave_room_for_session(&session).await?;
1532    response.send(proto::Ack {})?;
1533    Ok(())
1534}
1535
1536/// Updates the permissions of someone else in the room.
1537async fn set_room_participant_role(
1538    request: proto::SetRoomParticipantRole,
1539    response: Response<proto::SetRoomParticipantRole>,
1540    session: UserSession,
1541) -> Result<()> {
1542    let user_id = UserId::from_proto(request.user_id);
1543    let role = ChannelRole::from(request.role());
1544
1545    let (live_kit_room, can_publish) = {
1546        let room = session
1547            .db()
1548            .await
1549            .set_room_participant_role(
1550                session.user_id(),
1551                RoomId::from_proto(request.room_id),
1552                user_id,
1553                role,
1554            )
1555            .await?;
1556
1557        let live_kit_room = room.live_kit_room.clone();
1558        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1559        room_updated(&room, &session.peer);
1560        (live_kit_room, can_publish)
1561    };
1562
1563    if let Some(live_kit) = session.live_kit_client.as_ref() {
1564        live_kit
1565            .update_participant(
1566                live_kit_room.clone(),
1567                request.user_id.to_string(),
1568                live_kit_server::proto::ParticipantPermission {
1569                    can_subscribe: true,
1570                    can_publish,
1571                    can_publish_data: can_publish,
1572                    hidden: false,
1573                    recorder: false,
1574                },
1575            )
1576            .await
1577            .trace_err();
1578    }
1579
1580    response.send(proto::Ack {})?;
1581    Ok(())
1582}
1583
1584/// Call someone else into the current room
1585async fn call(
1586    request: proto::Call,
1587    response: Response<proto::Call>,
1588    session: UserSession,
1589) -> Result<()> {
1590    let room_id = RoomId::from_proto(request.room_id);
1591    let calling_user_id = session.user_id();
1592    let calling_connection_id = session.connection_id;
1593    let called_user_id = UserId::from_proto(request.called_user_id);
1594    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1595    if !session
1596        .db()
1597        .await
1598        .has_contact(calling_user_id, called_user_id)
1599        .await?
1600    {
1601        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1602    }
1603
1604    let incoming_call = {
1605        let (room, incoming_call) = &mut *session
1606            .db()
1607            .await
1608            .call(
1609                room_id,
1610                calling_user_id,
1611                calling_connection_id,
1612                called_user_id,
1613                initial_project_id,
1614            )
1615            .await?;
1616        room_updated(&room, &session.peer);
1617        mem::take(incoming_call)
1618    };
1619    update_user_contacts(called_user_id, &session).await?;
1620
1621    let mut calls = session
1622        .connection_pool()
1623        .await
1624        .user_connection_ids(called_user_id)
1625        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1626        .collect::<FuturesUnordered<_>>();
1627
1628    while let Some(call_response) = calls.next().await {
1629        match call_response.as_ref() {
1630            Ok(_) => {
1631                response.send(proto::Ack {})?;
1632                return Ok(());
1633            }
1634            Err(_) => {
1635                call_response.trace_err();
1636            }
1637        }
1638    }
1639
1640    {
1641        let room = session
1642            .db()
1643            .await
1644            .call_failed(room_id, called_user_id)
1645            .await?;
1646        room_updated(&room, &session.peer);
1647    }
1648    update_user_contacts(called_user_id, &session).await?;
1649
1650    Err(anyhow!("failed to ring user"))?
1651}
1652
1653/// Cancel an outgoing call.
1654async fn cancel_call(
1655    request: proto::CancelCall,
1656    response: Response<proto::CancelCall>,
1657    session: UserSession,
1658) -> Result<()> {
1659    let called_user_id = UserId::from_proto(request.called_user_id);
1660    let room_id = RoomId::from_proto(request.room_id);
1661    {
1662        let room = session
1663            .db()
1664            .await
1665            .cancel_call(room_id, session.connection_id, called_user_id)
1666            .await?;
1667        room_updated(&room, &session.peer);
1668    }
1669
1670    for connection_id in session
1671        .connection_pool()
1672        .await
1673        .user_connection_ids(called_user_id)
1674    {
1675        session
1676            .peer
1677            .send(
1678                connection_id,
1679                proto::CallCanceled {
1680                    room_id: room_id.to_proto(),
1681                },
1682            )
1683            .trace_err();
1684    }
1685    response.send(proto::Ack {})?;
1686
1687    update_user_contacts(called_user_id, &session).await?;
1688    Ok(())
1689}
1690
1691/// Decline an incoming call.
1692async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1693    let room_id = RoomId::from_proto(message.room_id);
1694    {
1695        let room = session
1696            .db()
1697            .await
1698            .decline_call(Some(room_id), session.user_id())
1699            .await?
1700            .ok_or_else(|| anyhow!("failed to decline call"))?;
1701        room_updated(&room, &session.peer);
1702    }
1703
1704    for connection_id in session
1705        .connection_pool()
1706        .await
1707        .user_connection_ids(session.user_id())
1708    {
1709        session
1710            .peer
1711            .send(
1712                connection_id,
1713                proto::CallCanceled {
1714                    room_id: room_id.to_proto(),
1715                },
1716            )
1717            .trace_err();
1718    }
1719    update_user_contacts(session.user_id(), &session).await?;
1720    Ok(())
1721}
1722
1723/// Updates other participants in the room with your current location.
1724async fn update_participant_location(
1725    request: proto::UpdateParticipantLocation,
1726    response: Response<proto::UpdateParticipantLocation>,
1727    session: UserSession,
1728) -> Result<()> {
1729    let room_id = RoomId::from_proto(request.room_id);
1730    let location = request
1731        .location
1732        .ok_or_else(|| anyhow!("invalid location"))?;
1733
1734    let db = session.db().await;
1735    let room = db
1736        .update_room_participant_location(room_id, session.connection_id, location)
1737        .await?;
1738
1739    room_updated(&room, &session.peer);
1740    response.send(proto::Ack {})?;
1741    Ok(())
1742}
1743
1744/// Share a project into the room.
1745async fn share_project(
1746    request: proto::ShareProject,
1747    response: Response<proto::ShareProject>,
1748    session: Session,
1749) -> Result<()> {
1750    let (project_id, room) = &*session
1751        .db()
1752        .await
1753        .share_project(
1754            RoomId::from_proto(request.room_id),
1755            session.connection_id,
1756            &request.worktrees,
1757        )
1758        .await?;
1759    response.send(proto::ShareProjectResponse {
1760        project_id: project_id.to_proto(),
1761    })?;
1762    room_updated(&room, &session.peer);
1763
1764    Ok(())
1765}
1766
1767/// Unshare a project from the room.
1768async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1769    let project_id = ProjectId::from_proto(message.project_id);
1770
1771    let (room, guest_connection_ids) = &*session
1772        .db()
1773        .await
1774        .unshare_project(project_id, session.connection_id)
1775        .await?;
1776
1777    broadcast(
1778        Some(session.connection_id),
1779        guest_connection_ids.iter().copied(),
1780        |conn_id| session.peer.send(conn_id, message.clone()),
1781    );
1782    room_updated(&room, &session.peer);
1783
1784    Ok(())
1785}
1786
1787/// Join someone elses shared project.
1788async fn join_project(
1789    request: proto::JoinProject,
1790    response: Response<proto::JoinProject>,
1791    session: UserSession,
1792) -> Result<()> {
1793    let project_id = ProjectId::from_proto(request.project_id);
1794
1795    tracing::info!(%project_id, "join project");
1796
1797    let (project, replica_id) = &mut *session
1798        .db()
1799        .await
1800        .join_project_in_room(project_id, session.connection_id)
1801        .await?;
1802
1803    join_project_internal(response, session, project, replica_id)
1804}
1805
1806trait JoinProjectInternalResponse {
1807    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1808}
1809impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1810    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1811        Response::<proto::JoinProject>::send(self, result)
1812    }
1813}
1814impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1815    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1816        Response::<proto::JoinHostedProject>::send(self, result)
1817    }
1818}
1819
1820fn join_project_internal(
1821    response: impl JoinProjectInternalResponse,
1822    session: UserSession,
1823    project: &mut Project,
1824    replica_id: &ReplicaId,
1825) -> Result<()> {
1826    let collaborators = project
1827        .collaborators
1828        .iter()
1829        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1830        .map(|collaborator| collaborator.to_proto())
1831        .collect::<Vec<_>>();
1832    let project_id = project.id;
1833    let guest_user_id = session.user_id();
1834
1835    let worktrees = project
1836        .worktrees
1837        .iter()
1838        .map(|(id, worktree)| proto::WorktreeMetadata {
1839            id: *id,
1840            root_name: worktree.root_name.clone(),
1841            visible: worktree.visible,
1842            abs_path: worktree.abs_path.clone(),
1843        })
1844        .collect::<Vec<_>>();
1845
1846    for collaborator in &collaborators {
1847        session
1848            .peer
1849            .send(
1850                collaborator.peer_id.unwrap().into(),
1851                proto::AddProjectCollaborator {
1852                    project_id: project_id.to_proto(),
1853                    collaborator: Some(proto::Collaborator {
1854                        peer_id: Some(session.connection_id.into()),
1855                        replica_id: replica_id.0 as u32,
1856                        user_id: guest_user_id.to_proto(),
1857                    }),
1858                },
1859            )
1860            .trace_err();
1861    }
1862
1863    // First, we send the metadata associated with each worktree.
1864    response.send(proto::JoinProjectResponse {
1865        project_id: project.id.0 as u64,
1866        worktrees: worktrees.clone(),
1867        replica_id: replica_id.0 as u32,
1868        collaborators: collaborators.clone(),
1869        language_servers: project.language_servers.clone(),
1870        role: project.role.into(), // todo
1871    })?;
1872
1873    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1874        #[cfg(any(test, feature = "test-support"))]
1875        const MAX_CHUNK_SIZE: usize = 2;
1876        #[cfg(not(any(test, feature = "test-support")))]
1877        const MAX_CHUNK_SIZE: usize = 256;
1878
1879        // Stream this worktree's entries.
1880        let message = proto::UpdateWorktree {
1881            project_id: project_id.to_proto(),
1882            worktree_id,
1883            abs_path: worktree.abs_path.clone(),
1884            root_name: worktree.root_name,
1885            updated_entries: worktree.entries,
1886            removed_entries: Default::default(),
1887            scan_id: worktree.scan_id,
1888            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1889            updated_repositories: worktree.repository_entries.into_values().collect(),
1890            removed_repositories: Default::default(),
1891        };
1892        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1893            session.peer.send(session.connection_id, update.clone())?;
1894        }
1895
1896        // Stream this worktree's diagnostics.
1897        for summary in worktree.diagnostic_summaries {
1898            session.peer.send(
1899                session.connection_id,
1900                proto::UpdateDiagnosticSummary {
1901                    project_id: project_id.to_proto(),
1902                    worktree_id: worktree.id,
1903                    summary: Some(summary),
1904                },
1905            )?;
1906        }
1907
1908        for settings_file in worktree.settings_files {
1909            session.peer.send(
1910                session.connection_id,
1911                proto::UpdateWorktreeSettings {
1912                    project_id: project_id.to_proto(),
1913                    worktree_id: worktree.id,
1914                    path: settings_file.path,
1915                    content: Some(settings_file.content),
1916                },
1917            )?;
1918        }
1919    }
1920
1921    for language_server in &project.language_servers {
1922        session.peer.send(
1923            session.connection_id,
1924            proto::UpdateLanguageServer {
1925                project_id: project_id.to_proto(),
1926                language_server_id: language_server.id,
1927                variant: Some(
1928                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1929                        proto::LspDiskBasedDiagnosticsUpdated {},
1930                    ),
1931                ),
1932            },
1933        )?;
1934    }
1935
1936    Ok(())
1937}
1938
1939/// Leave someone elses shared project.
1940async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1941    let sender_id = session.connection_id;
1942    let project_id = ProjectId::from_proto(request.project_id);
1943    let db = session.db().await;
1944    if db.is_hosted_project(project_id).await? {
1945        let project = db.leave_hosted_project(project_id, sender_id).await?;
1946        project_left(&project, &session);
1947        return Ok(());
1948    }
1949
1950    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1951    tracing::info!(
1952        %project_id,
1953        host_user_id = ?project.host_user_id,
1954        host_connection_id = ?project.host_connection_id,
1955        "leave project"
1956    );
1957
1958    project_left(&project, &session);
1959    room_updated(&room, &session.peer);
1960
1961    Ok(())
1962}
1963
1964async fn join_hosted_project(
1965    request: proto::JoinHostedProject,
1966    response: Response<proto::JoinHostedProject>,
1967    session: UserSession,
1968) -> Result<()> {
1969    let (mut project, replica_id) = session
1970        .db()
1971        .await
1972        .join_hosted_project(
1973            ProjectId(request.project_id as i32),
1974            session.user_id(),
1975            session.connection_id,
1976        )
1977        .await?;
1978
1979    join_project_internal(response, session, &mut project, &replica_id)
1980}
1981
1982/// Updates other participants with changes to the project
1983async fn update_project(
1984    request: proto::UpdateProject,
1985    response: Response<proto::UpdateProject>,
1986    session: Session,
1987) -> Result<()> {
1988    let project_id = ProjectId::from_proto(request.project_id);
1989    let (room, guest_connection_ids) = &*session
1990        .db()
1991        .await
1992        .update_project(project_id, session.connection_id, &request.worktrees)
1993        .await?;
1994    broadcast(
1995        Some(session.connection_id),
1996        guest_connection_ids.iter().copied(),
1997        |connection_id| {
1998            session
1999                .peer
2000                .forward_send(session.connection_id, connection_id, request.clone())
2001        },
2002    );
2003    room_updated(&room, &session.peer);
2004    response.send(proto::Ack {})?;
2005
2006    Ok(())
2007}
2008
2009/// Updates other participants with changes to the worktree
2010async fn update_worktree(
2011    request: proto::UpdateWorktree,
2012    response: Response<proto::UpdateWorktree>,
2013    session: Session,
2014) -> Result<()> {
2015    let guest_connection_ids = session
2016        .db()
2017        .await
2018        .update_worktree(&request, session.connection_id)
2019        .await?;
2020
2021    broadcast(
2022        Some(session.connection_id),
2023        guest_connection_ids.iter().copied(),
2024        |connection_id| {
2025            session
2026                .peer
2027                .forward_send(session.connection_id, connection_id, request.clone())
2028        },
2029    );
2030    response.send(proto::Ack {})?;
2031    Ok(())
2032}
2033
2034/// Updates other participants with changes to the diagnostics
2035async fn update_diagnostic_summary(
2036    message: proto::UpdateDiagnosticSummary,
2037    session: Session,
2038) -> Result<()> {
2039    let guest_connection_ids = session
2040        .db()
2041        .await
2042        .update_diagnostic_summary(&message, session.connection_id)
2043        .await?;
2044
2045    broadcast(
2046        Some(session.connection_id),
2047        guest_connection_ids.iter().copied(),
2048        |connection_id| {
2049            session
2050                .peer
2051                .forward_send(session.connection_id, connection_id, message.clone())
2052        },
2053    );
2054
2055    Ok(())
2056}
2057
2058/// Updates other participants with changes to the worktree settings
2059async fn update_worktree_settings(
2060    message: proto::UpdateWorktreeSettings,
2061    session: Session,
2062) -> Result<()> {
2063    let guest_connection_ids = session
2064        .db()
2065        .await
2066        .update_worktree_settings(&message, session.connection_id)
2067        .await?;
2068
2069    broadcast(
2070        Some(session.connection_id),
2071        guest_connection_ids.iter().copied(),
2072        |connection_id| {
2073            session
2074                .peer
2075                .forward_send(session.connection_id, connection_id, message.clone())
2076        },
2077    );
2078
2079    Ok(())
2080}
2081
2082/// Notify other participants that a  language server has started.
2083async fn start_language_server(
2084    request: proto::StartLanguageServer,
2085    session: Session,
2086) -> Result<()> {
2087    let guest_connection_ids = session
2088        .db()
2089        .await
2090        .start_language_server(&request, session.connection_id)
2091        .await?;
2092
2093    broadcast(
2094        Some(session.connection_id),
2095        guest_connection_ids.iter().copied(),
2096        |connection_id| {
2097            session
2098                .peer
2099                .forward_send(session.connection_id, connection_id, request.clone())
2100        },
2101    );
2102    Ok(())
2103}
2104
2105/// Notify other participants that a language server has changed.
2106async fn update_language_server(
2107    request: proto::UpdateLanguageServer,
2108    session: Session,
2109) -> Result<()> {
2110    let project_id = ProjectId::from_proto(request.project_id);
2111    let project_connection_ids = session
2112        .db()
2113        .await
2114        .project_connection_ids(project_id, session.connection_id)
2115        .await?;
2116    broadcast(
2117        Some(session.connection_id),
2118        project_connection_ids.iter().copied(),
2119        |connection_id| {
2120            session
2121                .peer
2122                .forward_send(session.connection_id, connection_id, request.clone())
2123        },
2124    );
2125    Ok(())
2126}
2127
2128/// forward a project request to the host. These requests should be read only
2129/// as guests are allowed to send them.
2130async fn forward_read_only_project_request<T>(
2131    request: T,
2132    response: Response<T>,
2133    session: Session,
2134) -> Result<()>
2135where
2136    T: EntityMessage + RequestMessage,
2137{
2138    let project_id = ProjectId::from_proto(request.remote_entity_id());
2139    let host_connection_id = session
2140        .db()
2141        .await
2142        .host_for_read_only_project_request(project_id, session.connection_id)
2143        .await?;
2144    let payload = session
2145        .peer
2146        .forward_request(session.connection_id, host_connection_id, request)
2147        .await?;
2148    response.send(payload)?;
2149    Ok(())
2150}
2151
2152/// forward a project request to the host. These requests are disallowed
2153/// for guests.
2154async fn forward_mutating_project_request<T>(
2155    request: T,
2156    response: Response<T>,
2157    session: Session,
2158) -> Result<()>
2159where
2160    T: EntityMessage + RequestMessage,
2161{
2162    let project_id = ProjectId::from_proto(request.remote_entity_id());
2163    let host_connection_id = session
2164        .db()
2165        .await
2166        .host_for_mutating_project_request(project_id, session.connection_id)
2167        .await?;
2168    let payload = session
2169        .peer
2170        .forward_request(session.connection_id, host_connection_id, request)
2171        .await?;
2172    response.send(payload)?;
2173    Ok(())
2174}
2175
2176/// Notify other participants that a new buffer has been created
2177async fn create_buffer_for_peer(
2178    request: proto::CreateBufferForPeer,
2179    session: Session,
2180) -> Result<()> {
2181    session
2182        .db()
2183        .await
2184        .check_user_is_project_host(
2185            ProjectId::from_proto(request.project_id),
2186            session.connection_id,
2187        )
2188        .await?;
2189    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2190    session
2191        .peer
2192        .forward_send(session.connection_id, peer_id.into(), request)?;
2193    Ok(())
2194}
2195
2196/// Notify other participants that a buffer has been updated. This is
2197/// allowed for guests as long as the update is limited to selections.
2198async fn update_buffer(
2199    request: proto::UpdateBuffer,
2200    response: Response<proto::UpdateBuffer>,
2201    session: Session,
2202) -> Result<()> {
2203    let project_id = ProjectId::from_proto(request.project_id);
2204    let mut guest_connection_ids;
2205    let mut host_connection_id = None;
2206
2207    let mut requires_write_permission = false;
2208
2209    for op in request.operations.iter() {
2210        match op.variant {
2211            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2212            Some(_) => requires_write_permission = true,
2213        }
2214    }
2215
2216    {
2217        let collaborators = session
2218            .db()
2219            .await
2220            .project_collaborators_for_buffer_update(
2221                project_id,
2222                session.connection_id,
2223                requires_write_permission,
2224            )
2225            .await?;
2226        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2227        for collaborator in collaborators.iter() {
2228            if collaborator.is_host {
2229                host_connection_id = Some(collaborator.connection_id);
2230            } else {
2231                guest_connection_ids.push(collaborator.connection_id);
2232            }
2233        }
2234    }
2235    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2236
2237    broadcast(
2238        Some(session.connection_id),
2239        guest_connection_ids,
2240        |connection_id| {
2241            session
2242                .peer
2243                .forward_send(session.connection_id, connection_id, request.clone())
2244        },
2245    );
2246    if host_connection_id != session.connection_id {
2247        session
2248            .peer
2249            .forward_request(session.connection_id, host_connection_id, request.clone())
2250            .await?;
2251    }
2252
2253    response.send(proto::Ack {})?;
2254    Ok(())
2255}
2256
2257/// Notify other participants that a project has been updated.
2258async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2259    request: T,
2260    session: Session,
2261) -> Result<()> {
2262    let project_id = ProjectId::from_proto(request.remote_entity_id());
2263    let project_connection_ids = session
2264        .db()
2265        .await
2266        .project_connection_ids(project_id, session.connection_id)
2267        .await?;
2268
2269    broadcast(
2270        Some(session.connection_id),
2271        project_connection_ids.iter().copied(),
2272        |connection_id| {
2273            session
2274                .peer
2275                .forward_send(session.connection_id, connection_id, request.clone())
2276        },
2277    );
2278    Ok(())
2279}
2280
2281/// Start following another user in a call.
2282async fn follow(
2283    request: proto::Follow,
2284    response: Response<proto::Follow>,
2285    session: UserSession,
2286) -> Result<()> {
2287    let room_id = RoomId::from_proto(request.room_id);
2288    let project_id = request.project_id.map(ProjectId::from_proto);
2289    let leader_id = request
2290        .leader_id
2291        .ok_or_else(|| anyhow!("invalid leader id"))?
2292        .into();
2293    let follower_id = session.connection_id;
2294
2295    session
2296        .db()
2297        .await
2298        .check_room_participants(room_id, leader_id, session.connection_id)
2299        .await?;
2300
2301    let response_payload = session
2302        .peer
2303        .forward_request(session.connection_id, leader_id, request)
2304        .await?;
2305    response.send(response_payload)?;
2306
2307    if let Some(project_id) = project_id {
2308        let room = session
2309            .db()
2310            .await
2311            .follow(room_id, project_id, leader_id, follower_id)
2312            .await?;
2313        room_updated(&room, &session.peer);
2314    }
2315
2316    Ok(())
2317}
2318
2319/// Stop following another user in a call.
2320async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2321    let room_id = RoomId::from_proto(request.room_id);
2322    let project_id = request.project_id.map(ProjectId::from_proto);
2323    let leader_id = request
2324        .leader_id
2325        .ok_or_else(|| anyhow!("invalid leader id"))?
2326        .into();
2327    let follower_id = session.connection_id;
2328
2329    session
2330        .db()
2331        .await
2332        .check_room_participants(room_id, leader_id, session.connection_id)
2333        .await?;
2334
2335    session
2336        .peer
2337        .forward_send(session.connection_id, leader_id, request)?;
2338
2339    if let Some(project_id) = project_id {
2340        let room = session
2341            .db()
2342            .await
2343            .unfollow(room_id, project_id, leader_id, follower_id)
2344            .await?;
2345        room_updated(&room, &session.peer);
2346    }
2347
2348    Ok(())
2349}
2350
2351/// Notify everyone following you of your current location.
2352async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2353    let room_id = RoomId::from_proto(request.room_id);
2354    let database = session.db.lock().await;
2355
2356    let connection_ids = if let Some(project_id) = request.project_id {
2357        let project_id = ProjectId::from_proto(project_id);
2358        database
2359            .project_connection_ids(project_id, session.connection_id)
2360            .await?
2361    } else {
2362        database
2363            .room_connection_ids(room_id, session.connection_id)
2364            .await?
2365    };
2366
2367    // For now, don't send view update messages back to that view's current leader.
2368    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2369        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2370        _ => None,
2371    });
2372
2373    for connection_id in connection_ids.iter().cloned() {
2374        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2375            session
2376                .peer
2377                .forward_send(session.connection_id, connection_id, request.clone())?;
2378        }
2379    }
2380    Ok(())
2381}
2382
2383/// Get public data about users.
2384async fn get_users(
2385    request: proto::GetUsers,
2386    response: Response<proto::GetUsers>,
2387    session: Session,
2388) -> Result<()> {
2389    let user_ids = request
2390        .user_ids
2391        .into_iter()
2392        .map(UserId::from_proto)
2393        .collect();
2394    let users = session
2395        .db()
2396        .await
2397        .get_users_by_ids(user_ids)
2398        .await?
2399        .into_iter()
2400        .map(|user| proto::User {
2401            id: user.id.to_proto(),
2402            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2403            github_login: user.github_login,
2404        })
2405        .collect();
2406    response.send(proto::UsersResponse { users })?;
2407    Ok(())
2408}
2409
2410/// Search for users (to invite) buy Github login
2411async fn fuzzy_search_users(
2412    request: proto::FuzzySearchUsers,
2413    response: Response<proto::FuzzySearchUsers>,
2414    session: UserSession,
2415) -> Result<()> {
2416    let query = request.query;
2417    let users = match query.len() {
2418        0 => vec![],
2419        1 | 2 => session
2420            .db()
2421            .await
2422            .get_user_by_github_login(&query)
2423            .await?
2424            .into_iter()
2425            .collect(),
2426        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2427    };
2428    let users = users
2429        .into_iter()
2430        .filter(|user| user.id != session.user_id())
2431        .map(|user| proto::User {
2432            id: user.id.to_proto(),
2433            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2434            github_login: user.github_login,
2435        })
2436        .collect();
2437    response.send(proto::UsersResponse { users })?;
2438    Ok(())
2439}
2440
2441/// Send a contact request to another user.
2442async fn request_contact(
2443    request: proto::RequestContact,
2444    response: Response<proto::RequestContact>,
2445    session: UserSession,
2446) -> Result<()> {
2447    let requester_id = session.user_id();
2448    let responder_id = UserId::from_proto(request.responder_id);
2449    if requester_id == responder_id {
2450        return Err(anyhow!("cannot add yourself as a contact"))?;
2451    }
2452
2453    let notifications = session
2454        .db()
2455        .await
2456        .send_contact_request(requester_id, responder_id)
2457        .await?;
2458
2459    // Update outgoing contact requests of requester
2460    let mut update = proto::UpdateContacts::default();
2461    update.outgoing_requests.push(responder_id.to_proto());
2462    for connection_id in session
2463        .connection_pool()
2464        .await
2465        .user_connection_ids(requester_id)
2466    {
2467        session.peer.send(connection_id, update.clone())?;
2468    }
2469
2470    // Update incoming contact requests of responder
2471    let mut update = proto::UpdateContacts::default();
2472    update
2473        .incoming_requests
2474        .push(proto::IncomingContactRequest {
2475            requester_id: requester_id.to_proto(),
2476        });
2477    let connection_pool = session.connection_pool().await;
2478    for connection_id in connection_pool.user_connection_ids(responder_id) {
2479        session.peer.send(connection_id, update.clone())?;
2480    }
2481
2482    send_notifications(&connection_pool, &session.peer, notifications);
2483
2484    response.send(proto::Ack {})?;
2485    Ok(())
2486}
2487
2488/// Accept or decline a contact request
2489async fn respond_to_contact_request(
2490    request: proto::RespondToContactRequest,
2491    response: Response<proto::RespondToContactRequest>,
2492    session: UserSession,
2493) -> Result<()> {
2494    let responder_id = session.user_id();
2495    let requester_id = UserId::from_proto(request.requester_id);
2496    let db = session.db().await;
2497    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2498        db.dismiss_contact_notification(responder_id, requester_id)
2499            .await?;
2500    } else {
2501        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2502
2503        let notifications = db
2504            .respond_to_contact_request(responder_id, requester_id, accept)
2505            .await?;
2506        let requester_busy = db.is_user_busy(requester_id).await?;
2507        let responder_busy = db.is_user_busy(responder_id).await?;
2508
2509        let pool = session.connection_pool().await;
2510        // Update responder with new contact
2511        let mut update = proto::UpdateContacts::default();
2512        if accept {
2513            update
2514                .contacts
2515                .push(contact_for_user(requester_id, requester_busy, &pool));
2516        }
2517        update
2518            .remove_incoming_requests
2519            .push(requester_id.to_proto());
2520        for connection_id in pool.user_connection_ids(responder_id) {
2521            session.peer.send(connection_id, update.clone())?;
2522        }
2523
2524        // Update requester with new contact
2525        let mut update = proto::UpdateContacts::default();
2526        if accept {
2527            update
2528                .contacts
2529                .push(contact_for_user(responder_id, responder_busy, &pool));
2530        }
2531        update
2532            .remove_outgoing_requests
2533            .push(responder_id.to_proto());
2534
2535        for connection_id in pool.user_connection_ids(requester_id) {
2536            session.peer.send(connection_id, update.clone())?;
2537        }
2538
2539        send_notifications(&pool, &session.peer, notifications);
2540    }
2541
2542    response.send(proto::Ack {})?;
2543    Ok(())
2544}
2545
2546/// Remove a contact.
2547async fn remove_contact(
2548    request: proto::RemoveContact,
2549    response: Response<proto::RemoveContact>,
2550    session: UserSession,
2551) -> Result<()> {
2552    let requester_id = session.user_id();
2553    let responder_id = UserId::from_proto(request.user_id);
2554    let db = session.db().await;
2555    let (contact_accepted, deleted_notification_id) =
2556        db.remove_contact(requester_id, responder_id).await?;
2557
2558    let pool = session.connection_pool().await;
2559    // Update outgoing contact requests of requester
2560    let mut update = proto::UpdateContacts::default();
2561    if contact_accepted {
2562        update.remove_contacts.push(responder_id.to_proto());
2563    } else {
2564        update
2565            .remove_outgoing_requests
2566            .push(responder_id.to_proto());
2567    }
2568    for connection_id in pool.user_connection_ids(requester_id) {
2569        session.peer.send(connection_id, update.clone())?;
2570    }
2571
2572    // Update incoming contact requests of responder
2573    let mut update = proto::UpdateContacts::default();
2574    if contact_accepted {
2575        update.remove_contacts.push(requester_id.to_proto());
2576    } else {
2577        update
2578            .remove_incoming_requests
2579            .push(requester_id.to_proto());
2580    }
2581    for connection_id in pool.user_connection_ids(responder_id) {
2582        session.peer.send(connection_id, update.clone())?;
2583        if let Some(notification_id) = deleted_notification_id {
2584            session.peer.send(
2585                connection_id,
2586                proto::DeleteNotification {
2587                    notification_id: notification_id.to_proto(),
2588                },
2589            )?;
2590        }
2591    }
2592
2593    response.send(proto::Ack {})?;
2594    Ok(())
2595}
2596
2597/// Creates a new channel.
2598async fn create_channel(
2599    request: proto::CreateChannel,
2600    response: Response<proto::CreateChannel>,
2601    session: UserSession,
2602) -> Result<()> {
2603    let db = session.db().await;
2604
2605    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2606    let (channel, membership) = db
2607        .create_channel(&request.name, parent_id, session.user_id())
2608        .await?;
2609
2610    let root_id = channel.root_id();
2611    let channel = Channel::from_model(channel);
2612
2613    response.send(proto::CreateChannelResponse {
2614        channel: Some(channel.to_proto()),
2615        parent_id: request.parent_id,
2616    })?;
2617
2618    let mut connection_pool = session.connection_pool().await;
2619    if let Some(membership) = membership {
2620        connection_pool.subscribe_to_channel(
2621            membership.user_id,
2622            membership.channel_id,
2623            membership.role,
2624        );
2625        let update = proto::UpdateUserChannels {
2626            channel_memberships: vec![proto::ChannelMembership {
2627                channel_id: membership.channel_id.to_proto(),
2628                role: membership.role.into(),
2629            }],
2630            ..Default::default()
2631        };
2632        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2633            session.peer.send(connection_id, update.clone())?;
2634        }
2635    }
2636
2637    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2638        if !role.can_see_channel(channel.visibility) {
2639            continue;
2640        }
2641
2642        let update = proto::UpdateChannels {
2643            channels: vec![channel.to_proto()],
2644            ..Default::default()
2645        };
2646        session.peer.send(connection_id, update.clone())?;
2647    }
2648
2649    Ok(())
2650}
2651
2652/// Delete a channel
2653async fn delete_channel(
2654    request: proto::DeleteChannel,
2655    response: Response<proto::DeleteChannel>,
2656    session: UserSession,
2657) -> Result<()> {
2658    let db = session.db().await;
2659
2660    let channel_id = request.channel_id;
2661    let (root_channel, removed_channels) = db
2662        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2663        .await?;
2664    response.send(proto::Ack {})?;
2665
2666    // Notify members of removed channels
2667    let mut update = proto::UpdateChannels::default();
2668    update
2669        .delete_channels
2670        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2671
2672    let connection_pool = session.connection_pool().await;
2673    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2674        session.peer.send(connection_id, update.clone())?;
2675    }
2676
2677    Ok(())
2678}
2679
2680/// Invite someone to join a channel.
2681async fn invite_channel_member(
2682    request: proto::InviteChannelMember,
2683    response: Response<proto::InviteChannelMember>,
2684    session: UserSession,
2685) -> Result<()> {
2686    let db = session.db().await;
2687    let channel_id = ChannelId::from_proto(request.channel_id);
2688    let invitee_id = UserId::from_proto(request.user_id);
2689    let InviteMemberResult {
2690        channel,
2691        notifications,
2692    } = db
2693        .invite_channel_member(
2694            channel_id,
2695            invitee_id,
2696            session.user_id(),
2697            request.role().into(),
2698        )
2699        .await?;
2700
2701    let update = proto::UpdateChannels {
2702        channel_invitations: vec![channel.to_proto()],
2703        ..Default::default()
2704    };
2705
2706    let connection_pool = session.connection_pool().await;
2707    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2708        session.peer.send(connection_id, update.clone())?;
2709    }
2710
2711    send_notifications(&connection_pool, &session.peer, notifications);
2712
2713    response.send(proto::Ack {})?;
2714    Ok(())
2715}
2716
2717/// remove someone from a channel
2718async fn remove_channel_member(
2719    request: proto::RemoveChannelMember,
2720    response: Response<proto::RemoveChannelMember>,
2721    session: UserSession,
2722) -> Result<()> {
2723    let db = session.db().await;
2724    let channel_id = ChannelId::from_proto(request.channel_id);
2725    let member_id = UserId::from_proto(request.user_id);
2726
2727    let RemoveChannelMemberResult {
2728        membership_update,
2729        notification_id,
2730    } = db
2731        .remove_channel_member(channel_id, member_id, session.user_id())
2732        .await?;
2733
2734    let mut connection_pool = session.connection_pool().await;
2735    notify_membership_updated(
2736        &mut connection_pool,
2737        membership_update,
2738        member_id,
2739        &session.peer,
2740    );
2741    for connection_id in connection_pool.user_connection_ids(member_id) {
2742        if let Some(notification_id) = notification_id {
2743            session
2744                .peer
2745                .send(
2746                    connection_id,
2747                    proto::DeleteNotification {
2748                        notification_id: notification_id.to_proto(),
2749                    },
2750                )
2751                .trace_err();
2752        }
2753    }
2754
2755    response.send(proto::Ack {})?;
2756    Ok(())
2757}
2758
2759/// Toggle the channel between public and private.
2760/// Care is taken to maintain the invariant that public channels only descend from public channels,
2761/// (though members-only channels can appear at any point in the hierarchy).
2762async fn set_channel_visibility(
2763    request: proto::SetChannelVisibility,
2764    response: Response<proto::SetChannelVisibility>,
2765    session: UserSession,
2766) -> Result<()> {
2767    let db = session.db().await;
2768    let channel_id = ChannelId::from_proto(request.channel_id);
2769    let visibility = request.visibility().into();
2770
2771    let channel_model = db
2772        .set_channel_visibility(channel_id, visibility, session.user_id())
2773        .await?;
2774    let root_id = channel_model.root_id();
2775    let channel = Channel::from_model(channel_model);
2776
2777    let mut connection_pool = session.connection_pool().await;
2778    for (user_id, role) in connection_pool
2779        .channel_user_ids(root_id)
2780        .collect::<Vec<_>>()
2781        .into_iter()
2782    {
2783        let update = if role.can_see_channel(channel.visibility) {
2784            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2785            proto::UpdateChannels {
2786                channels: vec![channel.to_proto()],
2787                ..Default::default()
2788            }
2789        } else {
2790            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2791            proto::UpdateChannels {
2792                delete_channels: vec![channel.id.to_proto()],
2793                ..Default::default()
2794            }
2795        };
2796
2797        for connection_id in connection_pool.user_connection_ids(user_id) {
2798            session.peer.send(connection_id, update.clone())?;
2799        }
2800    }
2801
2802    response.send(proto::Ack {})?;
2803    Ok(())
2804}
2805
2806/// Alter the role for a user in the channel.
2807async fn set_channel_member_role(
2808    request: proto::SetChannelMemberRole,
2809    response: Response<proto::SetChannelMemberRole>,
2810    session: UserSession,
2811) -> Result<()> {
2812    let db = session.db().await;
2813    let channel_id = ChannelId::from_proto(request.channel_id);
2814    let member_id = UserId::from_proto(request.user_id);
2815    let result = db
2816        .set_channel_member_role(
2817            channel_id,
2818            session.user_id(),
2819            member_id,
2820            request.role().into(),
2821        )
2822        .await?;
2823
2824    match result {
2825        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2826            let mut connection_pool = session.connection_pool().await;
2827            notify_membership_updated(
2828                &mut connection_pool,
2829                membership_update,
2830                member_id,
2831                &session.peer,
2832            )
2833        }
2834        db::SetMemberRoleResult::InviteUpdated(channel) => {
2835            let update = proto::UpdateChannels {
2836                channel_invitations: vec![channel.to_proto()],
2837                ..Default::default()
2838            };
2839
2840            for connection_id in session
2841                .connection_pool()
2842                .await
2843                .user_connection_ids(member_id)
2844            {
2845                session.peer.send(connection_id, update.clone())?;
2846            }
2847        }
2848    }
2849
2850    response.send(proto::Ack {})?;
2851    Ok(())
2852}
2853
2854/// Change the name of a channel
2855async fn rename_channel(
2856    request: proto::RenameChannel,
2857    response: Response<proto::RenameChannel>,
2858    session: UserSession,
2859) -> Result<()> {
2860    let db = session.db().await;
2861    let channel_id = ChannelId::from_proto(request.channel_id);
2862    let channel_model = db
2863        .rename_channel(channel_id, session.user_id(), &request.name)
2864        .await?;
2865    let root_id = channel_model.root_id();
2866    let channel = Channel::from_model(channel_model);
2867
2868    response.send(proto::RenameChannelResponse {
2869        channel: Some(channel.to_proto()),
2870    })?;
2871
2872    let connection_pool = session.connection_pool().await;
2873    let update = proto::UpdateChannels {
2874        channels: vec![channel.to_proto()],
2875        ..Default::default()
2876    };
2877    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2878        if role.can_see_channel(channel.visibility) {
2879            session.peer.send(connection_id, update.clone())?;
2880        }
2881    }
2882
2883    Ok(())
2884}
2885
2886/// Move a channel to a new parent.
2887async fn move_channel(
2888    request: proto::MoveChannel,
2889    response: Response<proto::MoveChannel>,
2890    session: UserSession,
2891) -> Result<()> {
2892    let channel_id = ChannelId::from_proto(request.channel_id);
2893    let to = ChannelId::from_proto(request.to);
2894
2895    let (root_id, channels) = session
2896        .db()
2897        .await
2898        .move_channel(channel_id, to, session.user_id())
2899        .await?;
2900
2901    let connection_pool = session.connection_pool().await;
2902    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2903        let channels = channels
2904            .iter()
2905            .filter_map(|channel| {
2906                if role.can_see_channel(channel.visibility) {
2907                    Some(channel.to_proto())
2908                } else {
2909                    None
2910                }
2911            })
2912            .collect::<Vec<_>>();
2913        if channels.is_empty() {
2914            continue;
2915        }
2916
2917        let update = proto::UpdateChannels {
2918            channels,
2919            ..Default::default()
2920        };
2921
2922        session.peer.send(connection_id, update.clone())?;
2923    }
2924
2925    response.send(Ack {})?;
2926    Ok(())
2927}
2928
2929/// Get the list of channel members
2930async fn get_channel_members(
2931    request: proto::GetChannelMembers,
2932    response: Response<proto::GetChannelMembers>,
2933    session: UserSession,
2934) -> Result<()> {
2935    let db = session.db().await;
2936    let channel_id = ChannelId::from_proto(request.channel_id);
2937    let members = db
2938        .get_channel_participant_details(channel_id, session.user_id())
2939        .await?;
2940    response.send(proto::GetChannelMembersResponse { members })?;
2941    Ok(())
2942}
2943
2944/// Accept or decline a channel invitation.
2945async fn respond_to_channel_invite(
2946    request: proto::RespondToChannelInvite,
2947    response: Response<proto::RespondToChannelInvite>,
2948    session: UserSession,
2949) -> Result<()> {
2950    let db = session.db().await;
2951    let channel_id = ChannelId::from_proto(request.channel_id);
2952    let RespondToChannelInvite {
2953        membership_update,
2954        notifications,
2955    } = db
2956        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2957        .await?;
2958
2959    let mut connection_pool = session.connection_pool().await;
2960    if let Some(membership_update) = membership_update {
2961        notify_membership_updated(
2962            &mut connection_pool,
2963            membership_update,
2964            session.user_id(),
2965            &session.peer,
2966        );
2967    } else {
2968        let update = proto::UpdateChannels {
2969            remove_channel_invitations: vec![channel_id.to_proto()],
2970            ..Default::default()
2971        };
2972
2973        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2974            session.peer.send(connection_id, update.clone())?;
2975        }
2976    };
2977
2978    send_notifications(&connection_pool, &session.peer, notifications);
2979
2980    response.send(proto::Ack {})?;
2981
2982    Ok(())
2983}
2984
2985/// Join the channels' room
2986async fn join_channel(
2987    request: proto::JoinChannel,
2988    response: Response<proto::JoinChannel>,
2989    session: UserSession,
2990) -> Result<()> {
2991    let channel_id = ChannelId::from_proto(request.channel_id);
2992    join_channel_internal(channel_id, Box::new(response), session).await
2993}
2994
2995trait JoinChannelInternalResponse {
2996    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2997}
2998impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2999    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3000        Response::<proto::JoinChannel>::send(self, result)
3001    }
3002}
3003impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3004    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3005        Response::<proto::JoinRoom>::send(self, result)
3006    }
3007}
3008
3009async fn join_channel_internal(
3010    channel_id: ChannelId,
3011    response: Box<impl JoinChannelInternalResponse>,
3012    session: UserSession,
3013) -> Result<()> {
3014    let joined_room = {
3015        leave_room_for_session(&session).await?;
3016        let db = session.db().await;
3017
3018        let (joined_room, membership_updated, role) = db
3019            .join_channel(channel_id, session.user_id(), session.connection_id)
3020            .await?;
3021
3022        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3023            let (can_publish, token) = if role == ChannelRole::Guest {
3024                (
3025                    false,
3026                    live_kit
3027                        .guest_token(
3028                            &joined_room.room.live_kit_room,
3029                            &session.user_id().to_string(),
3030                        )
3031                        .trace_err()?,
3032                )
3033            } else {
3034                (
3035                    true,
3036                    live_kit
3037                        .room_token(
3038                            &joined_room.room.live_kit_room,
3039                            &session.user_id().to_string(),
3040                        )
3041                        .trace_err()?,
3042                )
3043            };
3044
3045            Some(LiveKitConnectionInfo {
3046                server_url: live_kit.url().into(),
3047                token,
3048                can_publish,
3049            })
3050        });
3051
3052        response.send(proto::JoinRoomResponse {
3053            room: Some(joined_room.room.clone()),
3054            channel_id: joined_room
3055                .channel
3056                .as_ref()
3057                .map(|channel| channel.id.to_proto()),
3058            live_kit_connection_info,
3059        })?;
3060
3061        let mut connection_pool = session.connection_pool().await;
3062        if let Some(membership_updated) = membership_updated {
3063            notify_membership_updated(
3064                &mut connection_pool,
3065                membership_updated,
3066                session.user_id(),
3067                &session.peer,
3068            );
3069        }
3070
3071        room_updated(&joined_room.room, &session.peer);
3072
3073        joined_room
3074    };
3075
3076    channel_updated(
3077        &joined_room
3078            .channel
3079            .ok_or_else(|| anyhow!("channel not returned"))?,
3080        &joined_room.room,
3081        &session.peer,
3082        &*session.connection_pool().await,
3083    );
3084
3085    update_user_contacts(session.user_id(), &session).await?;
3086    Ok(())
3087}
3088
3089/// Start editing the channel notes
3090async fn join_channel_buffer(
3091    request: proto::JoinChannelBuffer,
3092    response: Response<proto::JoinChannelBuffer>,
3093    session: UserSession,
3094) -> Result<()> {
3095    let db = session.db().await;
3096    let channel_id = ChannelId::from_proto(request.channel_id);
3097
3098    let open_response = db
3099        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3100        .await?;
3101
3102    let collaborators = open_response.collaborators.clone();
3103    response.send(open_response)?;
3104
3105    let update = UpdateChannelBufferCollaborators {
3106        channel_id: channel_id.to_proto(),
3107        collaborators: collaborators.clone(),
3108    };
3109    channel_buffer_updated(
3110        session.connection_id,
3111        collaborators
3112            .iter()
3113            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3114        &update,
3115        &session.peer,
3116    );
3117
3118    Ok(())
3119}
3120
3121/// Edit the channel notes
3122async fn update_channel_buffer(
3123    request: proto::UpdateChannelBuffer,
3124    session: UserSession,
3125) -> Result<()> {
3126    let db = session.db().await;
3127    let channel_id = ChannelId::from_proto(request.channel_id);
3128
3129    let (collaborators, non_collaborators, epoch, version) = db
3130        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3131        .await?;
3132
3133    channel_buffer_updated(
3134        session.connection_id,
3135        collaborators,
3136        &proto::UpdateChannelBuffer {
3137            channel_id: channel_id.to_proto(),
3138            operations: request.operations,
3139        },
3140        &session.peer,
3141    );
3142
3143    let pool = &*session.connection_pool().await;
3144
3145    broadcast(
3146        None,
3147        non_collaborators
3148            .iter()
3149            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3150        |peer_id| {
3151            session.peer.send(
3152                peer_id,
3153                proto::UpdateChannels {
3154                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3155                        channel_id: channel_id.to_proto(),
3156                        epoch: epoch as u64,
3157                        version: version.clone(),
3158                    }],
3159                    ..Default::default()
3160                },
3161            )
3162        },
3163    );
3164
3165    Ok(())
3166}
3167
3168/// Rejoin the channel notes after a connection blip
3169async fn rejoin_channel_buffers(
3170    request: proto::RejoinChannelBuffers,
3171    response: Response<proto::RejoinChannelBuffers>,
3172    session: UserSession,
3173) -> Result<()> {
3174    let db = session.db().await;
3175    let buffers = db
3176        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3177        .await?;
3178
3179    for rejoined_buffer in &buffers {
3180        let collaborators_to_notify = rejoined_buffer
3181            .buffer
3182            .collaborators
3183            .iter()
3184            .filter_map(|c| Some(c.peer_id?.into()));
3185        channel_buffer_updated(
3186            session.connection_id,
3187            collaborators_to_notify,
3188            &proto::UpdateChannelBufferCollaborators {
3189                channel_id: rejoined_buffer.buffer.channel_id,
3190                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3191            },
3192            &session.peer,
3193        );
3194    }
3195
3196    response.send(proto::RejoinChannelBuffersResponse {
3197        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3198    })?;
3199
3200    Ok(())
3201}
3202
3203/// Stop editing the channel notes
3204async fn leave_channel_buffer(
3205    request: proto::LeaveChannelBuffer,
3206    response: Response<proto::LeaveChannelBuffer>,
3207    session: UserSession,
3208) -> Result<()> {
3209    let db = session.db().await;
3210    let channel_id = ChannelId::from_proto(request.channel_id);
3211
3212    let left_buffer = db
3213        .leave_channel_buffer(channel_id, session.connection_id)
3214        .await?;
3215
3216    response.send(Ack {})?;
3217
3218    channel_buffer_updated(
3219        session.connection_id,
3220        left_buffer.connections,
3221        &proto::UpdateChannelBufferCollaborators {
3222            channel_id: channel_id.to_proto(),
3223            collaborators: left_buffer.collaborators,
3224        },
3225        &session.peer,
3226    );
3227
3228    Ok(())
3229}
3230
3231fn channel_buffer_updated<T: EnvelopedMessage>(
3232    sender_id: ConnectionId,
3233    collaborators: impl IntoIterator<Item = ConnectionId>,
3234    message: &T,
3235    peer: &Peer,
3236) {
3237    broadcast(Some(sender_id), collaborators, |peer_id| {
3238        peer.send(peer_id, message.clone())
3239    });
3240}
3241
3242fn send_notifications(
3243    connection_pool: &ConnectionPool,
3244    peer: &Peer,
3245    notifications: db::NotificationBatch,
3246) {
3247    for (user_id, notification) in notifications {
3248        for connection_id in connection_pool.user_connection_ids(user_id) {
3249            if let Err(error) = peer.send(
3250                connection_id,
3251                proto::AddNotification {
3252                    notification: Some(notification.clone()),
3253                },
3254            ) {
3255                tracing::error!(
3256                    "failed to send notification to {:?} {}",
3257                    connection_id,
3258                    error
3259                );
3260            }
3261        }
3262    }
3263}
3264
3265/// Send a message to the channel
3266async fn send_channel_message(
3267    request: proto::SendChannelMessage,
3268    response: Response<proto::SendChannelMessage>,
3269    session: UserSession,
3270) -> Result<()> {
3271    // Validate the message body.
3272    let body = request.body.trim().to_string();
3273    if body.len() > MAX_MESSAGE_LEN {
3274        return Err(anyhow!("message is too long"))?;
3275    }
3276    if body.is_empty() {
3277        return Err(anyhow!("message can't be blank"))?;
3278    }
3279
3280    // TODO: adjust mentions if body is trimmed
3281
3282    let timestamp = OffsetDateTime::now_utc();
3283    let nonce = request
3284        .nonce
3285        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3286
3287    let channel_id = ChannelId::from_proto(request.channel_id);
3288    let CreatedChannelMessage {
3289        message_id,
3290        participant_connection_ids,
3291        channel_members,
3292        notifications,
3293    } = session
3294        .db()
3295        .await
3296        .create_channel_message(
3297            channel_id,
3298            session.user_id(),
3299            &body,
3300            &request.mentions,
3301            timestamp,
3302            nonce.clone().into(),
3303            match request.reply_to_message_id {
3304                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3305                None => None,
3306            },
3307        )
3308        .await?;
3309
3310    let message = proto::ChannelMessage {
3311        sender_id: session.user_id().to_proto(),
3312        id: message_id.to_proto(),
3313        body,
3314        mentions: request.mentions,
3315        timestamp: timestamp.unix_timestamp() as u64,
3316        nonce: Some(nonce),
3317        reply_to_message_id: request.reply_to_message_id,
3318        edited_at: None,
3319    };
3320    broadcast(
3321        Some(session.connection_id),
3322        participant_connection_ids,
3323        |connection| {
3324            session.peer.send(
3325                connection,
3326                proto::ChannelMessageSent {
3327                    channel_id: channel_id.to_proto(),
3328                    message: Some(message.clone()),
3329                },
3330            )
3331        },
3332    );
3333    response.send(proto::SendChannelMessageResponse {
3334        message: Some(message),
3335    })?;
3336
3337    let pool = &*session.connection_pool().await;
3338    broadcast(
3339        None,
3340        channel_members
3341            .iter()
3342            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3343        |peer_id| {
3344            session.peer.send(
3345                peer_id,
3346                proto::UpdateChannels {
3347                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3348                        channel_id: channel_id.to_proto(),
3349                        message_id: message_id.to_proto(),
3350                    }],
3351                    ..Default::default()
3352                },
3353            )
3354        },
3355    );
3356    send_notifications(pool, &session.peer, notifications);
3357
3358    Ok(())
3359}
3360
3361/// Delete a channel message
3362async fn remove_channel_message(
3363    request: proto::RemoveChannelMessage,
3364    response: Response<proto::RemoveChannelMessage>,
3365    session: UserSession,
3366) -> Result<()> {
3367    let channel_id = ChannelId::from_proto(request.channel_id);
3368    let message_id = MessageId::from_proto(request.message_id);
3369    let connection_ids = session
3370        .db()
3371        .await
3372        .remove_channel_message(channel_id, message_id, session.user_id())
3373        .await?;
3374    broadcast(Some(session.connection_id), connection_ids, |connection| {
3375        session.peer.send(connection, request.clone())
3376    });
3377    response.send(proto::Ack {})?;
3378    Ok(())
3379}
3380
3381async fn update_channel_message(
3382    request: proto::UpdateChannelMessage,
3383    response: Response<proto::UpdateChannelMessage>,
3384    session: UserSession,
3385) -> Result<()> {
3386    let channel_id = ChannelId::from_proto(request.channel_id);
3387    let message_id = MessageId::from_proto(request.message_id);
3388    let updated_at = OffsetDateTime::now_utc();
3389    let UpdatedChannelMessage {
3390        message_id,
3391        participant_connection_ids,
3392        notifications,
3393        reply_to_message_id,
3394        timestamp,
3395    } = session
3396        .db()
3397        .await
3398        .update_channel_message(
3399            channel_id,
3400            message_id,
3401            session.user_id(),
3402            request.body.as_str(),
3403            &request.mentions,
3404            updated_at,
3405        )
3406        .await?;
3407
3408    let nonce = request
3409        .nonce
3410        .clone()
3411        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3412
3413    let message = proto::ChannelMessage {
3414        sender_id: session.user_id().to_proto(),
3415        id: message_id.to_proto(),
3416        body: request.body.clone(),
3417        mentions: request.mentions.clone(),
3418        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3419        nonce: Some(nonce),
3420        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3421        edited_at: Some(updated_at.unix_timestamp() as u64),
3422    };
3423
3424    response.send(proto::Ack {})?;
3425
3426    let pool = &*session.connection_pool().await;
3427    broadcast(
3428        Some(session.connection_id),
3429        participant_connection_ids,
3430        |connection| {
3431            session.peer.send(
3432                connection,
3433                proto::ChannelMessageUpdate {
3434                    channel_id: channel_id.to_proto(),
3435                    message: Some(message.clone()),
3436                },
3437            )
3438        },
3439    );
3440
3441    send_notifications(pool, &session.peer, notifications);
3442
3443    Ok(())
3444}
3445
3446/// Mark a channel message as read
3447async fn acknowledge_channel_message(
3448    request: proto::AckChannelMessage,
3449    session: UserSession,
3450) -> Result<()> {
3451    let channel_id = ChannelId::from_proto(request.channel_id);
3452    let message_id = MessageId::from_proto(request.message_id);
3453    let notifications = session
3454        .db()
3455        .await
3456        .observe_channel_message(channel_id, session.user_id(), message_id)
3457        .await?;
3458    send_notifications(
3459        &*session.connection_pool().await,
3460        &session.peer,
3461        notifications,
3462    );
3463    Ok(())
3464}
3465
3466/// Mark a buffer version as synced
3467async fn acknowledge_buffer_version(
3468    request: proto::AckBufferOperation,
3469    session: UserSession,
3470) -> Result<()> {
3471    let buffer_id = BufferId::from_proto(request.buffer_id);
3472    session
3473        .db()
3474        .await
3475        .observe_buffer_version(
3476            buffer_id,
3477            session.user_id(),
3478            request.epoch as i32,
3479            &request.version,
3480        )
3481        .await?;
3482    Ok(())
3483}
3484
3485struct CompleteWithLanguageModelRateLimit;
3486
3487impl RateLimit for CompleteWithLanguageModelRateLimit {
3488    fn capacity() -> usize {
3489        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3490            .ok()
3491            .and_then(|v| v.parse().ok())
3492            .unwrap_or(120) // Picked arbitrarily
3493    }
3494
3495    fn refill_duration() -> chrono::Duration {
3496        chrono::Duration::hours(1)
3497    }
3498
3499    fn db_name() -> &'static str {
3500        "complete-with-language-model"
3501    }
3502}
3503
3504async fn complete_with_language_model(
3505    request: proto::CompleteWithLanguageModel,
3506    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3507    session: Session,
3508    open_ai_api_key: Option<Arc<str>>,
3509    google_ai_api_key: Option<Arc<str>>,
3510    anthropic_api_key: Option<Arc<str>>,
3511) -> Result<()> {
3512    let Some(session) = session.for_user() else {
3513        return Err(anyhow!("user not found"))?;
3514    };
3515    authorize_access_to_language_models(&session).await?;
3516    session
3517        .rate_limiter
3518        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3519        .await?;
3520
3521    if request.model.starts_with("gpt") {
3522        let api_key =
3523            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3524        complete_with_open_ai(request, response, session, api_key).await?;
3525    } else if request.model.starts_with("gemini") {
3526        let api_key = google_ai_api_key
3527            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3528        complete_with_google_ai(request, response, session, api_key).await?;
3529    } else if request.model.starts_with("claude") {
3530        let api_key = anthropic_api_key
3531            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3532        complete_with_anthropic(request, response, session, api_key).await?;
3533    }
3534
3535    Ok(())
3536}
3537
3538async fn complete_with_open_ai(
3539    request: proto::CompleteWithLanguageModel,
3540    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3541    session: UserSession,
3542    api_key: Arc<str>,
3543) -> Result<()> {
3544    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3545
3546    let mut completion_stream = open_ai::stream_completion(
3547        &session.http_client,
3548        OPEN_AI_API_URL,
3549        &api_key,
3550        crate::ai::language_model_request_to_open_ai(request)?,
3551    )
3552    .await
3553    .context("open_ai::stream_completion request failed")?;
3554
3555    while let Some(event) = completion_stream.next().await {
3556        let event = event?;
3557        response.send(proto::LanguageModelResponse {
3558            choices: event
3559                .choices
3560                .into_iter()
3561                .map(|choice| proto::LanguageModelChoiceDelta {
3562                    index: choice.index,
3563                    delta: Some(proto::LanguageModelResponseMessage {
3564                        role: choice.delta.role.map(|role| match role {
3565                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3566                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3567                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3568                        } as i32),
3569                        content: choice.delta.content,
3570                    }),
3571                    finish_reason: choice.finish_reason,
3572                })
3573                .collect(),
3574        })?;
3575    }
3576
3577    Ok(())
3578}
3579
3580async fn complete_with_google_ai(
3581    request: proto::CompleteWithLanguageModel,
3582    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3583    session: UserSession,
3584    api_key: Arc<str>,
3585) -> Result<()> {
3586    let mut stream = google_ai::stream_generate_content(
3587        &session.http_client,
3588        google_ai::API_URL,
3589        api_key.as_ref(),
3590        crate::ai::language_model_request_to_google_ai(request)?,
3591    )
3592    .await
3593    .context("google_ai::stream_generate_content request failed")?;
3594
3595    while let Some(event) = stream.next().await {
3596        let event = event?;
3597        response.send(proto::LanguageModelResponse {
3598            choices: event
3599                .candidates
3600                .unwrap_or_default()
3601                .into_iter()
3602                .map(|candidate| proto::LanguageModelChoiceDelta {
3603                    index: candidate.index as u32,
3604                    delta: Some(proto::LanguageModelResponseMessage {
3605                        role: Some(match candidate.content.role {
3606                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3607                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3608                        } as i32),
3609                        content: Some(
3610                            candidate
3611                                .content
3612                                .parts
3613                                .into_iter()
3614                                .filter_map(|part| match part {
3615                                    google_ai::Part::TextPart(part) => Some(part.text),
3616                                    google_ai::Part::InlineDataPart(_) => None,
3617                                })
3618                                .collect(),
3619                        ),
3620                    }),
3621                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3622                })
3623                .collect(),
3624        })?;
3625    }
3626
3627    Ok(())
3628}
3629
/// Run a completion against the Anthropic API and forward streamed text to the
/// client as `LanguageModelResponse` messages.
///
/// System-role messages in the request are not sent as conversation messages;
/// their contents are concatenated (separated by a blank line) into the
/// request's `system` field. User and assistant messages keep their order.
///
/// Returns the first error produced by model lookup, the initial request, the
/// event stream, or sending a response.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Filled in as a side effect of the `filter_map` below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| match message.role() {
            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                role: anthropic::Role::User,
                content: message.content,
            }),
            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                role: anthropic::Role::Assistant,
                content: message.content,
            }),
            // Anthropic's API breaks system instructions out as a separate field rather
            // than having a system message role.
            LanguageModelRole::LanguageModelSystem => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.content);

                None
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        &session.http_client,
        "https://api.anthropic.com",
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 looks like it may have been meant as 4096 — confirm.
            max_tokens: 4092,
        },
    )
    .await?;

    // Role attached to every forwarded text chunk. It starts as assistant and is
    // only changed by a `MessageStart` event that names "assistant" or "user".
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Any other (or missing) role leaves `current_role` unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // A block may start with initial text; forward it only
                        // when there is something to send.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            // Incremental text is forwarded as-is (including empty deltas).
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // A stop reason is reported to the client as a choice with a
                // `finish_reason` and no content delta.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing that is forwarded to the client.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
3744
3745struct CountTokensWithLanguageModelRateLimit;
3746
3747impl RateLimit for CountTokensWithLanguageModelRateLimit {
3748    fn capacity() -> usize {
3749        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3750            .ok()
3751            .and_then(|v| v.parse().ok())
3752            .unwrap_or(600) // Picked arbitrarily
3753    }
3754
3755    fn refill_duration() -> chrono::Duration {
3756        chrono::Duration::hours(1)
3757    }
3758
3759    fn db_name() -> &'static str {
3760        "count-tokens-with-language-model"
3761    }
3762}
3763
3764async fn count_tokens_with_language_model(
3765    request: proto::CountTokensWithLanguageModel,
3766    response: Response<proto::CountTokensWithLanguageModel>,
3767    session: UserSession,
3768    google_ai_api_key: Option<Arc<str>>,
3769) -> Result<()> {
3770    authorize_access_to_language_models(&session).await?;
3771
3772    if !request.model.starts_with("gemini") {
3773        return Err(anyhow!(
3774            "counting tokens for model: {:?} is not supported",
3775            request.model
3776        ))?;
3777    }
3778
3779    session
3780        .rate_limiter
3781        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3782        .await?;
3783
3784    let api_key = google_ai_api_key
3785        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3786    let tokens_response = google_ai::count_tokens(
3787        &session.http_client,
3788        google_ai::API_URL,
3789        &api_key,
3790        crate::ai::count_tokens_request_to_google_ai(request)?,
3791    )
3792    .await?;
3793    response.send(proto::CountTokensResponse {
3794        token_count: tokens_response.total_tokens as u32,
3795    })?;
3796    Ok(())
3797}
3798
3799async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3800    let db = session.db().await;
3801    let flags = db.get_user_flags(session.user_id()).await?;
3802    if flags.iter().any(|flag| flag == "language-models") {
3803        Ok(())
3804    } else {
3805        Err(anyhow!("permission denied"))?
3806    }
3807}
3808
3809/// Start receiving chat updates for a channel
3810async fn join_channel_chat(
3811    request: proto::JoinChannelChat,
3812    response: Response<proto::JoinChannelChat>,
3813    session: UserSession,
3814) -> Result<()> {
3815    let channel_id = ChannelId::from_proto(request.channel_id);
3816
3817    let db = session.db().await;
3818    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3819        .await?;
3820    let messages = db
3821        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3822        .await?;
3823    response.send(proto::JoinChannelChatResponse {
3824        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3825        messages,
3826    })?;
3827    Ok(())
3828}
3829
3830/// Stop receiving chat updates for a channel
3831async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3832    let channel_id = ChannelId::from_proto(request.channel_id);
3833    session
3834        .db()
3835        .await
3836        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3837        .await?;
3838    Ok(())
3839}
3840
3841/// Retrieve the chat history for a channel
3842async fn get_channel_messages(
3843    request: proto::GetChannelMessages,
3844    response: Response<proto::GetChannelMessages>,
3845    session: UserSession,
3846) -> Result<()> {
3847    let channel_id = ChannelId::from_proto(request.channel_id);
3848    let messages = session
3849        .db()
3850        .await
3851        .get_channel_messages(
3852            channel_id,
3853            session.user_id(),
3854            MESSAGE_COUNT_PER_PAGE,
3855            Some(MessageId::from_proto(request.before_message_id)),
3856        )
3857        .await?;
3858    response.send(proto::GetChannelMessagesResponse {
3859        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3860        messages,
3861    })?;
3862    Ok(())
3863}
3864
3865/// Retrieve specific chat messages
3866async fn get_channel_messages_by_id(
3867    request: proto::GetChannelMessagesById,
3868    response: Response<proto::GetChannelMessagesById>,
3869    session: UserSession,
3870) -> Result<()> {
3871    let message_ids = request
3872        .message_ids
3873        .iter()
3874        .map(|id| MessageId::from_proto(*id))
3875        .collect::<Vec<_>>();
3876    let messages = session
3877        .db()
3878        .await
3879        .get_channel_messages_by_id(session.user_id(), &message_ids)
3880        .await?;
3881    response.send(proto::GetChannelMessagesResponse {
3882        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3883        messages,
3884    })?;
3885    Ok(())
3886}
3887
3888/// Retrieve the current users notifications
3889async fn get_notifications(
3890    request: proto::GetNotifications,
3891    response: Response<proto::GetNotifications>,
3892    session: UserSession,
3893) -> Result<()> {
3894    let notifications = session
3895        .db()
3896        .await
3897        .get_notifications(
3898            session.user_id(),
3899            NOTIFICATION_COUNT_PER_PAGE,
3900            request
3901                .before_id
3902                .map(|id| db::NotificationId::from_proto(id)),
3903        )
3904        .await?;
3905    response.send(proto::GetNotificationsResponse {
3906        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3907        notifications,
3908    })?;
3909    Ok(())
3910}
3911
3912/// Mark notifications as read
3913async fn mark_notification_as_read(
3914    request: proto::MarkNotificationRead,
3915    response: Response<proto::MarkNotificationRead>,
3916    session: UserSession,
3917) -> Result<()> {
3918    let database = &session.db().await;
3919    let notifications = database
3920        .mark_notification_as_read_by_id(
3921            session.user_id(),
3922            NotificationId::from_proto(request.notification_id),
3923        )
3924        .await?;
3925    send_notifications(
3926        &*session.connection_pool().await,
3927        &session.peer,
3928        notifications,
3929    );
3930    response.send(proto::Ack {})?;
3931    Ok(())
3932}
3933
3934/// Get the current users information
3935async fn get_private_user_info(
3936    _request: proto::GetPrivateUserInfo,
3937    response: Response<proto::GetPrivateUserInfo>,
3938    session: UserSession,
3939) -> Result<()> {
3940    let db = session.db().await;
3941
3942    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3943    let user = db
3944        .get_user_by_id(session.user_id())
3945        .await?
3946        .ok_or_else(|| anyhow!("user not found"))?;
3947    let flags = db.get_user_flags(session.user_id()).await?;
3948
3949    response.send(proto::GetPrivateUserInfoResponse {
3950        metrics_id,
3951        staff: user.admin,
3952        flags,
3953    })?;
3954    Ok(())
3955}
3956
3957fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3958    match message {
3959        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3960        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3961        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3962        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3963        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3964            code: frame.code.into(),
3965            reason: frame.reason,
3966        })),
3967    }
3968}
3969
3970fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3971    match message {
3972        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3973        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3974        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3975        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3976        AxumMessage::Close(frame) => {
3977            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3978                code: frame.code.into(),
3979                reason: frame.reason,
3980            }))
3981        }
3982    }
3983}
3984
3985fn notify_membership_updated(
3986    connection_pool: &mut ConnectionPool,
3987    result: MembershipUpdated,
3988    user_id: UserId,
3989    peer: &Peer,
3990) {
3991    for membership in &result.new_channels.channel_memberships {
3992        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3993    }
3994    for channel_id in &result.removed_channels {
3995        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3996    }
3997
3998    let user_channels_update = proto::UpdateUserChannels {
3999        channel_memberships: result
4000            .new_channels
4001            .channel_memberships
4002            .iter()
4003            .map(|cm| proto::ChannelMembership {
4004                channel_id: cm.channel_id.to_proto(),
4005                role: cm.role.into(),
4006            })
4007            .collect(),
4008        ..Default::default()
4009    };
4010
4011    let mut update = build_channels_update(result.new_channels, vec![]);
4012    update.delete_channels = result
4013        .removed_channels
4014        .into_iter()
4015        .map(|id| id.to_proto())
4016        .collect();
4017    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4018
4019    for connection_id in connection_pool.user_connection_ids(user_id) {
4020        peer.send(connection_id, user_channels_update.clone())
4021            .trace_err();
4022        peer.send(connection_id, update.clone()).trace_err();
4023    }
4024}
4025
4026fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4027    proto::UpdateUserChannels {
4028        channel_memberships: channels
4029            .channel_memberships
4030            .iter()
4031            .map(|m| proto::ChannelMembership {
4032                channel_id: m.channel_id.to_proto(),
4033                role: m.role.into(),
4034            })
4035            .collect(),
4036        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4037        observed_channel_message_id: channels.observed_channel_messages.clone(),
4038    }
4039}
4040
4041fn build_channels_update(
4042    channels: ChannelsForUser,
4043    channel_invites: Vec<db::Channel>,
4044) -> proto::UpdateChannels {
4045    let mut update = proto::UpdateChannels::default();
4046
4047    for channel in channels.channels {
4048        update.channels.push(channel.to_proto());
4049    }
4050
4051    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4052    update.latest_channel_message_ids = channels.latest_channel_messages;
4053
4054    for (channel_id, participants) in channels.channel_participants {
4055        update
4056            .channel_participants
4057            .push(proto::ChannelParticipants {
4058                channel_id: channel_id.to_proto(),
4059                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4060            });
4061    }
4062
4063    for channel in channel_invites {
4064        update.channel_invitations.push(channel.to_proto());
4065    }
4066    for project in channels.hosted_projects {
4067        update.hosted_projects.push(project);
4068    }
4069
4070    update
4071}
4072
4073fn build_initial_contacts_update(
4074    contacts: Vec<db::Contact>,
4075    pool: &ConnectionPool,
4076) -> proto::UpdateContacts {
4077    let mut update = proto::UpdateContacts::default();
4078
4079    for contact in contacts {
4080        match contact {
4081            db::Contact::Accepted { user_id, busy } => {
4082                update.contacts.push(contact_for_user(user_id, busy, &pool));
4083            }
4084            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4085            db::Contact::Incoming { user_id } => {
4086                update
4087                    .incoming_requests
4088                    .push(proto::IncomingContactRequest {
4089                        requester_id: user_id.to_proto(),
4090                    })
4091            }
4092        }
4093    }
4094
4095    update
4096}
4097
4098fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4099    proto::Contact {
4100        user_id: user_id.to_proto(),
4101        online: pool.is_user_online(user_id),
4102        busy,
4103    }
4104}
4105
4106fn room_updated(room: &proto::Room, peer: &Peer) {
4107    broadcast(
4108        None,
4109        room.participants
4110            .iter()
4111            .filter_map(|participant| Some(participant.peer_id?.into())),
4112        |peer_id| {
4113            peer.send(
4114                peer_id,
4115                proto::RoomUpdated {
4116                    room: Some(room.clone()),
4117                },
4118            )
4119        },
4120    );
4121}
4122
4123fn channel_updated(
4124    channel: &db::channel::Model,
4125    room: &proto::Room,
4126    peer: &Peer,
4127    pool: &ConnectionPool,
4128) {
4129    let participants = room
4130        .participants
4131        .iter()
4132        .map(|p| p.user_id)
4133        .collect::<Vec<_>>();
4134
4135    broadcast(
4136        None,
4137        pool.channel_connection_ids(channel.root_id())
4138            .filter_map(|(channel_id, role)| {
4139                role.can_see_channel(channel.visibility).then(|| channel_id)
4140            }),
4141        |peer_id| {
4142            peer.send(
4143                peer_id,
4144                proto::UpdateChannels {
4145                    channel_participants: vec![proto::ChannelParticipants {
4146                        channel_id: channel.id.to_proto(),
4147                        participant_user_ids: participants.clone(),
4148                    }],
4149                    ..Default::default()
4150                },
4151            )
4152        },
4153    );
4154}
4155
4156async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4157    let db = session.db().await;
4158
4159    let contacts = db.get_contacts(user_id).await?;
4160    let busy = db.is_user_busy(user_id).await?;
4161
4162    let pool = session.connection_pool().await;
4163    let updated_contact = contact_for_user(user_id, busy, &pool);
4164    for contact in contacts {
4165        if let db::Contact::Accepted {
4166            user_id: contact_user_id,
4167            ..
4168        } = contact
4169        {
4170            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4171                session
4172                    .peer
4173                    .send(
4174                        contact_conn_id,
4175                        proto::UpdateContacts {
4176                            contacts: vec![updated_contact.clone()],
4177                            remove_contacts: Default::default(),
4178                            incoming_requests: Default::default(),
4179                            remove_incoming_requests: Default::default(),
4180                            outgoing_requests: Default::default(),
4181                            remove_outgoing_requests: Default::default(),
4182                        },
4183                    )
4184                    .trace_err();
4185            }
4186        }
4187    }
4188    Ok(())
4189}
4190
/// Removes this session's connection from the room it is in, if any, and
/// notifies everyone affected.
///
/// When the connection was in a room this:
/// * notifies collaborators of every project that was left,
/// * broadcasts the updated room to its participants,
/// * broadcasts the channel's participants if the room belongs to a channel,
/// * sends `CallCanceled` to users whose pending calls were canceled,
/// * refreshes the contact state of this user and of those users, and
/// * removes the participant from the LiveKit room (deleting the room when the
///   database reported it as deleted), if a LiveKit client is configured.
///
/// Returns `Ok(())` without doing anything when the connection was not in a room.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // These are assigned inside the `if let` below and used after it.
    // (Presumably declared up front so that the database handle obtained in the
    // `if let` scrutinee is released before the later steps run — confirm.)
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // `room` itself is taken, so the `room` broadcast below carries the
        // default value in its `live_kit_room` field — confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool guard is dropped before
    // `update_user_contacts` acquires the pool again below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err` and
    // do not fail the request.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4264
4265async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4266    let left_channel_buffers = session
4267        .db()
4268        .await
4269        .leave_channel_buffers(session.connection_id)
4270        .await?;
4271
4272    for left_buffer in left_channel_buffers {
4273        channel_buffer_updated(
4274            session.connection_id,
4275            left_buffer.connections,
4276            &proto::UpdateChannelBufferCollaborators {
4277                channel_id: left_buffer.channel_id.to_proto(),
4278                collaborators: left_buffer.collaborators,
4279            },
4280            &session.peer,
4281        );
4282    }
4283
4284    Ok(())
4285}
4286
4287fn project_left(project: &db::LeftProject, session: &UserSession) {
4288    for connection_id in &project.connection_ids {
4289        if project.host_user_id == Some(session.user_id()) {
4290            session
4291                .peer
4292                .send(
4293                    *connection_id,
4294                    proto::UnshareProject {
4295                        project_id: project.id.to_proto(),
4296                    },
4297                )
4298                .trace_err();
4299        } else {
4300            session
4301                .peer
4302                .send(
4303                    *connection_id,
4304                    proto::RemoveProjectCollaborator {
4305                        project_id: project.id.to_proto(),
4306                        peer_id: Some(session.connection_id.into()),
4307                    },
4308                )
4309                .trace_err();
4310        }
4311    }
4312}
4313
4314pub trait ResultExt {
4315    type Ok;
4316
4317    fn trace_err(self) -> Option<Self::Ok>;
4318}
4319
4320impl<T, E> ResultExt for Result<T, E>
4321where
4322    E: std::fmt::Debug,
4323{
4324    type Ok = T;
4325
4326    fn trace_err(self) -> Option<T> {
4327        match self {
4328            Ok(value) => Some(value),
4329            Err(error) => {
4330                tracing::error!("{:?}", error);
4331                None
4332            }
4333        }
4334    }
4335}