rpc.rs

   1mod connection_pool;
   2
   3use crate::{
   4    auth::{self},
   5    db::{
   6        self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
   7        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
   8        NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
   9        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  10    },
  11    executor::Executor,
  12    AppState, Error, RateLimit, RateLimiter, Result,
  13};
  14use anyhow::{anyhow, Context as _};
  15use async_tungstenite::tungstenite::{
  16    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  17};
  18use axum::{
  19    body::Body,
  20    extract::{
  21        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  22        ConnectInfo, WebSocketUpgrade,
  23    },
  24    headers::{Header, HeaderName},
  25    http::StatusCode,
  26    middleware,
  27    response::IntoResponse,
  28    routing::get,
  29    Extension, Router, TypedHeader,
  30};
  31use collections::{HashMap, HashSet};
  32pub use connection_pool::{ConnectionPool, ZedVersion};
  33use core::fmt::{self, Debug, Formatter};
  34
  35use futures::{
  36    channel::oneshot,
  37    future::{self, BoxFuture},
  38    stream::FuturesUnordered,
  39    FutureExt, SinkExt, StreamExt, TryStreamExt,
  40};
  41use prometheus::{register_int_gauge, IntGauge};
  42use rpc::{
  43    proto::{
  44        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
  45        LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  46    },
  47    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  48};
  49use semantic_version::SemanticVersion;
  50use serde::{Serialize, Serializer};
  51use std::{
  52    any::TypeId,
  53    future::Future,
  54    marker::PhantomData,
  55    mem,
  56    net::SocketAddr,
  57    ops::{Deref, DerefMut},
  58    rc::Rc,
  59    sync::{
  60        atomic::{AtomicBool, Ordering::SeqCst},
  61        Arc, OnceLock,
  62    },
  63    time::{Duration, Instant},
  64};
  65use time::OffsetDateTime;
  66use tokio::sync::{watch, Semaphore};
  67use tower::ServiceBuilder;
  68use tracing::{
  69    field::{self},
  70    info_span, instrument, Instrument,
  71};
  72use util::http::IsahcHttpClient;
  73
  74pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  75
  76// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  77pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  78
  79const MESSAGE_COUNT_PER_PAGE: usize = 100;
  80const MAX_MESSAGE_LEN: usize = 1024;
  81const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  82
  83type MessageHandler =
  84    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  85
  86struct Response<R> {
  87    peer: Arc<Peer>,
  88    receipt: Receipt<R>,
  89    responded: Arc<AtomicBool>,
  90}
  91
  92impl<R: RequestMessage> Response<R> {
  93    fn send(self, payload: R::Response) -> Result<()> {
  94        self.responded.store(true, SeqCst);
  95        self.peer.respond(self.receipt, payload)?;
  96        Ok(())
  97    }
  98}
  99
 100struct StreamingResponse<R: RequestMessage> {
 101    peer: Arc<Peer>,
 102    receipt: Receipt<R>,
 103}
 104
 105impl<R: RequestMessage> StreamingResponse<R> {
 106    fn send(&self, payload: R::Response) -> Result<()> {
 107        self.peer.respond(self.receipt, payload)?;
 108        Ok(())
 109    }
 110}
 111
 112#[derive(Clone, Debug)]
 113pub enum Principal {
 114    User(User),
 115    Impersonated { user: User, admin: User },
 116    DevServer(dev_server::Model),
 117}
 118
 119impl Principal {
 120    fn update_span(&self, span: &tracing::Span) {
 121        match &self {
 122            Principal::User(user) => {
 123                span.record("user_id", &user.id.0);
 124                span.record("login", &user.github_login);
 125            }
 126            Principal::Impersonated { user, admin } => {
 127                span.record("user_id", &user.id.0);
 128                span.record("login", &user.github_login);
 129                span.record("impersonator", &admin.github_login);
 130            }
 131            Principal::DevServer(dev_server) => {
 132                span.record("dev_server_id", &dev_server.id.0);
 133            }
 134        }
 135    }
 136}
 137
 138#[derive(Clone)]
 139struct Session {
 140    principal: Principal,
 141    connection_id: ConnectionId,
 142    db: Arc<tokio::sync::Mutex<DbHandle>>,
 143    peer: Arc<Peer>,
 144    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 145    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 146    http_client: IsahcHttpClient,
 147    rate_limiter: Arc<RateLimiter>,
 148    _executor: Executor,
 149}
 150
 151impl Session {
 152    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 153        #[cfg(test)]
 154        tokio::task::yield_now().await;
 155        let guard = self.db.lock().await;
 156        #[cfg(test)]
 157        tokio::task::yield_now().await;
 158        guard
 159    }
 160
 161    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 162        #[cfg(test)]
 163        tokio::task::yield_now().await;
 164        let guard = self.connection_pool.lock();
 165        ConnectionPoolGuard {
 166            guard,
 167            _not_send: PhantomData,
 168        }
 169    }
 170
 171    fn for_user(self) -> Option<UserSession> {
 172        UserSession::new(self)
 173    }
 174
 175    fn user_id(&self) -> Option<UserId> {
 176        match &self.principal {
 177            Principal::User(user) => Some(user.id),
 178            Principal::Impersonated { user, .. } => Some(user.id),
 179            Principal::DevServer(_) => None,
 180        }
 181    }
 182}
 183
 184impl Debug for Session {
 185    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 186        let mut result = f.debug_struct("Session");
 187        match &self.principal {
 188            Principal::User(user) => {
 189                result.field("user", &user.github_login);
 190            }
 191            Principal::Impersonated { user, admin } => {
 192                result.field("user", &user.github_login);
 193                result.field("impersonator", &admin.github_login);
 194            }
 195            Principal::DevServer(dev_server) => {
 196                result.field("dev_server", &dev_server.id);
 197            }
 198        }
 199        result.field("connection_id", &self.connection_id).finish()
 200    }
 201}
 202
 203struct UserSession(Session);
 204
 205impl UserSession {
 206    pub fn new(s: Session) -> Option<Self> {
 207        s.user_id().map(|_| UserSession(s))
 208    }
 209    pub fn user_id(&self) -> UserId {
 210        self.0.user_id().unwrap()
 211    }
 212}
 213
 214impl Deref for UserSession {
 215    type Target = Session;
 216
 217    fn deref(&self) -> &Self::Target {
 218        &self.0
 219    }
 220}
 221impl DerefMut for UserSession {
 222    fn deref_mut(&mut self) -> &mut Self::Target {
 223        &mut self.0
 224    }
 225}
 226
 227fn user_handler<M: RequestMessage, Fut>(
 228    handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
 229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
 230where
 231    Fut: Send + Future<Output = Result<()>>,
 232{
 233    let handler = Arc::new(handler);
 234    move |message, response, session| {
 235        let handler = handler.clone();
 236        Box::pin(async move {
 237            if let Some(user_session) = session.for_user() {
 238                Ok(handler(message, response, user_session).await?)
 239            } else {
 240                Err(Error::Internal(anyhow!("must be a user")))
 241            }
 242        })
 243    }
 244}
 245
 246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
 247    handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
 248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
 249where
 250    InnertRetFut: Send + Future<Output = Result<()>>,
 251{
 252    let handler = Arc::new(handler);
 253    move |message, session| {
 254        let handler = handler.clone();
 255        Box::pin(async move {
 256            if let Some(user_session) = session.for_user() {
 257                Ok(handler(message, user_session).await?)
 258            } else {
 259                Err(Error::Internal(anyhow!("must be a user")))
 260            }
 261        })
 262    }
 263}
 264
 265struct DbHandle(Arc<Database>);
 266
 267impl Deref for DbHandle {
 268    type Target = Database;
 269
 270    fn deref(&self) -> &Self::Target {
 271        self.0.as_ref()
 272    }
 273}
 274
 275pub struct Server {
 276    id: parking_lot::Mutex<ServerId>,
 277    peer: Arc<Peer>,
 278    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 279    app_state: Arc<AppState>,
 280    handlers: HashMap<TypeId, MessageHandler>,
 281    teardown: watch::Sender<bool>,
 282}
 283
 284pub(crate) struct ConnectionPoolGuard<'a> {
 285    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 286    _not_send: PhantomData<Rc<()>>,
 287}
 288
 289#[derive(Serialize)]
 290pub struct ServerSnapshot<'a> {
 291    peer: &'a Peer,
 292    #[serde(serialize_with = "serialize_deref")]
 293    connection_pool: ConnectionPoolGuard<'a>,
 294}
 295
 296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 297where
 298    S: Serializer,
 299    T: Deref<Target = U>,
 300    U: Serialize,
 301{
 302    Serialize::serialize(value.deref(), serializer)
 303}
 304
 305impl Server {
 306    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 307        let mut server = Self {
 308            id: parking_lot::Mutex::new(id),
 309            peer: Peer::new(id.0 as u32),
 310            app_state: app_state.clone(),
 311            connection_pool: Default::default(),
 312            handlers: Default::default(),
 313            teardown: watch::channel(false).0,
 314        };
 315
 316        server
 317            .add_request_handler(ping)
 318            .add_request_handler(user_handler(create_room))
 319            .add_request_handler(user_handler(join_room))
 320            .add_request_handler(user_handler(rejoin_room))
 321            .add_request_handler(user_handler(leave_room))
 322            .add_request_handler(user_handler(set_room_participant_role))
 323            .add_request_handler(user_handler(call))
 324            .add_request_handler(user_handler(cancel_call))
 325            .add_message_handler(user_message_handler(decline_call))
 326            .add_request_handler(user_handler(update_participant_location))
 327            .add_request_handler(share_project)
 328            .add_message_handler(unshare_project)
 329            .add_request_handler(user_handler(join_project))
 330            .add_request_handler(user_handler(join_hosted_project))
 331            .add_message_handler(user_message_handler(leave_project))
 332            .add_request_handler(update_project)
 333            .add_request_handler(update_worktree)
 334            .add_message_handler(start_language_server)
 335            .add_message_handler(update_language_server)
 336            .add_message_handler(update_diagnostic_summary)
 337            .add_message_handler(update_worktree_settings)
 338            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 339            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 340            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 341            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 342            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
 343            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 344            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 345            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 346            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 347            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 348            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 349            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 350            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 351            .add_request_handler(
 352                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 353            )
 354            .add_request_handler(
 355                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 356            )
 357            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 358            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 359            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 360            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 361            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 362            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 363            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 364            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 365            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 366            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 367            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 368            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 369            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 370            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 371            .add_message_handler(create_buffer_for_peer)
 372            .add_request_handler(update_buffer)
 373            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 374            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 375            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 376            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 377            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
 378            .add_request_handler(get_users)
 379            .add_request_handler(user_handler(fuzzy_search_users))
 380            .add_request_handler(user_handler(request_contact))
 381            .add_request_handler(user_handler(remove_contact))
 382            .add_request_handler(user_handler(respond_to_contact_request))
 383            .add_request_handler(user_handler(create_channel))
 384            .add_request_handler(user_handler(delete_channel))
 385            .add_request_handler(user_handler(invite_channel_member))
 386            .add_request_handler(user_handler(remove_channel_member))
 387            .add_request_handler(user_handler(set_channel_member_role))
 388            .add_request_handler(user_handler(set_channel_visibility))
 389            .add_request_handler(user_handler(rename_channel))
 390            .add_request_handler(user_handler(join_channel_buffer))
 391            .add_request_handler(user_handler(leave_channel_buffer))
 392            .add_message_handler(user_message_handler(update_channel_buffer))
 393            .add_request_handler(user_handler(rejoin_channel_buffers))
 394            .add_request_handler(user_handler(get_channel_members))
 395            .add_request_handler(user_handler(respond_to_channel_invite))
 396            .add_request_handler(user_handler(join_channel))
 397            .add_request_handler(user_handler(join_channel_chat))
 398            .add_message_handler(user_message_handler(leave_channel_chat))
 399            .add_request_handler(user_handler(send_channel_message))
 400            .add_request_handler(user_handler(remove_channel_message))
 401            .add_request_handler(user_handler(update_channel_message))
 402            .add_request_handler(user_handler(get_channel_messages))
 403            .add_request_handler(user_handler(get_channel_messages_by_id))
 404            .add_request_handler(user_handler(get_notifications))
 405            .add_request_handler(user_handler(mark_notification_as_read))
 406            .add_request_handler(user_handler(move_channel))
 407            .add_request_handler(user_handler(follow))
 408            .add_message_handler(user_message_handler(unfollow))
 409            .add_message_handler(user_message_handler(update_followers))
 410            .add_request_handler(user_handler(get_private_user_info))
 411            .add_message_handler(user_message_handler(acknowledge_channel_message))
 412            .add_message_handler(user_message_handler(acknowledge_buffer_version))
 413            .add_streaming_request_handler({
 414                let app_state = app_state.clone();
 415                move |request, response, session| {
 416                    complete_with_language_model(
 417                        request,
 418                        response,
 419                        session,
 420                        app_state.config.openai_api_key.clone(),
 421                        app_state.config.google_ai_api_key.clone(),
 422                        app_state.config.anthropic_api_key.clone(),
 423                    )
 424                }
 425            })
 426            .add_request_handler({
 427                let app_state = app_state.clone();
 428                user_handler(move |request, response, session| {
 429                    count_tokens_with_language_model(
 430                        request,
 431                        response,
 432                        session,
 433                        app_state.config.google_ai_api_key.clone(),
 434                    )
 435                })
 436            });
 437
 438        Arc::new(server)
 439    }
 440
 441    pub async fn start(&self) -> Result<()> {
 442        let server_id = *self.id.lock();
 443        let app_state = self.app_state.clone();
 444        let peer = self.peer.clone();
 445        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
 446        let pool = self.connection_pool.clone();
 447        let live_kit_client = self.app_state.live_kit_client.clone();
 448
 449        let span = info_span!("start server");
 450        self.app_state.executor.spawn_detached(
 451            async move {
 452                tracing::info!("waiting for cleanup timeout");
 453                timeout.await;
 454                tracing::info!("cleanup timeout expired, retrieving stale rooms");
 455                if let Some((room_ids, channel_ids)) = app_state
 456                    .db
 457                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
 458                    .await
 459                    .trace_err()
 460                {
 461                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
 462                    tracing::info!(
 463                        stale_channel_buffer_count = channel_ids.len(),
 464                        "retrieved stale channel buffers"
 465                    );
 466
 467                    for channel_id in channel_ids {
 468                        if let Some(refreshed_channel_buffer) = app_state
 469                            .db
 470                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
 471                            .await
 472                            .trace_err()
 473                        {
 474                            for connection_id in refreshed_channel_buffer.connection_ids {
 475                                peer.send(
 476                                    connection_id,
 477                                    proto::UpdateChannelBufferCollaborators {
 478                                        channel_id: channel_id.to_proto(),
 479                                        collaborators: refreshed_channel_buffer
 480                                            .collaborators
 481                                            .clone(),
 482                                    },
 483                                )
 484                                .trace_err();
 485                            }
 486                        }
 487                    }
 488
 489                    for room_id in room_ids {
 490                        let mut contacts_to_update = HashSet::default();
 491                        let mut canceled_calls_to_user_ids = Vec::new();
 492                        let mut live_kit_room = String::new();
 493                        let mut delete_live_kit_room = false;
 494
 495                        if let Some(mut refreshed_room) = app_state
 496                            .db
 497                            .clear_stale_room_participants(room_id, server_id)
 498                            .await
 499                            .trace_err()
 500                        {
 501                            tracing::info!(
 502                                room_id = room_id.0,
 503                                new_participant_count = refreshed_room.room.participants.len(),
 504                                "refreshed room"
 505                            );
 506                            room_updated(&refreshed_room.room, &peer);
 507                            if let Some(channel) = refreshed_room.channel.as_ref() {
 508                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
 509                            }
 510                            contacts_to_update
 511                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
 512                            contacts_to_update
 513                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
 514                            canceled_calls_to_user_ids =
 515                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
 516                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
 517                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
 518                        }
 519
 520                        {
 521                            let pool = pool.lock();
 522                            for canceled_user_id in canceled_calls_to_user_ids {
 523                                for connection_id in pool.user_connection_ids(canceled_user_id) {
 524                                    peer.send(
 525                                        connection_id,
 526                                        proto::CallCanceled {
 527                                            room_id: room_id.to_proto(),
 528                                        },
 529                                    )
 530                                    .trace_err();
 531                                }
 532                            }
 533                        }
 534
 535                        for user_id in contacts_to_update {
 536                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
 537                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
 538                            if let Some((busy, contacts)) = busy.zip(contacts) {
 539                                let pool = pool.lock();
 540                                let updated_contact = contact_for_user(user_id, busy, &pool);
 541                                for contact in contacts {
 542                                    if let db::Contact::Accepted {
 543                                        user_id: contact_user_id,
 544                                        ..
 545                                    } = contact
 546                                    {
 547                                        for contact_conn_id in
 548                                            pool.user_connection_ids(contact_user_id)
 549                                        {
 550                                            peer.send(
 551                                                contact_conn_id,
 552                                                proto::UpdateContacts {
 553                                                    contacts: vec![updated_contact.clone()],
 554                                                    remove_contacts: Default::default(),
 555                                                    incoming_requests: Default::default(),
 556                                                    remove_incoming_requests: Default::default(),
 557                                                    outgoing_requests: Default::default(),
 558                                                    remove_outgoing_requests: Default::default(),
 559                                                },
 560                                            )
 561                                            .trace_err();
 562                                        }
 563                                    }
 564                                }
 565                            }
 566                        }
 567
 568                        if let Some(live_kit) = live_kit_client.as_ref() {
 569                            if delete_live_kit_room {
 570                                live_kit.delete_room(live_kit_room).await.trace_err();
 571                            }
 572                        }
 573                    }
 574                }
 575
 576                app_state
 577                    .db
 578                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
 579                    .await
 580                    .trace_err();
 581            }
 582            .instrument(span),
 583        );
 584        Ok(())
 585    }
 586
 587    pub fn teardown(&self) {
 588        self.peer.teardown();
 589        self.connection_pool.lock().reset();
 590        let _ = self.teardown.send(true);
 591    }
 592
 593    #[cfg(test)]
 594    pub fn reset(&self, id: ServerId) {
 595        self.teardown();
 596        *self.id.lock() = id;
 597        self.peer.reset(id.0 as u32);
 598        let _ = self.teardown.send(false);
 599    }
 600
 601    #[cfg(test)]
 602    pub fn id(&self) -> ServerId {
 603        *self.id.lock()
 604    }
 605
 606    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 607    where
 608        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 609        Fut: 'static + Send + Future<Output = Result<()>>,
 610        M: EnvelopedMessage,
 611    {
 612        let prev_handler = self.handlers.insert(
 613            TypeId::of::<M>(),
 614            Box::new(move |envelope, session| {
 615                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 616                let received_at = envelope.received_at;
 617                    tracing::info!(
 618                        "message received"
 619                    );
 620                let start_time = Instant::now();
 621                let future = (handler)(*envelope, session);
 622                async move {
 623                    let result = future.await;
 624                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 625                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 626                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 627                    match result {
 628                        Err(error) => {
 629                            tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
 630                        }
 631                        Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
 632                    }
 633                }
 634                .boxed()
 635            }),
 636        );
 637        if prev_handler.is_some() {
 638            panic!("registered a handler for the same message twice");
 639        }
 640        self
 641    }
 642
 643    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 644    where
 645        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 646        Fut: 'static + Send + Future<Output = Result<()>>,
 647        M: EnvelopedMessage,
 648    {
 649        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 650        self
 651    }
 652
 653    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 654    where
 655        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 656        Fut: Send + Future<Output = Result<()>>,
 657        M: RequestMessage,
 658    {
 659        let handler = Arc::new(handler);
 660        self.add_handler(move |envelope, session| {
 661            let receipt = envelope.receipt();
 662            let handler = handler.clone();
 663            async move {
 664                let peer = session.peer.clone();
 665                let responded = Arc::new(AtomicBool::default());
 666                let response = Response {
 667                    peer: peer.clone(),
 668                    responded: responded.clone(),
 669                    receipt,
 670                };
 671                match (handler)(envelope.payload, response, session).await {
 672                    Ok(()) => {
 673                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 674                            Ok(())
 675                        } else {
 676                            Err(anyhow!("handler did not send a response"))?
 677                        }
 678                    }
 679                    Err(error) => {
 680                        let proto_err = match &error {
 681                            Error::Internal(err) => err.to_proto(),
 682                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 683                        };
 684                        peer.respond_with_error(receipt, proto_err)?;
 685                        Err(error)
 686                    }
 687                }
 688            }
 689        })
 690    }
 691
 692    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 693    where
 694        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
 695        Fut: Send + Future<Output = Result<()>>,
 696        M: RequestMessage,
 697    {
 698        let handler = Arc::new(handler);
 699        self.add_handler(move |envelope, session| {
 700            let receipt = envelope.receipt();
 701            let handler = handler.clone();
 702            async move {
 703                let peer = session.peer.clone();
 704                let response = StreamingResponse {
 705                    peer: peer.clone(),
 706                    receipt,
 707                };
 708                match (handler)(envelope.payload, response, session).await {
 709                    Ok(()) => {
 710                        peer.end_stream(receipt)?;
 711                        Ok(())
 712                    }
 713                    Err(error) => {
 714                        let proto_err = match &error {
 715                            Error::Internal(err) => err.to_proto(),
 716                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 717                        };
 718                        peer.respond_with_error(receipt, proto_err)?;
 719                        Err(error)
 720                    }
 721                }
 722            }
 723        })
 724    }
 725
    /// Drives a single client connection from registration until sign-out.
    ///
    /// The returned future:
    /// 1. bails out immediately (logging an error) if the server's teardown flag is set;
    /// 2. registers `connection` with the peer, using `executor` for its sleeps, and
    ///    records the resulting `connection_id` on the tracing span;
    /// 3. builds the per-connection `Session` (including a fresh `IsahcHttpClient`);
    /// 4. sends the initial client update (see `send_initial_client_update`), forwarding
    ///    `send_connection_id` to it;
    /// 5. dispatches incoming messages to the registered handlers until the I/O task
    ///    finishes, the incoming stream ends, or teardown is signalled;
    /// 6. on a normal loop exit, runs `connection_lost` for cleanup.
    ///
    /// NOTE(review): the early `return`s (teardown before/while running, HTTP client
    /// creation failure, initial-update failure) skip `connection_lost`, so this path does
    /// not call `peer.disconnect` for the connection registered in step 2 — confirm that
    /// this is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `record` / `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribe before entering the future so a teardown signalled in between is not missed.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before each message is read and released when its handler
            // completes, capping in-flight handlers (foreground and background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased select: teardown wins over I/O completion, which wins over driving
                // in-flight foreground handlers, which wins over reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future so it is held until
                                // the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Abandon any foreground handlers still in flight before cleaning up.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 862
 863    async fn send_initial_client_update(
 864        &self,
 865        connection_id: ConnectionId,
 866        principal: &Principal,
 867        zed_version: ZedVersion,
 868        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 869        session: &Session,
 870    ) -> Result<()> {
 871        self.peer.send(
 872            connection_id,
 873            proto::Hello {
 874                peer_id: Some(connection_id.into()),
 875            },
 876        )?;
 877        tracing::info!("sent hello message");
 878
 879        let Principal::User(user) = principal else {
 880            return Ok(());
 881        };
 882
 883        if let Some(send_connection_id) = send_connection_id.take() {
 884            let _ = send_connection_id.send(connection_id);
 885        }
 886
 887        if !user.connected_once {
 888            self.peer.send(connection_id, proto::ShowContacts {})?;
 889            self.app_state
 890                .db
 891                .set_user_connected_once(user.id, true)
 892                .await?;
 893        }
 894
 895        let (contacts, channels_for_user, channel_invites) = future::try_join3(
 896            self.app_state.db.get_contacts(user.id),
 897            self.app_state.db.get_channels_for_user(user.id),
 898            self.app_state.db.get_channel_invites_for_user(user.id),
 899        )
 900        .await?;
 901
 902        {
 903            let mut pool = self.connection_pool.lock();
 904            pool.add_connection(connection_id, user.id, user.admin, zed_version);
 905            for membership in &channels_for_user.channel_memberships {
 906                pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
 907            }
 908            self.peer.send(
 909                connection_id,
 910                build_initial_contacts_update(contacts, &pool),
 911            )?;
 912            self.peer.send(
 913                connection_id,
 914                build_update_user_channels(&channels_for_user),
 915            )?;
 916            self.peer.send(
 917                connection_id,
 918                build_channels_update(channels_for_user, channel_invites),
 919            )?;
 920        }
 921
 922        if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
 923            self.peer.send(connection_id, incoming_call)?;
 924        }
 925
 926        update_user_contacts(user.id, &session).await?;
 927        Ok(())
 928    }
 929
 930    pub async fn invite_code_redeemed(
 931        self: &Arc<Self>,
 932        inviter_id: UserId,
 933        invitee_id: UserId,
 934    ) -> Result<()> {
 935        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 936            if let Some(code) = &user.invite_code {
 937                let pool = self.connection_pool.lock();
 938                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 939                for connection_id in pool.user_connection_ids(inviter_id) {
 940                    self.peer.send(
 941                        connection_id,
 942                        proto::UpdateContacts {
 943                            contacts: vec![invitee_contact.clone()],
 944                            ..Default::default()
 945                        },
 946                    )?;
 947                    self.peer.send(
 948                        connection_id,
 949                        proto::UpdateInviteInfo {
 950                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 951                            count: user.invite_count as u32,
 952                        },
 953                    )?;
 954                }
 955            }
 956        }
 957        Ok(())
 958    }
 959
 960    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 961        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 962            if let Some(invite_code) = &user.invite_code {
 963                let pool = self.connection_pool.lock();
 964                for connection_id in pool.user_connection_ids(user_id) {
 965                    self.peer.send(
 966                        connection_id,
 967                        proto::UpdateInviteInfo {
 968                            url: format!(
 969                                "{}{}",
 970                                self.app_state.config.invite_link_prefix, invite_code
 971                            ),
 972                            count: user.invite_count as u32,
 973                        },
 974                    )?;
 975                }
 976            }
 977        }
 978        Ok(())
 979    }
 980
 981    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 982        ServerSnapshot {
 983            connection_pool: ConnectionPoolGuard {
 984                guard: self.connection_pool.lock(),
 985                _not_send: PhantomData,
 986            },
 987            peer: &self.peer,
 988        }
 989    }
 990}
 991
 992impl<'a> Deref for ConnectionPoolGuard<'a> {
 993    type Target = ConnectionPool;
 994
 995    fn deref(&self) -> &Self::Target {
 996        &self.guard
 997    }
 998}
 999
