rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use semantic_version::SemanticVersion;
  50use serde::{Serialize, Serializer};
  51use std::{
  52    any::TypeId,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc, OnceLock,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{
  69    field::{self},
  70    info_span, instrument, Instrument,
  71};
  72use util::http::IsahcHttpClient;
  73
/// Grace period of 30 seconds; presumably how long a disconnected client may take to
/// reconnect before its state is discarded (NOTE(review): usage is outside this view — confirm).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before clearing stale server state.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes / limits; their consumers are outside this view (NOTE(review): confirm units —
// presumably messages per page, characters/bytes per message, notifications per page).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: takes any envelope plus the sender's
/// session and returns a boxed `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100struct StreamingResponse<R: RequestMessage> {
 101    peer: Arc<Peer>,
 102    receipt: Receipt<R>,
 103}
 104
 105impl<R: RequestMessage> StreamingResponse<R> {
 106    fn send(&self, payload: R::Response) -> Result<()> {
 107        self.peer.respond(self.receipt, payload)?;
 108        Ok(())
 109    }
 110}
 111
 112#[derive(Clone, Debug)]
 113pub enum Principal {
 114    User(User),
 115    Impersonated { user: User, admin: User },
 116    DevServer(dev_server::Model),
 117}
 118
 119impl Principal {
 120    fn update_span(&self, span: &tracing::Span) {
 121        match &self {
 122            Principal::User(user) => {
 123                span.record("user_id", &user.id.0);
 124                span.record("login", &user.github_login);
 125            }
 126            Principal::Impersonated { user, admin } => {
 127                span.record("user_id", &user.id.0);
 128                span.record("login", &user.github_login);
 129                span.record("impersonator", &admin.github_login);
 130            }
 131            Principal::DevServer(dev_server) => {
 132                span.record("dev_server_id", &dev_server.id.0);
 133            }
 134        }
 135    }
 136}
 137
 138#[derive(Clone)]
 139struct Session {
 140    principal: Principal,
 141    connection_id: ConnectionId,
 142    db: Arc<tokio::sync::Mutex<DbHandle>>,
 143    peer: Arc<Peer>,
 144    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 145    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 146    http_client: IsahcHttpClient,
 147    rate_limiter: Arc<RateLimiter>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn user_id(&self) -> Option<UserId> {
 176        match &self.principal {
 177            Principal::User(user) => Some(user.id),
 178            Principal::Impersonated { user, .. } => Some(user.id),
 179            Principal::DevServer(_) => None,
 180        }
 181    }
 182}
 183
 184impl Debug for Session {
 185    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 186        let mut result = f.debug_struct("Session");
 187        match &self.principal {
 188            Principal::User(user) => {
 189                result.field("user", &user.github_login);
 190            }
 191            Principal::Impersonated { user, admin } => {
 192                result.field("user", &user.github_login);
 193                result.field("impersonator", &admin.github_login);
 194            }
 195            Principal::DevServer(dev_server) => {
 196                result.field("dev_server", &dev_server.id);
 197            }
 198        }
 199        result.field("connection_id", &self.connection_id).finish()
 200    }
 201}
 202
 203struct UserSession(Session);
 204
 205impl UserSession {
 206    pub fn new(s: Session) -> Option<Self> {
 207        s.user_id().map(|_| UserSession(s))
 208    }
 209    pub fn user_id(&self) -> UserId {
 210        self.0.user_id().unwrap()
 211    }
 212}
 213
 214impl Deref for UserSession {
 215    type Target = Session;
 216
 217    fn deref(&self) -> &Self::Target {
 218        &self.0
 219    }
 220}
 221impl DerefMut for UserSession {
 222    fn deref_mut(&mut self) -> &mut Self::Target {
 223        &mut self.0
 224    }
 225}
 226
 227fn user_handler<M: RequestMessage, Fut>(
 228    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 230where
 231    Fut: Send + Future<Output = Result<()>>,
 232{
 233    let handler = Arc::new(handler);
 234    move |message, response, session| {
 235        let handler = handler.clone();
 236        Box::pin(async move {
 237            if let Some(user_session) = session.for_user() {
 238                Ok(handler(message, response, user_session).await?)
 239            } else {
 240                Err(Error::Internal(anyhow!("must be a user")))
 241            }
 242        })
 243    }
 244}
 245
 246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 247    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 249where
 250    InnertRetFut: Send + Future<Output = Result<()>>,
 251{
 252    let handler = Arc::new(handler);
 253    move |message, session| {
 254        let handler = handler.clone();
 255        Box::pin(async move {
 256            if let Some(user_session) = session.for_user() {
 257                Ok(handler(message, user_session).await?)
 258            } else {
 259                Err(Error::Internal(anyhow!("must be a user")))
 260            }
 261        })
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
/// The collaboration RPC server. It owns the [`Peer`], the shared connection pool, and the
/// per-message-type handler table.
pub struct Server {
    // Mutable through `&self`: the test-only `reset` overwrites it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message type, keyed by `TypeId`; filled by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` sends `true`, test-only `reset` sends `false`; connections subscribe to it.
    teardown: watch::Sender<bool>,
}
 283
/// Lock guard over the [`ConnectionPool`], returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, which makes the whole guard `!Send`. Likely intent: a handler future
    // (which must be `Send`, see `MessageHandler`) cannot hold this lock across an `.await`.
    _not_send: PhantomData<Rc<()>>,
}
 288
/// Serializable view of the server's peer and a locked connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target rather than the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
    /// Builds a server with the given `id` and registers every RPC message handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
    /// principal is not a user. Registering the same message type twice panics
    /// (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): assumes the server id's numeric value fits in a u32.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are obtained later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // Language-model completion streams responses and receives the provider API keys
            // from config. NOTE(review): unlike the token-count handler below, this one is not
            // wrapped in `user_handler` — confirm which principals may reach it.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
 441
    /// Starts background cleanup of resources left behind by stale servers.
    ///
    /// Returns `Ok(())` immediately; the work runs in a detached task. That task:
    /// 1. waits `CLEANUP_TIMEOUT`;
    /// 2. asks the database for stale room and channel-buffer ids (`stale_server_resource_ids`);
    /// 3. clears stale channel-buffer collaborators and notifies remaining connections;
    /// 4. clears stale room participants, broadcasts room/channel updates, notifies users whose
    ///    calls were canceled, refreshes affected contacts, and deletes the LiveKit room once
    ///    it is empty;
    /// 5. deletes the stale server records.
    ///
    /// Errors at every step are only traced (`trace_err`), never propagated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and push the refreshed
                    // collaborator list to everyone still connected to it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only when the room refresh below succeeds; otherwise the
                        // later steps operate on these empty defaults and do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Lock taken only after the awaits above, so it never spans one.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 587
 588    pub fn teardown(&self) {
 589        self.peer.teardown();
 590        self.connection_pool.lock().reset();
 591        let _ = self.teardown.send(true);
 592    }
 593
 594    #[cfg(test)]
 595    pub fn reset(&self, id: ServerId) {
 596        self.teardown();
 597        *self.id.lock() = id;
 598        self.peer.reset(id.0 as u32);
 599        let _ = self.teardown.send(false);
 600    }
 601
 602    #[cfg(test)]
 603    pub fn id(&self) -> ServerId {
 604        *self.id.lock()
 605    }
 606
 607    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 608    where
 609        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 610        Fut: 'static + Send + Future<Output = Result<()>>,
 611        M: EnvelopedMessage,
 612    {
 613        let prev_handler = self.handlers.insert(
 614            TypeId::of::<M>(),
 615            Box::new(move |envelope, session| {
 616                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 617                let received_at = envelope.received_at;
 618                    tracing::info!(
 619                        "message received"
 620                    );
 621                let start_time = Instant::now();
 622                let future = (handler)(*envelope, session);
 623                async move {
 624                    let result = future.await;
 625                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 626                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 627                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 628                    match result {
 629                        Err(error) => {
 630                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 631                        }
 632                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 633                    }
 634                }
 635                .boxed()
 636            }),
 637        );
 638        if prev_handler.is_some() {
 639            panic!("registered a handler for the same message twice");
 640        }
 641        self
 642    }
 643
 644    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 645    where
 646        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 647        Fut: 'static + Send + Future<Output = Result<()>>,
 648        M: EnvelopedMessage,
 649    {
 650        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 651        self
 652    }
 653
 654    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 655    where
 656        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 657        Fut: Send + Future<Output = Result<()>>,
 658        M: RequestMessage,
 659    {
 660        let handler = Arc::new(handler);
 661        self.add_handler(move |envelope, session| {
 662            let receipt = envelope.receipt();
 663            let handler = handler.clone();
 664            async move {
 665                let peer = session.peer.clone();
 666                let responded = Arc::new(AtomicBool::default());
 667                let response = Response {
 668                    peer: peer.clone(),
 669                    responded: responded.clone(),
 670                    receipt,
 671                };
 672                match (handler)(envelope.payload, response, session).await {
 673                    Ok(()) => {
 674                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 675                            Ok(())
 676                        } else {
 677                            Err(anyhow!("handler did not send a response"))?
 678                        }
 679                    }
 680                    Err(error) => {
 681                        let proto_err = match &error {
 682                            Error::Internal(err) => err.to_proto(),
 683                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 684                        };
 685                        peer.respond_with_error(receipt, proto_err)?;
 686                        Err(error)
 687                    }
 688                }
 689            }
 690        })
 691    }
 692
 693    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 694    where
 695        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 696        Fut: Send + Future<Output = Result<()>>,
 697        M: RequestMessage,
 698    {
 699        let handler = Arc::new(handler);
 700        self.add_handler(move |envelope, session| {
 701            let receipt = envelope.receipt();
 702            let handler = handler.clone();
 703            async move {
 704                let peer = session.peer.clone();
 705                let response = StreamingResponse {
 706                    peer: peer.clone(),
 707                    receipt,
 708                };
 709                match (handler)(envelope.payload, response, session).await {
 710                    Ok(()) => {
 711                        peer.end_stream(receipt)?;
 712                        Ok(())
 713                    }
 714                    Err(error) => {
 715                        let proto_err = match &error {
 716                            Error::Internal(err) => err.to_proto(),
 717                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 718                        };
 719                        peer.respond_with_error(receipt, proto_err)?;
 720                        Err(error)
 721                    }
 722                }
 723            }
 724        })
 725    }
 726
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a "handle connection" span, that does the following:
    /// - registers `connection` with the peer;
    /// - builds a [`Session`] for `principal`;
    /// - sends the initial client update (forwarding `send_connection_id`);
    /// - dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O future finishes, the incoming stream ends,
    /// or the teardown signal changes. The first two exits run `connection_lost`;
    /// teardown returns immediately without it.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: `connection_id` once the peer
        // assigns it below, the principal-related ones presumably by `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse the connection outright if teardown was already signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` is awaited below as the connection's I/O future;
            // `incoming_rx` yields the incoming messages that get dispatched.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return (and the one after
            // `send_initial_client_update`) skips `connection_lost`, so the
            // connection registered with `this.peer` above is not disconnected on
            // these paths. Confirm whether that cleanup happens elsewhere.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256: a permit
            // is taken before each message is read and released when its handler
            // future completes (or immediately if no handler is registered).
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls branches in the order listed: teardown,
                // then I/O, then in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Only drives the foreground handlers forward; results are discarded.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before building the wrapper; the
                                // future is attached to `span` via `instrument` below.
                                drop(span_enter);

                                // Hold the semaphore permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Drop still-pending foreground handler futures before cleanup.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 863
 864    async fn send_initial_client_update(
 865        &self,
 866        connection_id: ConnectionId,
 867        principal: &Principal,
 868        zed_version: ZedVersion,
 869        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 870        session: &Session,
 871    ) -> Result<()> {
 872        self.peer.send(
 873            connection_id,
 874            proto::Hello {
 875                peer_id: Some(connection_id.into()),
 876            },
 877        )?;
 878        tracing::info!("sent hello message");
 879
 880        let Principal::User(user) = principal else {
 881            return Ok(());
 882        };
 883
 884        if let Some(send_connection_id) = send_connection_id.take() {
 885            let _ = send_connection_id.send(connection_id);
 886        }
 887
 888        if !user.connected_once {
 889            self.peer.send(connection_id, proto::ShowContacts {})?;
 890            self.app_state
 891                .db
 892                .set_user_connected_once(user.id, true)
 893                .await?;
 894        }
 895
 896        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 897            self.app_state.db.get_contacts(user.id),
 898            self.app_state.db.get_channels_for_user(user.id),
 899            self.app_state.db.get_channel_invites_for_user(user.id),
 900        )
 901        .await?;
 902
 903        {
 904            let mut pool = self.connection_pool.lock();
 905            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 906            for membership in &channels_for_user.channel_memberships {
 907                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 908            }
 909            self.peer.send(
 910                connection_id,
 911                build_initial_contacts_update(contacts, &pool),
 912            )?;
 913            self.peer.send(
 914                connection_id,
 915                build_update_user_channels(&channels_for_user),
 916            )?;
 917            self.peer.send(
 918                connection_id,
 919                build_channels_update(channels_for_user, channel_invites),
 920            )?;
 921        }
 922
 923        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 924            self.peer.send(connection_id, incoming_call)?;
 925        }
 926
 927        update_user_contacts(user.id, &session).await?;
 928        Ok(())
 929    }
 930
 931    pub async fn invite_code_redeemed(
 932        self: &Arc<Self>,
 933        inviter_id: UserId,
 934        invitee_id: UserId,
 935    ) -> Result<()> {
 936        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 937            if let Some(code) = &user.invite_code {
 938                let pool = self.connection_pool.lock();
 939                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 940                for connection_id in pool.user_connection_ids(inviter_id) {
 941                    self.peer.send(
 942                        connection_id,
 943                        proto::UpdateContacts {
 944                            contacts: vec![invitee_contact.clone()],
 945                            ..Default::default()
 946                        },
 947                    )?;
 948                    self.peer.send(
 949                        connection_id,
 950                        proto::UpdateInviteInfo {
 951                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 952                            count: user.invite_count as u32,
 953                        },
 954                    )?;
 955                }
 956            }
 957        }
 958        Ok(())
 959    }
 960
 961    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 962        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 963            if let Some(invite_code) = &user.invite_code {
 964                let pool = self.connection_pool.lock();
 965                for connection_id in pool.user_connection_ids(user_id) {
 966                    self.peer.send(
 967                        connection_id,
 968                        proto::UpdateInviteInfo {
 969                            url: format!(
 970                                "{}{}",
 971                                self.app_state.config.invite_link_prefix, invite_code
 972                            ),
 973                            count: user.invite_count as u32,
 974                        },
 975                    )?;
 976                }
 977            }
 978        }
 979        Ok(())
 980    }
 981
 982    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 983        ServerSnapshot {
 984            connection_pool: ConnectionPoolGuard {
 985                guard: self.connection_pool.lock(),
 986                _not_send: PhantomData,
 987            },
 988            peer: &self.peer,
 989        }
 990    }
 991}
 992
 993impl<'a> Deref for ConnectionPoolGuard<'a> {
 994    type Target = ConnectionPool;
 995
 996    fn deref(&self) -> &Self::Target {
 997        &self.guard
 998    }
 999}
1000
1001impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1002    fn deref_mut(&mut self) -> &mut Self::Target {
1003        &mut self.guard
1004    }
1005}
1006
/// In test builds, runs `check_invariants` on the pool each time the guard is
/// dropped. In non-test builds, dropping does nothing beyond releasing the inner
/// guard.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // The `cfg` attribute applies to this single statement, so the invariant
        // check is compiled out of non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1013
1014fn broadcast<F>(
1015    sender_id: Option<ConnectionId>,
1016    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1017    mut f: F,
1018) where
1019    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1020{
1021    for receiver_id in receiver_ids {
1022        if Some(receiver_id) != sender_id {
1023            if let Err(error) = f(receiver_id) {
1024                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1025            }
1026        }
1027    }
1028}
1029
1030pub struct ProtocolVersion(u32);
1031
1032impl Header for ProtocolVersion {
1033    fn name() -> &'static HeaderName {
1034        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1035        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1036    }
1037
1038    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1039    where
1040        Self: Sized,
1041        I: Iterator<Item = &'i axum::http::HeaderValue>,
1042    {
1043        let version = values
1044            .next()
1045            .ok_or_else(axum::headers::Error::invalid)?
1046            .to_str()
1047            .map_err(|_| axum::headers::Error::invalid())?
1048            .parse()
1049            .map_err(|_| axum::headers::Error::invalid())?;
1050        Ok(Self(version))
1051    }
1052
1053    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1054        values.extend([self.0.to_string().parse().unwrap()]);
1055    }
1056}
1057
1058pub struct AppVersionHeader(SemanticVersion);
1059impl Header for AppVersionHeader {
1060    fn name() -> &'static HeaderName {
1061        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1062        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1063    }
1064
1065    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1066    where
1067        Self: Sized,
1068        I: Iterator<Item = &'i axum::http::HeaderValue>,
1069    {
1070        let version = values
1071            .next()
1072            .ok_or_else(axum::headers::Error::invalid)?
1073            .to_str()
1074            .map_err(|_| axum::headers::Error::invalid())?
1075            .parse()
1076            .map_err(|_| axum::headers::Error::invalid())?;
1077        Ok(Self(version))
1078    }
1079
1080    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1081        values.extend([self.0.to_string().parse().unwrap()]);
1082    }
1083}
1084
/// Builds the router for the collab server's `/rpc` and `/metrics` endpoints.
///
/// Order is significant here. `Router::layer` only wraps routes that were added
/// before it, so the `ServiceBuilder` stack (the `AppState` extension plus the
/// `auth::validate_header` middleware) applies to `/rpc` only, and `/metrics` is
/// served without `validate_header`. The final `Extension(server)` layer wraps both
/// routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // `ServiceBuilder` layers run top-to-bottom (first added is outermost),
            // so the `AppState` extension is in place before the auth middleware
            // runs (presumably `validate_header` needs it; confirm in `auth`).
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1096
1097pub async fn handle_websocket_request(
1098    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1099    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1100    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1101    Extension(server): Extension<Arc<Server>>,
1102    Extension(principal): Extension<Principal>,
1103    ws: WebSocketUpgrade,
1104) -> axum::response::Response {
1105    if protocol_version != rpc::PROTOCOL_VERSION {
1106        return (
1107            StatusCode::UPGRADE_REQUIRED,
1108            "client must be upgraded".to_string(),
1109        )
1110            .into_response();
1111    }
1112
1113    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1114        return (
1115            StatusCode::UPGRADE_REQUIRED,
1116            "no version header found".to_string(),
1117        )
1118            .into_response();
1119    };
1120
1121    if !version.can_collaborate() {
1122        return (
1123            StatusCode::UPGRADE_REQUIRED,
1124            "client must be upgraded".to_string(),
1125        )
1126            .into_response();
1127    }
1128
1129    let socket_address = socket_address.to_string();
1130    ws.on_upgrade(move |socket| {
1131        let socket = socket
1132            .map_ok(to_tungstenite_message)
1133            .err_into()
1134            .with(|message| async move { Ok(to_axum_message(message)) });
1135        let connection = Connection::new(Box::pin(socket));
1136        async move {
1137            server
1138                .handle_connection(
1139                    connection,
1140                    socket_address,
1141                    principal,
1142                    version,
1143                    None,
1144                    Executor::Production,
1145                )
1146                .await;
1147        }
1148    })
1149}
1150
1151pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1152    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1153    let connections_metric = CONNECTIONS_METRIC
1154        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1155
1156    let connections = server
1157        .connection_pool
1158        .lock()
1159        .connections()
1160        .filter(|connection| !connection.admin)
1161        .count();
1162    connections_metric.set(connections as _);
1163
1164    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1165    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1166        register_int_gauge!(
1167            "shared_projects",
1168            "number of open projects with one or more guests"
1169        )
1170        .unwrap()
1171    });
1172
1173    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1174    shared_projects_metric.set(shared_projects as _);
1175
1176    let encoder = prometheus::TextEncoder::new();
1177    let metric_families = prometheus::gather();
1178    let encoded_metrics = encoder
1179        .encode_to_string(&metric_families)
1180        .map_err(|err| anyhow!("{}", err))?;
1181    Ok(encoded_metrics)
1182}
1183
/// Cleans up after a connection ends.
///
/// Immediate steps:
/// 1. Disconnect the connection from the peer.
/// 2. Remove it from the connection pool, returning early if that fails.
/// 3. Tell the database the connection was lost. Any error here is discarded via
///    `trace_err`.
///
/// It then waits for `RECONNECT_TIMEOUT`, presumably to give the client a window to
/// reconnect (e.g. via `rejoin_room`). If the timeout elapses before `teardown`
/// changes and the session belongs to a user:
/// - the user's room and channel-buffer participation is released;
/// - a pending call is declined if the user is no longer online;
/// - contacts are refreshed.
///
/// A teardown signal skips this delayed cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the timeout branch is polled before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Only user sessions own rooms, channel buffers and calls.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session, session.connection_id).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Decline an outstanding call only if no other connection of this
                // user is still online.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1231