1000impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1001    fn deref_mut(&mut self) -> &mut Self::Target {
1002        &mut self.guard
1003    }
1004}
1005
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, calls `check_invariants` on the pool each time a guard is released.
    /// In non-test builds this is a no-op.
    ///
    /// `drop` runs before the `guard` field itself is dropped, so the pool is still locked
    /// while the check runs.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1012
1013fn broadcast<F>(
1014    sender_id: Option<ConnectionId>,
1015    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1016    mut f: F,
1017) where
1018    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1019{
1020    for receiver_id in receiver_ids {
1021        if Some(receiver_id) != sender_id {
1022            if let Err(error) = f(receiver_id) {
1023                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1024            }
1025        }
1026    }
1027}
1028
1029pub struct ProtocolVersion(u32);
1030
1031impl Header for ProtocolVersion {
1032    fn name() -> &'static HeaderName {
1033        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1035    }
1036
1037    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038    where
1039        Self: Sized,
1040        I: Iterator<Item = &'i axum::http::HeaderValue>,
1041    {
1042        let version = values
1043            .next()
1044            .ok_or_else(axum::headers::Error::invalid)?
1045            .to_str()
1046            .map_err(|_| axum::headers::Error::invalid())?
1047            .parse()
1048            .map_err(|_| axum::headers::Error::invalid())?;
1049        Ok(Self(version))
1050    }
1051
1052    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053        values.extend([self.0.to_string().parse().unwrap()]);
1054    }
1055}
1056
1057pub struct AppVersionHeader(SemanticVersion);
1058impl Header for AppVersionHeader {
1059    fn name() -> &'static HeaderName {
1060        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1061        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1062    }
1063
1064    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065    where
1066        Self: Sized,
1067        I: Iterator<Item = &'i axum::http::HeaderValue>,
1068    {
1069        let version = values
1070            .next()
1071            .ok_or_else(axum::headers::Error::invalid)?
1072            .to_str()
1073            .map_err(|_| axum::headers::Error::invalid())?
1074            .parse()
1075            .map_err(|_| axum::headers::Error::invalid())?;
1076        Ok(Self(version))
1077    }
1078
1079    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080        values.extend([self.0.to_string().parse().unwrap()]);
1081    }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085    Router::new()
1086        .route("/rpc", get(handle_websocket_request))
1087        .layer(
1088            ServiceBuilder::new()
1089                .layer(Extension(server.app_state.clone()))
1090                .layer(middleware::from_fn(auth::validate_header)),
1091        )
1092        .route("/metrics", get(handle_metrics))
1093        .layer(Extension(server))
1094}
1095
1096pub async fn handle_websocket_request(
1097    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1098    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1099    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1100    Extension(server): Extension<Arc<Server>>,
1101    Extension(principal): Extension<Principal>,
1102    ws: WebSocketUpgrade,
1103) -> axum::response::Response {
1104    if protocol_version != rpc::PROTOCOL_VERSION {
1105        return (
1106            StatusCode::UPGRADE_REQUIRED,
1107            "client must be upgraded".to_string(),
1108        )
1109            .into_response();
1110    }
1111
1112    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1113        return (
1114            StatusCode::UPGRADE_REQUIRED,
1115            "no version header found".to_string(),
1116        )
1117            .into_response();
1118    };
1119
1120    if !version.can_collaborate() {
1121        return (
1122            StatusCode::UPGRADE_REQUIRED,
1123            "client must be upgraded".to_string(),
1124        )
1125            .into_response();
1126    }
1127
1128    let socket_address = socket_address.to_string();
1129    ws.on_upgrade(move |socket| {
1130        let socket = socket
1131            .map_ok(to_tungstenite_message)
1132            .err_into()
1133            .with(|message| async move { Ok(to_axum_message(message)) });
1134        let connection = Connection::new(Box::pin(socket));
1135        async move {
1136            server
1137                .handle_connection(
1138                    connection,
1139                    socket_address,
1140                    principal,
1141                    version,
1142                    None,
1143                    Executor::Production,
1144                )
1145                .await;
1146        }
1147    })
1148}
1149
1150pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1151    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152    let connections_metric = CONNECTIONS_METRIC
1153        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1154
1155    let connections = server
1156        .connection_pool
1157        .lock()
1158        .connections()
1159        .filter(|connection| !connection.admin)
1160        .count();
1161    connections_metric.set(connections as _);
1162
1163    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1164    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1165        register_int_gauge!(
1166            "shared_projects",
1167            "number of open projects with one or more guests"
1168        )
1169        .unwrap()
1170    });
1171
1172    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1173    shared_projects_metric.set(shared_projects as _);
1174
1175    let encoder = prometheus::TextEncoder::new();
1176    let metric_families = prometheus::gather();
1177    let encoded_metrics = encoder
1178        .encode_to_string(&metric_families)
1179        .map_err(|err| anyhow!("{}", err))?;
1180    Ok(encoded_metrics)
1181}
1182
/// Cleans up after a connection has ended.
///
/// Immediate steps:
/// - disconnect the connection from the peer;
/// - remove it from the connection pool (an error here aborts with `Err`);
/// - tell the database the connection was lost (errors are only traced).
///
/// Then it waits for either `RECONNECT_TIMEOUT` or a teardown signal:
/// - If the timeout elapses first and the session belongs to a user, that user's room and
///   channel-buffer memberships for this session are released. If the user then has no
///   online connections, a pending call is declined and the affected room is broadcast.
///   Finally the user's contacts are updated.
/// - If teardown is signalled first, no further cleanup is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client `RECONNECT_TIMEOUT` to come back before dropping its resources.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session, session.connection_id).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only decline a pending call once none of the user's connections remain.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1230
1231/// Acknowledges a ping from a client, used to keep the connection alive.
1232async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1233    response.send(proto::Ack {})?;
1234    Ok(())
1235}
1236
1237/// Creates a new room for calling (outside of channels)
1238async fn create_room(
1239    _request: proto::CreateRoom,
1240    response: Response<proto::CreateRoom>,
1241    session: UserSession,
1242) -> Result<()> {
1243    let live_kit_room = nanoid::nanoid!(30);
1244
1245    let live_kit_connection_info = util::maybe!(async {
1246        let live_kit = session.live_kit_client.as_ref();
1247        let live_kit = live_kit?;
1248        let user_id = session.user_id().to_string();
1249
1250        let token = live_kit
1251            .room_token(&live_kit_room, &user_id.to_string())
1252            .trace_err()?;
1253
1254        Some(proto::LiveKitConnectionInfo {
1255            server_url: live_kit.url().into(),
1256            token,
1257            can_publish: true,
1258        })
1259    })
1260    .await;
1261
1262    let room = session
1263        .db()
1264        .await
1265        .create_room(session.user_id(), session.connection_id, &live_kit_room)
1266        .await?;
1267
1268    response.send(proto::CreateRoomResponse {
1269        room: Some(room.clone()),
1270        live_kit_connection_info,
1271    })?;
1272
1273    update_user_contacts(session.user_id(), &session).await?;
1274    Ok(())
1275}
1276
1277/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1278async fn join_room(
1279    request: proto::JoinRoom,
1280    response: Response<proto::JoinRoom>,
1281    session: UserSession,
1282) -> Result<()> {
1283    let room_id = RoomId::from_proto(request.id);
1284
1285    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1286
1287    if let Some(channel_id) = channel_id {
1288        return join_channel_internal(channel_id, Box::new(response), session).await;
1289    }
1290
1291    let joined_room = {
1292        let room = session
1293            .db()
1294            .await
1295            .join_room(room_id, session.user_id(), session.connection_id)
1296            .await?;
1297        room_updated(&room.room, &session.peer);
1298        room.into_inner()
1299    };
1300
1301    for connection_id in session
1302        .connection_pool()
1303        .await
1304        .user_connection_ids(session.user_id())
1305    {
1306        session
1307            .peer
1308            .send(
1309                connection_id,
1310                proto::CallCanceled {
1311                    room_id: room_id.to_proto(),
1312                },
1313            )
1314            .trace_err();
1315    }
1316
1317    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1318        if let Some(token) = live_kit
1319            .room_token(
1320                &joined_room.room.live_kit_room,
1321                &session.user_id().to_string(),
1322            )
1323            .trace_err()
1324        {
1325            Some(proto::LiveKitConnectionInfo {
1326                server_url: live_kit.url().into(),
1327                token,
1328                can_publish: true,
1329            })
1330        } else {
1331            None
1332        }
1333    } else {
1334        None
1335    };
1336
1337    response.send(proto::JoinRoomResponse {
1338        room: Some(joined_room.room),
1339        channel_id: None,
1340        live_kit_connection_info,
1341    })?;
1342
1343    update_user_contacts(session.user_id(), &session).await?;
1344    Ok(())
1345}
1346
1347/// Rejoin room is used to reconnect to a room after connection errors.
1348async fn rejoin_room(
1349    request: proto::RejoinRoom,
1350    response: Response<proto::RejoinRoom>,
1351    session: UserSession,
1352) -> Result<()> {
1353    let room;
1354    let channel;
1355    {
1356        let mut rejoined_room = session
1357            .db()
1358            .await
1359            .rejoin_room(request, session.user_id(), session.connection_id)
1360            .await?;
1361
1362        response.send(proto::RejoinRoomResponse {
1363            room: Some(rejoined_room.room.clone()),
1364            reshared_projects: rejoined_room
1365                .reshared_projects
1366                .iter()
1367                .map(|project| proto::ResharedProject {
1368                    id: project.id.to_proto(),
1369                    collaborators: project
1370                        .collaborators
1371                        .iter()
1372                        .map(|collaborator| collaborator.to_proto())
1373                        .collect(),
1374                })
1375                .collect(),
1376            rejoined_projects: rejoined_room
1377                .rejoined_projects
1378                .iter()
1379                .map(|rejoined_project| proto::RejoinedProject {
1380                    id: rejoined_project.id.to_proto(),
1381                    worktrees: rejoined_project
1382                        .worktrees
1383                        .iter()
1384                        .map(|worktree| proto::WorktreeMetadata {
1385                            id: worktree.id,
1386                            root_name: worktree.root_name.clone(),
1387                            visible: worktree.visible,
1388                            abs_path: worktree.abs_path.clone(),
1389                        })
1390                        .collect(),
1391                    collaborators: rejoined_project
1392                        .collaborators
1393                        .iter()
1394                        .map(|collaborator| collaborator.to_proto())
1395                        .collect(),
1396                    language_servers: rejoined_project.language_servers.clone(),
1397                })
1398                .collect(),
1399        })?;
1400        room_updated(&rejoined_room.room, &session.peer);
1401
1402        for project in &rejoined_room.reshared_projects {
1403            for collaborator in &project.collaborators {
1404                session
1405                    .peer
1406                    .send(
1407                        collaborator.connection_id,
1408                        proto::UpdateProjectCollaborator {
1409                            project_id: project.id.to_proto(),
1410                            old_peer_id: Some(project.old_connection_id.into()),
1411                            new_peer_id: Some(session.connection_id.into()),
1412                        },
1413                    )
1414                    .trace_err();
1415            }
1416
1417            broadcast(
1418                Some(session.connection_id),
1419                project
1420                    .collaborators
1421                    .iter()
1422                    .map(|collaborator| collaborator.connection_id),
1423                |connection_id| {
1424                    session.peer.forward_send(
1425                        session.connection_id,
1426                        connection_id,
1427                        proto::UpdateProject {
1428                            project_id: project.id.to_proto(),
1429                            worktrees: project.worktrees.clone(),
1430                        },
1431                    )
1432                },
1433            );
1434        }
1435
1436        for project in &rejoined_room.rejoined_projects {
1437            for collaborator in &project.collaborators {
1438                session
1439                    .peer
1440                    .send(
1441                        collaborator.connection_id,
1442                        proto::UpdateProjectCollaborator {
1443                            project_id: project.id.to_proto(),
1444                            old_peer_id: Some(project.old_connection_id.into()),
1445                            new_peer_id: Some(session.connection_id.into()),
1446                        },
1447                    )
1448                    .trace_err();
1449            }
1450        }
1451
1452        for project in &mut rejoined_room.rejoined_projects {
1453            for worktree in mem::take(&mut project.worktrees) {
1454                #[cfg(any(test, feature = "test-support"))]
1455                const MAX_CHUNK_SIZE: usize = 2;
1456                #[cfg(not(any(test, feature = "test-support")))]
1457                const MAX_CHUNK_SIZE: usize = 256;
1458
1459                // Stream this worktree's entries.
1460                let message = proto::UpdateWorktree {
1461                    project_id: project.id.to_proto(),
1462                    worktree_id: worktree.id,
1463                    abs_path: worktree.abs_path.clone(),
1464                    root_name: worktree.root_name,
1465                    updated_entries: worktree.updated_entries,
1466                    removed_entries: worktree.removed_entries,
1467                    scan_id: worktree.scan_id,
1468                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
1469                    updated_repositories: worktree.updated_repositories,
1470                    removed_repositories: worktree.removed_repositories,
1471                };
1472                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1473                    session.peer.send(session.connection_id, update.clone())?;
1474                }
1475
1476                // Stream this worktree's diagnostics.
1477                for summary in worktree.diagnostic_summaries {
1478                    session.peer.send(
1479                        session.connection_id,
1480                        proto::UpdateDiagnosticSummary {
1481                            project_id: project.id.to_proto(),
1482                            worktree_id: worktree.id,
1483                            summary: Some(summary),
1484                        },
1485                    )?;
1486                }
1487
1488                for settings_file in worktree.settings_files {
1489                    session.peer.send(
1490                        session.connection_id,
1491                        proto::UpdateWorktreeSettings {
1492                            project_id: project.id.to_proto(),
1493                            worktree_id: worktree.id,
1494                            path: settings_file.path,
1495                            content: Some(settings_file.content),
1496                        },
1497                    )?;
1498                }
1499            }
1500
1501            for language_server in &project.language_servers {
1502                session.peer.send(
1503                    session.connection_id,
1504                    proto::UpdateLanguageServer {
1505                        project_id: project.id.to_proto(),
1506                        language_server_id: language_server.id,
1507                        variant: Some(
1508                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1509                                proto::LspDiskBasedDiagnosticsUpdated {},
1510                            ),
1511                        ),
1512                    },
1513                )?;
1514            }
1515        }
1516
1517        let rejoined_room = rejoined_room.into_inner();
1518
1519        room = rejoined_room.room;
1520        channel = rejoined_room.channel;
1521    }
1522
1523    if let Some(channel) = channel {
1524        channel_updated(
1525            &channel,
1526            &room,
1527            &session.peer,
1528            &*session.connection_pool().await,
1529        );
1530    }
1531
1532    update_user_contacts(session.user_id(), &session).await?;
1533    Ok(())
1534}
1535
1536/// leave room disconnects from the room.
1537async fn leave_room(
1538    _: proto::LeaveRoom,
1539    response: Response<proto::LeaveRoom>,
1540    session: UserSession,
1541) -> Result<()> {
1542    leave_room_for_session(&session, session.connection_id).await?;
1543    response.send(proto::Ack {})?;
1544    Ok(())
1545}
1546
1547/// Updates the permissions of someone else in the room.
1548async fn set_room_participant_role(
1549    request: proto::SetRoomParticipantRole,
1550    response: Response<proto::SetRoomParticipantRole>,
1551    session: UserSession,
1552) -> Result<()> {
1553    let user_id = UserId::from_proto(request.user_id);
1554    let role = ChannelRole::from(request.role());
1555
1556    let (live_kit_room, can_publish) = {
1557        let room = session
1558            .db()
1559            .await
1560            .set_room_participant_role(
1561                session.user_id(),
1562                RoomId::from_proto(request.room_id),
1563                user_id,
1564                role,
1565            )
1566            .await?;
1567
1568        let live_kit_room = room.live_kit_room.clone();
1569        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1570        room_updated(&room, &session.peer);
1571        (live_kit_room, can_publish)
1572    };
1573
1574    if let Some(live_kit) = session.live_kit_client.as_ref() {
1575        live_kit
1576            .update_participant(
1577                live_kit_room.clone(),
1578                request.user_id.to_string(),
1579                live_kit_server::proto::ParticipantPermission {
1580                    can_subscribe: true,
1581                    can_publish,
1582                    can_publish_data: can_publish,
1583                    hidden: false,
1584                    recorder: false,
1585                },
1586            )
1587            .await
1588            .trace_err();
1589    }
1590
1591    response.send(proto::Ack {})?;
1592    Ok(())
1593}
1594
1595/// Call someone else into the current room
1596async fn call(
1597    request: proto::Call,
1598    response: Response<proto::Call>,
1599    session: UserSession,
1600) -> Result<()> {
1601    let room_id = RoomId::from_proto(request.room_id);
1602    let calling_user_id = session.user_id();
1603    let calling_connection_id = session.connection_id;
1604    let called_user_id = UserId::from_proto(request.called_user_id);
1605    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1606    if !session
1607        .db()
1608        .await
1609        .has_contact(calling_user_id, called_user_id)
1610        .await?
1611    {
1612        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1613    }
1614
1615    let incoming_call = {
1616        let (room, incoming_call) = &mut *session
1617            .db()
1618            .await
1619            .call(
1620                room_id,
1621                calling_user_id,
1622                calling_connection_id,
1623                called_user_id,
1624                initial_project_id,
1625            )
1626            .await?;
1627        room_updated(&room, &session.peer);
1628        mem::take(incoming_call)
1629    };
1630    update_user_contacts(called_user_id, &session).await?;
1631
1632    let mut calls = session
1633        .connection_pool()
1634        .await
1635        .user_connection_ids(called_user_id)
1636        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1637        .collect::<FuturesUnordered<_>>();
1638
1639    while let Some(call_response) = calls.next().await {
1640        match call_response.as_ref() {
1641            Ok(_) => {
1642                response.send(proto::Ack {})?;
1643                return Ok(());
1644            }
1645            Err(_) => {
1646                call_response.trace_err();
1647            }
1648        }
1649    }
1650
1651    {
1652        let room = session
1653            .db()
1654            .await
1655            .call_failed(room_id, called_user_id)
1656            .await?;
1657        room_updated(&room, &session.peer);
1658    }
1659    update_user_contacts(called_user_id, &session).await?;
1660
1661    Err(anyhow!("failed to ring user"))?
1662}
1663
1664/// Cancel an outgoing call.
1665async fn cancel_call(
1666    request: proto::CancelCall,
1667    response: Response<proto::CancelCall>,
1668    session: UserSession,
1669) -> Result<()> {
1670    let called_user_id = UserId::from_proto(request.called_user_id);
1671    let room_id = RoomId::from_proto(request.room_id);
1672    {
1673        let room = session
1674            .db()
1675            .await
1676            .cancel_call(room_id, session.connection_id, called_user_id)
1677            .await?;
1678        room_updated(&room, &session.peer);
1679    }
1680
1681    for connection_id in session
1682        .connection_pool()
1683        .await
1684        .user_connection_ids(called_user_id)
1685    {
1686        session
1687            .peer
1688            .send(
1689                connection_id,
1690                proto::CallCanceled {
1691                    room_id: room_id.to_proto(),
1692                },
1693            )
1694            .trace_err();
1695    }
1696    response.send(proto::Ack {})?;
1697
1698    update_user_contacts(called_user_id, &session).await?;
1699    Ok(())
1700}
1701
1702/// Decline an incoming call.
1703async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1704    let room_id = RoomId::from_proto(message.room_id);
1705    {
1706        let room = session
1707            .db()
1708            .await
1709            .decline_call(Some(room_id), session.user_id())
1710            .await?
1711            .ok_or_else(|| anyhow!("failed to decline call"))?;
1712        room_updated(&room, &session.peer);
1713    }
1714
1715    for connection_id in session
1716        .connection_pool()
1717        .await
1718        .user_connection_ids(session.user_id())
1719    {
1720        session
1721            .peer
1722            .send(
1723                connection_id,
1724                proto::CallCanceled {
1725                    room_id: room_id.to_proto(),
1726                },
1727            )
1728            .trace_err();
1729    }
1730    update_user_contacts(session.user_id(), &session).await?;
1731    Ok(())
1732}
1733
1734/// Updates other participants in the room with your current location.
1735async fn update_participant_location(
1736    request: proto::UpdateParticipantLocation,
1737    response: Response<proto::UpdateParticipantLocation>,
1738    session: UserSession,
1739) -> Result<()> {
1740    let room_id = RoomId::from_proto(request.room_id);
1741    let location = request
1742        .location
1743        .ok_or_else(|| anyhow!("invalid location"))?;
1744
1745    let db = session.db().await;
1746    let room = db
1747        .update_room_participant_location(room_id, session.connection_id, location)
1748        .await?;
1749
1750    room_updated(&room, &session.peer);
1751    response.send(proto::Ack {})?;
1752    Ok(())
1753}
1754
1755/// Share a project into the room.
1756async fn share_project(
1757    request: proto::ShareProject,
1758    response: Response<proto::ShareProject>,
1759    session: Session,
1760) -> Result<()> {
1761    let (project_id, room) = &*session
1762        .db()
1763        .await
1764        .share_project(
1765            RoomId::from_proto(request.room_id),
1766            session.connection_id,
1767            &request.worktrees,
1768        )
1769        .await?;
1770    response.send(proto::ShareProjectResponse {
1771        project_id: project_id.to_proto(),
1772    })?;
1773    room_updated(&room, &session.peer);
1774
1775    Ok(())
1776}
1777
1778/// Unshare a project from the room.
1779async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1780    let project_id = ProjectId::from_proto(message.project_id);
1781
1782    let (room, guest_connection_ids) = &*session
1783        .db()
1784        .await
1785        .unshare_project(project_id, session.connection_id)
1786        .await?;
1787
1788    broadcast(
1789        Some(session.connection_id),
1790        guest_connection_ids.iter().copied(),
1791        |conn_id| session.peer.send(conn_id, message.clone()),
1792    );
1793    room_updated(&room, &session.peer);
1794
1795    Ok(())
1796}
1797
1798/// Join someone elses shared project.
1799async fn join_project(
1800    request: proto::JoinProject,
1801    response: Response<proto::JoinProject>,
1802    session: UserSession,
1803) -> Result<()> {
1804    let project_id = ProjectId::from_proto(request.project_id);
1805
1806    tracing::info!(%project_id, "join project");
1807
1808    let (project, replica_id) = &mut *session
1809        .db()
1810        .await
1811        .join_project_in_room(project_id, session.connection_id)
1812        .await?;
1813
1814    join_project_internal(response, session, project, replica_id)
1815}
1816
1817trait JoinProjectInternalResponse {
1818    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1819}
1820impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1821    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1822        Response::<proto::JoinProject>::send(self, result)
1823    }
1824}
1825impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1826    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1827        Response::<proto::JoinHostedProject>::send(self, result)
1828    }
1829}
1830
1831fn join_project_internal(
1832    response: impl JoinProjectInternalResponse,
1833    session: UserSession,
1834    project: &mut Project,
1835    replica_id: &ReplicaId,
1836) -> Result<()> {
1837    let collaborators = project
1838        .collaborators
1839        .iter()
1840        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1841        .map(|collaborator| collaborator.to_proto())
1842        .collect::<Vec<_>>();
1843    let project_id = project.id;
1844    let guest_user_id = session.user_id();
1845
1846    let worktrees = project
1847        .worktrees
1848        .iter()
1849        .map(|(id, worktree)| proto::WorktreeMetadata {
1850            id: *id,
1851            root_name: worktree.root_name.clone(),
1852            visible: worktree.visible,
1853            abs_path: worktree.abs_path.clone(),
1854        })
1855        .collect::<Vec<_>>();
1856
1857    for collaborator in &collaborators {
1858        session
1859            .peer
1860            .send(
1861                collaborator.peer_id.unwrap().into(),
1862                proto::AddProjectCollaborator {
1863                    project_id: project_id.to_proto(),
1864                    collaborator: Some(proto::Collaborator {
1865                        peer_id: Some(session.connection_id.into()),
1866                        replica_id: replica_id.0 as u32,
1867                        user_id: guest_user_id.to_proto(),
1868                    }),
1869                },
1870            )
1871            .trace_err();
1872    }
1873
1874    // First, we send the metadata associated with each worktree.
1875    response.send(proto::JoinProjectResponse {
1876        project_id: project.id.0 as u64,
1877        worktrees: worktrees.clone(),
1878        replica_id: replica_id.0 as u32,
1879        collaborators: collaborators.clone(),
1880        language_servers: project.language_servers.clone(),
1881        role: project.role.into(), // todo
1882    })?;
1883
1884    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1885        #[cfg(any(test, feature = "test-support"))]
1886        const MAX_CHUNK_SIZE: usize = 2;
1887        #[cfg(not(any(test, feature = "test-support")))]
1888        const MAX_CHUNK_SIZE: usize = 256;
1889
1890        // Stream this worktree's entries.
1891        let message = proto::UpdateWorktree {
1892            project_id: project_id.to_proto(),
1893            worktree_id,
1894            abs_path: worktree.abs_path.clone(),
1895            root_name: worktree.root_name,
1896            updated_entries: worktree.entries,
1897            removed_entries: Default::default(),
1898            scan_id: worktree.scan_id,
1899            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1900            updated_repositories: worktree.repository_entries.into_values().collect(),
1901            removed_repositories: Default::default(),
1902        };
1903        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1904            session.peer.send(session.connection_id, update.clone())?;
1905        }
1906
1907        // Stream this worktree's diagnostics.
1908        for summary in worktree.diagnostic_summaries {
1909            session.peer.send(
1910                session.connection_id,
1911                proto::UpdateDiagnosticSummary {
1912                    project_id: project_id.to_proto(),
1913                    worktree_id: worktree.id,
1914                    summary: Some(summary),
1915                },
1916            )?;
1917        }
1918
1919        for settings_file in worktree.settings_files {
1920            session.peer.send(
1921                session.connection_id,
1922                proto::UpdateWorktreeSettings {
1923                    project_id: project_id.to_proto(),
1924                    worktree_id: worktree.id,
1925                    path: settings_file.path,
1926                    content: Some(settings_file.content),
1927                },
1928            )?;
1929        }
1930    }
1931
1932    for language_server in &project.language_servers {
1933        session.peer.send(
1934            session.connection_id,
1935            proto::UpdateLanguageServer {
1936                project_id: project_id.to_proto(),
1937                language_server_id: language_server.id,
1938                variant: Some(
1939                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1940                        proto::LspDiskBasedDiagnosticsUpdated {},
1941                    ),
1942                ),
1943            },
1944        )?;
1945    }
1946
1947    Ok(())
1948}
1949
1950/// Leave someone elses shared project.
1951async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1952    let sender_id = session.connection_id;
1953    let project_id = ProjectId::from_proto(request.project_id);
1954    let db = session.db().await;
1955    if db.is_hosted_project(project_id).await? {
1956        let project = db.leave_hosted_project(project_id, sender_id).await?;
1957        project_left(&project, &session);
1958        return Ok(());
1959    }
1960
1961    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1962    tracing::info!(
1963        %project_id,
1964        host_user_id = ?project.host_user_id,
1965        host_connection_id = ?project.host_connection_id,
1966        "leave project"
1967    );
1968
1969    project_left(&project, &session);
1970    room_updated(&room, &session.peer);
1971
1972    Ok(())
1973}
1974
1975async fn join_hosted_project(
1976    request: proto::JoinHostedProject,
1977    response: Response<proto::JoinHostedProject>,
1978    session: UserSession,
1979) -> Result<()> {
1980    let (mut project, replica_id) = session
1981        .db()
1982        .await
1983        .join_hosted_project(
1984            ProjectId(request.project_id as i32),
1985            session.user_id(),
1986            session.connection_id,
1987        )
1988        .await?;
1989
1990    join_project_internal(response, session, &mut project, &replica_id)
1991}
1992
1993/// Updates other participants with changes to the project
1994async fn update_project(
1995    request: proto::UpdateProject,
1996    response: Response<proto::UpdateProject>,
1997    session: Session,
1998) -> Result<()> {
1999    let project_id = ProjectId::from_proto(request.project_id);
2000    let (room, guest_connection_ids) = &*session
2001        .db()
2002        .await
2003        .update_project(project_id, session.connection_id, &request.worktrees)
2004        .await?;
2005    broadcast(
2006        Some(session.connection_id),
2007        guest_connection_ids.iter().copied(),
2008        |connection_id| {
2009            session
2010                .peer
2011                .forward_send(session.connection_id, connection_id, request.clone())
2012        },
2013    );
2014    room_updated(&room, &session.peer);
2015    response.send(proto::Ack {})?;
2016
2017    Ok(())
2018}
2019
2020/// Updates other participants with changes to the worktree
2021async fn update_worktree(
2022    request: proto::UpdateWorktree,
2023    response: Response<proto::UpdateWorktree>,
2024    session: Session,
2025) -> Result<()> {
2026    let guest_connection_ids = session
2027        .db()
2028        .await
2029        .update_worktree(&request, session.connection_id)
2030        .await?;
2031
2032    broadcast(
2033        Some(session.connection_id),
2034        guest_connection_ids.iter().copied(),
2035        |connection_id| {
2036            session
2037                .peer
2038                .forward_send(session.connection_id, connection_id, request.clone())
2039        },
2040    );
2041    response.send(proto::Ack {})?;
2042    Ok(())
2043}
2044
2045/// Updates other participants with changes to the diagnostics
2046async fn update_diagnostic_summary(
2047    message: proto::UpdateDiagnosticSummary,
2048    session: Session,
2049) -> Result<()> {
2050    let guest_connection_ids = session
2051        .db()
2052        .await
2053        .update_diagnostic_summary(&message, session.connection_id)
2054        .await?;
2055
2056    broadcast(
2057        Some(session.connection_id),
2058        guest_connection_ids.iter().copied(),
2059        |connection_id| {
2060            session
2061                .peer
2062                .forward_send(session.connection_id, connection_id, message.clone())
2063        },
2064    );
2065
2066    Ok(())
2067}
2068
2069/// Updates other participants with changes to the worktree settings
2070async fn update_worktree_settings(
2071    message: proto::UpdateWorktreeSettings,
2072    session: Session,
2073) -> Result<()> {
2074    let guest_connection_ids = session
2075        .db()
2076        .await
2077        .update_worktree_settings(&message, session.connection_id)
2078        .await?;
2079
2080    broadcast(
2081        Some(session.connection_id),
2082        guest_connection_ids.iter().copied(),
2083        |connection_id| {
2084            session
2085                .peer
2086                .forward_send(session.connection_id, connection_id, message.clone())
2087        },
2088    );
2089
2090    Ok(())
2091}
2092
2093/// Notify other participants that a  language server has started.
2094async fn start_language_server(
2095    request: proto::StartLanguageServer,
2096    session: Session,
2097) -> Result<()> {
2098    let guest_connection_ids = session
2099        .db()
2100        .await
2101        .start_language_server(&request, session.connection_id)
2102        .await?;
2103
2104    broadcast(
2105        Some(session.connection_id),
2106        guest_connection_ids.iter().copied(),
2107        |connection_id| {
2108            session
2109                .peer
2110                .forward_send(session.connection_id, connection_id, request.clone())
2111        },
2112    );
2113    Ok(())
2114}
2115
2116/// Notify other participants that a language server has changed.
2117async fn update_language_server(
2118    request: proto::UpdateLanguageServer,
2119    session: Session,
2120) -> Result<()> {
2121    let project_id = ProjectId::from_proto(request.project_id);
2122    let project_connection_ids = session
2123        .db()
2124        .await
2125        .project_connection_ids(project_id, session.connection_id)
2126        .await?;
2127    broadcast(
2128        Some(session.connection_id),
2129        project_connection_ids.iter().copied(),
2130        |connection_id| {
2131            session
2132                .peer
2133                .forward_send(session.connection_id, connection_id, request.clone())
2134        },
2135    );
2136    Ok(())
2137}
2138
2139/// forward a project request to the host. These requests should be read only
2140/// as guests are allowed to send them.
2141async fn forward_read_only_project_request<T>(
2142    request: T,
2143    response: Response<T>,
2144    session: Session,
2145) -> Result<()>
2146where
2147    T: EntityMessage + RequestMessage,
2148{
2149    let project_id = ProjectId::from_proto(request.remote_entity_id());
2150    let host_connection_id = session
2151        .db()
2152        .await
2153        .host_for_read_only_project_request(project_id, session.connection_id)
2154        .await?;
2155    let payload = session
2156        .peer
2157        .forward_request(session.connection_id, host_connection_id, request)
2158        .await?;
2159    response.send(payload)?;
2160    Ok(())
2161}
2162
2163/// forward a project request to the host. These requests are disallowed
2164/// for guests.
2165async fn forward_mutating_project_request<T>(
2166    request: T,
2167    response: Response<T>,
2168    session: Session,
2169) -> Result<()>
2170where
2171    T: EntityMessage + RequestMessage,
2172{
2173    let project_id = ProjectId::from_proto(request.remote_entity_id());
2174    let host_connection_id = session
2175        .db()
2176        .await
2177        .host_for_mutating_project_request(project_id, session.connection_id)
2178        .await?;
2179    let payload = session
2180        .peer
2181        .forward_request(session.connection_id, host_connection_id, request)
2182        .await?;
2183    response.send(payload)?;
2184    Ok(())
2185}
2186
2187/// Notify other participants that a new buffer has been created
2188async fn create_buffer_for_peer(
2189    request: proto::CreateBufferForPeer,
2190    session: Session,
2191) -> Result<()> {
2192    session
2193        .db()
2194        .await
2195        .check_user_is_project_host(
2196            ProjectId::from_proto(request.project_id),
2197            session.connection_id,
2198        )
2199        .await?;
2200    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2201    session
2202        .peer
2203        .forward_send(session.connection_id, peer_id.into(), request)?;
2204    Ok(())
2205}
2206
2207/// Notify other participants that a buffer has been updated. This is
2208/// allowed for guests as long as the update is limited to selections.
2209async fn update_buffer(
2210    request: proto::UpdateBuffer,
2211    response: Response<proto::UpdateBuffer>,
2212    session: Session,
2213) -> Result<()> {
2214    let project_id = ProjectId::from_proto(request.project_id);
2215    let mut guest_connection_ids;
2216    let mut host_connection_id = None;
2217
2218    let mut requires_write_permission = false;
2219
2220    for op in request.operations.iter() {
2221        match op.variant {
2222            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2223            Some(_) => requires_write_permission = true,
2224        }
2225    }
2226
2227    {
2228        let collaborators = session
2229            .db()
2230            .await
2231            .project_collaborators_for_buffer_update(
2232                project_id,
2233                session.connection_id,
2234                requires_write_permission,
2235            )
2236            .await?;
2237        guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2238        for collaborator in collaborators.iter() {
2239            if collaborator.is_host {
2240                host_connection_id = Some(collaborator.connection_id);
2241            } else {
2242                guest_connection_ids.push(collaborator.connection_id);
2243            }
2244        }
2245    }
2246    let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2247
2248    broadcast(
2249        Some(session.connection_id),
2250        guest_connection_ids,
2251        |connection_id| {
2252            session
2253                .peer
2254                .forward_send(session.connection_id, connection_id, request.clone())
2255        },
2256    );
2257    if host_connection_id != session.connection_id {
2258        session
2259            .peer
2260            .forward_request(session.connection_id, host_connection_id, request.clone())
2261            .await?;
2262    }
2263
2264    response.send(proto::Ack {})?;
2265    Ok(())
2266}
2267
2268/// Notify other participants that a project has been updated.
2269async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2270    request: T,
2271    session: Session,
2272) -> Result<()> {
2273    let project_id = ProjectId::from_proto(request.remote_entity_id());
2274    let project_connection_ids = session
2275        .db()
2276        .await
2277        .project_connection_ids(project_id, session.connection_id)
2278        .await?;
2279
2280    broadcast(
2281        Some(session.connection_id),
2282        project_connection_ids.iter().copied(),
2283        |connection_id| {
2284            session
2285                .peer
2286                .forward_send(session.connection_id, connection_id, request.clone())
2287        },
2288    );
2289    Ok(())
2290}
2291
2292/// Start following another user in a call.
2293async fn follow(
2294    request: proto::Follow,
2295    response: Response<proto::Follow>,
2296    session: UserSession,
2297) -> Result<()> {
2298    let room_id = RoomId::from_proto(request.room_id);
2299    let project_id = request.project_id.map(ProjectId::from_proto);
2300    let leader_id = request
2301        .leader_id
2302        .ok_or_else(|| anyhow!("invalid leader id"))?
2303        .into();
2304    let follower_id = session.connection_id;
2305
2306    session
2307        .db()
2308        .await
2309        .check_room_participants(room_id, leader_id, session.connection_id)
2310        .await?;
2311
2312    let response_payload = session
2313        .peer
2314        .forward_request(session.connection_id, leader_id, request)
2315        .await?;
2316    response.send(response_payload)?;
2317
2318    if let Some(project_id) = project_id {
2319        let room = session
2320            .db()
2321            .await
2322            .follow(room_id, project_id, leader_id, follower_id)
2323            .await?;
2324        room_updated(&room, &session.peer);
2325    }
2326
2327    Ok(())
2328}
2329
2330/// Stop following another user in a call.
2331async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2332    let room_id = RoomId::from_proto(request.room_id);
2333    let project_id = request.project_id.map(ProjectId::from_proto);
2334    let leader_id = request
2335        .leader_id
2336        .ok_or_else(|| anyhow!("invalid leader id"))?
2337        .into();
2338    let follower_id = session.connection_id;
2339
2340    session
2341        .db()
2342        .await
2343        .check_room_participants(room_id, leader_id, session.connection_id)
2344        .await?;
2345
2346    session
2347        .peer
2348        .forward_send(session.connection_id, leader_id, request)?;
2349
2350    if let Some(project_id) = project_id {
2351        let room = session
2352            .db()
2353            .await
2354            .unfollow(room_id, project_id, leader_id, follower_id)
2355            .await?;
2356        room_updated(&room, &session.peer);
2357    }
2358
2359    Ok(())
2360}
2361
2362/// Notify everyone following you of your current location.
2363async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2364    let room_id = RoomId::from_proto(request.room_id);
2365    let database = session.db.lock().await;
2366
2367    let connection_ids = if let Some(project_id) = request.project_id {
2368        let project_id = ProjectId::from_proto(project_id);
2369        database
2370            .project_connection_ids(project_id, session.connection_id)
2371            .await?
2372    } else {
2373        database
2374            .room_connection_ids(room_id, session.connection_id)
2375            .await?
2376    };
2377
2378    // For now, don't send view update messages back to that view's current leader.
2379    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2380        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2381        _ => None,
2382    });
2383
2384    for connection_id in connection_ids.iter().cloned() {
2385        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2386            session
2387                .peer
2388                .forward_send(session.connection_id, connection_id, request.clone())?;
2389        }
2390    }
2391    Ok(())
2392}
2393
2394/// Get public data about users.
2395async fn get_users(
2396    request: proto::GetUsers,
2397    response: Response<proto::GetUsers>,
2398    session: Session,
2399) -> Result<()> {
2400    let user_ids = request
2401        .user_ids
2402        .into_iter()
2403        .map(UserId::from_proto)
2404        .collect();
2405    let users = session
2406        .db()
2407        .await
2408        .get_users_by_ids(user_ids)
2409        .await?
2410        .into_iter()
2411        .map(|user| proto::User {
2412            id: user.id.to_proto(),
2413            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2414            github_login: user.github_login,
2415        })
2416        .collect();
2417    response.send(proto::UsersResponse { users })?;
2418    Ok(())
2419}
2420
2421/// Search for users (to invite) buy Github login
2422async fn fuzzy_search_users(
2423    request: proto::FuzzySearchUsers,
2424    response: Response<proto::FuzzySearchUsers>,
2425    session: UserSession,
2426) -> Result<()> {
2427    let query = request.query;
2428    let users = match query.len() {
2429        0 => vec![],
2430        1 | 2 => session
2431            .db()
2432            .await
2433            .get_user_by_github_login(&query)
2434            .await?
2435            .into_iter()
2436            .collect(),
2437        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2438    };
2439    let users = users
2440        .into_iter()
2441        .filter(|user| user.id != session.user_id())
2442        .map(|user| proto::User {
2443            id: user.id.to_proto(),
2444            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2445            github_login: user.github_login,
2446        })
2447        .collect();
2448    response.send(proto::UsersResponse { users })?;
2449    Ok(())
2450}
2451
2452/// Send a contact request to another user.
2453async fn request_contact(
2454    request: proto::RequestContact,
2455    response: Response<proto::RequestContact>,
2456    session: UserSession,
2457) -> Result<()> {
2458    let requester_id = session.user_id();
2459    let responder_id = UserId::from_proto(request.responder_id);
2460    if requester_id == responder_id {
2461        return Err(anyhow!("cannot add yourself as a contact"))?;
2462    }
2463
2464    let notifications = session
2465        .db()
2466        .await
2467        .send_contact_request(requester_id, responder_id)
2468        .await?;
2469
2470    // Update outgoing contact requests of requester
2471    let mut update = proto::UpdateContacts::default();
2472    update.outgoing_requests.push(responder_id.to_proto());
2473    for connection_id in session
2474        .connection_pool()
2475        .await
2476        .user_connection_ids(requester_id)
2477    {
2478        session.peer.send(connection_id, update.clone())?;
2479    }
2480
2481    // Update incoming contact requests of responder
2482    let mut update = proto::UpdateContacts::default();
2483    update
2484        .incoming_requests
2485        .push(proto::IncomingContactRequest {
2486            requester_id: requester_id.to_proto(),
2487        });
2488    let connection_pool = session.connection_pool().await;
2489    for connection_id in connection_pool.user_connection_ids(responder_id) {
2490        session.peer.send(connection_id, update.clone())?;
2491    }
2492
2493    send_notifications(&connection_pool, &session.peer, notifications);
2494
2495    response.send(proto::Ack {})?;
2496    Ok(())
2497}
2498
2499/// Accept or decline a contact request
2500async fn respond_to_contact_request(
2501    request: proto::RespondToContactRequest,
2502    response: Response<proto::RespondToContactRequest>,
2503    session: UserSession,
2504) -> Result<()> {
2505    let responder_id = session.user_id();
2506    let requester_id = UserId::from_proto(request.requester_id);
2507    let db = session.db().await;
2508    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2509        db.dismiss_contact_notification(responder_id, requester_id)
2510            .await?;
2511    } else {
2512        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2513
2514        let notifications = db
2515            .respond_to_contact_request(responder_id, requester_id, accept)
2516            .await?;
2517        let requester_busy = db.is_user_busy(requester_id).await?;
2518        let responder_busy = db.is_user_busy(responder_id).await?;
2519
2520        let pool = session.connection_pool().await;
2521        // Update responder with new contact
2522        let mut update = proto::UpdateContacts::default();
2523        if accept {
2524            update
2525                .contacts
2526                .push(contact_for_user(requester_id, requester_busy, &pool));
2527        }
2528        update
2529            .remove_incoming_requests
2530            .push(requester_id.to_proto());
2531        for connection_id in pool.user_connection_ids(responder_id) {
2532            session.peer.send(connection_id, update.clone())?;
2533        }
2534
2535        // Update requester with new contact
2536        let mut update = proto::UpdateContacts::default();
2537        if accept {
2538            update
2539                .contacts
2540                .push(contact_for_user(responder_id, responder_busy, &pool));
2541        }
2542        update
2543            .remove_outgoing_requests
2544            .push(responder_id.to_proto());
2545
2546        for connection_id in pool.user_connection_ids(requester_id) {
2547            session.peer.send(connection_id, update.clone())?;
2548        }
2549
2550        send_notifications(&pool, &session.peer, notifications);
2551    }
2552
2553    response.send(proto::Ack {})?;
2554    Ok(())
2555}
2556
2557/// Remove a contact.
2558async fn remove_contact(
2559    request: proto::RemoveContact,
2560    response: Response<proto::RemoveContact>,
2561    session: UserSession,
2562) -> Result<()> {
2563    let requester_id = session.user_id();
2564    let responder_id = UserId::from_proto(request.user_id);
2565    let db = session.db().await;
2566    let (contact_accepted, deleted_notification_id) =
2567        db.remove_contact(requester_id, responder_id).await?;
2568
2569    let pool = session.connection_pool().await;
2570    // Update outgoing contact requests of requester
2571    let mut update = proto::UpdateContacts::default();
2572    if contact_accepted {
2573        update.remove_contacts.push(responder_id.to_proto());
2574    } else {
2575        update
2576            .remove_outgoing_requests
2577            .push(responder_id.to_proto());
2578    }
2579    for connection_id in pool.user_connection_ids(requester_id) {
2580        session.peer.send(connection_id, update.clone())?;
2581    }
2582
2583    // Update incoming contact requests of responder
2584    let mut update = proto::UpdateContacts::default();
2585    if contact_accepted {
2586        update.remove_contacts.push(requester_id.to_proto());
2587    } else {
2588        update
2589            .remove_incoming_requests
2590            .push(requester_id.to_proto());
2591    }
2592    for connection_id in pool.user_connection_ids(responder_id) {
2593        session.peer.send(connection_id, update.clone())?;
2594        if let Some(notification_id) = deleted_notification_id {
2595            session.peer.send(
2596                connection_id,
2597                proto::DeleteNotification {
2598                    notification_id: notification_id.to_proto(),
2599                },
2600            )?;
2601        }
2602    }
2603
2604    response.send(proto::Ack {})?;
2605    Ok(())
2606}
2607
2608/// Creates a new channel.
2609async fn create_channel(
2610    request: proto::CreateChannel,
2611    response: Response<proto::CreateChannel>,
2612    session: UserSession,
2613) -> Result<()> {
2614    let db = session.db().await;
2615
2616    let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2617    let (channel, membership) = db
2618        .create_channel(&request.name, parent_id, session.user_id())
2619        .await?;
2620
2621    let root_id = channel.root_id();
2622    let channel = Channel::from_model(channel);
2623
2624    response.send(proto::CreateChannelResponse {
2625        channel: Some(channel.to_proto()),
2626        parent_id: request.parent_id,
2627    })?;
2628
2629    let mut connection_pool = session.connection_pool().await;
2630    if let Some(membership) = membership {
2631        connection_pool.subscribe_to_channel(
2632            membership.user_id,
2633            membership.channel_id,
2634            membership.role,
2635        );
2636        let update = proto::UpdateUserChannels {
2637            channel_memberships: vec![proto::ChannelMembership {
2638                channel_id: membership.channel_id.to_proto(),
2639                role: membership.role.into(),
2640            }],
2641            ..Default::default()
2642        };
2643        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2644            session.peer.send(connection_id, update.clone())?;
2645        }
2646    }
2647
2648    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2649        if !role.can_see_channel(channel.visibility) {
2650            continue;
2651        }
2652
2653        let update = proto::UpdateChannels {
2654            channels: vec![channel.to_proto()],
2655            ..Default::default()
2656        };
2657        session.peer.send(connection_id, update.clone())?;
2658    }
2659
2660    Ok(())
2661}
2662
2663/// Delete a channel
2664async fn delete_channel(
2665    request: proto::DeleteChannel,
2666    response: Response<proto::DeleteChannel>,
2667    session: UserSession,
2668) -> Result<()> {
2669    let db = session.db().await;
2670
2671    let channel_id = request.channel_id;
2672    let (root_channel, removed_channels) = db
2673        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2674        .await?;
2675    response.send(proto::Ack {})?;
2676
2677    // Notify members of removed channels
2678    let mut update = proto::UpdateChannels::default();
2679    update
2680        .delete_channels
2681        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2682
2683    let connection_pool = session.connection_pool().await;
2684    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2685        session.peer.send(connection_id, update.clone())?;
2686    }
2687
2688    Ok(())
2689}
2690
2691/// Invite someone to join a channel.
2692async fn invite_channel_member(
2693    request: proto::InviteChannelMember,
2694    response: Response<proto::InviteChannelMember>,
2695    session: UserSession,
2696) -> Result<()> {
2697    let db = session.db().await;
2698    let channel_id = ChannelId::from_proto(request.channel_id);
2699    let invitee_id = UserId::from_proto(request.user_id);
2700    let InviteMemberResult {
2701        channel,
2702        notifications,
2703    } = db
2704        .invite_channel_member(
2705            channel_id,
2706            invitee_id,
2707            session.user_id(),
2708            request.role().into(),
2709        )
2710        .await?;
2711
2712    let update = proto::UpdateChannels {
2713        channel_invitations: vec![channel.to_proto()],
2714        ..Default::default()
2715    };
2716
2717    let connection_pool = session.connection_pool().await;
2718    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2719        session.peer.send(connection_id, update.clone())?;
2720    }
2721
2722    send_notifications(&connection_pool, &session.peer, notifications);
2723
2724    response.send(proto::Ack {})?;
2725    Ok(())
2726}
2727
2728/// remove someone from a channel
2729async fn remove_channel_member(
2730    request: proto::RemoveChannelMember,
2731    response: Response<proto::RemoveChannelMember>,
2732    session: UserSession,
2733) -> Result<()> {
2734    let db = session.db().await;
2735    let channel_id = ChannelId::from_proto(request.channel_id);
2736    let member_id = UserId::from_proto(request.user_id);
2737
2738    let RemoveChannelMemberResult {
2739        membership_update,
2740        notification_id,
2741    } = db
2742        .remove_channel_member(channel_id, member_id, session.user_id())
2743        .await?;
2744
2745    let mut connection_pool = session.connection_pool().await;
2746    notify_membership_updated(
2747        &mut connection_pool,
2748        membership_update,
2749        member_id,
2750        &session.peer,
2751    );
2752    for connection_id in connection_pool.user_connection_ids(member_id) {
2753        if let Some(notification_id) = notification_id {
2754            session
2755                .peer
2756                .send(
2757                    connection_id,
2758                    proto::DeleteNotification {
2759                        notification_id: notification_id.to_proto(),
2760                    },
2761                )
2762                .trace_err();
2763        }
2764    }
2765
2766    response.send(proto::Ack {})?;
2767    Ok(())
2768}
2769
2770/// Toggle the channel between public and private.
2771/// Care is taken to maintain the invariant that public channels only descend from public channels,
2772/// (though members-only channels can appear at any point in the hierarchy).
2773async fn set_channel_visibility(
2774    request: proto::SetChannelVisibility,
2775    response: Response<proto::SetChannelVisibility>,
2776    session: UserSession,
2777) -> Result<()> {
2778    let db = session.db().await;
2779    let channel_id = ChannelId::from_proto(request.channel_id);
2780    let visibility = request.visibility().into();
2781
2782    let channel_model = db
2783        .set_channel_visibility(channel_id, visibility, session.user_id())
2784        .await?;
2785    let root_id = channel_model.root_id();
2786    let channel = Channel::from_model(channel_model);
2787
2788    let mut connection_pool = session.connection_pool().await;
2789    for (user_id, role) in connection_pool
2790        .channel_user_ids(root_id)
2791        .collect::<Vec<_>>()
2792        .into_iter()
2793    {
2794        let update = if role.can_see_channel(channel.visibility) {
2795            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2796            proto::UpdateChannels {
2797                channels: vec![channel.to_proto()],
2798                ..Default::default()
2799            }
2800        } else {
2801            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2802            proto::UpdateChannels {
2803                delete_channels: vec![channel.id.to_proto()],
2804                ..Default::default()
2805            }
2806        };
2807
2808        for connection_id in connection_pool.user_connection_ids(user_id) {
2809            session.peer.send(connection_id, update.clone())?;
2810        }
2811    }
2812
2813    response.send(proto::Ack {})?;
2814    Ok(())
2815}
2816
2817/// Alter the role for a user in the channel.
2818async fn set_channel_member_role(
2819    request: proto::SetChannelMemberRole,
2820    response: Response<proto::SetChannelMemberRole>,
2821    session: UserSession,
2822) -> Result<()> {
2823    let db = session.db().await;
2824    let channel_id = ChannelId::from_proto(request.channel_id);
2825    let member_id = UserId::from_proto(request.user_id);
2826    let result = db
2827        .set_channel_member_role(
2828            channel_id,
2829            session.user_id(),
2830            member_id,
2831            request.role().into(),
2832        )
2833        .await?;
2834
2835    match result {
2836        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2837            let mut connection_pool = session.connection_pool().await;
2838            notify_membership_updated(
2839                &mut connection_pool,
2840                membership_update,
2841                member_id,
2842                &session.peer,
2843            )
2844        }
2845        db::SetMemberRoleResult::InviteUpdated(channel) => {
2846            let update = proto::UpdateChannels {
2847                channel_invitations: vec![channel.to_proto()],
2848                ..Default::default()
2849            };
2850
2851            for connection_id in session
2852                .connection_pool()
2853                .await
2854                .user_connection_ids(member_id)
2855            {
2856                session.peer.send(connection_id, update.clone())?;
2857            }
2858        }
2859    }
2860
2861    response.send(proto::Ack {})?;
2862    Ok(())
2863}
2864
2865/// Change the name of a channel
2866async fn rename_channel(
2867    request: proto::RenameChannel,
2868    response: Response<proto::RenameChannel>,
2869    session: UserSession,
2870) -> Result<()> {
2871    let db = session.db().await;
2872    let channel_id = ChannelId::from_proto(request.channel_id);
2873    let channel_model = db
2874        .rename_channel(channel_id, session.user_id(), &request.name)
2875        .await?;
2876    let root_id = channel_model.root_id();
2877    let channel = Channel::from_model(channel_model);
2878
2879    response.send(proto::RenameChannelResponse {
2880        channel: Some(channel.to_proto()),
2881    })?;
2882
2883    let connection_pool = session.connection_pool().await;
2884    let update = proto::UpdateChannels {
2885        channels: vec![channel.to_proto()],
2886        ..Default::default()
2887    };
2888    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2889        if role.can_see_channel(channel.visibility) {
2890            session.peer.send(connection_id, update.clone())?;
2891        }
2892    }
2893
2894    Ok(())
2895}
2896
2897/// Move a channel to a new parent.
2898async fn move_channel(
2899    request: proto::MoveChannel,
2900    response: Response<proto::MoveChannel>,
2901    session: UserSession,
2902) -> Result<()> {
2903    let channel_id = ChannelId::from_proto(request.channel_id);
2904    let to = ChannelId::from_proto(request.to);
2905
2906    let (root_id, channels) = session
2907        .db()
2908        .await
2909        .move_channel(channel_id, to, session.user_id())
2910        .await?;
2911
2912    let connection_pool = session.connection_pool().await;
2913    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2914        let channels = channels
2915            .iter()
2916            .filter_map(|channel| {
2917                if role.can_see_channel(channel.visibility) {
2918                    Some(channel.to_proto())
2919                } else {
2920                    None
2921                }
2922            })
2923            .collect::<Vec<_>>();
2924        if channels.is_empty() {
2925            continue;
2926        }
2927
2928        let update = proto::UpdateChannels {
2929            channels,
2930            ..Default::default()
2931        };
2932
2933        session.peer.send(connection_id, update.clone())?;
2934    }
2935
2936    response.send(Ack {})?;
2937    Ok(())
2938}
2939
2940/// Get the list of channel members
2941async fn get_channel_members(
2942    request: proto::GetChannelMembers,
2943    response: Response<proto::GetChannelMembers>,
2944    session: UserSession,
2945) -> Result<()> {
2946    let db = session.db().await;
2947    let channel_id = ChannelId::from_proto(request.channel_id);
2948    let members = db
2949        .get_channel_participant_details(channel_id, session.user_id())
2950        .await?;
2951    response.send(proto::GetChannelMembersResponse { members })?;
2952    Ok(())
2953}
2954
2955/// Accept or decline a channel invitation.
2956async fn respond_to_channel_invite(
2957    request: proto::RespondToChannelInvite,
2958    response: Response<proto::RespondToChannelInvite>,
2959    session: UserSession,
2960) -> Result<()> {
2961    let db = session.db().await;
2962    let channel_id = ChannelId::from_proto(request.channel_id);
2963    let RespondToChannelInvite {
2964        membership_update,
2965        notifications,
2966    } = db
2967        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2968        .await?;
2969
2970    let mut connection_pool = session.connection_pool().await;
2971    if let Some(membership_update) = membership_update {
2972        notify_membership_updated(
2973            &mut connection_pool,
2974            membership_update,
2975            session.user_id(),
2976            &session.peer,
2977        );
2978    } else {
2979        let update = proto::UpdateChannels {
2980            remove_channel_invitations: vec![channel_id.to_proto()],
2981            ..Default::default()
2982        };
2983
2984        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2985            session.peer.send(connection_id, update.clone())?;
2986        }
2987    };
2988
2989    send_notifications(&connection_pool, &session.peer, notifications);
2990
2991    response.send(proto::Ack {})?;
2992
2993    Ok(())
2994}
2995
2996/// Join the channels' room
2997async fn join_channel(
2998    request: proto::JoinChannel,
2999    response: Response<proto::JoinChannel>,
3000    session: UserSession,
3001) -> Result<()> {
3002    let channel_id = ChannelId::from_proto(request.channel_id);
3003    join_channel_internal(channel_id, Box::new(response), session).await
3004}
3005
/// Abstracts over the response handles of the RPCs that end up joining a channel's room
/// (`proto::JoinChannel` and `proto::JoinRoom`), so `join_channel_internal` can reply to
/// either with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls deliberately call `Response`'s `send` through a fully-qualified path.
// NOTE(review): this presumably resolves to an inherent `Response::send`; rewriting it as
// `self.send(result)` inside this trait impl would be easy to get wrong and could recurse.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3019
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply to the client with the room, channel id, and (if a LiveKit client is configured)
///    LiveKit connection info.
/// 4. Propagate any membership change and broadcast the updated room.
/// 5. Broadcast the updated channel, then refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This inner block scopes the db guard and the first connection-pool guard, so both are
    // released before the channel broadcast and the contact update below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a room token
        // and may publish. A token error is swallowed (via `trace_err`), so the client simply
        // receives no LiveKit connection info instead of the whole join failing.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // The client is answered before any broadcasts go out.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): if the database returned no channel, this errors out after the client has
    // already received a successful `JoinRoomResponse` above.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3110
3111/// Start editing the channel notes
3112async fn join_channel_buffer(
3113    request: proto::JoinChannelBuffer,
3114    response: Response<proto::JoinChannelBuffer>,
3115    session: UserSession,
3116) -> Result<()> {
3117    let db = session.db().await;
3118    let channel_id = ChannelId::from_proto(request.channel_id);
3119
3120    let open_response = db
3121        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3122        .await?;
3123
3124    let collaborators = open_response.collaborators.clone();
3125    response.send(open_response)?;
3126
3127    let update = UpdateChannelBufferCollaborators {
3128        channel_id: channel_id.to_proto(),
3129        collaborators: collaborators.clone(),
3130    };
3131    channel_buffer_updated(
3132        session.connection_id,
3133        collaborators
3134            .iter()
3135            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3136        &update,
3137        &session.peer,
3138    );
3139
3140    Ok(())
3141}
3142
3143/// Edit the channel notes
3144async fn update_channel_buffer(
3145    request: proto::UpdateChannelBuffer,
3146    session: UserSession,
3147) -> Result<()> {
3148    let db = session.db().await;
3149    let channel_id = ChannelId::from_proto(request.channel_id);
3150
3151    let (collaborators, non_collaborators, epoch, version) = db
3152        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3153        .await?;
3154
3155    channel_buffer_updated(
3156        session.connection_id,
3157        collaborators,
3158        &proto::UpdateChannelBuffer {
3159            channel_id: channel_id.to_proto(),
3160            operations: request.operations,
3161        },
3162        &session.peer,
3163    );
3164
3165    let pool = &*session.connection_pool().await;
3166
3167    broadcast(
3168        None,
3169        non_collaborators
3170            .iter()
3171            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3172        |peer_id| {
3173            session.peer.send(
3174                peer_id,
3175                proto::UpdateChannels {
3176                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3177                        channel_id: channel_id.to_proto(),
3178                        epoch: epoch as u64,
3179                        version: version.clone(),
3180                    }],
3181                    ..Default::default()
3182                },
3183            )
3184        },
3185    );
3186
3187    Ok(())
3188}
3189
3190/// Rejoin the channel notes after a connection blip
3191async fn rejoin_channel_buffers(
3192    request: proto::RejoinChannelBuffers,
3193    response: Response<proto::RejoinChannelBuffers>,
3194    session: UserSession,
3195) -> Result<()> {
3196    let db = session.db().await;
3197    let buffers = db
3198        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3199        .await?;
3200
3201    for rejoined_buffer in &buffers {
3202        let collaborators_to_notify = rejoined_buffer
3203            .buffer
3204            .collaborators
3205            .iter()
3206            .filter_map(|c| Some(c.peer_id?.into()));
3207        channel_buffer_updated(
3208            session.connection_id,
3209            collaborators_to_notify,
3210            &proto::UpdateChannelBufferCollaborators {
3211                channel_id: rejoined_buffer.buffer.channel_id,
3212                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3213            },
3214            &session.peer,
3215        );
3216    }
3217
3218    response.send(proto::RejoinChannelBuffersResponse {
3219        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3220    })?;
3221
3222    Ok(())
3223}
3224
3225/// Stop editing the channel notes
3226async fn leave_channel_buffer(
3227    request: proto::LeaveChannelBuffer,
3228    response: Response<proto::LeaveChannelBuffer>,
3229    session: UserSession,
3230) -> Result<()> {
3231    let db = session.db().await;
3232    let channel_id = ChannelId::from_proto(request.channel_id);
3233
3234    let left_buffer = db
3235        .leave_channel_buffer(channel_id, session.connection_id)
3236        .await?;
3237
3238    response.send(Ack {})?;
3239
3240    channel_buffer_updated(
3241        session.connection_id,
3242        left_buffer.connections,
3243        &proto::UpdateChannelBufferCollaborators {
3244            channel_id: channel_id.to_proto(),
3245            collaborators: left_buffer.collaborators,
3246        },
3247        &session.peer,
3248    );
3249
3250    Ok(())
3251}
3252
3253fn channel_buffer_updated<T: EnvelopedMessage>(
3254    sender_id: ConnectionId,
3255    collaborators: impl IntoIterator<Item = ConnectionId>,
3256    message: &T,
3257    peer: &Peer,
3258) {
3259    broadcast(Some(sender_id), collaborators, |peer_id| {
3260        peer.send(peer_id, message.clone())
3261    });
3262}
3263
3264fn send_notifications(
3265    connection_pool: &ConnectionPool,
3266    peer: &Peer,
3267    notifications: db::NotificationBatch,
3268) {
3269    for (user_id, notification) in notifications {
3270        for connection_id in connection_pool.user_connection_ids(user_id) {
3271            if let Err(error) = peer.send(
3272                connection_id,
3273                proto::AddNotification {
3274                    notification: Some(notification.clone()),
3275                },
3276            ) {
3277                tracing::error!(
3278                    "failed to send notification to {:?} {}",
3279                    connection_id,
3280                    error
3281                );
3282            }
3283        }
3284    }
3285}
3286
3287/// Send a message to the channel
3288async fn send_channel_message(
3289    request: proto::SendChannelMessage,
3290    response: Response<proto::SendChannelMessage>,
3291    session: UserSession,
3292) -> Result<()> {
3293    // Validate the message body.
3294    let body = request.body.trim().to_string();
3295    if body.len() > MAX_MESSAGE_LEN {
3296        return Err(anyhow!("message is too long"))?;
3297    }
3298    if body.is_empty() {
3299        return Err(anyhow!("message can't be blank"))?;
3300    }
3301
3302    // TODO: adjust mentions if body is trimmed
3303
3304    let timestamp = OffsetDateTime::now_utc();
3305    let nonce = request
3306        .nonce
3307        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3308
3309    let channel_id = ChannelId::from_proto(request.channel_id);
3310    let CreatedChannelMessage {
3311        message_id,
3312        participant_connection_ids,
3313        channel_members,
3314        notifications,
3315    } = session
3316        .db()
3317        .await
3318        .create_channel_message(
3319            channel_id,
3320            session.user_id(),
3321            &body,
3322            &request.mentions,
3323            timestamp,
3324            nonce.clone().into(),
3325            match request.reply_to_message_id {
3326                Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3327                None => None,
3328            },
3329        )
3330        .await?;
3331
3332    let message = proto::ChannelMessage {
3333        sender_id: session.user_id().to_proto(),
3334        id: message_id.to_proto(),
3335        body,
3336        mentions: request.mentions,
3337        timestamp: timestamp.unix_timestamp() as u64,
3338        nonce: Some(nonce),
3339        reply_to_message_id: request.reply_to_message_id,
3340        edited_at: None,
3341    };
3342    broadcast(
3343        Some(session.connection_id),
3344        participant_connection_ids,
3345        |connection| {
3346            session.peer.send(
3347                connection,
3348                proto::ChannelMessageSent {
3349                    channel_id: channel_id.to_proto(),
3350                    message: Some(message.clone()),
3351                },
3352            )
3353        },
3354    );
3355    response.send(proto::SendChannelMessageResponse {
3356        message: Some(message),
3357    })?;
3358
3359    let pool = &*session.connection_pool().await;
3360    broadcast(
3361        None,
3362        channel_members
3363            .iter()
3364            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3365        |peer_id| {
3366            session.peer.send(
3367                peer_id,
3368                proto::UpdateChannels {
3369                    latest_channel_message_ids: vec![proto::ChannelMessageId {
3370                        channel_id: channel_id.to_proto(),
3371                        message_id: message_id.to_proto(),
3372                    }],
3373                    ..Default::default()
3374                },
3375            )
3376        },
3377    );
3378    send_notifications(pool, &session.peer, notifications);
3379
3380    Ok(())
3381}
3382
3383/// Delete a channel message
3384async fn remove_channel_message(
3385    request: proto::RemoveChannelMessage,
3386    response: Response<proto::RemoveChannelMessage>,
3387    session: UserSession,
3388) -> Result<()> {
3389    let channel_id = ChannelId::from_proto(request.channel_id);
3390    let message_id = MessageId::from_proto(request.message_id);
3391    let (connection_ids, existing_notification_ids) = session
3392        .db()
3393        .await
3394        .remove_channel_message(channel_id, message_id, session.user_id())
3395        .await?;
3396
3397    broadcast(
3398        Some(session.connection_id),
3399        connection_ids,
3400        move |connection| {
3401            session.peer.send(connection, request.clone())?;
3402
3403            for notification_id in &existing_notification_ids {
3404                session.peer.send(
3405                    connection,
3406                    proto::DeleteNotification {
3407                        notification_id: (*notification_id).to_proto(),
3408                    },
3409                )?;
3410            }
3411
3412            Ok(())
3413        },
3414    );
3415    response.send(proto::Ack {})?;
3416    Ok(())
3417}
3418
3419async fn update_channel_message(
3420    request: proto::UpdateChannelMessage,
3421    response: Response<proto::UpdateChannelMessage>,
3422    session: UserSession,
3423) -> Result<()> {
3424    let channel_id = ChannelId::from_proto(request.channel_id);
3425    let message_id = MessageId::from_proto(request.message_id);
3426    let updated_at = OffsetDateTime::now_utc();
3427    let UpdatedChannelMessage {
3428        message_id,
3429        participant_connection_ids,
3430        notifications,
3431        reply_to_message_id,
3432        timestamp,
3433        deleted_mention_notification_ids,
3434        updated_mention_notifications,
3435    } = session
3436        .db()
3437        .await
3438        .update_channel_message(
3439            channel_id,
3440            message_id,
3441            session.user_id(),
3442            request.body.as_str(),
3443            &request.mentions,
3444            updated_at,
3445        )
3446        .await?;
3447
3448    let nonce = request
3449        .nonce
3450        .clone()
3451        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3452
3453    let message = proto::ChannelMessage {
3454        sender_id: session.user_id().to_proto(),
3455        id: message_id.to_proto(),
3456        body: request.body.clone(),
3457        mentions: request.mentions.clone(),
3458        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3459        nonce: Some(nonce),
3460        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3461        edited_at: Some(updated_at.unix_timestamp() as u64),
3462    };
3463
3464    response.send(proto::Ack {})?;
3465
3466    let pool = &*session.connection_pool().await;
3467    broadcast(
3468        Some(session.connection_id),
3469        participant_connection_ids,
3470        |connection| {
3471            session.peer.send(
3472                connection,
3473                proto::ChannelMessageUpdate {
3474                    channel_id: channel_id.to_proto(),
3475                    message: Some(message.clone()),
3476                },
3477            )?;
3478
3479            for notification_id in &deleted_mention_notification_ids {
3480                session.peer.send(
3481                    connection,
3482                    proto::DeleteNotification {
3483                        notification_id: (*notification_id).to_proto(),
3484                    },
3485                )?;
3486            }
3487
3488            for notification in &updated_mention_notifications {
3489                session.peer.send(
3490                    connection,
3491                    proto::UpdateNotification {
3492                        notification: Some(notification.clone()),
3493                    },
3494                )?;
3495            }
3496
3497            Ok(())
3498        },
3499    );
3500
3501    send_notifications(pool, &session.peer, notifications);
3502
3503    Ok(())
3504}
3505
3506/// Mark a channel message as read
3507async fn acknowledge_channel_message(
3508    request: proto::AckChannelMessage,
3509    session: UserSession,
3510) -> Result<()> {
3511    let channel_id = ChannelId::from_proto(request.channel_id);
3512    let message_id = MessageId::from_proto(request.message_id);
3513    let notifications = session
3514        .db()
3515        .await
3516        .observe_channel_message(channel_id, session.user_id(), message_id)
3517        .await?;
3518    send_notifications(
3519        &*session.connection_pool().await,
3520        &session.peer,
3521        notifications,
3522    );
3523    Ok(())
3524}
3525
3526/// Mark a buffer version as synced
3527async fn acknowledge_buffer_version(
3528    request: proto::AckBufferOperation,
3529    session: UserSession,
3530) -> Result<()> {
3531    let buffer_id = BufferId::from_proto(request.buffer_id);
3532    session
3533        .db()
3534        .await
3535        .observe_buffer_version(
3536            buffer_id,
3537            session.user_id(),
3538            request.epoch as i32,
3539            &request.version,
3540        )
3541        .await?;
3542    Ok(())
3543}
3544
3545struct CompleteWithLanguageModelRateLimit;
3546
3547impl RateLimit for CompleteWithLanguageModelRateLimit {
3548    fn capacity() -> usize {
3549        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3550            .ok()
3551            .and_then(|v| v.parse().ok())
3552            .unwrap_or(120) // Picked arbitrarily
3553    }
3554
3555    fn refill_duration() -> chrono::Duration {
3556        chrono::Duration::hours(1)
3557    }
3558
3559    fn db_name() -> &'static str {
3560        "complete-with-language-model"
3561    }
3562}
3563
3564async fn complete_with_language_model(
3565    request: proto::CompleteWithLanguageModel,
3566    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3567    session: Session,
3568    open_ai_api_key: Option<Arc<str>>,
3569    google_ai_api_key: Option<Arc<str>>,
3570    anthropic_api_key: Option<Arc<str>>,
3571) -> Result<()> {
3572    let Some(session) = session.for_user() else {
3573        return Err(anyhow!("user not found"))?;
3574    };
3575    authorize_access_to_language_models(&session).await?;
3576    session
3577        .rate_limiter
3578        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3579        .await?;
3580
3581    if request.model.starts_with("gpt") {
3582        let api_key =
3583            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3584        complete_with_open_ai(request, response, session, api_key).await?;
3585    } else if request.model.starts_with("gemini") {
3586        let api_key = google_ai_api_key
3587            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3588        complete_with_google_ai(request, response, session, api_key).await?;
3589    } else if request.model.starts_with("claude") {
3590        let api_key = anthropic_api_key
3591            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
3592        complete_with_anthropic(request, response, session, api_key).await?;
3593    }
3594
3595    Ok(())
3596}
3597
3598async fn complete_with_open_ai(
3599    request: proto::CompleteWithLanguageModel,
3600    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3601    session: UserSession,
3602    api_key: Arc<str>,
3603) -> Result<()> {
3604    const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3605
3606    let mut completion_stream = open_ai::stream_completion(
3607        &session.http_client,
3608        OPEN_AI_API_URL,
3609        &api_key,
3610        crate::ai::language_model_request_to_open_ai(request)?,
3611    )
3612    .await
3613    .context("open_ai::stream_completion request failed")?;
3614
3615    while let Some(event) = completion_stream.next().await {
3616        let event = event?;
3617        response.send(proto::LanguageModelResponse {
3618            choices: event
3619                .choices
3620                .into_iter()
3621                .map(|choice| proto::LanguageModelChoiceDelta {
3622                    index: choice.index,
3623                    delta: Some(proto::LanguageModelResponseMessage {
3624                        role: choice.delta.role.map(|role| match role {
3625                            open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3626                            open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3627                            open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3628                        } as i32),
3629                        content: choice.delta.content,
3630                    }),
3631                    finish_reason: choice.finish_reason,
3632                })
3633                .collect(),
3634        })?;
3635    }
3636
3637    Ok(())
3638}
3639
3640async fn complete_with_google_ai(
3641    request: proto::CompleteWithLanguageModel,
3642    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3643    session: UserSession,
3644    api_key: Arc<str>,
3645) -> Result<()> {
3646    let mut stream = google_ai::stream_generate_content(
3647        &session.http_client,
3648        google_ai::API_URL,
3649        api_key.as_ref(),
3650        crate::ai::language_model_request_to_google_ai(request)?,
3651    )
3652    .await
3653    .context("google_ai::stream_generate_content request failed")?;
3654
3655    while let Some(event) = stream.next().await {
3656        let event = event?;
3657        response.send(proto::LanguageModelResponse {
3658            choices: event
3659                .candidates
3660                .unwrap_or_default()
3661                .into_iter()
3662                .map(|candidate| proto::LanguageModelChoiceDelta {
3663                    index: candidate.index as u32,
3664                    delta: Some(proto::LanguageModelResponseMessage {
3665                        role: Some(match candidate.content.role {
3666                            google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3667                            google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3668                        } as i32),
3669                        content: Some(
3670                            candidate
3671                                .content
3672                                .parts
3673                                .into_iter()
3674                                .filter_map(|part| match part {
3675                                    google_ai::Part::TextPart(part) => Some(part.text),
3676                                    google_ai::Part::InlineDataPart(_) => None,
3677                                })
3678                                .collect(),
3679                        ),
3680                    }),
3681                    finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3682                })
3683                .collect(),
3684        })?;
3685    }
3686
3687    Ok(())
3688}
3689
3690async fn complete_with_anthropic(
3691    request: proto::CompleteWithLanguageModel,
3692    response: StreamingResponse<proto::CompleteWithLanguageModel>,
3693    session: UserSession,
3694    api_key: Arc<str>,
3695) -> Result<()> {
3696    let model = anthropic::Model::from_id(&request.model)?;
3697
3698    let mut system_message = String::new();
3699    let messages = request
3700        .messages
3701        .into_iter()
3702        .filter_map(|message| match message.role() {
3703            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
3704                role: anthropic::Role::User,
3705                content: message.content,
3706            }),
3707            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
3708                role: anthropic::Role::Assistant,
3709                content: message.content,
3710            }),
3711            // Anthropic's API breaks system instructions out as a separate field rather
3712            // than having a system message role.
3713            LanguageModelRole::LanguageModelSystem => {
3714                if !system_message.is_empty() {
3715                    system_message.push_str("\n\n");
3716                }
3717                system_message.push_str(&message.content);
3718
3719                None
3720            }
3721        })
3722        .collect();
3723
3724    let mut stream = anthropic::stream_completion(
3725        &session.http_client,
3726        "https://api.anthropic.com",
3727        &api_key,
3728        anthropic::Request {
3729            model,
3730            messages,
3731            stream: true,
3732            system: system_message,
3733            max_tokens: 4092,
3734        },
3735    )
3736    .await?;
3737
3738    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
3739
3740    while let Some(event) = stream.next().await {
3741        let event = event?;
3742
3743        match event {
3744            anthropic::ResponseEvent::MessageStart { message } => {
3745                if let Some(role) = message.role {
3746                    if role == "assistant" {
3747                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
3748                    } else if role == "user" {
3749                        current_role = proto::LanguageModelRole::LanguageModelUser;
3750                    }
3751                }
3752            }
3753            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
3754                match content_block {
3755                    anthropic::ContentBlock::Text { text } => {
3756                        if !text.is_empty() {
3757                            response.send(proto::LanguageModelResponse {
3758                                choices: vec![proto::LanguageModelChoiceDelta {
3759                                    index: 0,
3760                                    delta: Some(proto::LanguageModelResponseMessage {
3761                                        role: Some(current_role as i32),
3762                                        content: Some(text),
3763                                    }),
3764                                    finish_reason: None,
3765                                }],
3766                            })?;
3767                        }
3768                    }
3769                }
3770            }
3771            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
3772                anthropic::TextDelta::TextDelta { text } => {
3773                    response.send(proto::LanguageModelResponse {
3774                        choices: vec![proto::LanguageModelChoiceDelta {
3775                            index: 0,
3776                            delta: Some(proto::LanguageModelResponseMessage {
3777                                role: Some(current_role as i32),
3778                                content: Some(text),
3779                            }),
3780                            finish_reason: None,
3781                        }],
3782                    })?;
3783                }
3784            },
3785            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
3786                if let Some(stop_reason) = delta.stop_reason {
3787                    response.send(proto::LanguageModelResponse {
3788                        choices: vec![proto::LanguageModelChoiceDelta {
3789                            index: 0,
3790                            delta: None,
3791                            finish_reason: Some(stop_reason),
3792                        }],
3793                    })?;
3794                }
3795            }
3796            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
3797            anthropic::ResponseEvent::MessageStop {} => {}
3798            anthropic::ResponseEvent::Ping {} => {}
3799        }
3800    }
3801
3802    Ok(())
3803}
3804
3805struct CountTokensWithLanguageModelRateLimit;
3806
3807impl RateLimit for CountTokensWithLanguageModelRateLimit {
3808    fn capacity() -> usize {
3809        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3810            .ok()
3811            .and_then(|v| v.parse().ok())
3812            .unwrap_or(600) // Picked arbitrarily
3813    }
3814
3815    fn refill_duration() -> chrono::Duration {
3816        chrono::Duration::hours(1)
3817    }
3818
3819    fn db_name() -> &'static str {
3820        "count-tokens-with-language-model"
3821    }
3822}
3823
3824async fn count_tokens_with_language_model(
3825    request: proto::CountTokensWithLanguageModel,
3826    response: Response<proto::CountTokensWithLanguageModel>,
3827    session: UserSession,
3828    google_ai_api_key: Option<Arc<str>>,
3829) -> Result<()> {
3830    authorize_access_to_language_models(&session).await?;
3831
3832    if !request.model.starts_with("gemini") {
3833        return Err(anyhow!(
3834            "counting tokens for model: {:?} is not supported",
3835            request.model
3836        ))?;
3837    }
3838
3839    session
3840        .rate_limiter
3841        .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3842        .await?;
3843
3844    let api_key = google_ai_api_key
3845        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3846    let tokens_response = google_ai::count_tokens(
3847        &session.http_client,
3848        google_ai::API_URL,
3849        &api_key,
3850        crate::ai::count_tokens_request_to_google_ai(request)?,
3851    )
3852    .await?;
3853    response.send(proto::CountTokensResponse {
3854        token_count: tokens_response.total_tokens as u32,
3855    })?;
3856    Ok(())
3857}
3858
3859async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3860    let db = session.db().await;
3861    let flags = db.get_user_flags(session.user_id()).await?;
3862    if flags.iter().any(|flag| flag == "language-models") {
3863        Ok(())
3864    } else {
3865        Err(anyhow!("permission denied"))?
3866    }
3867}
3868
3869/// Start receiving chat updates for a channel
3870async fn join_channel_chat(
3871    request: proto::JoinChannelChat,
3872    response: Response<proto::JoinChannelChat>,
3873    session: UserSession,
3874) -> Result<()> {
3875    let channel_id = ChannelId::from_proto(request.channel_id);
3876
3877    let db = session.db().await;
3878    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3879        .await?;
3880    let messages = db
3881        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3882        .await?;
3883    response.send(proto::JoinChannelChatResponse {
3884        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3885        messages,
3886    })?;
3887    Ok(())
3888}
3889
3890/// Stop receiving chat updates for a channel
3891async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3892    let channel_id = ChannelId::from_proto(request.channel_id);
3893    session
3894        .db()
3895        .await
3896        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3897        .await?;
3898    Ok(())
3899}
3900
3901/// Retrieve the chat history for a channel
3902async fn get_channel_messages(
3903    request: proto::GetChannelMessages,
3904    response: Response<proto::GetChannelMessages>,
3905    session: UserSession,
3906) -> Result<()> {
3907    let channel_id = ChannelId::from_proto(request.channel_id);
3908    let messages = session
3909        .db()
3910        .await
3911        .get_channel_messages(
3912            channel_id,
3913            session.user_id(),
3914            MESSAGE_COUNT_PER_PAGE,
3915            Some(MessageId::from_proto(request.before_message_id)),
3916        )
3917        .await?;
3918    response.send(proto::GetChannelMessagesResponse {
3919        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3920        messages,
3921    })?;
3922    Ok(())
3923}
3924
3925/// Retrieve specific chat messages
3926async fn get_channel_messages_by_id(
3927    request: proto::GetChannelMessagesById,
3928    response: Response<proto::GetChannelMessagesById>,
3929    session: UserSession,
3930) -> Result<()> {
3931    let message_ids = request
3932        .message_ids
3933        .iter()
3934        .map(|id| MessageId::from_proto(*id))
3935        .collect::<Vec<_>>();
3936    let messages = session
3937        .db()
3938        .await
3939        .get_channel_messages_by_id(session.user_id(), &message_ids)
3940        .await?;
3941    response.send(proto::GetChannelMessagesResponse {
3942        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3943        messages,
3944    })?;
3945    Ok(())
3946}
3947
3948/// Retrieve the current users notifications
3949async fn get_notifications(
3950    request: proto::GetNotifications,
3951    response: Response<proto::GetNotifications>,
3952    session: UserSession,
3953) -> Result<()> {
3954    let notifications = session
3955        .db()
3956        .await
3957        .get_notifications(
3958            session.user_id(),
3959            NOTIFICATION_COUNT_PER_PAGE,
3960            request
3961                .before_id
3962                .map(|id| db::NotificationId::from_proto(id)),
3963        )
3964        .await?;
3965    response.send(proto::GetNotificationsResponse {
3966        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3967        notifications,
3968    })?;
3969    Ok(())
3970}
3971
3972/// Mark notifications as read
3973async fn mark_notification_as_read(
3974    request: proto::MarkNotificationRead,
3975    response: Response<proto::MarkNotificationRead>,
3976    session: UserSession,
3977) -> Result<()> {
3978    let database = &session.db().await;
3979    let notifications = database
3980        .mark_notification_as_read_by_id(
3981            session.user_id(),
3982            NotificationId::from_proto(request.notification_id),
3983        )
3984        .await?;
3985    send_notifications(
3986        &*session.connection_pool().await,
3987        &session.peer,
3988        notifications,
3989    );
3990    response.send(proto::Ack {})?;
3991    Ok(())
3992}
3993
3994/// Get the current users information
3995async fn get_private_user_info(
3996    _request: proto::GetPrivateUserInfo,
3997    response: Response<proto::GetPrivateUserInfo>,
3998    session: UserSession,
3999) -> Result<()> {
4000    let db = session.db().await;
4001
4002    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4003    let user = db
4004        .get_user_by_id(session.user_id())
4005        .await?
4006        .ok_or_else(|| anyhow!("user not found"))?;
4007    let flags = db.get_user_flags(session.user_id()).await?;
4008
4009    response.send(proto::GetPrivateUserInfoResponse {
4010        metrics_id,
4011        staff: user.admin,
4012        flags,
4013    })?;
4014    Ok(())
4015}
4016
4017fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4018    match message {
4019        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4020        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4021        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4022        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4023        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4024            code: frame.code.into(),
4025            reason: frame.reason,
4026        })),
4027    }
4028}
4029
4030fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4031    match message {
4032        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4033        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4034        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4035        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4036        AxumMessage::Close(frame) => {
4037            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4038                code: frame.code.into(),
4039                reason: frame.reason,
4040            }))
4041        }
4042    }
4043}
4044
4045fn notify_membership_updated(
4046    connection_pool: &mut ConnectionPool,
4047    result: MembershipUpdated,
4048    user_id: UserId,
4049    peer: &Peer,
4050) {
4051    for membership in &result.new_channels.channel_memberships {
4052        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4053    }
4054    for channel_id in &result.removed_channels {
4055        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4056    }
4057
4058    let user_channels_update = proto::UpdateUserChannels {
4059        channel_memberships: result
4060            .new_channels
4061            .channel_memberships
4062            .iter()
4063            .map(|cm| proto::ChannelMembership {
4064                channel_id: cm.channel_id.to_proto(),
4065                role: cm.role.into(),
4066            })
4067            .collect(),
4068        ..Default::default()
4069    };
4070
4071    let mut update = build_channels_update(result.new_channels, vec![]);
4072    update.delete_channels = result
4073        .removed_channels
4074        .into_iter()
4075        .map(|id| id.to_proto())
4076        .collect();
4077    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4078
4079    for connection_id in connection_pool.user_connection_ids(user_id) {
4080        peer.send(connection_id, user_channels_update.clone())
4081            .trace_err();
4082        peer.send(connection_id, update.clone()).trace_err();
4083    }
4084}
4085
4086fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4087    proto::UpdateUserChannels {
4088        channel_memberships: channels
4089            .channel_memberships
4090            .iter()
4091            .map(|m| proto::ChannelMembership {
4092                channel_id: m.channel_id.to_proto(),
4093                role: m.role.into(),
4094            })
4095            .collect(),
4096        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4097        observed_channel_message_id: channels.observed_channel_messages.clone(),
4098    }
4099}
4100
4101fn build_channels_update(
4102    channels: ChannelsForUser,
4103    channel_invites: Vec<db::Channel>,
4104) -> proto::UpdateChannels {
4105    let mut update = proto::UpdateChannels::default();
4106
4107    for channel in channels.channels {
4108        update.channels.push(channel.to_proto());
4109    }
4110
4111    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4112    update.latest_channel_message_ids = channels.latest_channel_messages;
4113
4114    for (channel_id, participants) in channels.channel_participants {
4115        update
4116            .channel_participants
4117            .push(proto::ChannelParticipants {
4118                channel_id: channel_id.to_proto(),
4119                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4120            });
4121    }
4122
4123    for channel in channel_invites {
4124        update.channel_invitations.push(channel.to_proto());
4125    }
4126    for project in channels.hosted_projects {
4127        update.hosted_projects.push(project);
4128    }
4129
4130    update
4131}
4132
4133fn build_initial_contacts_update(
4134    contacts: Vec<db::Contact>,
4135    pool: &ConnectionPool,
4136) -> proto::UpdateContacts {
4137    let mut update = proto::UpdateContacts::default();
4138
4139    for contact in contacts {
4140        match contact {
4141            db::Contact::Accepted { user_id, busy } => {
4142                update.contacts.push(contact_for_user(user_id, busy, &pool));
4143            }
4144            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4145            db::Contact::Incoming { user_id } => {
4146                update
4147                    .incoming_requests
4148                    .push(proto::IncomingContactRequest {
4149                        requester_id: user_id.to_proto(),
4150                    })
4151            }
4152        }
4153    }
4154
4155    update
4156}
4157
4158fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4159    proto::Contact {
4160        user_id: user_id.to_proto(),
4161        online: pool.is_user_online(user_id),
4162        busy,
4163    }
4164}
4165
4166fn room_updated(room: &proto::Room, peer: &Peer) {
4167    broadcast(
4168        None,
4169        room.participants
4170            .iter()
4171            .filter_map(|participant| Some(participant.peer_id?.into())),
4172        |peer_id| {
4173            peer.send(
4174                peer_id,
4175                proto::RoomUpdated {
4176                    room: Some(room.clone()),
4177                },
4178            )
4179        },
4180    );
4181}
4182
4183fn channel_updated(
4184    channel: &db::channel::Model,
4185    room: &proto::Room,
4186    peer: &Peer,
4187    pool: &ConnectionPool,
4188) {
4189    let participants = room
4190        .participants
4191        .iter()
4192        .map(|p| p.user_id)
4193        .collect::<Vec<_>>();
4194
4195    broadcast(
4196        None,
4197        pool.channel_connection_ids(channel.root_id())
4198            .filter_map(|(channel_id, role)| {
4199                role.can_see_channel(channel.visibility).then(|| channel_id)
4200            }),
4201        |peer_id| {
4202            peer.send(
4203                peer_id,
4204                proto::UpdateChannels {
4205                    channel_participants: vec![proto::ChannelParticipants {
4206                        channel_id: channel.id.to_proto(),
4207                        participant_user_ids: participants.clone(),
4208                    }],
4209                    ..Default::default()
4210                },
4211            )
4212        },
4213    );
4214}
4215
4216async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4217    let db = session.db().await;
4218
4219    let contacts = db.get_contacts(user_id).await?;
4220    let busy = db.is_user_busy(user_id).await?;
4221
4222    let pool = session.connection_pool().await;
4223    let updated_contact = contact_for_user(user_id, busy, &pool);
4224    for contact in contacts {
4225        if let db::Contact::Accepted {
4226            user_id: contact_user_id,
4227            ..
4228        } = contact
4229        {
4230            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4231                session
4232                    .peer
4233                    .send(
4234                        contact_conn_id,
4235                        proto::UpdateContacts {
4236                            contacts: vec![updated_contact.clone()],
4237                            remove_contacts: Default::default(),
4238                            incoming_requests: Default::default(),
4239                            remove_incoming_requests: Default::default(),
4240                            outgoing_requests: Default::default(),
4241                            remove_outgoing_requests: Default::default(),
4242                        },
4243                    )
4244                    .trace_err();
4245            }
4246        }
4247    }
4248    Ok(())
4249}
4250
4251async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
4252    let mut contacts_to_update = HashSet::default();
4253
4254    let room_id;
4255    let canceled_calls_to_user_ids;
4256    let live_kit_room;
4257    let delete_live_kit_room;
4258    let room;
4259    let channel;
4260
4261    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4262        contacts_to_update.insert(session.user_id());
4263
4264        for project in left_room.left_projects.values() {
4265            project_left(project, session);
4266        }
4267
4268        room_id = RoomId::from_proto(left_room.room.id);
4269        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4270        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
4271        delete_live_kit_room = left_room.deleted;
4272        room = mem::take(&mut left_room.room);
4273        channel = mem::take(&mut left_room.channel);
4274
4275        room_updated(&room, &session.peer);
4276    } else {
4277        return Ok(());
4278    }
4279
4280    if let Some(channel) = channel {
4281        channel_updated(
4282            &channel,
4283            &room,
4284            &session.peer,
4285            &*session.connection_pool().await,
4286        );
4287    }
4288
4289    {
4290        let pool = session.connection_pool().await;
4291        for canceled_user_id in canceled_calls_to_user_ids {
4292            for connection_id in pool.user_connection_ids(canceled_user_id) {
4293                session
4294                    .peer
4295                    .send(
4296                        connection_id,
4297                        proto::CallCanceled {
4298                            room_id: room_id.to_proto(),
4299                        },
4300                    )
4301                    .trace_err();
4302            }
4303            contacts_to_update.insert(canceled_user_id);
4304        }
4305    }
4306
4307    for contact_user_id in contacts_to_update {
4308        update_user_contacts(contact_user_id, &session).await?;
4309    }
4310
4311    if let Some(live_kit) = session.live_kit_client.as_ref() {
4312        live_kit
4313            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
4314            .await
4315            .trace_err();
4316
4317        if delete_live_kit_room {
4318            live_kit.delete_room(live_kit_room).await.trace_err();
4319        }
4320    }
4321
4322    Ok(())
4323}
4324
4325async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4326    let left_channel_buffers = session
4327        .db()
4328        .await
4329        .leave_channel_buffers(session.connection_id)
4330        .await?;
4331
4332    for left_buffer in left_channel_buffers {
4333        channel_buffer_updated(
4334            session.connection_id,
4335            left_buffer.connections,
4336            &proto::UpdateChannelBufferCollaborators {
4337                channel_id: left_buffer.channel_id.to_proto(),
4338                collaborators: left_buffer.collaborators,
4339            },
4340            &session.peer,
4341        );
4342    }
4343
4344    Ok(())
4345}
4346
4347fn project_left(project: &db::LeftProject, session: &UserSession) {
4348    for connection_id in &project.connection_ids {
4349        if project.host_user_id == Some(session.user_id()) {
4350            session
4351                .peer
4352                .send(
4353                    *connection_id,
4354                    proto::UnshareProject {
4355                        project_id: project.id.to_proto(),
4356                    },
4357                )
4358                .trace_err();
4359        } else {
4360            session
4361                .peer
4362                .send(
4363                    *connection_id,
4364                    proto::RemoveProjectCollaborator {
4365                        project_id: project.id.to_proto(),
4366                        peer_id: Some(session.connection_id.into()),
4367                    },
4368                )
4369                .trace_err();
4370        }
4371    }
4372}
4373
4374pub trait ResultExt {
4375    type Ok;
4376
4377    fn trace_err(self) -> Option<Self::Ok>;
4378}
4379
4380impl<T, E> ResultExt for Result<T, E>
4381where
4382    E: std::fmt::Debug,
4383{
4384    type Ok = T;
4385
4386    fn trace_err(self) -> Option<T> {
4387        match self {
4388            Ok(value) => Some(value),
4389            Err(error) => {
4390                tracing::error!("{:?}", error);
4391                None
4392            }
4393        }
4394    }
4395}