1232/// Acknowledges a ping from a client, used to keep the connection alive.
1233async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1234    response.send(proto::Ack {})?;
1235    Ok(())
1236}
1237
1238/// Creates a new room for calling (outside of channels)
1239async fn create_room(
1240    _request: proto::CreateRoom,
1241    response: Response<proto::CreateRoom>,
1242    session: UserSession,
1243) -> Result<()> {
1244    let live_kit_room = nanoid::nanoid!(30);
1245
1246    let live_kit_connection_info = util::maybe!(async {
1247        let live_kit = session.live_kit_client.as_ref();
1248        let live_kit = live_kit?;
1249        let user_id = session.user_id().to_string();
1250
1251        let token = live_kit
1252            .room_token(&live_kit_room, &user_id.to_string())
1253            .trace_err()?;
1254
1255        Some(proto::LiveKitConnectionInfo {
1256            server_url: live_kit.url().into(),
1257            token,
1258            can_publish: true,
1259        })
1260    })
1261    .await;
1262
1263    let room = session
1264        .db()
1265        .await
1266        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1267        .await?;
1268
1269    response.send(proto::CreateRoomResponse {
1270        room: Some(room.clone()),
1271        live_kit_connection_info,
1272    })?;
1273
1274    update_user_contacts(session.user_id(), &session).await?;
1275    Ok(())
1276}
1277
1278/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1279async fn join_room(
1280    request: proto::JoinRoom,
1281    response: Response<proto::JoinRoom>,
1282    session: UserSession,
1283) -> Result<()> {
1284    let room_id = RoomId::from_proto(request.id);
1285
1286    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1287
1288    if let Some(channel_id) = channel_id {
1289        return join_channel_internal(channel_id, Box::new(response), session).await;
1290    }
1291
1292    let joined_room = {
1293        let room = session
1294            .db()
1295            .await
1296            .join_room(room_id, session.user_id(), session.connection_id)
1297            .await?;
1298        room_updated(&room.room, &session.peer);
1299        room.into_inner()
1300    };
1301
1302    for connection_id in session
1303        .connection_pool()
1304        .await
1305        .user_connection_ids(session.user_id())
1306    {
1307        session
1308            .peer
1309            .send(
1310                connection_id,
1311                proto::CallCanceled {
1312                    room_id: room_id.to_proto(),
1313                },
1314            )
1315            .trace_err();
1316    }
1317
1318    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1319        if let Some(token) = live_kit
1320            .room_token(
1321                &joined_room.room.live_kit_room,
1322                &session.user_id().to_string(),
1323            )
1324            .trace_err()
1325        {
1326            Some(proto::LiveKitConnectionInfo {
1327                server_url: live_kit.url().into(),
1328                token,
1329                can_publish: true,
1330            })
1331        } else {
1332            None
1333        }
1334    } else {
1335        None
1336    };
1337
1338    response.send(proto::JoinRoomResponse {
1339        room: Some(joined_room.room),
1340        channel_id: None,
1341        live_kit_connection_info,
1342    })?;
1343
1344    update_user_contacts(session.user_id(), &session).await?;
1345    Ok(())
1346}
1347
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Rejoin the room in the database and respond with the room and the projects
///    this connection re-shared or rejoined.
/// 2. Broadcast the room update.
/// 3. Tell collaborators of those projects about the caller's new peer id.
/// 4. Forward re-shared worktrees to the other collaborators.
/// 5. Re-stream each rejoined project's worktree state to the caller.
/// 6. If the room belongs to a channel, broadcast a channel update.
/// 7. Refresh the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // `rejoined_room` exposes `into_inner()`, so it is presumably a guard; this
    // block bounds its lifetime to before the connection pool is locked below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: notify each collaborator that this connection's peer
        // id changed, then forward the project's worktree list to everyone except
        // this connection.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: the same peer-id change notification.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream each rejoined project's state to this connection: worktree
        // entries (chunked), diagnostic summaries, settings files, and language
        // server status. Worktrees are moved out with `mem::take` to avoid cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in tests so the chunking path is exercised.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                // NOTE(review): `update` is already owned by the loop, so the
                // `.clone()` below looks unnecessary.
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1536
1537/// leave room disconnects from the room.
1538async fn leave_room(
1539    _: proto::LeaveRoom,
1540    response: Response<proto::LeaveRoom>,
1541    session: UserSession,
1542) -> Result<()> {
1543    leave_room_for_session(&session, session.connection_id).await?;
1544    response.send(proto::Ack {})?;
1545    Ok(())
1546}
1547
1548/// Updates the permissions of someone else in the room.
1549async fn set_room_participant_role(
1550    request: proto::SetRoomParticipantRole,
1551    response: Response<proto::SetRoomParticipantRole>,
1552    session: UserSession,
1553) -> Result<()> {
1554    let user_id = UserId::from_proto(request.user_id);
1555    let role = ChannelRole::from(request.role());
1556
1557    let (live_kit_room, can_publish) = {
1558        let room = session
1559            .db()
1560            .await
1561            .set_room_participant_role(
1562                session.user_id(),
1563                RoomId::from_proto(request.room_id),
1564                user_id,
1565                role,
1566            )
1567            .await?;
1568
1569        let live_kit_room = room.live_kit_room.clone();
1570        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1571        room_updated(&room, &session.peer);
1572        (live_kit_room, can_publish)
1573    };
1574
1575    if let Some(live_kit) = session.live_kit_client.as_ref() {
1576        live_kit
1577            .update_participant(
1578                live_kit_room.clone(),
1579                request.user_id.to_string(),
1580                live_kit_server::proto::ParticipantPermission {
1581                    can_subscribe: true,
1582                    can_publish,
1583                    can_publish_data: can_publish,
1584                    hidden: false,
1585                    recorder: false,
1586                },
1587            )
1588            .await
1589            .trace_err();
1590    }
1591
1592    response.send(proto::Ack {})?;
1593    Ok(())
1594}
1595
1596/// Call someone else into the current room
1597async fn call(
1598    request: proto::Call,
1599    response: Response<proto::Call>,
1600    session: UserSession,
1601) -> Result<()> {
1602    let room_id = RoomId::from_proto(request.room_id);
1603    let calling_user_id = session.user_id();
1604    let calling_connection_id = session.connection_id;
1605    let called_user_id = UserId::from_proto(request.called_user_id);
1606    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1607    if !session
1608        .db()
1609        .await
1610        .has_contact(calling_user_id, called_user_id)
1611        .await?
1612    {
1613        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1614    }
1615
1616    let incoming_call = {
1617        let (room, incoming_call) = &mut *session
1618            .db()
1619            .await
1620            .call(
1621                room_id,
1622                calling_user_id,
1623                calling_connection_id,
1624                called_user_id,
1625                initial_project_id,
1626            )
1627            .await?;
1628        room_updated(&room, &session.peer);
1629        mem::take(incoming_call)
1630    };
1631    update_user_contacts(called_user_id, &session).await?;
1632
1633    let mut calls = session
1634        .connection_pool()
1635        .await
1636        .user_connection_ids(called_user_id)
1637        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1638        .collect::<FuturesUnordered<_>>();
1639
1640    while let Some(call_response) = calls.next().await {
1641        match call_response.as_ref() {
1642            Ok(_) => {
1643                response.send(proto::Ack {})?;
1644                return Ok(());
1645            }
1646            Err(_) => {
1647                call_response.trace_err();
1648            }
1649        }
1650    }
1651
1652    {
1653        let room = session
1654            .db()
1655            .await
1656            .call_failed(room_id, called_user_id)
1657            .await?;
1658        room_updated(&room, &session.peer);
1659    }
1660    update_user_contacts(called_user_id, &session).await?;
1661
1662    Err(anyhow!("failed to ring user"))?
1663}
1664
1665/// Cancel an outgoing call.
1666async fn cancel_call(
1667    request: proto::CancelCall,
1668    response: Response<proto::CancelCall>,
1669    session: UserSession,
1670) -> Result<()> {
1671    let called_user_id = UserId::from_proto(request.called_user_id);
1672    let room_id = RoomId::from_proto(request.room_id);
1673    {
1674        let room = session
1675            .db()
1676            .await
1677            .cancel_call(room_id, session.connection_id, called_user_id)
1678            .await?;
1679        room_updated(&room, &session.peer);
1680    }
1681
1682    for connection_id in session
1683        .connection_pool()
1684        .await
1685        .user_connection_ids(called_user_id)
1686    {
1687        session
1688            .peer
1689            .send(
1690                connection_id,
1691                proto::CallCanceled {
1692                    room_id: room_id.to_proto(),
1693                },
1694            )
1695            .trace_err();
1696    }
1697    response.send(proto::Ack {})?;
1698
1699    update_user_contacts(called_user_id, &session).await?;
1700    Ok(())
1701}
1702
1703/// Decline an incoming call.
1704async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1705    let room_id = RoomId::from_proto(message.room_id);
1706    {
1707        let room = session
1708            .db()
1709            .await
1710            .decline_call(Some(room_id), session.user_id())
1711            .await?
1712            .ok_or_else(|| anyhow!("failed to decline call"))?;
1713        room_updated(&room, &session.peer);
1714    }
1715
1716    for connection_id in session
1717        .connection_pool()
1718        .await
1719        .user_connection_ids(session.user_id())
1720    {
1721        session
1722            .peer
1723            .send(
1724                connection_id,
1725                proto::CallCanceled {
1726                    room_id: room_id.to_proto(),
1727                },
1728            )
1729            .trace_err();
1730    }
1731    update_user_contacts(session.user_id(), &session).await?;
1732    Ok(())
1733}
1734
1735/// Updates other participants in the room with your current location.
1736async fn update_participant_location(
1737    request: proto::UpdateParticipantLocation,
1738    response: Response<proto::UpdateParticipantLocation>,
1739    session: UserSession,
1740) -> Result<()> {
1741    let room_id = RoomId::from_proto(request.room_id);
1742    let location = request
1743        .location
1744        .ok_or_else(|| anyhow!("invalid location"))?;
1745
1746    let db = session.db().await;
1747    let room = db
1748        .update_room_participant_location(room_id, session.connection_id, location)
1749        .await?;
1750
1751    room_updated(&room, &session.peer);
1752    response.send(proto::Ack {})?;
1753    Ok(())
1754}
1755
1756/// Share a project into the room.
1757async fn share_project(
1758    request: proto::ShareProject,
1759    response: Response<proto::ShareProject>,
1760    session: Session,
1761) -> Result<()> {
1762    let (project_id, room) = &*session
1763        .db()
1764        .await
1765        .share_project(
1766            RoomId::from_proto(request.room_id),
1767            session.connection_id,
1768            &request.worktrees,
1769        )
1770        .await?;
1771    response.send(proto::ShareProjectResponse {
1772        project_id: project_id.to_proto(),
1773    })?;
1774    room_updated(&room, &session.peer);
1775
1776    Ok(())
1777}
1778
1779/// Unshare a project from the room.
1780async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1781    let project_id = ProjectId::from_proto(message.project_id);
1782
1783    let (room, guest_connection_ids) = &*session
1784        .db()
1785        .await
1786        .unshare_project(project_id, session.connection_id)
1787        .await?;
1788
1789    broadcast(
1790        Some(session.connection_id),
1791        guest_connection_ids.iter().copied(),
1792        |conn_id| session.peer.send(conn_id, message.clone()),
1793    );
1794    room_updated(&room, &session.peer);
1795
1796    Ok(())
1797}
1798
1799/// Join someone elses shared project.
1800async fn join_project(
1801    request: proto::JoinProject,
1802    response: Response<proto::JoinProject>,
1803    session: UserSession,
1804) -> Result<()> {
1805    let project_id = ProjectId::from_proto(request.project_id);
1806
1807    tracing::info!(%project_id, "join project");
1808
1809    let (project, replica_id) = &mut *session
1810        .db()
1811        .await
1812        .join_project_in_room(project_id, session.connection_id)
1813        .await?;
1814
1815    join_project_internal(response, session, project, replica_id)
1816}
1817
1818trait JoinProjectInternalResponse {
1819    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1820}
1821impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1822    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1823        Response::<proto::JoinProject>::send(self, result)
1824    }
1825}
1826impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1827    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1828        Response::<proto::JoinHostedProject>::send(self, result)
1829    }
1830}
1831
1832fn join_project_internal(
1833    response: impl JoinProjectInternalResponse,
1834    session: UserSession,
1835    project: &mut Project,
1836    replica_id: &ReplicaId,
1837) -> Result<()> {
1838    let collaborators = project
1839        .collaborators
1840        .iter()
1841        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1842        .map(|collaborator| collaborator.to_proto())
1843        .collect::<Vec<_>>();
1844    let project_id = project.id;
1845    let guest_user_id = session.user_id();
1846
1847    let worktrees = project
1848        .worktrees
1849        .iter()
1850        .map(|(id, worktree)| proto::WorktreeMetadata {
1851            id: *id,
1852            root_name: worktree.root_name.clone(),
1853            visible: worktree.visible,
1854            abs_path: worktree.abs_path.clone(),
1855        })
1856        .collect::<Vec<_>>();
1857
1858    for collaborator in &collaborators {
1859        session
1860            .peer
1861            .send(
1862                collaborator.peer_id.unwrap().into(),
1863                proto::AddProjectCollaborator {
1864                    project_id: project_id.to_proto(),
1865                    collaborator: Some(proto::Collaborator {
1866                        peer_id: Some(session.connection_id.into()),
1867                        replica_id: replica_id.0 as u32,
1868                        user_id: guest_user_id.to_proto(),
1869                    }),
1870                },
1871            )
1872            .trace_err();
1873    }
1874
1875    // First, we send the metadata associated with each worktree.
1876    response.send(proto::JoinProjectResponse {
1877        project_id: project.id.0 as u64,
1878        worktrees: worktrees.clone(),
1879        replica_id: replica_id.0 as u32,
1880        collaborators: collaborators.clone(),
1881        language_servers: project.language_servers.clone(),
1882        role: project.role.into(), // todo
1883    })?;
1884
1885    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1886        #[cfg(any(test, feature = "test-support"))]
1887        const MAX_CHUNK_SIZE: usize = 2;
1888        #[cfg(not(any(test, feature = "test-support")))]
1889        const MAX_CHUNK_SIZE: usize = 256;
1890
1891        // Stream this worktree's entries.
1892        let message = proto::UpdateWorktree {
1893            project_id: project_id.to_proto(),
1894            worktree_id,
1895            abs_path: worktree.abs_path.clone(),
1896            root_name: worktree.root_name,
1897            updated_entries: worktree.entries,
1898            removed_entries: Default::default(),
1899            scan_id: worktree.scan_id,
1900            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1901            updated_repositories: worktree.repository_entries.into_values().collect(),
1902            removed_repositories: Default::default(),
1903        };
1904        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1905            session.peer.send(session.connection_id, update.clone())?;
1906        }
1907
1908        // Stream this worktree's diagnostics.
1909        for summary in worktree.diagnostic_summaries {
1910            session.peer.send(
1911                session.connection_id,
1912                proto::UpdateDiagnosticSummary {
1913                    project_id: project_id.to_proto(),
1914                    worktree_id: worktree.id,
1915                    summary: Some(summary),
1916                },
1917            )?;
1918        }
1919
1920        for settings_file in worktree.settings_files {
1921            session.peer.send(
1922                session.connection_id,
1923                proto::UpdateWorktreeSettings {
1924                    project_id: project_id.to_proto(),
1925                    worktree_id: worktree.id,
1926                    path: settings_file.path,
1927                    content: Some(settings_file.content),
1928                },
1929            )?;
1930        }
1931    }
1932
1933    for language_server in &project.language_servers {
1934        session.peer.send(
1935            session.connection_id,
1936            proto::UpdateLanguageServer {
1937                project_id: project_id.to_proto(),
1938                language_server_id: language_server.id,
1939                variant: Some(
1940                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1941                        proto::LspDiskBasedDiagnosticsUpdated {},
1942                    ),
1943                ),
1944            },
1945        )?;
1946    }
1947
1948    Ok(())
1949}
1950
1951/// Leave someone elses shared project.
1952async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1953    let sender_id = session.connection_id;
1954    let project_id = ProjectId::from_proto(request.project_id);
1955    let db = session.db().await;
1956    if db.is_hosted_project(project_id).await? {
1957        let project = db.leave_hosted_project(project_id, sender_id).await?;
1958        project_left(&project, &session);
1959        return Ok(());
1960    }
1961
1962    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1963    tracing::info!(
1964        %project_id,
1965        host_user_id = ?project.host_user_id,
1966        host_connection_id = ?project.host_connection_id,
1967        "leave project"
1968    );
1969
1970    project_left(&project, &session);
1971    room_updated(&room, &session.peer);
1972
1973    Ok(())
1974}
1975
1976async fn join_hosted_project(
1977    request: proto::JoinHostedProject,
1978    response: Response<proto::JoinHostedProject>,
1979    session: UserSession,
1980) -> Result<()> {
1981    let (mut project, replica_id) = session
1982        .db()
1983        .await
1984        .join_hosted_project(
1985            ProjectId(request.project_id as i32),
1986            session.user_id(),
1987            session.connection_id,
1988        )
1989        .await?;
1990
1991    join_project_internal(response, session, &mut project, &replica_id)
1992}
1993
1994/// Updates other participants with changes to the project
1995async fn update_project(
1996    request: proto::UpdateProject,
1997    response: Response<proto::UpdateProject>,
1998    session: Session,
1999) -> Result<()> {
2000    let project_id = ProjectId::from_proto(request.project_id);
2001    let (room, guest_connection_ids) = &*session
2002        .db()
2003        .await
2004        .update_project(project_id, session.connection_id, &request.worktrees)
2005        .await?;
2006    broadcast(
2007        Some(session.connection_id),
2008        guest_connection_ids.iter().copied(),
2009        |connection_id| {
2010            session
2011                .peer
2012                .forward_send(session.connection_id, connection_id, request.clone())
2013        },
2014    );
2015    room_updated(&room, &session.peer);
2016    response.send(proto::Ack {})?;
2017
2018    Ok(())
2019}
2020
2021/// Updates other participants with changes to the worktree
2022async fn update_worktree(
2023    request: proto::UpdateWorktree,
2024    response: Response<proto::UpdateWorktree>,
2025    session: Session,
2026) -> Result<()> {
2027    let guest_connection_ids = session
2028        .db()
2029        .await
2030        .update_worktree(&request, session.connection_id)
2031        .await?;
2032
2033    broadcast(
2034        Some(session.connection_id),
2035        guest_connection_ids.iter().copied(),
2036        |connection_id| {
2037            session
2038                .peer
2039                .forward_send(session.connection_id, connection_id, request.clone())
2040        },
2041    );
2042    response.send(proto::Ack {})?;
2043    Ok(())
2044}
2045
2046/// Updates other participants with changes to the diagnostics
2047async fn update_diagnostic_summary(
2048    message: proto::UpdateDiagnosticSummary,
2049    session: Session,
2050) -> Result<()> {
2051    let guest_connection_ids = session
2052        .db()
2053        .await
2054        .update_diagnostic_summary(&message, session.connection_id)
2055        .await?;
2056
2057    broadcast(
2058        Some(session.connection_id),
2059        guest_connection_ids.iter().copied(),
2060        |connection_id| {
2061            session
2062                .peer
2063                .forward_send(session.connection_id, connection_id, message.clone())
2064        },
2065    );
2066
2067    Ok(())
2068}
2069
2070/// Updates other participants with changes to the worktree settings
2071async fn update_worktree_settings(
2072    message: proto::UpdateWorktreeSettings,
2073    session: Session,
2074) -> Result<()> {
2075    let guest_connection_ids = session
2076        .db()
2077        .await
2078        .update_worktree_settings(&message, session.connection_id)
2079        .await?;
2080
2081    broadcast(
2082        Some(session.connection_id),
2083        guest_connection_ids.iter().copied(),
2084        |connection_id| {
2085            session
2086                .peer
2087                .forward_send(session.connection_id, connection_id, message.clone())
2088        },
2089    );
2090
2091    Ok(())
2092}
2093
2094/// Notify other participants that a  language server has started.
2095async fn start_language_server(
2096    request: proto::StartLanguageServer,
2097    session: Session,
2098) -> Result<()> {
2099    let guest_connection_ids = session
2100        .db()
2101        .await
2102        .start_language_server(&request, session.connection_id)
2103        .await?;
2104
2105    broadcast(
2106        Some(session.connection_id),
2107        guest_connection_ids.iter().copied(),
2108        |connection_id| {
2109            session
2110                .peer
2111                .forward_send(session.connection_id, connection_id, request.clone())
2112        },
2113    );
2114    Ok(())
2115}
2116
2117/// Notify other participants that a language server has changed.
2118async fn update_language_server(
2119    request: proto::UpdateLanguageServer,
2120    session: Session,
2121) -> Result<()> {
2122    let project_id = ProjectId::from_proto(request.project_id);
2123    let project_connection_ids = session
2124        .db()
2125        .await
2126        .project_connection_ids(project_id, session.connection_id)
2127        .await?;
2128    broadcast(
2129        Some(session.connection_id),
2130        project_connection_ids.iter().copied(),
2131        |connection_id| {
2132            session
2133                .peer
2134                .forward_send(session.connection_id, connection_id, request.clone())
2135        },
2136    );
2137    Ok(())
2138}
2139
2140/// forward a project request to the host. These requests should be read only
2141/// as guests are allowed to send them.
2142async fn forward_read_only_project_request<T>(
2143    request: T,
2144    response: Response<T>,
2145    session: Session,
2146) -> Result<()>
2147where
2148    T: EntityMessage + RequestMessage,
2149{
2150    let project_id = ProjectId::from_proto(request.remote_entity_id());
2151    let host_connection_id = session
2152        .db()
2153        .await
2154        .host_for_read_only_project_request(project_id, session.connection_id)
2155        .await?;
2156    let payload = session
2157        .peer
2158        .forward_request(session.connection_id, host_connection_id, request)
2159        .await?;
2160    response.send(payload)?;
2161    Ok(())
2162}
2163
2164/// forward a project request to the host. These requests are disallowed
2165/// for guests.
2166async fn forward_mutating_project_request<T>(
2167    request: T,
2168    response: Response<T>,
2169    session: Session,
2170) -> Result<()>
2171where
2172    T: EntityMessage + RequestMessage,
2173{
2174    let project_id = ProjectId::from_proto(request.remote_entity_id());
2175    let host_connection_id = session
2176        .db()
2177        .await
2178        .host_for_mutating_project_request(project_id, session.connection_id)
2179        .await?;
2180    let payload = session
2181        .peer
2182        .forward_request(session.connection_id, host_connection_id, request)
2183        .await?;
2184    response.send(payload)?;
2185    Ok(())
2186}
2187
2188/// Notify other participants that a new buffer has been created
2189async fn create_buffer_for_peer(
2190    request: proto::CreateBufferForPeer,
2191    session: Session,
2192) -> Result<()> {
2193    session
2194        .db()
2195        .await
2196        .check_user_is_project_host(
2197            ProjectId::from_proto(request.project_id),
2198            session.connection_id,
2199        )
2200        .await?;
2201    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2202    session
2203        .peer
2204        .forward_send(session.connection_id, peer_id.into(), request)?;
2205    Ok(())
2206}
2207
2208/// Notify other participants that a buffer has been updated. This is
2209/// allowed for guests as long as the update is limited to selections.
2210async fn update_buffer(
2211    request: proto::UpdateBuffer,
2212    response: Response<proto::UpdateBuffer>,
2213    session: Session,
2214) -> Result<()> {
2215    let project_id = ProjectId::from_proto(request.project_id);
2216    let mut guest_connection_ids;
2217    let mut host_connection_id = None;
2218
2219    let mut requires_write_permission = false;
2220
2221    for op in request.operations.iter() {
2222        match op.variant {
2223            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2224            Some(_) => requires_write_permission = true,
2225        }
2226    }
2227
2228    {
2229        let collaborators = session
2230            .db()
2231            .await
2232            .project_collaborators_for_buffer_update(
2233                project_id,
2234                session.connection_id,
2235                requires_write_permission,
2236            )
2237            .await?;
2238        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2239        for collaborator in collaborators.iter() {
2240            if collaborator.is_host {
2241                host_connection_id = Some(collaborator.connection_id);
2242            } else {
2243                guest_connection_ids.push(collaborator.connection_id);
2244            }
2245        }
2246    }
2247    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2248
2249    broadcast(
2250        Some(session.connection_id),
2251        guest_connection_ids,
2252        |connection_id| {
2253            session
2254                .peer
2255                .forward_send(session.connection_id, connection_id, request.clone())
2256        },
2257    );
2258    if host_connection_id != session.connection_id {
2259        session
2260            .peer
2261            .forward_request(session.connection_id, host_connection_id, request.clone())
2262            .await?;
2263    }
2264
2265    response.send(proto::Ack {})?;
2266    Ok(())
2267}
2268
2269/// Notify other participants that a project has been updated.
2270async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2271    request: T,
2272    session: Session,
2273) -> Result<()> {
2274    let project_id = ProjectId::from_proto(request.remote_entity_id());
2275    let project_connection_ids = session
2276        .db()
2277        .await
2278        .project_connection_ids(project_id, session.connection_id)
2279        .await?;
2280
2281    broadcast(
2282        Some(session.connection_id),
2283        project_connection_ids.iter().copied(),
2284        |connection_id| {
2285            session
2286                .peer
2287                .forward_send(session.connection_id, connection_id, request.clone())
2288        },
2289    );
2290    Ok(())
2291}
2292
2293/// Start following another user in a call.
2294async fn follow(
2295    request: proto::Follow,
2296    response: Response<proto::Follow>,
2297    session: UserSession,
2298) -> Result<()> {
2299    let room_id = RoomId::from_proto(request.room_id);
2300    let project_id = request.project_id.map(ProjectId::from_proto);
2301    let leader_id = request
2302        .leader_id
2303        .ok_or_else(|| anyhow!("invalid leader id"))?
2304        .into();
2305    let follower_id = session.connection_id;
2306
2307    session
2308        .db()
2309        .await
2310        .check_room_participants(room_id, leader_id, session.connection_id)
2311        .await?;
2312
2313    let response_payload = session
2314        .peer
2315        .forward_request(session.connection_id, leader_id, request)
2316        .await?;
2317    response.send(response_payload)?;
2318
2319    if let Some(project_id) = project_id {
2320        let room = session
2321            .db()
2322            .await
2323            .follow(room_id, project_id, leader_id, follower_id)
2324            .await?;
2325        room_updated(&room, &session.peer);
2326    }
2327
2328    Ok(())
2329}
2330
2331/// Stop following another user in a call.
2332async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2333    let room_id = RoomId::from_proto(request.room_id);
2334    let project_id = request.project_id.map(ProjectId::from_proto);
2335    let leader_id = request
2336        .leader_id
2337        .ok_or_else(|| anyhow!("invalid leader id"))?
2338        .into();
2339    let follower_id = session.connection_id;
2340
2341    session
2342        .db()
2343        .await
2344        .check_room_participants(room_id, leader_id, session.connection_id)
2345        .await?;
2346
2347    session
2348        .peer
2349        .forward_send(session.connection_id, leader_id, request)?;
2350
2351    if let Some(project_id) = project_id {
2352        let room = session
2353            .db()
2354            .await
2355            .unfollow(room_id, project_id, leader_id, follower_id)
2356            .await?;
2357        room_updated(&room, &session.peer);
2358    }
2359
2360    Ok(())
2361}
2362
2363/// Notify everyone following you of your current location.
2364async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2365    let room_id = RoomId::from_proto(request.room_id);
2366    let database = session.db.lock().await;
2367
2368    let connection_ids = if let Some(project_id) = request.project_id {
2369        let project_id = ProjectId::from_proto(project_id);
2370        database
2371            .project_connection_ids(project_id, session.connection_id)
2372            .await?
2373    } else {
2374        database
2375            .room_connection_ids(room_id, session.connection_id)
2376            .await?
2377    };
2378
2379    // For now, don't send view update messages back to that view's current leader.
2380    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2381        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2382        _ => None,
2383    });
2384
2385    for connection_id in connection_ids.iter().cloned() {
2386        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2387            session
2388                .peer
2389                .forward_send(session.connection_id, connection_id, request.clone())?;
2390        }
2391    }
2392    Ok(())
2393}
2394
2395/// Get public data about users.
2396async fn get_users(
2397    request: proto::GetUsers,
2398    response: Response<proto::GetUsers>,
2399    session: Session,
2400) -> Result<()> {
2401    let user_ids = request
2402        .user_ids
2403        .into_iter()
2404        .map(UserId::from_proto)
2405        .collect();
2406    let users = session
2407        .db()
2408        .await
2409        .get_users_by_ids(user_ids)
2410        .await?
2411        .into_iter()
2412        .map(|user| proto::User {
2413            id: user.id.to_proto(),
2414            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2415            github_login: user.github_login,
2416        })
2417        .collect();
2418    response.send(proto::UsersResponse { users })?;
2419    Ok(())
2420}
2421
2422/// Search for users (to invite) buy Github login
2423async fn fuzzy_search_users(
2424    request: proto::FuzzySearchUsers,
2425    response: Response<proto::FuzzySearchUsers>,
2426    session: UserSession,
2427) -> Result<()> {
2428    let query = request.query;
2429    let users = match query.len() {
2430        0 => vec![],
2431        1 | 2 => session
2432            .db()
2433            .await
2434            .get_user_by_github_login(&query)
2435            .await?
2436            .into_iter()
2437            .collect(),
2438        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2439    };
2440    let users = users
2441        .into_iter()
2442        .filter(|user| user.id != session.user_id())
2443        .map(|user| proto::User {
2444            id: user.id.to_proto(),
2445            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2446            github_login: user.github_login,
2447        })
2448        .collect();
2449    response.send(proto::UsersResponse { users })?;
2450    Ok(())
2451}
2452
2453/// Send a contact request to another user.
2454async fn request_contact(
2455    request: proto::RequestContact,
2456    response: Response<proto::RequestContact>,
2457    session: UserSession,
2458) -> Result<()> {
2459    let requester_id = session.user_id();
2460    let responder_id = UserId::from_proto(request.responder_id);
2461    if requester_id == responder_id {
2462        return Err(anyhow!("cannot add yourself as a contact"))?;
2463    }
2464
2465    let notifications = session
2466        .db()
2467        .await
2468        .send_contact_request(requester_id, responder_id)
2469        .await?;
2470
2471    // Update outgoing contact requests of requester
2472    let mut update = proto::UpdateContacts::default();
2473    update.outgoing_requests.push(responder_id.to_proto());
2474    for connection_id in session
2475        .connection_pool()
2476        .await
2477        .user_connection_ids(requester_id)
2478    {
2479        session.peer.send(connection_id, update.clone())?;
2480    }
2481
2482    // Update incoming contact requests of responder
2483    let mut update = proto::UpdateContacts::default();
2484    update
2485        .incoming_requests
2486        .push(proto::IncomingContactRequest {
2487            requester_id: requester_id.to_proto(),
2488        });
2489    let connection_pool = session.connection_pool().await;
2490    for connection_id in connection_pool.user_connection_ids(responder_id) {
2491        session.peer.send(connection_id, update.clone())?;
2492    }
2493
2494    send_notifications(&connection_pool, &session.peer, notifications);
2495
2496    response.send(proto::Ack {})?;
2497    Ok(())
2498}
2499
2500/// Accept or decline a contact request
2501async fn respond_to_contact_request(
2502    request: proto::RespondToContactRequest,
2503    response: Response<proto::RespondToContactRequest>,
2504    session: UserSession,
2505) -> Result<()> {
2506    let responder_id = session.user_id();
2507    let requester_id = UserId::from_proto(request.requester_id);
2508    let db = session.db().await;
2509    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2510        db.dismiss_contact_notification(responder_id, requester_id)
2511            .await?;
2512    } else {
2513        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2514
2515        let notifications = db
2516            .respond_to_contact_request(responder_id, requester_id, accept)
2517            .await?;
2518        let requester_busy = db.is_user_busy(requester_id).await?;
2519        let responder_busy = db.is_user_busy(responder_id).await?;
2520
2521        let pool = session.connection_pool().await;
2522        // Update responder with new contact
2523        let mut update = proto::UpdateContacts::default();
2524        if accept {
2525            update
2526                .contacts
2527                .push(contact_for_user(requester_id, requester_busy, &pool));
2528        }
2529        update
2530            .remove_incoming_requests
2531            .push(requester_id.to_proto());
2532        for connection_id in pool.user_connection_ids(responder_id) {
2533            session.peer.send(connection_id, update.clone())?;
2534        }
2535
2536        // Update requester with new contact
2537        let mut update = proto::UpdateContacts::default();
2538        if accept {
2539            update
2540                .contacts
2541                .push(contact_for_user(responder_id, responder_busy, &pool));
2542        }
2543        update
2544            .remove_outgoing_requests
2545            .push(responder_id.to_proto());
2546
2547        for connection_id in pool.user_connection_ids(requester_id) {
2548            session.peer.send(connection_id, update.clone())?;
2549        }
2550
2551        send_notifications(&pool, &session.peer, notifications);
2552    }
2553
2554    response.send(proto::Ack {})?;
2555    Ok(())
2556}
2557
2558/// Remove a contact.
2559async fn remove_contact(
2560    request: proto::RemoveContact,
2561    response: Response<proto::RemoveContact>,
2562    session: UserSession,
2563) -> Result<()> {
2564    let requester_id = session.user_id();
2565    let responder_id = UserId::from_proto(request.user_id);
2566    let db = session.db().await;
2567    let (contact_accepted, deleted_notification_id) =
2568        db.remove_contact(requester_id, responder_id).await?;
2569
2570    let pool = session.connection_pool().await;
2571    // Update outgoing contact requests of requester
2572    let mut update = proto::UpdateContacts::default();
2573    if contact_accepted {
2574        update.remove_contacts.push(responder_id.to_proto());
2575    } else {
2576        update
2577            .remove_outgoing_requests
2578            .push(responder_id.to_proto());
2579    }
2580    for connection_id in pool.user_connection_ids(requester_id) {
2581        session.peer.send(connection_id, update.clone())?;
2582    }
2583
2584    // Update incoming contact requests of responder
2585    let mut update = proto::UpdateContacts::default();
2586    if contact_accepted {
2587        update.remove_contacts.push(requester_id.to_proto());
2588    } else {
2589        update
2590            .remove_incoming_requests
2591            .push(requester_id.to_proto());
2592    }
2593    for connection_id in pool.user_connection_ids(responder_id) {
2594        session.peer.send(connection_id, update.clone())?;
2595        if let Some(notification_id) = deleted_notification_id {
2596            session.peer.send(
2597                connection_id,
2598                proto::DeleteNotification {
2599                    notification_id: notification_id.to_proto(),
2600                },
2601            )?;
2602        }
2603    }
2604
2605    response.send(proto::Ack {})?;
2606    Ok(())
2607}
2608
2609/// Creates a new channel.
2610async fn create_channel(
2611    request: proto::CreateChannel,
2612    response: Response<proto::CreateChannel>,
2613    session: UserSession,
2614) -> Result<()> {
2615    let db = session.db().await;
2616
2617    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2618    let (channel, membership) = db
2619        .create_channel(&request.name, parent_id, session.user_id())
2620        .await?;
2621
2622    let root_id = channel.root_id();
2623    let channel = Channel::from_model(channel);
2624
2625    response.send(proto::CreateChannelResponse {
2626        channel: Some(channel.to_proto()),
2627        parent_id: request.parent_id,
2628    })?;
2629
2630    let mut connection_pool = session.connection_pool().await;
2631    if let Some(membership) = membership {
2632        connection_pool.subscribe_to_channel(
2633            membership.user_id,
2634            membership.channel_id,
2635            membership.role,
2636        );
2637        let update = proto::UpdateUserChannels {
2638            channel_memberships: vec![proto::ChannelMembership {
2639                channel_id: membership.channel_id.to_proto(),
2640                role: membership.role.into(),
2641            }],
2642            ..Default::default()
2643        };
2644        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2645            session.peer.send(connection_id, update.clone())?;
2646        }
2647    }
2648
2649    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2650        if !role.can_see_channel(channel.visibility) {
2651            continue;
2652        }
2653
2654        let update = proto::UpdateChannels {
2655            channels: vec![channel.to_proto()],
2656            ..Default::default()
2657        };
2658        session.peer.send(connection_id, update.clone())?;
2659    }
2660
2661    Ok(())
2662}
2663
2664/// Delete a channel
2665async fn delete_channel(
2666    request: proto::DeleteChannel,
2667    response: Response<proto::DeleteChannel>,
2668    session: UserSession,
2669) -> Result<()> {
2670    let db = session.db().await;
2671
2672    let channel_id = request.channel_id;
2673    let (root_channel, removed_channels) = db
2674        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2675        .await?;
2676    response.send(proto::Ack {})?;
2677
2678    // Notify members of removed channels
2679    let mut update = proto::UpdateChannels::default();
2680    update
2681        .delete_channels
2682        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2683
2684    let connection_pool = session.connection_pool().await;
2685    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2686        session.peer.send(connection_id, update.clone())?;
2687    }
2688
2689    Ok(())
2690}
2691
2692/// Invite someone to join a channel.
2693async fn invite_channel_member(
2694    request: proto::InviteChannelMember,
2695    response: Response<proto::InviteChannelMember>,
2696    session: UserSession,
2697) -> Result<()> {
2698    let db = session.db().await;
2699    let channel_id = ChannelId::from_proto(request.channel_id);
2700    let invitee_id = UserId::from_proto(request.user_id);
2701    let InviteMemberResult {
2702        channel,
2703        notifications,
2704    } = db
2705        .invite_channel_member(
2706            channel_id,
2707            invitee_id,
2708            session.user_id(),
2709            request.role().into(),
2710        )
2711        .await?;
2712
2713    let update = proto::UpdateChannels {
2714        channel_invitations: vec![channel.to_proto()],
2715        ..Default::default()
2716    };
2717
2718    let connection_pool = session.connection_pool().await;
2719    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2720        session.peer.send(connection_id, update.clone())?;
2721    }
2722
2723    send_notifications(&connection_pool, &session.peer, notifications);
2724
2725    response.send(proto::Ack {})?;
2726    Ok(())
2727}
2728
2729/// remove someone from a channel
2730async fn remove_channel_member(
2731    request: proto::RemoveChannelMember,
2732    response: Response<proto::RemoveChannelMember>,
2733    session: UserSession,
2734) -> Result<()> {
2735    let db = session.db().await;
2736    let channel_id = ChannelId::from_proto(request.channel_id);
2737    let member_id = UserId::from_proto(request.user_id);
2738
2739    let RemoveChannelMemberResult {
2740        membership_update,
2741        notification_id,
2742    } = db
2743        .remove_channel_member(channel_id, member_id, session.user_id())
2744        .await?;
2745
2746    let mut connection_pool = session.connection_pool().await;
2747    notify_membership_updated(
2748        &mut connection_pool,
2749        membership_update,
2750        member_id,
2751        &session.peer,
2752    );
2753    for connection_id in connection_pool.user_connection_ids(member_id) {
2754        if let Some(notification_id) = notification_id {
2755            session
2756                .peer
2757                .send(
2758                    connection_id,
2759                    proto::DeleteNotification {
2760                        notification_id: notification_id.to_proto(),
2761                    },
2762                )
2763                .trace_err();
2764        }
2765    }
2766
2767    response.send(proto::Ack {})?;
2768    Ok(())
2769}
2770
2771/// Toggle the channel between public and private.
2772/// Care is taken to maintain the invariant that public channels only descend from public channels,
2773/// (though members-only channels can appear at any point in the hierarchy).
2774async fn set_channel_visibility(
2775    request: proto::SetChannelVisibility,
2776    response: Response<proto::SetChannelVisibility>,
2777    session: UserSession,
2778) -> Result<()> {
2779    let db = session.db().await;
2780    let channel_id = ChannelId::from_proto(request.channel_id);
2781    let visibility = request.visibility().into();
2782
2783    let channel_model = db
2784        .set_channel_visibility(channel_id, visibility, session.user_id())
2785        .await?;
2786    let root_id = channel_model.root_id();
2787    let channel = Channel::from_model(channel_model);
2788
2789    let mut connection_pool = session.connection_pool().await;
2790    for (user_id, role) in connection_pool
2791        .channel_user_ids(root_id)
2792        .collect::<Vec<_>>()
2793        .into_iter()
2794    {
2795        let update = if role.can_see_channel(channel.visibility) {
2796            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2797            proto::UpdateChannels {
2798                channels: vec![channel.to_proto()],
2799                ..Default::default()
2800            }
2801        } else {
2802            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2803            proto::UpdateChannels {
2804                delete_channels: vec![channel.id.to_proto()],
2805                ..Default::default()
2806            }
2807        };
2808
2809        for connection_id in connection_pool.user_connection_ids(user_id) {
2810            session.peer.send(connection_id, update.clone())?;
2811        }
2812    }
2813
2814    response.send(proto::Ack {})?;
2815    Ok(())
2816}
2817
2818/// Alter the role for a user in the channel.
2819async fn set_channel_member_role(
2820    request: proto::SetChannelMemberRole,
2821    response: Response<proto::SetChannelMemberRole>,
2822    session: UserSession,
2823) -> Result<()> {
2824    let db = session.db().await;
2825    let channel_id = ChannelId::from_proto(request.channel_id);
2826    let member_id = UserId::from_proto(request.user_id);
2827    let result = db
2828        .set_channel_member_role(
2829            channel_id,
2830            session.user_id(),
2831            member_id,
2832            request.role().into(),
2833        )
2834        .await?;
2835
2836    match result {
2837        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2838            let mut connection_pool = session.connection_pool().await;
2839            notify_membership_updated(
2840                &mut connection_pool,
2841                membership_update,
2842                member_id,
2843                &session.peer,
2844            )
2845        }
2846        db::SetMemberRoleResult::InviteUpdated(channel) => {
2847            let update = proto::UpdateChannels {
2848                channel_invitations: vec![channel.to_proto()],
2849                ..Default::default()
2850            };
2851
2852            for connection_id in session
2853                .connection_pool()
2854                .await
2855                .user_connection_ids(member_id)
2856            {
2857                session.peer.send(connection_id, update.clone())?;
2858            }
2859        }
2860    }
2861
2862    response.send(proto::Ack {})?;
2863    Ok(())
2864}
2865
2866/// Change the name of a channel
2867async fn rename_channel(
2868    request: proto::RenameChannel,
2869    response: Response<proto::RenameChannel>,
2870    session: UserSession,
2871) -> Result<()> {
2872    let db = session.db().await;
2873    let channel_id = ChannelId::from_proto(request.channel_id);
2874    let channel_model = db
2875        .rename_channel(channel_id, session.user_id(), &request.name)
2876        .await?;
2877    let root_id = channel_model.root_id();
2878    let channel = Channel::from_model(channel_model);
2879
2880    response.send(proto::RenameChannelResponse {
2881        channel: Some(channel.to_proto()),
2882    })?;
2883
2884    let connection_pool = session.connection_pool().await;
2885    let update = proto::UpdateChannels {
2886        channels: vec![channel.to_proto()],
2887        ..Default::default()
2888    };
2889    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2890        if role.can_see_channel(channel.visibility) {
2891            session.peer.send(connection_id, update.clone())?;
2892        }
2893    }
2894
2895    Ok(())
2896}
2897
2898/// Move a channel to a new parent.
2899async fn move_channel(
2900    request: proto::MoveChannel,
2901    response: Response<proto::MoveChannel>,
2902    session: UserSession,
2903) -> Result<()> {
2904    let channel_id = ChannelId::from_proto(request.channel_id);
2905    let to = ChannelId::from_proto(request.to);
2906
2907    let (root_id, channels) = session
2908        .db()
2909        .await
2910        .move_channel(channel_id, to, session.user_id())
2911        .await?;
2912
2913    let connection_pool = session.connection_pool().await;
2914    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2915        let channels = channels
2916            .iter()
2917            .filter_map(|channel| {
2918                if role.can_see_channel(channel.visibility) {
2919                    Some(channel.to_proto())
2920                } else {
2921                    None
2922                }
2923            })
2924            .collect::<Vec<_>>();
2925        if channels.is_empty() {
2926            continue;
2927        }
2928
2929        let update = proto::UpdateChannels {
2930            channels,
2931            ..Default::default()
2932        };
2933
2934        session.peer.send(connection_id, update.clone())?;
2935    }
2936
2937    response.send(Ack {})?;
2938    Ok(())
2939}
2940
2941/// Get the list of channel members
2942async fn get_channel_members(
2943    request: proto::GetChannelMembers,
2944    response: Response<proto::GetChannelMembers>,
2945    session: UserSession,
2946) -> Result<()> {
2947    let db = session.db().await;
2948    let channel_id = ChannelId::from_proto(request.channel_id);
2949    let members = db
2950        .get_channel_participant_details(channel_id, session.user_id())
2951        .await?;
2952    response.send(proto::GetChannelMembersResponse { members })?;
2953    Ok(())
2954}
2955
2956/// Accept or decline a channel invitation.
2957async fn respond_to_channel_invite(
2958    request: proto::RespondToChannelInvite,
2959    response: Response<proto::RespondToChannelInvite>,
2960    session: UserSession,
2961) -> Result<()> {
2962    let db = session.db().await;
2963    let channel_id = ChannelId::from_proto(request.channel_id);
2964    let RespondToChannelInvite {
2965        membership_update,
2966        notifications,
2967    } = db
2968        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2969        .await?;
2970
2971    let mut connection_pool = session.connection_pool().await;
2972    if let Some(membership_update) = membership_update {
2973        notify_membership_updated(
2974            &mut connection_pool,
2975            membership_update,
2976            session.user_id(),
2977            &session.peer,
2978        );
2979    } else {
2980        let update = proto::UpdateChannels {
2981            remove_channel_invitations: vec![channel_id.to_proto()],
2982            ..Default::default()
2983        };
2984
2985        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2986            session.peer.send(connection_id, update.clone())?;
2987        }
2988    };
2989
2990    send_notifications(&connection_pool, &session.peer, notifications);
2991
2992    response.send(proto::Ack {})?;
2993
2994    Ok(())
2995}
2996
2997/// Join the channels' room
2998async fn join_channel(
2999    request: proto::JoinChannel,
3000    response: Response<proto::JoinChannel>,
3001    session: UserSession,
3002) -> Result<()> {
3003    let channel_id = ChannelId::from_proto(request.channel_id);
3004    join_channel_internal(channel_id, Box::new(response), session).await
3005}
3006
3007trait JoinChannelInternalResponse {
3008    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3009}
3010impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3011    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3012        Response::<proto::JoinChannel>::send(self, result)
3013    }
3014}
3015impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3016    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3017        Response::<proto::JoinRoom>::send(self, result)
3018    }
3019}
3020
3021async fn join_channel_internal(
3022    channel_id: ChannelId,
3023    response: Box<impl JoinChannelInternalResponse>,
3024    session: UserSession,
3025) -> Result<()> {
3026    let joined_room = {
3027        let mut db = session.db().await;
3028        // If zed quits without leaving the room, and the user re-opens zed before the
3029        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
3030        // room they were in.
3031        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3032            tracing::info!(
3033                stale_connection_id = %connection,
3034                "cleaning up stale connection",
3035            );
3036            drop(db);
3037            leave_room_for_session(&session, connection).await?;
3038            db = session.db().await;
3039        }
3040
3041        let (joined_room, membership_updated, role) = db
3042            .join_channel(channel_id, session.user_id(), session.connection_id)
3043            .await?;
3044
3045        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
3046            let (can_publish, token) = if role == ChannelRole::Guest {
3047                (
3048                    false,
3049                    live_kit
3050                        .guest_token(
3051                            &joined_room.room.live_kit_room,
3052                            &session.user_id().to_string(),
3053                        )
3054                        .trace_err()?,
3055                )
3056            } else {
3057                (
3058                    true,
3059                    live_kit
3060                        .room_token(
3061                            &joined_room.room.live_kit_room,
3062                            &session.user_id().to_string(),
3063                        )
3064                        .trace_err()?,
3065                )
3066            };
3067
3068            Some(LiveKitConnectionInfo {
3069                server_url: live_kit.url().into(),
3070                token,
3071                can_publish,
3072            })
3073        });
3074
3075        response.send(proto::JoinRoomResponse {
3076            room: Some(joined_room.room.clone()),
3077            channel_id: joined_room
3078                .channel
3079                .as_ref()
3080                .map(|channel| channel.id.to_proto()),
3081            live_kit_connection_info,
3082        })?;
3083
3084        let mut connection_pool = session.connection_pool().await;
3085        if let Some(membership_updated) = membership_updated {
3086            notify_membership_updated(
3087                &mut connection_pool,
3088                membership_updated,
3089                session.user_id(),
3090                &session.peer,
3091            );
3092        }
3093
3094        room_updated(&joined_room.room, &session.peer);
3095
3096        joined_room
3097    };
3098
3099    channel_updated(
3100        &joined_room
3101            .channel
3102            .ok_or_else(|| anyhow!("channel not returned"))?,
3103        &joined_room.room,
3104        &session.peer,
3105        &*session.connection_pool().await,
3106    );
3107
3108    update_user_contacts(session.user_id(), &session).await?;
3109    Ok(())
3110}
3111
3112/// Start editing the channel notes
3113async fn join_channel_buffer(
3114    request: proto::JoinChannelBuffer,
3115    response: Response<proto::JoinChannelBuffer>,
3116    session: UserSession,
3117) -> Result<()> {
3118    let db = session.db().await;
3119    let channel_id = ChannelId::from_proto(request.channel_id);
3120
3121    let open_response = db
3122        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3123        .await?;
3124
3125    let collaborators = open_response.collaborators.clone();
3126    response.send(open_response)?;
3127
3128    let update = UpdateChannelBufferCollaborators {
3129        channel_id: channel_id.to_proto(),
3130        collaborators: collaborators.clone(),
3131    };
3132    channel_buffer_updated(
3133        session.connection_id,
3134        collaborators
3135            .iter()
3136            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3137        &update,
3138        &session.peer,
3139    );
3140
3141    Ok(())
3142}
3143
3144/// Edit the channel notes
3145async fn update_channel_buffer(
3146    request: proto::UpdateChannelBuffer,
3147    session: UserSession,
3148) -> Result<()> {
3149    let db = session.db().await;
3150    let channel_id = ChannelId::from_proto(request.channel_id);
3151
3152    let (collaborators, non_collaborators, epoch, version) = db
3153        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3154        .await?;
3155
3156    channel_buffer_updated(
3157        session.connection_id,
3158        collaborators,
3159        &proto::UpdateChannelBuffer {
3160            channel_id: channel_id.to_proto(),
3161            operations: request.operations,
3162        },
3163        &session.peer,
3164    );
3165
3166    let pool = &*session.connection_pool().await;
3167
3168    broadcast(
3169        None,
3170        non_collaborators
3171            .iter()
3172            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3173        |peer_id| {
3174            session.peer.send(
3175                peer_id,
3176                proto::UpdateChannels {
3177                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3178                        channel_id: channel_id.to_proto(),
3179                        epoch: epoch as u64,
3180                        version: version.clone(),
3181                    }],
3182                    ..Default::default()
3183                },
3184            )
3185        },
3186    );
3187
3188    Ok(())
3189}
3190
3191/// Rejoin the channel notes after a connection blip
3192async fn rejoin_channel_buffers(
3193    request: proto::RejoinChannelBuffers,
3194    response: Response<proto::RejoinChannelBuffers>,
3195    session: UserSession,
3196) -> Result<()> {
3197    let db = session.db().await;
3198    let buffers = db
3199        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3200        .await?;
3201
3202    for rejoined_buffer in &buffers {
3203        let collaborators_to_notify = rejoined_buffer
3204            .buffer
3205            .collaborators
3206            .iter()
3207            .filter_map(|c| Some(c.peer_id?.into()));
3208        channel_buffer_updated(
3209            session.connection_id,
3210            collaborators_to_notify,
3211            &proto::UpdateChannelBufferCollaborators {
3212                channel_id: rejoined_buffer.buffer.channel_id,
3213                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3214            },
3215            &session.peer,
3216        );
3217    }
3218
3219    response.send(proto::RejoinChannelBuffersResponse {
3220        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3221    })?;
3222
3223    Ok(())
3224}
3225
3226/// Stop editing the channel notes
3227async fn leave_channel_buffer(
3228    request: proto::LeaveChannelBuffer,
3229    response: Response<proto::LeaveChannelBuffer>,
3230    session: UserSession,
3231) -> Result<()> {
3232    let db = session.db().await;
3233    let channel_id = ChannelId::from_proto(request.channel_id);
3234
3235    let left_buffer = db
3236        .leave_channel_buffer(channel_id, session.connection_id)
3237        .await?;
3238
3239    response.send(Ack {})?;
3240
3241    channel_buffer_updated(
3242        session.connection_id,
3243        left_buffer.connections,
3244        &proto::UpdateChannelBufferCollaborators {
3245            channel_id: channel_id.to_proto(),
3246            collaborators: left_buffer.collaborators,
3247        },
3248        &session.peer,
3249    );
3250
3251    Ok(())
3252}
3253
3254fn channel_buffer_updated<T: EnvelopedMessage>(
3255    sender_id: ConnectionId,
3256    collaborators: impl IntoIterator<Item = ConnectionId>,
3257    message: &T,
3258    peer: &Peer,
3259) {
3260    broadcast(Some(sender_id), collaborators, |peer_id| {
3261        peer.send(peer_id, message.clone())
3262    });
3263}
3264
3265fn send_notifications(
3266    connection_pool: &ConnectionPool,
3267    peer: &Peer,
3268    notifications: db::NotificationBatch,
3269) {
3270    for (user_id, notification) in notifications {
3271        for connection_id in connection_pool.user_connection_ids(user_id) {
3272            if let Err(error) = peer.send(
3273                connection_id,
3274                proto::AddNotification {
3275                    notification: Some(notification.clone()),
3276                },
3277            ) {
3278                tracing::error!(
3279                    "failed to send notification to {:?} {}",
3280                    connection_id,
3281                    error
3282                );
3283            }
3284        }
3285    }
3286}
3287
3288/// Send a message to the channel
3289async fn send_channel_message(
3290    request: proto::SendChannelMessage,
3291    response: Response<proto::SendChannelMessage>,
3292    session: UserSession,
3293) -> Result<()> {
3294    // Validate the message body.
3295    let body = request.body.trim().to_string();
3296    if body.len() > MAX_MESSAGE_LEN {
3297        return Err(anyhow!("message is too long"))?;
3298    }
3299    if body.is_empty() {
3300        return Err(anyhow!("message can't be blank"))?;
3301    }
3302
3303    // TODO: adjust mentions if body is trimmed
3304
3305    let timestamp = OffsetDateTime::now_utc();
3306    let nonce = request
3307        .nonce
3308        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3309
3310    let channel_id = ChannelId::from_proto(request.channel_id);
3311    let CreatedChannelMessage {
3312        message_id,
3313        participant_connection_ids,
3314        channel_members,
3315        notifications,
3316    } = session
3317        .db()
3318        .await
3319        .create_channel_message(
3320            channel_id,
3321            session.user_id(),
3322            &body,
3323            &request.mentions,
3324            timestamp,
3325            nonce.clone().into(),
3326            match request.reply_to_message_id {
3327                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3328                None => None,
3329            },
3330        )
3331        .await?;
3332
3333    let message = proto::ChannelMessage {
3334        sender_id: session.user_id().to_proto(),
3335        id: message_id.to_proto(),
3336        body,
3337        mentions: request.mentions,
3338        timestamp: timestamp.unix_timestamp() as u64,
3339        nonce: Some(nonce),
3340        reply_to_message_id: request.reply_to_message_id,
3341        edited_at: None,
3342    };
3343    broadcast(
3344        Some(session.connection_id),
3345        participant_connection_ids,
3346        |connection| {
3347            session.peer.send(
3348                connection,
3349                proto::ChannelMessageSent {
3350                    channel_id: channel_id.to_proto(),
3351                    message: Some(message.clone()),
3352                },
3353            )
3354        },
3355    );
3356    response.send(proto::SendChannelMessageResponse {
3357        message: Some(message),
3358    })?;
3359
3360    let pool = &*session.connection_pool().await;
3361    broadcast(
3362        None,
3363        channel_members
3364            .iter()
3365            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3366        |peer_id| {
3367            session.peer.send(
3368                peer_id,
3369                proto::UpdateChannels {
3370                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3371                        channel_id: channel_id.to_proto(),
3372                        message_id: message_id.to_proto(),
3373                    }],
3374                    ..Default::default()
3375                },
3376            )
3377        },
3378    );
3379    send_notifications(pool, &session.peer, notifications);
3380
3381    Ok(())
3382}
3383
3384/// Delete a channel message
3385async fn remove_channel_message(
3386    request: proto::RemoveChannelMessage,
3387    response: Response<proto::RemoveChannelMessage>,
3388    session: UserSession,
3389) -> Result<()> {
3390    let channel_id = ChannelId::from_proto(request.channel_id);
3391    let message_id = MessageId::from_proto(request.message_id);
3392    let (connection_ids, existing_notification_ids) = session
3393        .db()
3394        .await
3395        .remove_channel_message(channel_id, message_id, session.user_id())
3396        .await?;
3397
3398    broadcast(
3399        Some(session.connection_id),
3400        connection_ids,
3401        move |connection| {
3402            session.peer.send(connection, request.clone())?;
3403
3404            for notification_id in &existing_notification_ids {
3405                session.peer.send(
3406                    connection,
3407                    proto::DeleteNotification {
3408                        notification_id: (*notification_id).to_proto(),
3409                    },
3410                )?;
3411            }
3412
3413            Ok(())
3414        },
3415    );
3416    response.send(proto::Ack {})?;
3417    Ok(())
3418}
3419
3420async fn update_channel_message(
3421    request: proto::UpdateChannelMessage,
3422    response: Response<proto::UpdateChannelMessage>,
3423    session: UserSession,
3424) -> Result<()> {
3425    let channel_id = ChannelId::from_proto(request.channel_id);
3426    let message_id = MessageId::from_proto(request.message_id);
3427    let updated_at = OffsetDateTime::now_utc();
3428    let UpdatedChannelMessage {
3429        message_id,
3430        participant_connection_ids,
3431        notifications,
3432        reply_to_message_id,
3433        timestamp,
3434        deleted_mention_notification_ids,
3435        updated_mention_notifications,
3436    } = session
3437        .db()
3438        .await
3439        .update_channel_message(
3440            channel_id,
3441            message_id,
3442            session.user_id(),
3443            request.body.as_str(),
3444            &request.mentions,
3445            updated_at,
3446        )
3447        .await?;
3448
3449    let nonce = request
3450        .nonce
3451        .clone()
3452        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3453
3454    let message = proto::ChannelMessage {
3455        sender_id: session.user_id().to_proto(),
3456        id: message_id.to_proto(),
3457        body: request.body.clone(),
3458        mentions: request.mentions.clone(),
3459        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3460        nonce: Some(nonce),
3461        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3462        edited_at: Some(updated_at.unix_timestamp() as u64),
3463    };
3464
3465    response.send(proto::Ack {})?;
3466
3467    let pool = &*session.connection_pool().await;
3468    broadcast(
3469        Some(session.connection_id),
3470        participant_connection_ids,
3471        |connection| {
3472            session.peer.send(
3473                connection,
3474                proto::ChannelMessageUpdate {
3475                    channel_id: channel_id.to_proto(),
3476                    message: Some(message.clone()),
3477                },
3478            )?;
3479
3480            for notification_id in &deleted_mention_notification_ids {
3481                session.peer.send(
3482                    connection,
3483                    proto::DeleteNotification {
3484                        notification_id: (*notification_id).to_proto(),
3485                    },
3486                )?;
3487            }
3488
3489            for notification in &updated_mention_notifications {
3490                session.peer.send(
3491                    connection,
3492                    proto::UpdateNotification {
3493                        notification: Some(notification.clone()),
3494                    },
3495                )?;
3496            }
3497
3498            Ok(())
3499        },
3500    );
3501
3502    send_notifications(pool, &session.peer, notifications);
3503
3504    Ok(())
3505}
3506
3507/// Mark a channel message as read
3508async fn acknowledge_channel_message(
3509    request: proto::AckChannelMessage,
3510    session: UserSession,
3511) -> Result<()> {
3512    let channel_id = ChannelId::from_proto(request.channel_id);
3513    let message_id = MessageId::from_proto(request.message_id);
3514    let notifications = session
3515        .db()
3516        .await
3517        .observe_channel_message(channel_id, session.user_id(), message_id)
3518        .await?;
3519    send_notifications(
3520        &*session.connection_pool().await,
3521        &session.peer,
3522        notifications,
3523    );
3524    Ok(())
3525}
3526
3527/// Mark a buffer version as synced
3528async fn acknowledge_buffer_version(
3529    request: proto::AckBufferOperation,
3530    session: UserSession,
3531) -> Result<()> {
3532    let buffer_id = BufferId::from_proto(request.buffer_id);
3533    session
3534        .db()
3535        .await
3536        .observe_buffer_version(
3537            buffer_id,
3538            session.user_id(),
3539            request.epoch as i32,
3540            &request.version,
3541        )
3542        .await?;
3543    Ok(())
3544}
3545
3546struct CompleteWithLanguageModelRateLimit;
3547
3548impl RateLimit for CompleteWithLanguageModelRateLimit {
3549    fn capacity() -> usize {
3550        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3551            .ok()
3552            .and_then(|v| v.parse().ok())
3553            .unwrap_or(120) // Picked arbitrarily
3554    }
3555
3556    fn refill_duration() -> chrono::Duration {
3557        chrono::Duration::hours(1)
3558    }
3559
3560    fn db_name() -> &'static str {
3561        "complete-with-language-model"
3562    }
3563}
3564
3565async fn complete_with_language_model(
3566    request: proto::CompleteWithLanguageModel,
3567    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3568    session: Session,
3569    open_ai_api_key: Option<Arc<str>>,
3570    google_ai_api_key: Option<Arc<str>>,
3571    anthropic_api_key: Option<Arc<str>>,
3572) -> Result<()> {
3573    let Some(session) = session.for_user() else {
3574        return Err(anyhow!("user not found"))?;
3575    };
3576    authorize_access_to_language_models(&session).await?;
3577    session
3578        .rate_limiter
3579        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3580        .await?;
3581
3582    if request.model.starts_with("gpt") {
3583        let api_key =
3584            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3585        complete_with_open_ai(request, response, session, api_key).await?;
3586    } else if request.model.starts_with("gemini") {
3587        let api_key = google_ai_api_key
3588            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3589        complete_with_google_ai(request, response, session, api_key).await?;
3590    } else if request.model.starts_with("claude") {
3591        let api_key = anthropic_api_key
3592            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3593        complete_with_anthropic(request, response, session, api_key).await?;
3594    }
3595
3596    Ok(())
3597}
3598
3599async fn complete_with_open_ai(
3600    request: proto::CompleteWithLanguageModel,
3601    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3602    session: UserSession,
3603    api_key: Arc<str>,
3604) -> Result<()> {
3605    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3606
3607    let mut completion_stream = open_ai::stream_completion(
3608        &session.http_client,
3609        OPEN_AI_API_URL,
3610        &api_key,
3611        crate::ai::language_model_request_to_open_ai(request)?,
3612    )
3613    .await
3614    .context("open_ai::stream_completion request failed")?;
3615
3616    while let Some(event) = completion_stream.next().await {
3617        let event = event?;
3618        response.send(proto::LanguageModelResponse {
3619            choices: event
3620                .choices
3621                .into_iter()
3622                .map(|choice| proto::LanguageModelChoiceDelta {
3623                    index: choice.index,
3624                    delta: Some(proto::LanguageModelResponseMessage {
3625                        role: choice.delta.role.map(|role| match role {
3626                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3627                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3628                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3629                        } as i32),
3630                        content: choice.delta.content,
3631                    }),
3632                    finish_reason: choice.finish_reason,
3633                })
3634                .collect(),
3635        })?;
3636    }
3637
3638    Ok(())
3639}
3640
3641async fn complete_with_google_ai(
3642    request: proto::CompleteWithLanguageModel,
3643    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3644    session: UserSession,
3645    api_key: Arc<str>,
3646) -> Result<()> {
3647    let mut stream = google_ai::stream_generate_content(
3648        &session.http_client,
3649        google_ai::API_URL,
3650        api_key.as_ref(),
3651        crate::ai::language_model_request_to_google_ai(request)?,
3652    )
3653    .await
3654    .context("google_ai::stream_generate_content request failed")?;
3655
3656    while let Some(event) = stream.next().await {
3657        let event = event?;
3658        response.send(proto::LanguageModelResponse {
3659            choices: event
3660                .candidates
3661                .unwrap_or_default()
3662                .into_iter()
3663                .map(|candidate| proto::LanguageModelChoiceDelta {
3664                    index: candidate.index as u32,
3665                    delta: Some(proto::LanguageModelResponseMessage {
3666                        role: Some(match candidate.content.role {
3667                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3668                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3669                        } as i32),
3670                        content: Some(
3671                            candidate
3672                                .content
3673                                .parts
3674                                .into_iter()
3675                                .filter_map(|part| match part {
3676                                    google_ai::Part::TextPart(part) => Some(part.text),
3677                                    google_ai::Part::InlineDataPart(_) => None,
3678                                })
3679                                .collect(),
3680                        ),
3681                    }),
3682                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3683                })
3684                .collect(),
3685        })?;
3686    }
3687
3688    Ok(())
3689}
3690
3691async fn complete_with_anthropic(
3692    request: proto::CompleteWithLanguageModel,
3693    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3694    session: UserSession,
3695    api_key: Arc<str>,
3696) -> Result<()> {
3697    let model = anthropic::Model::from_id(&request.model)?;
3698
3699    let mut system_message = String::new();
3700    let messages = request
3701        .messages
3702        .into_iter()
3703        .filter_map(|message| match message.role() {
3704            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
3705                role: anthropic::Role::User,
3706                content: message.content,
3707            }),
3708            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
3709                role: anthropic::Role::Assistant,
3710                content: message.content,
3711            }),
3712            // Anthropic's API breaks system instructions out as a separate field rather
3713            // than having a system message role.
3714            LanguageModelRole::LanguageModelSystem => {
3715                if !system_message.is_empty() {
3716                    system_message.push_str("\n\n");
3717                }
3718                system_message.push_str(&message.content);
3719
3720                None
3721            }
3722        })
3723        .collect();
3724
3725    let mut stream = anthropic::stream_completion(
3726        &session.http_client,
3727        "https://api.anthropic.com",
3728        &api_key,
3729        anthropic::Request {
3730            model,
3731            messages,
3732            stream: true,
3733            system: system_message,
3734            max_tokens: 4092,
3735        },
3736    )
3737    .await?;
3738
3739    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
3740
3741    while let Some(event) = stream.next().await {
3742        let event = event?;
3743
3744        match event {
3745            anthropic::ResponseEvent::MessageStart { message } => {
3746                if let Some(role) = message.role {
3747                    if role == "assistant" {
3748                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
3749                    } else if role == "user" {
3750                        current_role = proto::LanguageModelRole::LanguageModelUser;
3751                    }
3752                }
3753            }
3754            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
3755                match content_block {
3756                    anthropic::ContentBlock::Text { text } => {
3757                        if !text.is_empty() {
3758                            response.send(proto::LanguageModelResponse {
3759                                choices: vec![proto::LanguageModelChoiceDelta {
3760                                    index: 0,
3761                                    delta: Some(proto::LanguageModelResponseMessage {
3762                                        role: Some(current_role as i32),
3763                                        content: Some(text),
3764                                    }),
3765                                    finish_reason: None,
3766                                }],
3767                            })?;
3768                        }
3769                    }
3770                }
3771            }
3772            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
3773                anthropic::TextDelta::TextDelta { text } => {
3774                    response.send(proto::LanguageModelResponse {
3775                        choices: vec![proto::LanguageModelChoiceDelta {
3776                            index: 0,
3777                            delta: Some(proto::LanguageModelResponseMessage {
3778                                role: Some(current_role as i32),
3779                                content: Some(text),
3780                            }),
3781                            finish_reason: None,
3782                        }],
3783                    })?;
3784                }
3785            },
3786            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
3787                if let Some(stop_reason) = delta.stop_reason {
3788                    response.send(proto::LanguageModelResponse {
3789                        choices: vec![proto::LanguageModelChoiceDelta {
3790                            index: 0,
3791                            delta: None,
3792                            finish_reason: Some(stop_reason),
3793                        }],
3794                    })?;
3795                }
3796            }
3797            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
3798            anthropic::ResponseEvent::MessageStop {} => {}
3799            anthropic::ResponseEvent::Ping {} => {}
3800        }
3801    }
3802
3803    Ok(())
3804}
3805
3806struct CountTokensWithLanguageModelRateLimit;
3807
3808impl RateLimit for CountTokensWithLanguageModelRateLimit {
3809    fn capacity() -> usize {
3810        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3811            .ok()
3812            .and_then(|v| v.parse().ok())
3813            .unwrap_or(600) // Picked arbitrarily
3814    }
3815
3816    fn refill_duration() -> chrono::Duration {
3817        chrono::Duration::hours(1)
3818    }
3819
3820    fn db_name() -> &'static str {
3821        "count-tokens-with-language-model"
3822    }
3823}
3824
3825async fn count_tokens_with_language_model(
3826    request: proto::CountTokensWithLanguageModel,
3827    response: Response<proto::CountTokensWithLanguageModel>,
3828    session: UserSession,
3829    google_ai_api_key: Option<Arc<str>>,
3830) -> Result<()> {
3831    authorize_access_to_language_models(&session).await?;
3832
3833    if !request.model.starts_with("gemini") {
3834        return Err(anyhow!(
3835            "counting tokens for model: {:?} is not supported",
3836            request.model
3837        ))?;
3838    }
3839
3840    session
3841        .rate_limiter
3842        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3843        .await?;
3844
3845    let api_key = google_ai_api_key
3846        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3847    let tokens_response = google_ai::count_tokens(
3848        &session.http_client,
3849        google_ai::API_URL,
3850        &api_key,
3851        crate::ai::count_tokens_request_to_google_ai(request)?,
3852    )
3853    .await?;
3854    response.send(proto::CountTokensResponse {
3855        token_count: tokens_response.total_tokens as u32,
3856    })?;
3857    Ok(())
3858}
3859
3860async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3861    let db = session.db().await;
3862    let flags = db.get_user_flags(session.user_id()).await?;
3863    if flags.iter().any(|flag| flag == "language-models") {
3864        Ok(())
3865    } else {
3866        Err(anyhow!("permission denied"))?
3867    }
3868}
3869
3870/// Start receiving chat updates for a channel
3871async fn join_channel_chat(
3872    request: proto::JoinChannelChat,
3873    response: Response<proto::JoinChannelChat>,
3874    session: UserSession,
3875) -> Result<()> {
3876    let channel_id = ChannelId::from_proto(request.channel_id);
3877
3878    let db = session.db().await;
3879    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3880        .await?;
3881    let messages = db
3882        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3883        .await?;
3884    response.send(proto::JoinChannelChatResponse {
3885        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3886        messages,
3887    })?;
3888    Ok(())
3889}
3890
3891/// Stop receiving chat updates for a channel
3892async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3893    let channel_id = ChannelId::from_proto(request.channel_id);
3894    session
3895        .db()
3896        .await
3897        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3898        .await?;
3899    Ok(())
3900}
3901
3902/// Retrieve the chat history for a channel
3903async fn get_channel_messages(
3904    request: proto::GetChannelMessages,
3905    response: Response<proto::GetChannelMessages>,
3906    session: UserSession,
3907) -> Result<()> {
3908    let channel_id = ChannelId::from_proto(request.channel_id);
3909    let messages = session
3910        .db()
3911        .await
3912        .get_channel_messages(
3913            channel_id,
3914            session.user_id(),
3915            MESSAGE_COUNT_PER_PAGE,
3916            Some(MessageId::from_proto(request.before_message_id)),
3917        )
3918        .await?;
3919    response.send(proto::GetChannelMessagesResponse {
3920        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3921        messages,
3922    })?;
3923    Ok(())
3924}
3925
3926/// Retrieve specific chat messages
3927async fn get_channel_messages_by_id(
3928    request: proto::GetChannelMessagesById,
3929    response: Response<proto::GetChannelMessagesById>,
3930    session: UserSession,
3931) -> Result<()> {
3932    let message_ids = request
3933        .message_ids
3934        .iter()
3935        .map(|id| MessageId::from_proto(*id))
3936        .collect::<Vec<_>>();
3937    let messages = session
3938        .db()
3939        .await
3940        .get_channel_messages_by_id(session.user_id(), &message_ids)
3941        .await?;
3942    response.send(proto::GetChannelMessagesResponse {
3943        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3944        messages,
3945    })?;
3946    Ok(())
3947}
3948
3949/// Retrieve the current users notifications
3950async fn get_notifications(
3951    request: proto::GetNotifications,
3952    response: Response<proto::GetNotifications>,
3953    session: UserSession,
3954) -> Result<()> {
3955    let notifications = session
3956        .db()
3957        .await
3958        .get_notifications(
3959            session.user_id(),
3960            NOTIFICATION_COUNT_PER_PAGE,
3961            request
3962                .before_id
3963                .map(|id| db::NotificationId::from_proto(id)),
3964        )
3965        .await?;
3966    response.send(proto::GetNotificationsResponse {
3967        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3968        notifications,
3969    })?;
3970    Ok(())
3971}
3972
3973/// Mark notifications as read
3974async fn mark_notification_as_read(
3975    request: proto::MarkNotificationRead,
3976    response: Response<proto::MarkNotificationRead>,
3977    session: UserSession,
3978) -> Result<()> {
3979    let database = &session.db().await;
3980    let notifications = database
3981        .mark_notification_as_read_by_id(
3982            session.user_id(),
3983            NotificationId::from_proto(request.notification_id),
3984        )
3985        .await?;
3986    send_notifications(
3987        &*session.connection_pool().await,
3988        &session.peer,
3989        notifications,
3990    );
3991    response.send(proto::Ack {})?;
3992    Ok(())
3993}
3994
3995/// Get the current users information
3996async fn get_private_user_info(
3997    _request: proto::GetPrivateUserInfo,
3998    response: Response<proto::GetPrivateUserInfo>,
3999    session: UserSession,
4000) -> Result<()> {
4001    let db = session.db().await;
4002
4003    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4004    let user = db
4005        .get_user_by_id(session.user_id())
4006        .await?
4007        .ok_or_else(|| anyhow!("user not found"))?;
4008    let flags = db.get_user_flags(session.user_id()).await?;
4009
4010    response.send(proto::GetPrivateUserInfoResponse {
4011        metrics_id,
4012        staff: user.admin,
4013        flags,
4014    })?;
4015    Ok(())
4016}
4017
4018fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4019    match message {
4020        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4021        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4022        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4023        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4024        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4025            code: frame.code.into(),
4026            reason: frame.reason,
4027        })),
4028    }
4029}
4030
4031fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4032    match message {
4033        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4034        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4035        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4036        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4037        AxumMessage::Close(frame) => {
4038            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4039                code: frame.code.into(),
4040                reason: frame.reason,
4041            }))
4042        }
4043    }
4044}
4045
4046fn notify_membership_updated(
4047    connection_pool: &mut ConnectionPool,
4048    result: MembershipUpdated,
4049    user_id: UserId,
4050    peer: &Peer,
4051) {
4052    for membership in &result.new_channels.channel_memberships {
4053        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4054    }
4055    for channel_id in &result.removed_channels {
4056        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4057    }
4058
4059    let user_channels_update = proto::UpdateUserChannels {
4060        channel_memberships: result
4061            .new_channels
4062            .channel_memberships
4063            .iter()
4064            .map(|cm| proto::ChannelMembership {
4065                channel_id: cm.channel_id.to_proto(),
4066                role: cm.role.into(),
4067            })
4068            .collect(),
4069        ..Default::default()
4070    };
4071
4072    let mut update = build_channels_update(result.new_channels, vec![]);
4073    update.delete_channels = result
4074        .removed_channels
4075        .into_iter()
4076        .map(|id| id.to_proto())
4077        .collect();
4078    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4079
4080    for connection_id in connection_pool.user_connection_ids(user_id) {
4081        peer.send(connection_id, user_channels_update.clone())
4082            .trace_err();
4083        peer.send(connection_id, update.clone()).trace_err();
4084    }
4085}
4086
4087fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4088    proto::UpdateUserChannels {
4089        channel_memberships: channels
4090            .channel_memberships
4091            .iter()
4092            .map(|m| proto::ChannelMembership {
4093                channel_id: m.channel_id.to_proto(),
4094                role: m.role.into(),
4095            })
4096            .collect(),
4097        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4098        observed_channel_message_id: channels.observed_channel_messages.clone(),
4099    }
4100}
4101
4102fn build_channels_update(
4103    channels: ChannelsForUser,
4104    channel_invites: Vec<db::Channel>,
4105) -> proto::UpdateChannels {
4106    let mut update = proto::UpdateChannels::default();
4107
4108    for channel in channels.channels {
4109        update.channels.push(channel.to_proto());
4110    }
4111
4112    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4113    update.latest_channel_message_ids = channels.latest_channel_messages;
4114
4115    for (channel_id, participants) in channels.channel_participants {
4116        update
4117            .channel_participants
4118            .push(proto::ChannelParticipants {
4119                channel_id: channel_id.to_proto(),
4120                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4121            });
4122    }
4123
4124    for channel in channel_invites {
4125        update.channel_invitations.push(channel.to_proto());
4126    }
4127    for project in channels.hosted_projects {
4128        update.hosted_projects.push(project);
4129    }
4130
4131    update
4132}
4133
4134fn build_initial_contacts_update(
4135    contacts: Vec<db::Contact>,
4136    pool: &ConnectionPool,
4137) -> proto::UpdateContacts {
4138    let mut update = proto::UpdateContacts::default();
4139
4140    for contact in contacts {
4141        match contact {
4142            db::Contact::Accepted { user_id, busy } => {
4143                update.contacts.push(contact_for_user(user_id, busy, &pool));
4144            }
4145            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4146            db::Contact::Incoming { user_id } => {
4147                update
4148                    .incoming_requests
4149                    .push(proto::IncomingContactRequest {
4150                        requester_id: user_id.to_proto(),
4151                    })
4152            }
4153        }
4154    }
4155
4156    update
4157}
4158
4159fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4160    proto::Contact {
4161        user_id: user_id.to_proto(),
4162        online: pool.is_user_online(user_id),
4163        busy,
4164    }
4165}
4166
4167fn room_updated(room: &proto::Room, peer: &Peer) {
4168    broadcast(
4169        None,
4170        room.participants
4171            .iter()
4172            .filter_map(|participant| Some(participant.peer_id?.into())),
4173        |peer_id| {
4174            peer.send(
4175                peer_id,
4176                proto::RoomUpdated {
4177                    room: Some(room.clone()),
4178                },
4179            )
4180        },
4181    );
4182}
4183
4184fn channel_updated(
4185    channel: &db::channel::Model,
4186    room: &proto::Room,
4187    peer: &Peer,
4188    pool: &ConnectionPool,
4189) {
4190    let participants = room
4191        .participants
4192        .iter()
4193        .map(|p| p.user_id)
4194        .collect::<Vec<_>>();
4195
4196    broadcast(
4197        None,
4198        pool.channel_connection_ids(channel.root_id())
4199            .filter_map(|(channel_id, role)| {
4200                role.can_see_channel(channel.visibility).then(|| channel_id)
4201            }),
4202        |peer_id| {
4203            peer.send(
4204                peer_id,
4205                proto::UpdateChannels {
4206                    channel_participants: vec![proto::ChannelParticipants {
4207                        channel_id: channel.id.to_proto(),
4208                        participant_user_ids: participants.clone(),
4209                    }],
4210                    ..Default::default()
4211                },
4212            )
4213        },
4214    );
4215}
4216
4217async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4218    let db = session.db().await;
4219
4220    let contacts = db.get_contacts(user_id).await?;
4221    let busy = db.is_user_busy(user_id).await?;
4222
4223    let pool = session.connection_pool().await;
4224    let updated_contact = contact_for_user(user_id, busy, &pool);
4225    for contact in contacts {
4226        if let db::Contact::Accepted {
4227            user_id: contact_user_id,
4228            ..
4229        } = contact
4230        {
4231            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4232                session
4233                    .peer
4234                    .send(
4235                        contact_conn_id,
4236                        proto::UpdateContacts {
4237                            contacts: vec![updated_contact.clone()],
4238                            remove_contacts: Default::default(),
4239                            incoming_requests: Default::default(),
4240                            remove_incoming_requests: Default::default(),
4241                            outgoing_requests: Default::default(),
4242                            remove_outgoing_requests: Default::default(),
4243                        },
4244                    )
4245                    .trace_err();
4246            }
4247        }
4248    }
4249    Ok(())
4250}
4251
4252async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
4253    let mut contacts_to_update = HashSet::default();
4254
4255    let room_id;
4256    let canceled_calls_to_user_ids;
4257    let live_kit_room;
4258    let delete_live_kit_room;
4259    let room;
4260    let channel;
4261
4262    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4263        contacts_to_update.insert(session.user_id());
4264
4265        for project in left_room.left_projects.values() {
4266            project_left(project, session);
4267        }
4268
4269        room_id = RoomId::from_proto(left_room.room.id);
4270        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4271        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4272        delete_live_kit_room = left_room.deleted;
4273        room = mem::take(&mut left_room.room);
4274        channel = mem::take(&mut left_room.channel);
4275
4276        room_updated(&room, &session.peer);
4277    } else {
4278        return Ok(());
4279    }
4280
4281    if let Some(channel) = channel {
4282        channel_updated(
4283            &channel,
4284            &room,
4285            &session.peer,
4286            &*session.connection_pool().await,
4287        );
4288    }
4289
4290    {
4291        let pool = session.connection_pool().await;
4292        for canceled_user_id in canceled_calls_to_user_ids {
4293            for connection_id in pool.user_connection_ids(canceled_user_id) {
4294                session
4295                    .peer
4296                    .send(
4297                        connection_id,
4298                        proto::CallCanceled {
4299                            room_id: room_id.to_proto(),
4300                        },
4301                    )
4302                    .trace_err();
4303            }
4304            contacts_to_update.insert(canceled_user_id);
4305        }
4306    }
4307
4308    for contact_user_id in contacts_to_update {
4309        update_user_contacts(contact_user_id, &session).await?;
4310    }
4311
4312    if let Some(live_kit) = session.live_kit_client.as_ref() {
4313        live_kit
4314            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4315            .await
4316            .trace_err();
4317
4318        if delete_live_kit_room {
4319            live_kit.delete_room(live_kit_room).await.trace_err();
4320        }
4321    }
4322
4323    Ok(())
4324}
4325
4326async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4327    let left_channel_buffers = session
4328        .db()
4329        .await
4330        .leave_channel_buffers(session.connection_id)
4331        .await?;
4332
4333    for left_buffer in left_channel_buffers {
4334        channel_buffer_updated(
4335            session.connection_id,
4336            left_buffer.connections,
4337            &proto::UpdateChannelBufferCollaborators {
4338                channel_id: left_buffer.channel_id.to_proto(),
4339                collaborators: left_buffer.collaborators,
4340            },
4341            &session.peer,
4342        );
4343    }
4344
4345    Ok(())
4346}
4347
4348fn project_left(project: &db::LeftProject, session: &UserSession) {
4349    for connection_id in &project.connection_ids {
4350        if project.host_user_id == Some(session.user_id()) {
4351            session
4352                .peer
4353                .send(
4354                    *connection_id,
4355                    proto::UnshareProject {
4356                        project_id: project.id.to_proto(),
4357                    },
4358                )
4359                .trace_err();
4360        } else {
4361            session
4362                .peer
4363                .send(
4364                    *connection_id,
4365                    proto::RemoveProjectCollaborator {
4366                        project_id: project.id.to_proto(),
4367                        peer_id: Some(session.connection_id.into()),
4368                    },
4369                )
4370                .trace_err();
4371        }
4372    }
4373}
4374
4375pub trait ResultExt {
4376    type Ok;
4377
4378    fn trace_err(self) -> Option<Self::Ok>;
4379}
4380
4381impl<T, E> ResultExt for Result<T, E>
4382where
4383    E: std::fmt::Debug,
4384{
4385    type Ok = T;
4386
4387    fn trace_err(self) -> Option<T> {
4388        match self {
4389            Ok(value) => Some(value),
4390            Err(error) => {
4391                tracing::error!("{:?}", error);
4392                None
4393            }
4394        }
4395    }
4396}