rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
/// Reconnection grace period.
// NOTE(review): not referenced in the visible part of this file; presumably how
// long a disconnected client may take to reconnect — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up stale rooms, channel buffers
/// and server records.
// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
// Waiting 15s leaves a few seconds of margin beyond that 10s window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in the visible part of this
// file; their descriptions are inferred from the names — confirm at use sites.
/// Page size for channel message history (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (presumably in bytes — confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in the server's handler table, keyed by
/// the `TypeId` of the message it accepts.
///
/// Built by `Server::add_handler`: it takes the boxed envelope plus the
/// receiving connection's [`Session`] and returns a boxed future that runs the
/// handler.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (possibly an impersonated user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Shared database handle behind an async mutex; lock it via [`Session::db`].
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared connection pool; lock it via [`Session::connection_pool`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `#[allow(unused)]` suggests this field is currently not
    // read anywhere; confirm before removing it.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The RPC server: owns the peer, the pool of live connections, and the table
/// of message handlers registered in [`Server::new`].
pub struct Server {
    /// Identifier of this server instance; behind a mutex so the test-only
    /// `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; `teardown` publishes `true` on it and `handle_connection`
    /// subscribes to it.
    teardown: watch::Sender<bool>,
}
 244
/// Lock guard over the shared [`ConnectionPool`], as returned by
/// `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, which makes this guard `!Send` too, so it cannot be
    /// held across an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer together with its locked connection
/// pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the pool the guard dereferences to (via `serialize_deref`).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 332            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 333            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 342            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 345            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 347            .add_message_handler(create_buffer_for_peer)
 348            .add_request_handler(update_buffer)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 354            .add_request_handler(get_users)
 355            .add_request_handler(fuzzy_search_users)
 356            .add_request_handler(request_contact)
 357            .add_request_handler(remove_contact)
 358            .add_request_handler(respond_to_contact_request)
 359            .add_message_handler(subscribe_to_channels)
 360            .add_request_handler(create_channel)
 361            .add_request_handler(delete_channel)
 362            .add_request_handler(invite_channel_member)
 363            .add_request_handler(remove_channel_member)
 364            .add_request_handler(set_channel_member_role)
 365            .add_request_handler(set_channel_visibility)
 366            .add_request_handler(rename_channel)
 367            .add_request_handler(join_channel_buffer)
 368            .add_request_handler(leave_channel_buffer)
 369            .add_message_handler(update_channel_buffer)
 370            .add_request_handler(rejoin_channel_buffers)
 371            .add_request_handler(get_channel_members)
 372            .add_request_handler(respond_to_channel_invite)
 373            .add_request_handler(join_channel)
 374            .add_request_handler(join_channel_chat)
 375            .add_message_handler(leave_channel_chat)
 376            .add_request_handler(send_channel_message)
 377            .add_request_handler(remove_channel_message)
 378            .add_request_handler(update_channel_message)
 379            .add_request_handler(get_channel_messages)
 380            .add_request_handler(get_channel_messages_by_id)
 381            .add_request_handler(get_notifications)
 382            .add_request_handler(mark_notification_as_read)
 383            .add_request_handler(move_channel)
 384            .add_request_handler(follow)
 385            .add_message_handler(unfollow)
 386            .add_message_handler(update_followers)
 387            .add_request_handler(get_private_user_info)
 388            .add_request_handler(get_llm_api_token)
 389            .add_request_handler(accept_terms_of_service)
 390            .add_message_handler(acknowledge_channel_message)
 391            .add_message_handler(acknowledge_buffer_version)
 392            .add_request_handler(get_supermaven_api_key)
 393            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 395            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 399            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
 400            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 401            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 403            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 404            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 405            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 406            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 407            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 408            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 409            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 410            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 411            .add_message_handler(update_context)
 412            .add_request_handler({
 413                let app_state = app_state.clone();
 414                move |request, response, session| {
 415                    let app_state = app_state.clone();
 416                    async move {
 417                        count_language_model_tokens(request, response, session, &app_state.config)
 418                            .await
 419                    }
 420                }
 421            })
 422            .add_request_handler(get_cached_embeddings)
 423            .add_request_handler({
 424                let app_state = app_state.clone();
 425                move |request, response, session| {
 426                    compute_embeddings(
 427                        request,
 428                        response,
 429                        session,
 430                        app_state.config.openai_api_key.clone(),
 431                    )
 432                }
 433            });
 434
 435        Arc::new(server)
 436    }
 437
    /// Spawns a detached background task that cleans up stale resources, then
    /// returns immediately.
    ///
    /// The task waits [`CLEANUP_TIMEOUT`], then asks the database for stale room
    /// and channel-buffer ids in this environment (`stale_server_resource_ids`).
    /// For each stale channel buffer it clears stale collaborators and sends the
    /// refreshed collaborator list to the remaining connections. For each stale
    /// room it clears stale participants, broadcasts the room (and channel)
    /// update, sends `CallCanceled` to users whose calls were canceled, pushes
    /// refreshed contact entries to the affected users' accepted contacts, and
    /// deletes the LiveKit room (when a LiveKit client is configured) if no
    /// participants remain. Finally it deletes stale server records.
    ///
    /// Each fallible step is passed through `trace_err()` and skipped on
    /// failure; nothing from the task is propagated, so this always returns
    /// `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Build the delay up front; it is only awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing this room fails: nothing
                        // to notify and no LiveKit room to delete.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to
                        // the connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 583
 584    pub fn teardown(&self) {
 585        self.peer.teardown();
 586        self.connection_pool.lock().reset();
 587        let _ = self.teardown.send(true);
 588    }
 589
 590    #[cfg(test)]
 591    pub fn reset(&self, id: ServerId) {
 592        self.teardown();
 593        *self.id.lock() = id;
 594        self.peer.reset(id.0 as u32);
 595        let _ = self.teardown.send(false);
 596    }
 597
 598    #[cfg(test)]
 599    pub fn id(&self) -> ServerId {
 600        *self.id.lock()
 601    }
 602
 603    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 604    where
 605        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 606        Fut: 'static + Send + Future<Output = Result<()>>,
 607        M: EnvelopedMessage,
 608    {
 609        let prev_handler = self.handlers.insert(
 610            TypeId::of::<M>(),
 611            Box::new(move |envelope, session| {
 612                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 613                let received_at = envelope.received_at;
 614                tracing::info!("message received");
 615                let start_time = Instant::now();
 616                let future = (handler)(*envelope, session);
 617                async move {
 618                    let result = future.await;
 619                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 620                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 621                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 622                    let payload_type = M::NAME;
 623
 624                    match result {
 625                        Err(error) => {
 626                            tracing::error!(
 627                                ?error,
 628                                total_duration_ms,
 629                                processing_duration_ms,
 630                                queue_duration_ms,
 631                                payload_type,
 632                                "error handling message"
 633                            )
 634                        }
 635                        Ok(()) => tracing::info!(
 636                            total_duration_ms,
 637                            processing_duration_ms,
 638                            queue_duration_ms,
 639                            "finished handling message"
 640                        ),
 641                    }
 642                }
 643                .boxed()
 644            }),
 645        );
 646        if prev_handler.is_some() {
 647            panic!("registered a handler for the same message twice");
 648        }
 649        self
 650    }
 651
 652    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 653    where
 654        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 655        Fut: 'static + Send + Future<Output = Result<()>>,
 656        M: EnvelopedMessage,
 657    {
 658        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 659        self
 660    }
 661
 662    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 663    where
 664        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 665        Fut: Send + Future<Output = Result<()>>,
 666        M: RequestMessage,
 667    {
 668        let handler = Arc::new(handler);
 669        self.add_handler(move |envelope, session| {
 670            let receipt = envelope.receipt();
 671            let handler = handler.clone();
 672            async move {
 673                let peer = session.peer.clone();
 674                let responded = Arc::new(AtomicBool::default());
 675                let response = Response {
 676                    peer: peer.clone(),
 677                    responded: responded.clone(),
 678                    receipt,
 679                };
 680                match (handler)(envelope.payload, response, session).await {
 681                    Ok(()) => {
 682                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 683                            Ok(())
 684                        } else {
 685                            Err(anyhow!("handler did not send a response"))?
 686                        }
 687                    }
 688                    Err(error) => {
 689                        let proto_err = match &error {
 690                            Error::Internal(err) => err.to_proto(),
 691                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 692                        };
 693                        peer.respond_with_error(receipt, proto_err)?;
 694                        Err(error)
 695                    }
 696                }
 697            }
 698        })
 699    }
 700
    /// Serves one client connection until it closes or the server tears down.
    ///
    /// Registers `connection` with the peer, builds the [`Session`] shared by all
    /// message handlers, performs the initial handshake via
    /// `send_initial_client_update`, and then dispatches incoming messages to the
    /// registered handlers. When the connection's I/O future finishes or the
    /// incoming stream ends, the session is cleaned up via `connection_lost`.
    /// If teardown is signalled, the future returns without that cleanup.
    ///
    /// `send_connection_id`, when provided, receives the assigned
    /// [`ConnectionId`] right after the `Hello` message is sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are recorded later: the connection id once the
        // peer assigns it, the user fields via `principal.update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse to serve the connection if teardown has already begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Each connection builds its own HTTP client; if that fails the
            // connection is abandoned.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only present when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Bounds the number of in-flight handlers (foreground and background) for
            // this connection: a permit is taken before the next message is read and
            // released only after that message's handler has finished.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased polling order: teardown, then I/O completion, then progress on
                // in-flight foreground handlers, and only then a new message.
                futures::select_biased! {
                    // Returning here skips `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached on the executor; foreground
                                // handlers are driven by this loop via the select above.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 849
 850    async fn send_initial_client_update(
 851        &self,
 852        connection_id: ConnectionId,
 853        principal: &Principal,
 854        zed_version: ZedVersion,
 855        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 856        session: &Session,
 857    ) -> Result<()> {
 858        self.peer.send(
 859            connection_id,
 860            proto::Hello {
 861                peer_id: Some(connection_id.into()),
 862            },
 863        )?;
 864        tracing::info!("sent hello message");
 865        if let Some(send_connection_id) = send_connection_id.take() {
 866            let _ = send_connection_id.send(connection_id);
 867        }
 868
 869        match principal {
 870            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 871                if !user.connected_once {
 872                    self.peer.send(connection_id, proto::ShowContacts {})?;
 873                    self.app_state
 874                        .db
 875                        .set_user_connected_once(user.id, true)
 876                        .await?;
 877                }
 878
 879                update_user_plan(user.id, session).await?;
 880
 881                let contacts = self.app_state.db.get_contacts(user.id).await?;
 882
 883                {
 884                    let mut pool = self.connection_pool.lock();
 885                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 886                    self.peer.send(
 887                        connection_id,
 888                        build_initial_contacts_update(contacts, &pool),
 889                    )?;
 890                }
 891
 892                if should_auto_subscribe_to_channels(zed_version) {
 893                    subscribe_user_to_channels(user.id, session).await?;
 894                }
 895
 896                if let Some(incoming_call) =
 897                    self.app_state.db.incoming_call_for_user(user.id).await?
 898                {
 899                    self.peer.send(connection_id, incoming_call)?;
 900                }
 901
 902                update_user_contacts(user.id, session).await?;
 903            }
 904        }
 905
 906        Ok(())
 907    }
 908
 909    pub async fn invite_code_redeemed(
 910        self: &Arc<Self>,
 911        inviter_id: UserId,
 912        invitee_id: UserId,
 913    ) -> Result<()> {
 914        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 915            if let Some(code) = &user.invite_code {
 916                let pool = self.connection_pool.lock();
 917                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 918                for connection_id in pool.user_connection_ids(inviter_id) {
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateContacts {
 922                            contacts: vec![invitee_contact.clone()],
 923                            ..Default::default()
 924                        },
 925                    )?;
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateInviteInfo {
 929                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 930                            count: user.invite_count as u32,
 931                        },
 932                    )?;
 933                }
 934            }
 935        }
 936        Ok(())
 937    }
 938
 939    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 940        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 941            if let Some(invite_code) = &user.invite_code {
 942                let pool = self.connection_pool.lock();
 943                for connection_id in pool.user_connection_ids(user_id) {
 944                    self.peer.send(
 945                        connection_id,
 946                        proto::UpdateInviteInfo {
 947                            url: format!(
 948                                "{}{}",
 949                                self.app_state.config.invite_link_prefix, invite_code
 950                            ),
 951                            count: user.invite_count as u32,
 952                        },
 953                    )?;
 954                }
 955            }
 956        }
 957        Ok(())
 958    }
 959
 960    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 961        let pool = self.connection_pool.lock();
 962        for connection_id in pool.user_connection_ids(user_id) {
 963            self.peer
 964                .send(connection_id, proto::RefreshLlmToken {})
 965                .trace_err();
 966        }
 967    }
 968
 969    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 970        ServerSnapshot {
 971            connection_pool: ConnectionPoolGuard {
 972                guard: self.connection_pool.lock(),
 973                _not_send: PhantomData,
 974            },
 975            peer: &self.peer,
 976        }
 977    }
 978}
 979
 980impl Deref for ConnectionPoolGuard<'_> {
 981    type Target = ConnectionPool;
 982
 983    fn deref(&self) -> &Self::Target {
 984        &self.guard
 985    }
 986}
 987
 988impl DerefMut for ConnectionPoolGuard<'_> {
 989    fn deref_mut(&mut self) -> &mut Self::Target {
 990        &mut self.guard
 991    }
 992}
 993
 994impl Drop for ConnectionPoolGuard<'_> {
 995    fn drop(&mut self) {
 996        #[cfg(test)]
 997        self.check_invariants();
 998    }
 999}
1000
1001fn broadcast<F>(
1002    sender_id: Option<ConnectionId>,
1003    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004    mut f: F,
1005) where
1006    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008    for receiver_id in receiver_ids {
1009        if Some(receiver_id) != sender_id {
1010            if let Err(error) = f(receiver_id) {
1011                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012            }
1013        }
1014    }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020    fn name() -> &'static HeaderName {
1021        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023    }
1024
1025    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026    where
1027        Self: Sized,
1028        I: Iterator<Item = &'i axum::http::HeaderValue>,
1029    {
1030        let version = values
1031            .next()
1032            .ok_or_else(axum::headers::Error::invalid)?
1033            .to_str()
1034            .map_err(|_| axum::headers::Error::invalid())?
1035            .parse()
1036            .map_err(|_| axum::headers::Error::invalid())?;
1037        Ok(Self(version))
1038    }
1039
1040    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041        values.extend([self.0.to_string().parse().unwrap()]);
1042    }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047    fn name() -> &'static HeaderName {
1048        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050    }
1051
1052    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053    where
1054        Self: Sized,
1055        I: Iterator<Item = &'i axum::http::HeaderValue>,
1056    {
1057        let version = values
1058            .next()
1059            .ok_or_else(axum::headers::Error::invalid)?
1060            .to_str()
1061            .map_err(|_| axum::headers::Error::invalid())?
1062            .parse()
1063            .map_err(|_| axum::headers::Error::invalid())?;
1064        Ok(Self(version))
1065    }
1066
1067    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068        values.extend([self.0.to_string().parse().unwrap()]);
1069    }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073    Router::new()
1074        .route("/rpc", get(handle_websocket_request))
1075        .layer(
1076            ServiceBuilder::new()
1077                .layer(Extension(server.app_state.clone()))
1078                .layer(middleware::from_fn(auth::validate_header)),
1079        )
1080        .route("/metrics", get(handle_metrics))
1081        .layer(Extension(server))
1082}
1083
1084pub async fn handle_websocket_request(
1085    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088    Extension(server): Extension<Arc<Server>>,
1089    Extension(principal): Extension<Principal>,
1090    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1091    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1092    ws: WebSocketUpgrade,
1093) -> axum::response::Response {
1094    if protocol_version != rpc::PROTOCOL_VERSION {
1095        return (
1096            StatusCode::UPGRADE_REQUIRED,
1097            "client must be upgraded".to_string(),
1098        )
1099            .into_response();
1100    }
1101
1102    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1103        return (
1104            StatusCode::UPGRADE_REQUIRED,
1105            "no version header found".to_string(),
1106        )
1107            .into_response();
1108    };
1109
1110    if !version.can_collaborate() {
1111        return (
1112            StatusCode::UPGRADE_REQUIRED,
1113            "client must be upgraded".to_string(),
1114        )
1115            .into_response();
1116    }
1117
1118    let socket_address = socket_address.to_string();
1119    ws.on_upgrade(move |socket| {
1120        let socket = socket
1121            .map_ok(to_tungstenite_message)
1122            .err_into()
1123            .with(|message| async move { to_axum_message(message) });
1124        let connection = Connection::new(Box::pin(socket));
1125        async move {
1126            server
1127                .handle_connection(
1128                    connection,
1129                    socket_address,
1130                    principal,
1131                    version,
1132                    country_code_header.map(|header| header.to_string()),
1133                    system_id_header.map(|header| header.to_string()),
1134                    None,
1135                    Executor::Production,
1136                )
1137                .await;
1138        }
1139    })
1140}
1141
1142pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1143    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1144    let connections_metric = CONNECTIONS_METRIC
1145        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1146
1147    let connections = server
1148        .connection_pool
1149        .lock()
1150        .connections()
1151        .filter(|connection| !connection.admin)
1152        .count();
1153    connections_metric.set(connections as _);
1154
1155    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1156    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1157        register_int_gauge!(
1158            "shared_projects",
1159            "number of open projects with one or more guests"
1160        )
1161        .unwrap()
1162    });
1163
1164    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1165    shared_projects_metric.set(shared_projects as _);
1166
1167    let encoder = prometheus::TextEncoder::new();
1168    let metric_families = prometheus::gather();
1169    let encoded_metrics = encoder
1170        .encode_to_string(&metric_families)
1171        .map_err(|err| anyhow!("{}", err))?;
1172    Ok(encoded_metrics)
1173}
1174
1175#[instrument(err, skip(executor))]
1176async fn connection_lost(
1177    session: Session,
1178    mut teardown: watch::Receiver<bool>,
1179    executor: Executor,
1180) -> Result<()> {
1181    session.peer.disconnect(session.connection_id);
1182    session
1183        .connection_pool()
1184        .await
1185        .remove_connection(session.connection_id)?;
1186
1187    session
1188        .db()
1189        .await
1190        .connection_lost(session.connection_id)
1191        .await
1192        .trace_err();
1193
1194    futures::select_biased! {
1195        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1196
1197            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1198            leave_room_for_session(&session, session.connection_id).await.trace_err();
1199            leave_channel_buffers_for_session(&session)
1200                .await
1201                .trace_err();
1202
1203            if !session
1204                .connection_pool()
1205                .await
1206                .is_user_online(session.user_id())
1207            {
1208                let db = session.db().await;
1209                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1210                    room_updated(&room, &session.peer);
1211                }
1212            }
1213
1214            update_user_contacts(session.user_id(), &session).await?;
1215        },
1216        _ = teardown.changed().fuse() => {}
1217    }
1218
1219    Ok(())
1220}
1221
1222/// Acknowledges a ping from a client, used to keep the connection alive.
1223async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1224    response.send(proto::Ack {})?;
1225    Ok(())
1226}
1227
1228/// Creates a new room for calling (outside of channels)
1229async fn create_room(
1230    _request: proto::CreateRoom,
1231    response: Response<proto::CreateRoom>,
1232    session: Session,
1233) -> Result<()> {
1234    let livekit_room = nanoid::nanoid!(30);
1235
1236    let live_kit_connection_info = util::maybe!(async {
1237        let live_kit = session.app_state.livekit_client.as_ref();
1238        let live_kit = live_kit?;
1239        let user_id = session.user_id().to_string();
1240
1241        let token = live_kit
1242            .room_token(&livekit_room, &user_id.to_string())
1243            .trace_err()?;
1244
1245        Some(proto::LiveKitConnectionInfo {
1246            server_url: live_kit.url().into(),
1247            token,
1248            can_publish: true,
1249        })
1250    })
1251    .await;
1252
1253    let room = session
1254        .db()
1255        .await
1256        .create_room(session.user_id(), session.connection_id, &livekit_room)
1257        .await?;
1258
1259    response.send(proto::CreateRoomResponse {
1260        room: Some(room.clone()),
1261        live_kit_connection_info,
1262    })?;
1263
1264    update_user_contacts(session.user_id(), &session).await?;
1265    Ok(())
1266}
1267
1268/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1269async fn join_room(
1270    request: proto::JoinRoom,
1271    response: Response<proto::JoinRoom>,
1272    session: Session,
1273) -> Result<()> {
1274    let room_id = RoomId::from_proto(request.id);
1275
1276    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1277
1278    if let Some(channel_id) = channel_id {
1279        return join_channel_internal(channel_id, Box::new(response), session).await;
1280    }
1281
1282    let joined_room = {
1283        let room = session
1284            .db()
1285            .await
1286            .join_room(room_id, session.user_id(), session.connection_id)
1287            .await?;
1288        room_updated(&room.room, &session.peer);
1289        room.into_inner()
1290    };
1291
1292    for connection_id in session
1293        .connection_pool()
1294        .await
1295        .user_connection_ids(session.user_id())
1296    {
1297        session
1298            .peer
1299            .send(
1300                connection_id,
1301                proto::CallCanceled {
1302                    room_id: room_id.to_proto(),
1303                },
1304            )
1305            .trace_err();
1306    }
1307
1308    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1309    {
1310        live_kit
1311            .room_token(
1312                &joined_room.room.livekit_room,
1313                &session.user_id().to_string(),
1314            )
1315            .trace_err()
1316            .map(|token| proto::LiveKitConnectionInfo {
1317                server_url: live_kit.url().into(),
1318                token,
1319                can_publish: true,
1320            })
1321    } else {
1322        None
1323    };
1324
1325    response.send(proto::JoinRoomResponse {
1326        room: Some(joined_room.room),
1327        channel_id: None,
1328        live_kit_connection_info,
1329    })?;
1330
1331    update_user_contacts(session.user_id(), &session).await?;
1332    Ok(())
1333}
1334
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's membership via the database and replies with the room,
/// the projects the caller reshares (with their collaborators), and the projects
/// it rejoins. Then:
/// - `room_updated` is called with the room;
/// - collaborators of each reshared project are told the caller's new peer id and
///   sent the project's worktrees;
/// - rejoined projects are handled by `notify_rejoined_projects`;
/// - if the room belongs to a channel, `channel_updated` is called for it;
/// - the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // NOTE(review): this inner scope appears to exist so that `rejoined_room`
    // (unwrapped with `into_inner` below, presumably a guard) is released before
    // the connection pool is locked for `channel_updated` — confirm.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Replace the caller's old peer id with its new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Send the reshared project's worktrees to every collaborator except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1426
1427fn notify_rejoined_projects(
1428    rejoined_projects: &mut Vec<RejoinedProject>,
1429    session: &Session,
1430) -> Result<()> {
1431    for project in rejoined_projects.iter() {
1432        for collaborator in &project.collaborators {
1433            session
1434                .peer
1435                .send(
1436                    collaborator.connection_id,
1437                    proto::UpdateProjectCollaborator {
1438                        project_id: project.id.to_proto(),
1439                        old_peer_id: Some(project.old_connection_id.into()),
1440                        new_peer_id: Some(session.connection_id.into()),
1441                    },
1442                )
1443                .trace_err();
1444        }
1445    }
1446
1447    for project in rejoined_projects {
1448        for worktree in mem::take(&mut project.worktrees) {
1449            // Stream this worktree's entries.
1450            let message = proto::UpdateWorktree {
1451                project_id: project.id.to_proto(),
1452                worktree_id: worktree.id,
1453                abs_path: worktree.abs_path.clone(),
1454                root_name: worktree.root_name,
1455                updated_entries: worktree.updated_entries,
1456                removed_entries: worktree.removed_entries,
1457                scan_id: worktree.scan_id,
1458                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1459                updated_repositories: worktree.updated_repositories,
1460                removed_repositories: worktree.removed_repositories,
1461            };
1462            for update in proto::split_worktree_update(message) {
1463                session.peer.send(session.connection_id, update.clone())?;
1464            }
1465
1466            // Stream this worktree's diagnostics.
1467            for summary in worktree.diagnostic_summaries {
1468                session.peer.send(
1469                    session.connection_id,
1470                    proto::UpdateDiagnosticSummary {
1471                        project_id: project.id.to_proto(),
1472                        worktree_id: worktree.id,
1473                        summary: Some(summary),
1474                    },
1475                )?;
1476            }
1477
1478            for settings_file in worktree.settings_files {
1479                session.peer.send(
1480                    session.connection_id,
1481                    proto::UpdateWorktreeSettings {
1482                        project_id: project.id.to_proto(),
1483                        worktree_id: worktree.id,
1484                        path: settings_file.path,
1485                        content: Some(settings_file.content),
1486                        kind: Some(settings_file.kind.to_proto().into()),
1487                    },
1488                )?;
1489            }
1490        }
1491
1492        for language_server in &project.language_servers {
1493            session.peer.send(
1494                session.connection_id,
1495                proto::UpdateLanguageServer {
1496                    project_id: project.id.to_proto(),
1497                    language_server_id: language_server.id,
1498                    variant: Some(
1499                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1500                            proto::LspDiskBasedDiagnosticsUpdated {},
1501                        ),
1502                    ),
1503                },
1504            )?;
1505        }
1506    }
1507    Ok(())
1508}
1509
1510/// leave room disconnects from the room.
1511async fn leave_room(
1512    _: proto::LeaveRoom,
1513    response: Response<proto::LeaveRoom>,
1514    session: Session,
1515) -> Result<()> {
1516    leave_room_for_session(&session, session.connection_id).await?;
1517    response.send(proto::Ack {})?;
1518    Ok(())
1519}
1520
1521/// Updates the permissions of someone else in the room.
1522async fn set_room_participant_role(
1523    request: proto::SetRoomParticipantRole,
1524    response: Response<proto::SetRoomParticipantRole>,
1525    session: Session,
1526) -> Result<()> {
1527    let user_id = UserId::from_proto(request.user_id);
1528    let role = ChannelRole::from(request.role());
1529
1530    let (livekit_room, can_publish) = {
1531        let room = session
1532            .db()
1533            .await
1534            .set_room_participant_role(
1535                session.user_id(),
1536                RoomId::from_proto(request.room_id),
1537                user_id,
1538                role,
1539            )
1540            .await?;
1541
1542        let livekit_room = room.livekit_room.clone();
1543        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1544        room_updated(&room, &session.peer);
1545        (livekit_room, can_publish)
1546    };
1547
1548    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1549        live_kit
1550            .update_participant(
1551                livekit_room.clone(),
1552                request.user_id.to_string(),
1553                livekit_api::proto::ParticipantPermission {
1554                    can_subscribe: true,
1555                    can_publish,
1556                    can_publish_data: can_publish,
1557                    hidden: false,
1558                    recorder: false,
1559                },
1560            )
1561            .await
1562            .trace_err();
1563    }
1564
1565    response.send(proto::Ack {})?;
1566    Ok(())
1567}
1568
/// Call someone else into the current room.
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, and every connection of the callee is rung concurrently. The first
/// connection that acknowledges the ring completes the request with an `Ack`.
/// If no connection acknowledges, the call is recorded as failed via
/// `call_failed` and an error is returned to the caller.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` (borrowed here via `&mut *`) is
    // dropped before the callee's contacts are updated and the callee is rung.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the ring message out of the borrowed result so it outlives this scope.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful ring answers the caller; failed rings are only logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the ring: record the failure and refresh state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1637
1638/// Cancel an outgoing call.
1639async fn cancel_call(
1640    request: proto::CancelCall,
1641    response: Response<proto::CancelCall>,
1642    session: Session,
1643) -> Result<()> {
1644    let called_user_id = UserId::from_proto(request.called_user_id);
1645    let room_id = RoomId::from_proto(request.room_id);
1646    {
1647        let room = session
1648            .db()
1649            .await
1650            .cancel_call(room_id, session.connection_id, called_user_id)
1651            .await?;
1652        room_updated(&room, &session.peer);
1653    }
1654
1655    for connection_id in session
1656        .connection_pool()
1657        .await
1658        .user_connection_ids(called_user_id)
1659    {
1660        session
1661            .peer
1662            .send(
1663                connection_id,
1664                proto::CallCanceled {
1665                    room_id: room_id.to_proto(),
1666                },
1667            )
1668            .trace_err();
1669    }
1670    response.send(proto::Ack {})?;
1671
1672    update_user_contacts(called_user_id, &session).await?;
1673    Ok(())
1674}
1675
1676/// Decline an incoming call.
1677async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1678    let room_id = RoomId::from_proto(message.room_id);
1679    {
1680        let room = session
1681            .db()
1682            .await
1683            .decline_call(Some(room_id), session.user_id())
1684            .await?
1685            .ok_or_else(|| anyhow!("failed to decline call"))?;
1686        room_updated(&room, &session.peer);
1687    }
1688
1689    for connection_id in session
1690        .connection_pool()
1691        .await
1692        .user_connection_ids(session.user_id())
1693    {
1694        session
1695            .peer
1696            .send(
1697                connection_id,
1698                proto::CallCanceled {
1699                    room_id: room_id.to_proto(),
1700                },
1701            )
1702            .trace_err();
1703    }
1704    update_user_contacts(session.user_id(), &session).await?;
1705    Ok(())
1706}
1707
1708/// Updates other participants in the room with your current location.
1709async fn update_participant_location(
1710    request: proto::UpdateParticipantLocation,
1711    response: Response<proto::UpdateParticipantLocation>,
1712    session: Session,
1713) -> Result<()> {
1714    let room_id = RoomId::from_proto(request.room_id);
1715    let location = request
1716        .location
1717        .ok_or_else(|| anyhow!("invalid location"))?;
1718
1719    let db = session.db().await;
1720    let room = db
1721        .update_room_participant_location(room_id, session.connection_id, location)
1722        .await?;
1723
1724    room_updated(&room, &session.peer);
1725    response.send(proto::Ack {})?;
1726    Ok(())
1727}
1728
1729/// Share a project into the room.
1730async fn share_project(
1731    request: proto::ShareProject,
1732    response: Response<proto::ShareProject>,
1733    session: Session,
1734) -> Result<()> {
1735    let (project_id, room) = &*session
1736        .db()
1737        .await
1738        .share_project(
1739            RoomId::from_proto(request.room_id),
1740            session.connection_id,
1741            &request.worktrees,
1742            request.is_ssh_project,
1743        )
1744        .await?;
1745    response.send(proto::ShareProjectResponse {
1746        project_id: project_id.to_proto(),
1747    })?;
1748    room_updated(room, &session.peer);
1749
1750    Ok(())
1751}
1752
1753/// Unshare a project from the room.
1754async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1755    let project_id = ProjectId::from_proto(message.project_id);
1756    unshare_project_internal(project_id, session.connection_id, &session).await
1757}
1758
1759async fn unshare_project_internal(
1760    project_id: ProjectId,
1761    connection_id: ConnectionId,
1762    session: &Session,
1763) -> Result<()> {
1764    let delete = {
1765        let room_guard = session
1766            .db()
1767            .await
1768            .unshare_project(project_id, connection_id)
1769            .await?;
1770
1771        let (delete, room, guest_connection_ids) = &*room_guard;
1772
1773        let message = proto::UnshareProject {
1774            project_id: project_id.to_proto(),
1775        };
1776
1777        broadcast(
1778            Some(connection_id),
1779            guest_connection_ids.iter().copied(),
1780            |conn_id| session.peer.send(conn_id, message.clone()),
1781        );
1782        if let Some(room) = room {
1783            room_updated(room, &session.peer);
1784        }
1785
1786        *delete
1787    };
1788
1789    if delete {
1790        let db = session.db().await;
1791        db.delete_project(project_id).await?;
1792    }
1793
1794    Ok(())
1795}
1796
1797/// Join someone elses shared project.
1798async fn join_project(
1799    request: proto::JoinProject,
1800    response: Response<proto::JoinProject>,
1801    session: Session,
1802) -> Result<()> {
1803    let project_id = ProjectId::from_proto(request.project_id);
1804
1805    tracing::info!(%project_id, "join project");
1806
1807    let db = session.db().await;
1808    let (project, replica_id) = &mut *db
1809        .join_project(project_id, session.connection_id, session.user_id())
1810        .await?;
1811    drop(db);
1812    tracing::info!(%project_id, "join remote project");
1813    join_project_internal(response, session, project, replica_id)
1814}
1815
/// Reply channel used by `join_project_internal`.
///
/// NOTE(review): presumably a trait so other join-style requests can reuse
/// `join_project_internal`; only the `JoinProject` impl is visible here.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The type-qualified path resolves to `Response`'s inherent `send`
        // (inherent items take precedence), not back into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1824
1825fn join_project_internal(
1826    response: impl JoinProjectInternalResponse,
1827    session: Session,
1828    project: &mut Project,
1829    replica_id: &ReplicaId,
1830) -> Result<()> {
1831    let collaborators = project
1832        .collaborators
1833        .iter()
1834        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1835        .map(|collaborator| collaborator.to_proto())
1836        .collect::<Vec<_>>();
1837    let project_id = project.id;
1838    let guest_user_id = session.user_id();
1839
1840    let worktrees = project
1841        .worktrees
1842        .iter()
1843        .map(|(id, worktree)| proto::WorktreeMetadata {
1844            id: *id,
1845            root_name: worktree.root_name.clone(),
1846            visible: worktree.visible,
1847            abs_path: worktree.abs_path.clone(),
1848        })
1849        .collect::<Vec<_>>();
1850
1851    let add_project_collaborator = proto::AddProjectCollaborator {
1852        project_id: project_id.to_proto(),
1853        collaborator: Some(proto::Collaborator {
1854            peer_id: Some(session.connection_id.into()),
1855            replica_id: replica_id.0 as u32,
1856            user_id: guest_user_id.to_proto(),
1857            is_host: false,
1858        }),
1859    };
1860
1861    for collaborator in &collaborators {
1862        session
1863            .peer
1864            .send(
1865                collaborator.peer_id.unwrap().into(),
1866                add_project_collaborator.clone(),
1867            )
1868            .trace_err();
1869    }
1870
1871    // First, we send the metadata associated with each worktree.
1872    response.send(proto::JoinProjectResponse {
1873        project_id: project.id.0 as u64,
1874        worktrees: worktrees.clone(),
1875        replica_id: replica_id.0 as u32,
1876        collaborators: collaborators.clone(),
1877        language_servers: project.language_servers.clone(),
1878        role: project.role.into(),
1879    })?;
1880
1881    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1882        // Stream this worktree's entries.
1883        let message = proto::UpdateWorktree {
1884            project_id: project_id.to_proto(),
1885            worktree_id,
1886            abs_path: worktree.abs_path.clone(),
1887            root_name: worktree.root_name,
1888            updated_entries: worktree.entries,
1889            removed_entries: Default::default(),
1890            scan_id: worktree.scan_id,
1891            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1892            updated_repositories: worktree.repository_entries.into_values().collect(),
1893            removed_repositories: Default::default(),
1894        };
1895        for update in proto::split_worktree_update(message) {
1896            session.peer.send(session.connection_id, update.clone())?;
1897        }
1898
1899        // Stream this worktree's diagnostics.
1900        for summary in worktree.diagnostic_summaries {
1901            session.peer.send(
1902                session.connection_id,
1903                proto::UpdateDiagnosticSummary {
1904                    project_id: project_id.to_proto(),
1905                    worktree_id: worktree.id,
1906                    summary: Some(summary),
1907                },
1908            )?;
1909        }
1910
1911        for settings_file in worktree.settings_files {
1912            session.peer.send(
1913                session.connection_id,
1914                proto::UpdateWorktreeSettings {
1915                    project_id: project_id.to_proto(),
1916                    worktree_id: worktree.id,
1917                    path: settings_file.path,
1918                    content: Some(settings_file.content),
1919                    kind: Some(settings_file.kind.to_proto() as i32),
1920                },
1921            )?;
1922        }
1923    }
1924
1925    for language_server in &project.language_servers {
1926        session.peer.send(
1927            session.connection_id,
1928            proto::UpdateLanguageServer {
1929                project_id: project_id.to_proto(),
1930                language_server_id: language_server.id,
1931                variant: Some(
1932                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1933                        proto::LspDiskBasedDiagnosticsUpdated {},
1934                    ),
1935                ),
1936            },
1937        )?;
1938    }
1939
1940    Ok(())
1941}
1942
1943/// Leave someone elses shared project.
1944async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1945    let sender_id = session.connection_id;
1946    let project_id = ProjectId::from_proto(request.project_id);
1947    let db = session.db().await;
1948
1949    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1950    tracing::info!(
1951        %project_id,
1952        "leave project"
1953    );
1954
1955    project_left(project, &session);
1956    if let Some(room) = room {
1957        room_updated(room, &session.peer);
1958    }
1959
1960    Ok(())
1961}
1962
1963/// Updates other participants with changes to the project
1964async fn update_project(
1965    request: proto::UpdateProject,
1966    response: Response<proto::UpdateProject>,
1967    session: Session,
1968) -> Result<()> {
1969    let project_id = ProjectId::from_proto(request.project_id);
1970    let (room, guest_connection_ids) = &*session
1971        .db()
1972        .await
1973        .update_project(project_id, session.connection_id, &request.worktrees)
1974        .await?;
1975    broadcast(
1976        Some(session.connection_id),
1977        guest_connection_ids.iter().copied(),
1978        |connection_id| {
1979            session
1980                .peer
1981                .forward_send(session.connection_id, connection_id, request.clone())
1982        },
1983    );
1984    if let Some(room) = room {
1985        room_updated(room, &session.peer);
1986    }
1987    response.send(proto::Ack {})?;
1988
1989    Ok(())
1990}
1991
1992/// Updates other participants with changes to the worktree
1993async fn update_worktree(
1994    request: proto::UpdateWorktree,
1995    response: Response<proto::UpdateWorktree>,
1996    session: Session,
1997) -> Result<()> {
1998    let guest_connection_ids = session
1999        .db()
2000        .await
2001        .update_worktree(&request, session.connection_id)
2002        .await?;
2003
2004    broadcast(
2005        Some(session.connection_id),
2006        guest_connection_ids.iter().copied(),
2007        |connection_id| {
2008            session
2009                .peer
2010                .forward_send(session.connection_id, connection_id, request.clone())
2011        },
2012    );
2013    response.send(proto::Ack {})?;
2014    Ok(())
2015}
2016
2017/// Updates other participants with changes to the diagnostics
2018async fn update_diagnostic_summary(
2019    message: proto::UpdateDiagnosticSummary,
2020    session: Session,
2021) -> Result<()> {
2022    let guest_connection_ids = session
2023        .db()
2024        .await
2025        .update_diagnostic_summary(&message, session.connection_id)
2026        .await?;
2027
2028    broadcast(
2029        Some(session.connection_id),
2030        guest_connection_ids.iter().copied(),
2031        |connection_id| {
2032            session
2033                .peer
2034                .forward_send(session.connection_id, connection_id, message.clone())
2035        },
2036    );
2037
2038    Ok(())
2039}
2040
2041/// Updates other participants with changes to the worktree settings
2042async fn update_worktree_settings(
2043    message: proto::UpdateWorktreeSettings,
2044    session: Session,
2045) -> Result<()> {
2046    let guest_connection_ids = session
2047        .db()
2048        .await
2049        .update_worktree_settings(&message, session.connection_id)
2050        .await?;
2051
2052    broadcast(
2053        Some(session.connection_id),
2054        guest_connection_ids.iter().copied(),
2055        |connection_id| {
2056            session
2057                .peer
2058                .forward_send(session.connection_id, connection_id, message.clone())
2059        },
2060    );
2061
2062    Ok(())
2063}
2064
2065/// Notify other participants that a  language server has started.
2066async fn start_language_server(
2067    request: proto::StartLanguageServer,
2068    session: Session,
2069) -> Result<()> {
2070    let guest_connection_ids = session
2071        .db()
2072        .await
2073        .start_language_server(&request, session.connection_id)
2074        .await?;
2075
2076    broadcast(
2077        Some(session.connection_id),
2078        guest_connection_ids.iter().copied(),
2079        |connection_id| {
2080            session
2081                .peer
2082                .forward_send(session.connection_id, connection_id, request.clone())
2083        },
2084    );
2085    Ok(())
2086}
2087
2088/// Notify other participants that a language server has changed.
2089async fn update_language_server(
2090    request: proto::UpdateLanguageServer,
2091    session: Session,
2092) -> Result<()> {
2093    let project_id = ProjectId::from_proto(request.project_id);
2094    let project_connection_ids = session
2095        .db()
2096        .await
2097        .project_connection_ids(project_id, session.connection_id, true)
2098        .await?;
2099    broadcast(
2100        Some(session.connection_id),
2101        project_connection_ids.iter().copied(),
2102        |connection_id| {
2103            session
2104                .peer
2105                .forward_send(session.connection_id, connection_id, request.clone())
2106        },
2107    );
2108    Ok(())
2109}
2110
2111/// forward a project request to the host. These requests should be read only
2112/// as guests are allowed to send them.
2113async fn forward_read_only_project_request<T>(
2114    request: T,
2115    response: Response<T>,
2116    session: Session,
2117) -> Result<()>
2118where
2119    T: EntityMessage + RequestMessage,
2120{
2121    let project_id = ProjectId::from_proto(request.remote_entity_id());
2122    let host_connection_id = session
2123        .db()
2124        .await
2125        .host_for_read_only_project_request(project_id, session.connection_id)
2126        .await?;
2127    let payload = session
2128        .peer
2129        .forward_request(session.connection_id, host_connection_id, request)
2130        .await?;
2131    response.send(payload)?;
2132    Ok(())
2133}
2134
2135async fn forward_find_search_candidates_request(
2136    request: proto::FindSearchCandidates,
2137    response: Response<proto::FindSearchCandidates>,
2138    session: Session,
2139) -> Result<()> {
2140    let project_id = ProjectId::from_proto(request.remote_entity_id());
2141    let host_connection_id = session
2142        .db()
2143        .await
2144        .host_for_read_only_project_request(project_id, session.connection_id)
2145        .await?;
2146    let payload = session
2147        .peer
2148        .forward_request(session.connection_id, host_connection_id, request)
2149        .await?;
2150    response.send(payload)?;
2151    Ok(())
2152}
2153
2154/// forward a project request to the host. These requests are disallowed
2155/// for guests.
2156async fn forward_mutating_project_request<T>(
2157    request: T,
2158    response: Response<T>,
2159    session: Session,
2160) -> Result<()>
2161where
2162    T: EntityMessage + RequestMessage,
2163{
2164    let project_id = ProjectId::from_proto(request.remote_entity_id());
2165
2166    let host_connection_id = session
2167        .db()
2168        .await
2169        .host_for_mutating_project_request(project_id, session.connection_id)
2170        .await?;
2171    let payload = session
2172        .peer
2173        .forward_request(session.connection_id, host_connection_id, request)
2174        .await?;
2175    response.send(payload)?;
2176    Ok(())
2177}
2178
2179/// Notify other participants that a new buffer has been created
2180async fn create_buffer_for_peer(
2181    request: proto::CreateBufferForPeer,
2182    session: Session,
2183) -> Result<()> {
2184    session
2185        .db()
2186        .await
2187        .check_user_is_project_host(
2188            ProjectId::from_proto(request.project_id),
2189            session.connection_id,
2190        )
2191        .await?;
2192    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2193    session
2194        .peer
2195        .forward_send(session.connection_id, peer_id.into(), request)?;
2196    Ok(())
2197}
2198
2199/// Notify other participants that a buffer has been updated. This is
2200/// allowed for guests as long as the update is limited to selections.
2201async fn update_buffer(
2202    request: proto::UpdateBuffer,
2203    response: Response<proto::UpdateBuffer>,
2204    session: Session,
2205) -> Result<()> {
2206    let project_id = ProjectId::from_proto(request.project_id);
2207    let mut capability = Capability::ReadOnly;
2208
2209    for op in request.operations.iter() {
2210        match op.variant {
2211            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2212            Some(_) => capability = Capability::ReadWrite,
2213        }
2214    }
2215
2216    let host = {
2217        let guard = session
2218            .db()
2219            .await
2220            .connections_for_buffer_update(project_id, session.connection_id, capability)
2221            .await?;
2222
2223        let (host, guests) = &*guard;
2224
2225        broadcast(
2226            Some(session.connection_id),
2227            guests.clone(),
2228            |connection_id| {
2229                session
2230                    .peer
2231                    .forward_send(session.connection_id, connection_id, request.clone())
2232            },
2233        );
2234
2235        *host
2236    };
2237
2238    if host != session.connection_id {
2239        session
2240            .peer
2241            .forward_request(session.connection_id, host, request.clone())
2242            .await?;
2243    }
2244
2245    response.send(proto::Ack {})?;
2246    Ok(())
2247}
2248
2249async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2250    let project_id = ProjectId::from_proto(message.project_id);
2251
2252    let operation = message.operation.as_ref().context("invalid operation")?;
2253    let capability = match operation.variant.as_ref() {
2254        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2255            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2256                match buffer_op.variant {
2257                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2258                        Capability::ReadOnly
2259                    }
2260                    _ => Capability::ReadWrite,
2261                }
2262            } else {
2263                Capability::ReadWrite
2264            }
2265        }
2266        Some(_) => Capability::ReadWrite,
2267        None => Capability::ReadOnly,
2268    };
2269
2270    let guard = session
2271        .db()
2272        .await
2273        .connections_for_buffer_update(project_id, session.connection_id, capability)
2274        .await?;
2275
2276    let (host, guests) = &*guard;
2277
2278    broadcast(
2279        Some(session.connection_id),
2280        guests.iter().chain([host]).copied(),
2281        |connection_id| {
2282            session
2283                .peer
2284                .forward_send(session.connection_id, connection_id, message.clone())
2285        },
2286    );
2287
2288    Ok(())
2289}
2290
2291/// Notify other participants that a project has been updated.
2292async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2293    request: T,
2294    session: Session,
2295) -> Result<()> {
2296    let project_id = ProjectId::from_proto(request.remote_entity_id());
2297    let project_connection_ids = session
2298        .db()
2299        .await
2300        .project_connection_ids(project_id, session.connection_id, false)
2301        .await?;
2302
2303    broadcast(
2304        Some(session.connection_id),
2305        project_connection_ids.iter().copied(),
2306        |connection_id| {
2307            session
2308                .peer
2309                .forward_send(session.connection_id, connection_id, request.clone())
2310        },
2311    );
2312    Ok(())
2313}
2314
2315/// Start following another user in a call.
2316async fn follow(
2317    request: proto::Follow,
2318    response: Response<proto::Follow>,
2319    session: Session,
2320) -> Result<()> {
2321    let room_id = RoomId::from_proto(request.room_id);
2322    let project_id = request.project_id.map(ProjectId::from_proto);
2323    let leader_id = request
2324        .leader_id
2325        .ok_or_else(|| anyhow!("invalid leader id"))?
2326        .into();
2327    let follower_id = session.connection_id;
2328
2329    session
2330        .db()
2331        .await
2332        .check_room_participants(room_id, leader_id, session.connection_id)
2333        .await?;
2334
2335    let response_payload = session
2336        .peer
2337        .forward_request(session.connection_id, leader_id, request)
2338        .await?;
2339    response.send(response_payload)?;
2340
2341    if let Some(project_id) = project_id {
2342        let room = session
2343            .db()
2344            .await
2345            .follow(room_id, project_id, leader_id, follower_id)
2346            .await?;
2347        room_updated(&room, &session.peer);
2348    }
2349
2350    Ok(())
2351}
2352
2353/// Stop following another user in a call.
2354async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2355    let room_id = RoomId::from_proto(request.room_id);
2356    let project_id = request.project_id.map(ProjectId::from_proto);
2357    let leader_id = request
2358        .leader_id
2359        .ok_or_else(|| anyhow!("invalid leader id"))?
2360        .into();
2361    let follower_id = session.connection_id;
2362
2363    session
2364        .db()
2365        .await
2366        .check_room_participants(room_id, leader_id, session.connection_id)
2367        .await?;
2368
2369    session
2370        .peer
2371        .forward_send(session.connection_id, leader_id, request)?;
2372
2373    if let Some(project_id) = project_id {
2374        let room = session
2375            .db()
2376            .await
2377            .unfollow(room_id, project_id, leader_id, follower_id)
2378            .await?;
2379        room_updated(&room, &session.peer);
2380    }
2381
2382    Ok(())
2383}
2384
2385/// Notify everyone following you of your current location.
2386async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2387    let room_id = RoomId::from_proto(request.room_id);
2388    let database = session.db.lock().await;
2389
2390    let connection_ids = if let Some(project_id) = request.project_id {
2391        let project_id = ProjectId::from_proto(project_id);
2392        database
2393            .project_connection_ids(project_id, session.connection_id, true)
2394            .await?
2395    } else {
2396        database
2397            .room_connection_ids(room_id, session.connection_id)
2398            .await?
2399    };
2400
2401    // For now, don't send view update messages back to that view's current leader.
2402    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2403        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2404        _ => None,
2405    });
2406
2407    for connection_id in connection_ids.iter().cloned() {
2408        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2409            session
2410                .peer
2411                .forward_send(session.connection_id, connection_id, request.clone())?;
2412        }
2413    }
2414    Ok(())
2415}
2416
2417/// Get public data about users.
2418async fn get_users(
2419    request: proto::GetUsers,
2420    response: Response<proto::GetUsers>,
2421    session: Session,
2422) -> Result<()> {
2423    let user_ids = request
2424        .user_ids
2425        .into_iter()
2426        .map(UserId::from_proto)
2427        .collect();
2428    let users = session
2429        .db()
2430        .await
2431        .get_users_by_ids(user_ids)
2432        .await?
2433        .into_iter()
2434        .map(|user| proto::User {
2435            id: user.id.to_proto(),
2436            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2437            github_login: user.github_login,
2438            email: user.email_address,
2439            name: user.name,
2440        })
2441        .collect();
2442    response.send(proto::UsersResponse { users })?;
2443    Ok(())
2444}
2445
2446/// Search for users (to invite) buy Github login
2447async fn fuzzy_search_users(
2448    request: proto::FuzzySearchUsers,
2449    response: Response<proto::FuzzySearchUsers>,
2450    session: Session,
2451) -> Result<()> {
2452    let query = request.query;
2453    let users = match query.len() {
2454        0 => vec![],
2455        1 | 2 => session
2456            .db()
2457            .await
2458            .get_user_by_github_login(&query)
2459            .await?
2460            .into_iter()
2461            .collect(),
2462        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2463    };
2464    let users = users
2465        .into_iter()
2466        .filter(|user| user.id != session.user_id())
2467        .map(|user| proto::User {
2468            id: user.id.to_proto(),
2469            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2470            github_login: user.github_login,
2471            name: user.name,
2472            email: user.email_address,
2473        })
2474        .collect();
2475    response.send(proto::UsersResponse { users })?;
2476    Ok(())
2477}
2478
2479/// Send a contact request to another user.
2480async fn request_contact(
2481    request: proto::RequestContact,
2482    response: Response<proto::RequestContact>,
2483    session: Session,
2484) -> Result<()> {
2485    let requester_id = session.user_id();
2486    let responder_id = UserId::from_proto(request.responder_id);
2487    if requester_id == responder_id {
2488        return Err(anyhow!("cannot add yourself as a contact"))?;
2489    }
2490
2491    let notifications = session
2492        .db()
2493        .await
2494        .send_contact_request(requester_id, responder_id)
2495        .await?;
2496
2497    // Update outgoing contact requests of requester
2498    let mut update = proto::UpdateContacts::default();
2499    update.outgoing_requests.push(responder_id.to_proto());
2500    for connection_id in session
2501        .connection_pool()
2502        .await
2503        .user_connection_ids(requester_id)
2504    {
2505        session.peer.send(connection_id, update.clone())?;
2506    }
2507
2508    // Update incoming contact requests of responder
2509    let mut update = proto::UpdateContacts::default();
2510    update
2511        .incoming_requests
2512        .push(proto::IncomingContactRequest {
2513            requester_id: requester_id.to_proto(),
2514        });
2515    let connection_pool = session.connection_pool().await;
2516    for connection_id in connection_pool.user_connection_ids(responder_id) {
2517        session.peer.send(connection_id, update.clone())?;
2518    }
2519
2520    send_notifications(&connection_pool, &session.peer, notifications);
2521
2522    response.send(proto::Ack {})?;
2523    Ok(())
2524}
2525
2526/// Accept or decline a contact request
2527async fn respond_to_contact_request(
2528    request: proto::RespondToContactRequest,
2529    response: Response<proto::RespondToContactRequest>,
2530    session: Session,
2531) -> Result<()> {
2532    let responder_id = session.user_id();
2533    let requester_id = UserId::from_proto(request.requester_id);
2534    let db = session.db().await;
2535    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2536        db.dismiss_contact_notification(responder_id, requester_id)
2537            .await?;
2538    } else {
2539        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2540
2541        let notifications = db
2542            .respond_to_contact_request(responder_id, requester_id, accept)
2543            .await?;
2544        let requester_busy = db.is_user_busy(requester_id).await?;
2545        let responder_busy = db.is_user_busy(responder_id).await?;
2546
2547        let pool = session.connection_pool().await;
2548        // Update responder with new contact
2549        let mut update = proto::UpdateContacts::default();
2550        if accept {
2551            update
2552                .contacts
2553                .push(contact_for_user(requester_id, requester_busy, &pool));
2554        }
2555        update
2556            .remove_incoming_requests
2557            .push(requester_id.to_proto());
2558        for connection_id in pool.user_connection_ids(responder_id) {
2559            session.peer.send(connection_id, update.clone())?;
2560        }
2561
2562        // Update requester with new contact
2563        let mut update = proto::UpdateContacts::default();
2564        if accept {
2565            update
2566                .contacts
2567                .push(contact_for_user(responder_id, responder_busy, &pool));
2568        }
2569        update
2570            .remove_outgoing_requests
2571            .push(responder_id.to_proto());
2572
2573        for connection_id in pool.user_connection_ids(requester_id) {
2574            session.peer.send(connection_id, update.clone())?;
2575        }
2576
2577        send_notifications(&pool, &session.peer, notifications);
2578    }
2579
2580    response.send(proto::Ack {})?;
2581    Ok(())
2582}
2583
2584/// Remove a contact.
2585async fn remove_contact(
2586    request: proto::RemoveContact,
2587    response: Response<proto::RemoveContact>,
2588    session: Session,
2589) -> Result<()> {
2590    let requester_id = session.user_id();
2591    let responder_id = UserId::from_proto(request.user_id);
2592    let db = session.db().await;
2593    let (contact_accepted, deleted_notification_id) =
2594        db.remove_contact(requester_id, responder_id).await?;
2595
2596    let pool = session.connection_pool().await;
2597    // Update outgoing contact requests of requester
2598    let mut update = proto::UpdateContacts::default();
2599    if contact_accepted {
2600        update.remove_contacts.push(responder_id.to_proto());
2601    } else {
2602        update
2603            .remove_outgoing_requests
2604            .push(responder_id.to_proto());
2605    }
2606    for connection_id in pool.user_connection_ids(requester_id) {
2607        session.peer.send(connection_id, update.clone())?;
2608    }
2609
2610    // Update incoming contact requests of responder
2611    let mut update = proto::UpdateContacts::default();
2612    if contact_accepted {
2613        update.remove_contacts.push(requester_id.to_proto());
2614    } else {
2615        update
2616            .remove_incoming_requests
2617            .push(requester_id.to_proto());
2618    }
2619    for connection_id in pool.user_connection_ids(responder_id) {
2620        session.peer.send(connection_id, update.clone())?;
2621        if let Some(notification_id) = deleted_notification_id {
2622            session.peer.send(
2623                connection_id,
2624                proto::DeleteNotification {
2625                    notification_id: notification_id.to_proto(),
2626                },
2627            )?;
2628        }
2629    }
2630
2631    response.send(proto::Ack {})?;
2632    Ok(())
2633}
2634
2635fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2636    version.0.minor() < 139
2637}
2638
2639async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2640    let plan = session.current_plan(&session.db().await).await?;
2641
2642    session
2643        .peer
2644        .send(
2645            session.connection_id,
2646            proto::UpdateUserPlan { plan: plan.into() },
2647        )
2648        .trace_err();
2649
2650    Ok(())
2651}
2652
2653async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2654    subscribe_user_to_channels(session.user_id(), &session).await?;
2655    Ok(())
2656}
2657
2658async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2659    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2660    let mut pool = session.connection_pool().await;
2661    for membership in &channels_for_user.channel_memberships {
2662        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2663    }
2664    session.peer.send(
2665        session.connection_id,
2666        build_update_user_channels(&channels_for_user),
2667    )?;
2668    session.peer.send(
2669        session.connection_id,
2670        build_channels_update(channels_for_user),
2671    )?;
2672    Ok(())
2673}
2674
2675/// Creates a new channel.
2676async fn create_channel(
2677    request: proto::CreateChannel,
2678    response: Response<proto::CreateChannel>,
2679    session: Session,
2680) -> Result<()> {
2681    let db = session.db().await;
2682
2683    let parent_id = request.parent_id.map(ChannelId::from_proto);
2684    let (channel, membership) = db
2685        .create_channel(&request.name, parent_id, session.user_id())
2686        .await?;
2687
2688    let root_id = channel.root_id();
2689    let channel = Channel::from_model(channel);
2690
2691    response.send(proto::CreateChannelResponse {
2692        channel: Some(channel.to_proto()),
2693        parent_id: request.parent_id,
2694    })?;
2695
2696    let mut connection_pool = session.connection_pool().await;
2697    if let Some(membership) = membership {
2698        connection_pool.subscribe_to_channel(
2699            membership.user_id,
2700            membership.channel_id,
2701            membership.role,
2702        );
2703        let update = proto::UpdateUserChannels {
2704            channel_memberships: vec![proto::ChannelMembership {
2705                channel_id: membership.channel_id.to_proto(),
2706                role: membership.role.into(),
2707            }],
2708            ..Default::default()
2709        };
2710        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2711            session.peer.send(connection_id, update.clone())?;
2712        }
2713    }
2714
2715    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2716        if !role.can_see_channel(channel.visibility) {
2717            continue;
2718        }
2719
2720        let update = proto::UpdateChannels {
2721            channels: vec![channel.to_proto()],
2722            ..Default::default()
2723        };
2724        session.peer.send(connection_id, update.clone())?;
2725    }
2726
2727    Ok(())
2728}
2729
2730/// Delete a channel
2731async fn delete_channel(
2732    request: proto::DeleteChannel,
2733    response: Response<proto::DeleteChannel>,
2734    session: Session,
2735) -> Result<()> {
2736    let db = session.db().await;
2737
2738    let channel_id = request.channel_id;
2739    let (root_channel, removed_channels) = db
2740        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2741        .await?;
2742    response.send(proto::Ack {})?;
2743
2744    // Notify members of removed channels
2745    let mut update = proto::UpdateChannels::default();
2746    update
2747        .delete_channels
2748        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2749
2750    let connection_pool = session.connection_pool().await;
2751    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2752        session.peer.send(connection_id, update.clone())?;
2753    }
2754
2755    Ok(())
2756}
2757
2758/// Invite someone to join a channel.
2759async fn invite_channel_member(
2760    request: proto::InviteChannelMember,
2761    response: Response<proto::InviteChannelMember>,
2762    session: Session,
2763) -> Result<()> {
2764    let db = session.db().await;
2765    let channel_id = ChannelId::from_proto(request.channel_id);
2766    let invitee_id = UserId::from_proto(request.user_id);
2767    let InviteMemberResult {
2768        channel,
2769        notifications,
2770    } = db
2771        .invite_channel_member(
2772            channel_id,
2773            invitee_id,
2774            session.user_id(),
2775            request.role().into(),
2776        )
2777        .await?;
2778
2779    let update = proto::UpdateChannels {
2780        channel_invitations: vec![channel.to_proto()],
2781        ..Default::default()
2782    };
2783
2784    let connection_pool = session.connection_pool().await;
2785    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2786        session.peer.send(connection_id, update.clone())?;
2787    }
2788
2789    send_notifications(&connection_pool, &session.peer, notifications);
2790
2791    response.send(proto::Ack {})?;
2792    Ok(())
2793}
2794
2795/// remove someone from a channel
2796async fn remove_channel_member(
2797    request: proto::RemoveChannelMember,
2798    response: Response<proto::RemoveChannelMember>,
2799    session: Session,
2800) -> Result<()> {
2801    let db = session.db().await;
2802    let channel_id = ChannelId::from_proto(request.channel_id);
2803    let member_id = UserId::from_proto(request.user_id);
2804
2805    let RemoveChannelMemberResult {
2806        membership_update,
2807        notification_id,
2808    } = db
2809        .remove_channel_member(channel_id, member_id, session.user_id())
2810        .await?;
2811
2812    let mut connection_pool = session.connection_pool().await;
2813    notify_membership_updated(
2814        &mut connection_pool,
2815        membership_update,
2816        member_id,
2817        &session.peer,
2818    );
2819    for connection_id in connection_pool.user_connection_ids(member_id) {
2820        if let Some(notification_id) = notification_id {
2821            session
2822                .peer
2823                .send(
2824                    connection_id,
2825                    proto::DeleteNotification {
2826                        notification_id: notification_id.to_proto(),
2827                    },
2828                )
2829                .trace_err();
2830        }
2831    }
2832
2833    response.send(proto::Ack {})?;
2834    Ok(())
2835}
2836
2837/// Toggle the channel between public and private.
2838/// Care is taken to maintain the invariant that public channels only descend from public channels,
2839/// (though members-only channels can appear at any point in the hierarchy).
2840async fn set_channel_visibility(
2841    request: proto::SetChannelVisibility,
2842    response: Response<proto::SetChannelVisibility>,
2843    session: Session,
2844) -> Result<()> {
2845    let db = session.db().await;
2846    let channel_id = ChannelId::from_proto(request.channel_id);
2847    let visibility = request.visibility().into();
2848
2849    let channel_model = db
2850        .set_channel_visibility(channel_id, visibility, session.user_id())
2851        .await?;
2852    let root_id = channel_model.root_id();
2853    let channel = Channel::from_model(channel_model);
2854
2855    let mut connection_pool = session.connection_pool().await;
2856    for (user_id, role) in connection_pool
2857        .channel_user_ids(root_id)
2858        .collect::<Vec<_>>()
2859        .into_iter()
2860    {
2861        let update = if role.can_see_channel(channel.visibility) {
2862            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2863            proto::UpdateChannels {
2864                channels: vec![channel.to_proto()],
2865                ..Default::default()
2866            }
2867        } else {
2868            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2869            proto::UpdateChannels {
2870                delete_channels: vec![channel.id.to_proto()],
2871                ..Default::default()
2872            }
2873        };
2874
2875        for connection_id in connection_pool.user_connection_ids(user_id) {
2876            session.peer.send(connection_id, update.clone())?;
2877        }
2878    }
2879
2880    response.send(proto::Ack {})?;
2881    Ok(())
2882}
2883
2884/// Alter the role for a user in the channel.
2885async fn set_channel_member_role(
2886    request: proto::SetChannelMemberRole,
2887    response: Response<proto::SetChannelMemberRole>,
2888    session: Session,
2889) -> Result<()> {
2890    let db = session.db().await;
2891    let channel_id = ChannelId::from_proto(request.channel_id);
2892    let member_id = UserId::from_proto(request.user_id);
2893    let result = db
2894        .set_channel_member_role(
2895            channel_id,
2896            session.user_id(),
2897            member_id,
2898            request.role().into(),
2899        )
2900        .await?;
2901
2902    match result {
2903        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2904            let mut connection_pool = session.connection_pool().await;
2905            notify_membership_updated(
2906                &mut connection_pool,
2907                membership_update,
2908                member_id,
2909                &session.peer,
2910            )
2911        }
2912        db::SetMemberRoleResult::InviteUpdated(channel) => {
2913            let update = proto::UpdateChannels {
2914                channel_invitations: vec![channel.to_proto()],
2915                ..Default::default()
2916            };
2917
2918            for connection_id in session
2919                .connection_pool()
2920                .await
2921                .user_connection_ids(member_id)
2922            {
2923                session.peer.send(connection_id, update.clone())?;
2924            }
2925        }
2926    }
2927
2928    response.send(proto::Ack {})?;
2929    Ok(())
2930}
2931
2932/// Change the name of a channel
2933async fn rename_channel(
2934    request: proto::RenameChannel,
2935    response: Response<proto::RenameChannel>,
2936    session: Session,
2937) -> Result<()> {
2938    let db = session.db().await;
2939    let channel_id = ChannelId::from_proto(request.channel_id);
2940    let channel_model = db
2941        .rename_channel(channel_id, session.user_id(), &request.name)
2942        .await?;
2943    let root_id = channel_model.root_id();
2944    let channel = Channel::from_model(channel_model);
2945
2946    response.send(proto::RenameChannelResponse {
2947        channel: Some(channel.to_proto()),
2948    })?;
2949
2950    let connection_pool = session.connection_pool().await;
2951    let update = proto::UpdateChannels {
2952        channels: vec![channel.to_proto()],
2953        ..Default::default()
2954    };
2955    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2956        if role.can_see_channel(channel.visibility) {
2957            session.peer.send(connection_id, update.clone())?;
2958        }
2959    }
2960
2961    Ok(())
2962}
2963
2964/// Move a channel to a new parent.
2965async fn move_channel(
2966    request: proto::MoveChannel,
2967    response: Response<proto::MoveChannel>,
2968    session: Session,
2969) -> Result<()> {
2970    let channel_id = ChannelId::from_proto(request.channel_id);
2971    let to = ChannelId::from_proto(request.to);
2972
2973    let (root_id, channels) = session
2974        .db()
2975        .await
2976        .move_channel(channel_id, to, session.user_id())
2977        .await?;
2978
2979    let connection_pool = session.connection_pool().await;
2980    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2981        let channels = channels
2982            .iter()
2983            .filter_map(|channel| {
2984                if role.can_see_channel(channel.visibility) {
2985                    Some(channel.to_proto())
2986                } else {
2987                    None
2988                }
2989            })
2990            .collect::<Vec<_>>();
2991        if channels.is_empty() {
2992            continue;
2993        }
2994
2995        let update = proto::UpdateChannels {
2996            channels,
2997            ..Default::default()
2998        };
2999
3000        session.peer.send(connection_id, update.clone())?;
3001    }
3002
3003    response.send(Ack {})?;
3004    Ok(())
3005}
3006
3007/// Get the list of channel members
3008async fn get_channel_members(
3009    request: proto::GetChannelMembers,
3010    response: Response<proto::GetChannelMembers>,
3011    session: Session,
3012) -> Result<()> {
3013    let db = session.db().await;
3014    let channel_id = ChannelId::from_proto(request.channel_id);
3015    let limit = if request.limit == 0 {
3016        u16::MAX as u64
3017    } else {
3018        request.limit
3019    };
3020    let (members, users) = db
3021        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3022        .await?;
3023    response.send(proto::GetChannelMembersResponse { members, users })?;
3024    Ok(())
3025}
3026
3027/// Accept or decline a channel invitation.
3028async fn respond_to_channel_invite(
3029    request: proto::RespondToChannelInvite,
3030    response: Response<proto::RespondToChannelInvite>,
3031    session: Session,
3032) -> Result<()> {
3033    let db = session.db().await;
3034    let channel_id = ChannelId::from_proto(request.channel_id);
3035    let RespondToChannelInvite {
3036        membership_update,
3037        notifications,
3038    } = db
3039        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3040        .await?;
3041
3042    let mut connection_pool = session.connection_pool().await;
3043    if let Some(membership_update) = membership_update {
3044        notify_membership_updated(
3045            &mut connection_pool,
3046            membership_update,
3047            session.user_id(),
3048            &session.peer,
3049        );
3050    } else {
3051        let update = proto::UpdateChannels {
3052            remove_channel_invitations: vec![channel_id.to_proto()],
3053            ..Default::default()
3054        };
3055
3056        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3057            session.peer.send(connection_id, update.clone())?;
3058        }
3059    };
3060
3061    send_notifications(&connection_pool, &session.peer, notifications);
3062
3063    response.send(proto::Ack {})?;
3064
3065    Ok(())
3066}
3067
3068/// Join the channels' room
3069async fn join_channel(
3070    request: proto::JoinChannel,
3071    response: Response<proto::JoinChannel>,
3072    session: Session,
3073) -> Result<()> {
3074    let channel_id = ChannelId::from_proto(request.channel_id);
3075    join_channel_internal(channel_id, Box::new(response), session).await
3076}
3077
/// A response handle that can complete a channel join.
///
/// Implemented for both `JoinChannel` and `JoinRoom` responses, so that
/// `join_channel_internal` can reply to either request with the same
/// `JoinRoomResponse` payload.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Type-qualified call intended to reach the inherent `Response::send`.
        // Inherent items win over trait items in path resolution, so this
        // should not recurse into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation to the inherent `Response::send`, for `JoinRoom` responses.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3091
/// Join `channel_id`'s room on behalf of `session` and reply through `response`.
///
/// Steps, in order:
/// 1. Leave any stale room connection this user still has (via
///    `leave_room_for_session`).
/// 2. Join the channel in the database.
/// 3. If a LiveKit client is configured, mint a token: a guest token without
///    publish rights for `Guest` roles, otherwise a room token with publish rights.
/// 4. Reply to the requester, then send membership and room updates.
/// 5. Send the channel update and refresh the user's contacts.
///
/// A failure to mint a LiveKit token is passed through `trace_err` and results
/// in `live_kit_connection_info: None` rather than an error.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are
    // released when it ends, before `channel_updated` below takes the pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then reacquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` if no LiveKit client is configured or if minting the token failed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; everyone else
                    // gets a room token with publish rights.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply before notifying others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE(review): if the database returns no channel, this errors only after the
    // requester has already been sent a successful `JoinRoomResponse`.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3187
3188/// Start editing the channel notes
3189async fn join_channel_buffer(
3190    request: proto::JoinChannelBuffer,
3191    response: Response<proto::JoinChannelBuffer>,
3192    session: Session,
3193) -> Result<()> {
3194    let db = session.db().await;
3195    let channel_id = ChannelId::from_proto(request.channel_id);
3196
3197    let open_response = db
3198        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3199        .await?;
3200
3201    let collaborators = open_response.collaborators.clone();
3202    response.send(open_response)?;
3203
3204    let update = UpdateChannelBufferCollaborators {
3205        channel_id: channel_id.to_proto(),
3206        collaborators: collaborators.clone(),
3207    };
3208    channel_buffer_updated(
3209        session.connection_id,
3210        collaborators
3211            .iter()
3212            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3213        &update,
3214        &session.peer,
3215    );
3216
3217    Ok(())
3218}
3219
3220/// Edit the channel notes
3221async fn update_channel_buffer(
3222    request: proto::UpdateChannelBuffer,
3223    session: Session,
3224) -> Result<()> {
3225    let db = session.db().await;
3226    let channel_id = ChannelId::from_proto(request.channel_id);
3227
3228    let (collaborators, epoch, version) = db
3229        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3230        .await?;
3231
3232    channel_buffer_updated(
3233        session.connection_id,
3234        collaborators.clone(),
3235        &proto::UpdateChannelBuffer {
3236            channel_id: channel_id.to_proto(),
3237            operations: request.operations,
3238        },
3239        &session.peer,
3240    );
3241
3242    let pool = &*session.connection_pool().await;
3243
3244    let non_collaborators =
3245        pool.channel_connection_ids(channel_id)
3246            .filter_map(|(connection_id, _)| {
3247                if collaborators.contains(&connection_id) {
3248                    None
3249                } else {
3250                    Some(connection_id)
3251                }
3252            });
3253
3254    broadcast(None, non_collaborators, |peer_id| {
3255        session.peer.send(
3256            peer_id,
3257            proto::UpdateChannels {
3258                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3259                    channel_id: channel_id.to_proto(),
3260                    epoch: epoch as u64,
3261                    version: version.clone(),
3262                }],
3263                ..Default::default()
3264            },
3265        )
3266    });
3267
3268    Ok(())
3269}
3270
3271/// Rejoin the channel notes after a connection blip
3272async fn rejoin_channel_buffers(
3273    request: proto::RejoinChannelBuffers,
3274    response: Response<proto::RejoinChannelBuffers>,
3275    session: Session,
3276) -> Result<()> {
3277    let db = session.db().await;
3278    let buffers = db
3279        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3280        .await?;
3281
3282    for rejoined_buffer in &buffers {
3283        let collaborators_to_notify = rejoined_buffer
3284            .buffer
3285            .collaborators
3286            .iter()
3287            .filter_map(|c| Some(c.peer_id?.into()));
3288        channel_buffer_updated(
3289            session.connection_id,
3290            collaborators_to_notify,
3291            &proto::UpdateChannelBufferCollaborators {
3292                channel_id: rejoined_buffer.buffer.channel_id,
3293                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3294            },
3295            &session.peer,
3296        );
3297    }
3298
3299    response.send(proto::RejoinChannelBuffersResponse {
3300        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3301    })?;
3302
3303    Ok(())
3304}
3305
3306/// Stop editing the channel notes
3307async fn leave_channel_buffer(
3308    request: proto::LeaveChannelBuffer,
3309    response: Response<proto::LeaveChannelBuffer>,
3310    session: Session,
3311) -> Result<()> {
3312    let db = session.db().await;
3313    let channel_id = ChannelId::from_proto(request.channel_id);
3314
3315    let left_buffer = db
3316        .leave_channel_buffer(channel_id, session.connection_id)
3317        .await?;
3318
3319    response.send(Ack {})?;
3320
3321    channel_buffer_updated(
3322        session.connection_id,
3323        left_buffer.connections,
3324        &proto::UpdateChannelBufferCollaborators {
3325            channel_id: channel_id.to_proto(),
3326            collaborators: left_buffer.collaborators,
3327        },
3328        &session.peer,
3329    );
3330
3331    Ok(())
3332}
3333
3334fn channel_buffer_updated<T: EnvelopedMessage>(
3335    sender_id: ConnectionId,
3336    collaborators: impl IntoIterator<Item = ConnectionId>,
3337    message: &T,
3338    peer: &Peer,
3339) {
3340    broadcast(Some(sender_id), collaborators, |peer_id| {
3341        peer.send(peer_id, message.clone())
3342    });
3343}
3344
3345fn send_notifications(
3346    connection_pool: &ConnectionPool,
3347    peer: &Peer,
3348    notifications: db::NotificationBatch,
3349) {
3350    for (user_id, notification) in notifications {
3351        for connection_id in connection_pool.user_connection_ids(user_id) {
3352            if let Err(error) = peer.send(
3353                connection_id,
3354                proto::AddNotification {
3355                    notification: Some(notification.clone()),
3356                },
3357            ) {
3358                tracing::error!(
3359                    "failed to send notification to {:?} {}",
3360                    connection_id,
3361                    error
3362                );
3363            }
3364        }
3365    }
3366}
3367
/// Send a message to the channel
///
/// Validation: the body is trimmed and must be non-empty and at most `MAX_MESSAGE_LEN`
/// bytes (`String::len`), and a nonce must be present. The message is then persisted, and:
/// - it is broadcast as `ChannelMessageSent` to the participant connections returned by
///   the database (this connection is passed to `broadcast` as the sender);
/// - it is returned to the caller in `SendChannelMessageResponse`;
/// - the channel's remaining connections (those not in the participant list) receive an
///   `UpdateChannels` carrying only the new latest message id;
/// - notifications produced by `create_channel_message` are delivered
///   (NOTE(review): presumably mention notifications — confirm in the db layer).
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // The nonce is stored with the message and echoed back in the broadcast/response.
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Borrowing through `&*` extends the temporary's lifetime, so the connection-pool
    // value obtained here stays alive until the end of the function.
    let pool = &*session.connection_pool().await;
    // Channel connections that did not receive the full message above; they are only
    // told the channel's latest message id.
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3462
3463/// Delete a channel message
3464async fn remove_channel_message(
3465    request: proto::RemoveChannelMessage,
3466    response: Response<proto::RemoveChannelMessage>,
3467    session: Session,
3468) -> Result<()> {
3469    let channel_id = ChannelId::from_proto(request.channel_id);
3470    let message_id = MessageId::from_proto(request.message_id);
3471    let (connection_ids, existing_notification_ids) = session
3472        .db()
3473        .await
3474        .remove_channel_message(channel_id, message_id, session.user_id())
3475        .await?;
3476
3477    broadcast(
3478        Some(session.connection_id),
3479        connection_ids,
3480        move |connection| {
3481            session.peer.send(connection, request.clone())?;
3482
3483            for notification_id in &existing_notification_ids {
3484                session.peer.send(
3485                    connection,
3486                    proto::DeleteNotification {
3487                        notification_id: (*notification_id).to_proto(),
3488                    },
3489                )?;
3490            }
3491
3492            Ok(())
3493        },
3494    );
3495    response.send(proto::Ack {})?;
3496    Ok(())
3497}
3498
3499async fn update_channel_message(
3500    request: proto::UpdateChannelMessage,
3501    response: Response<proto::UpdateChannelMessage>,
3502    session: Session,
3503) -> Result<()> {
3504    let channel_id = ChannelId::from_proto(request.channel_id);
3505    let message_id = MessageId::from_proto(request.message_id);
3506    let updated_at = OffsetDateTime::now_utc();
3507    let UpdatedChannelMessage {
3508        message_id,
3509        participant_connection_ids,
3510        notifications,
3511        reply_to_message_id,
3512        timestamp,
3513        deleted_mention_notification_ids,
3514        updated_mention_notifications,
3515    } = session
3516        .db()
3517        .await
3518        .update_channel_message(
3519            channel_id,
3520            message_id,
3521            session.user_id(),
3522            request.body.as_str(),
3523            &request.mentions,
3524            updated_at,
3525        )
3526        .await?;
3527
3528    let nonce = request
3529        .nonce
3530        .clone()
3531        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3532
3533    let message = proto::ChannelMessage {
3534        sender_id: session.user_id().to_proto(),
3535        id: message_id.to_proto(),
3536        body: request.body.clone(),
3537        mentions: request.mentions.clone(),
3538        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3539        nonce: Some(nonce),
3540        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3541        edited_at: Some(updated_at.unix_timestamp() as u64),
3542    };
3543
3544    response.send(proto::Ack {})?;
3545
3546    let pool = &*session.connection_pool().await;
3547    broadcast(
3548        Some(session.connection_id),
3549        participant_connection_ids,
3550        |connection| {
3551            session.peer.send(
3552                connection,
3553                proto::ChannelMessageUpdate {
3554                    channel_id: channel_id.to_proto(),
3555                    message: Some(message.clone()),
3556                },
3557            )?;
3558
3559            for notification_id in &deleted_mention_notification_ids {
3560                session.peer.send(
3561                    connection,
3562                    proto::DeleteNotification {
3563                        notification_id: (*notification_id).to_proto(),
3564                    },
3565                )?;
3566            }
3567
3568            for notification in &updated_mention_notifications {
3569                session.peer.send(
3570                    connection,
3571                    proto::UpdateNotification {
3572                        notification: Some(notification.clone()),
3573                    },
3574                )?;
3575            }
3576
3577            Ok(())
3578        },
3579    );
3580
3581    send_notifications(pool, &session.peer, notifications);
3582
3583    Ok(())
3584}
3585
3586/// Mark a channel message as read
3587async fn acknowledge_channel_message(
3588    request: proto::AckChannelMessage,
3589    session: Session,
3590) -> Result<()> {
3591    let channel_id = ChannelId::from_proto(request.channel_id);
3592    let message_id = MessageId::from_proto(request.message_id);
3593    let notifications = session
3594        .db()
3595        .await
3596        .observe_channel_message(channel_id, session.user_id(), message_id)
3597        .await?;
3598    send_notifications(
3599        &*session.connection_pool().await,
3600        &session.peer,
3601        notifications,
3602    );
3603    Ok(())
3604}
3605
3606/// Mark a buffer version as synced
3607async fn acknowledge_buffer_version(
3608    request: proto::AckBufferOperation,
3609    session: Session,
3610) -> Result<()> {
3611    let buffer_id = BufferId::from_proto(request.buffer_id);
3612    session
3613        .db()
3614        .await
3615        .observe_buffer_version(
3616            buffer_id,
3617            session.user_id(),
3618            request.epoch as i32,
3619            &request.version,
3620        )
3621        .await?;
3622    Ok(())
3623}
3624
3625async fn count_language_model_tokens(
3626    request: proto::CountLanguageModelTokens,
3627    response: Response<proto::CountLanguageModelTokens>,
3628    session: Session,
3629    config: &Config,
3630) -> Result<()> {
3631    authorize_access_to_legacy_llm_endpoints(&session).await?;
3632
3633    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3634        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3635        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3636    };
3637
3638    session
3639        .app_state
3640        .rate_limiter
3641        .check(&*rate_limit, session.user_id())
3642        .await?;
3643
3644    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3645        Some(proto::LanguageModelProvider::Google) => {
3646            let api_key = config
3647                .google_ai_api_key
3648                .as_ref()
3649                .context("no Google AI API key configured on the server")?;
3650            google_ai::count_tokens(
3651                session.http_client.as_ref(),
3652                google_ai::API_URL,
3653                api_key,
3654                serde_json::from_str(&request.request)?,
3655            )
3656            .await?
3657        }
3658        _ => return Err(anyhow!("unsupported provider"))?,
3659    };
3660
3661    response.send(proto::CountLanguageModelTokensResponse {
3662        token_count: result.total_tokens as u32,
3663    })?;
3664
3665    Ok(())
3666}
3667
3668struct ZedProCountLanguageModelTokensRateLimit;
3669
3670impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3671    fn capacity(&self) -> usize {
3672        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3673            .ok()
3674            .and_then(|v| v.parse().ok())
3675            .unwrap_or(600) // Picked arbitrarily
3676    }
3677
3678    fn refill_duration(&self) -> chrono::Duration {
3679        chrono::Duration::hours(1)
3680    }
3681
3682    fn db_name(&self) -> &'static str {
3683        "zed-pro:count-language-model-tokens"
3684    }
3685}
3686
3687struct FreeCountLanguageModelTokensRateLimit;
3688
3689impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3690    fn capacity(&self) -> usize {
3691        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3692            .ok()
3693            .and_then(|v| v.parse().ok())
3694            .unwrap_or(600 / 10) // Picked arbitrarily
3695    }
3696
3697    fn refill_duration(&self) -> chrono::Duration {
3698        chrono::Duration::hours(1)
3699    }
3700
3701    fn db_name(&self) -> &'static str {
3702        "free:count-language-model-tokens"
3703    }
3704}
3705
3706struct ZedProComputeEmbeddingsRateLimit;
3707
3708impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3709    fn capacity(&self) -> usize {
3710        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3711            .ok()
3712            .and_then(|v| v.parse().ok())
3713            .unwrap_or(5000) // Picked arbitrarily
3714    }
3715
3716    fn refill_duration(&self) -> chrono::Duration {
3717        chrono::Duration::hours(1)
3718    }
3719
3720    fn db_name(&self) -> &'static str {
3721        "zed-pro:compute-embeddings"
3722    }
3723}
3724
3725struct FreeComputeEmbeddingsRateLimit;
3726
3727impl RateLimit for FreeComputeEmbeddingsRateLimit {
3728    fn capacity(&self) -> usize {
3729        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3730            .ok()
3731            .and_then(|v| v.parse().ok())
3732            .unwrap_or(5000 / 10) // Picked arbitrarily
3733    }
3734
3735    fn refill_duration(&self) -> chrono::Duration {
3736        chrono::Duration::hours(1)
3737    }
3738
3739    fn db_name(&self) -> &'static str {
3740        "free:compute-embeddings"
3741    }
3742}
3743
3744async fn compute_embeddings(
3745    request: proto::ComputeEmbeddings,
3746    response: Response<proto::ComputeEmbeddings>,
3747    session: Session,
3748    api_key: Option<Arc<str>>,
3749) -> Result<()> {
3750    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3751    authorize_access_to_legacy_llm_endpoints(&session).await?;
3752
3753    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3754        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3755        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3756    };
3757
3758    session
3759        .app_state
3760        .rate_limiter
3761        .check(&*rate_limit, session.user_id())
3762        .await?;
3763
3764    let embeddings = match request.model.as_str() {
3765        "openai/text-embedding-3-small" => {
3766            open_ai::embed(
3767                session.http_client.as_ref(),
3768                OPEN_AI_API_URL,
3769                &api_key,
3770                OpenAiEmbeddingModel::TextEmbedding3Small,
3771                request.texts.iter().map(|text| text.as_str()),
3772            )
3773            .await?
3774        }
3775        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3776    };
3777
3778    let embeddings = request
3779        .texts
3780        .iter()
3781        .map(|text| {
3782            let mut hasher = sha2::Sha256::new();
3783            hasher.update(text.as_bytes());
3784            let result = hasher.finalize();
3785            result.to_vec()
3786        })
3787        .zip(
3788            embeddings
3789                .data
3790                .into_iter()
3791                .map(|embedding| embedding.embedding),
3792        )
3793        .collect::<HashMap<_, _>>();
3794
3795    let db = session.db().await;
3796    db.save_embeddings(&request.model, &embeddings)
3797        .await
3798        .context("failed to save embeddings")
3799        .trace_err();
3800
3801    response.send(proto::ComputeEmbeddingsResponse {
3802        embeddings: embeddings
3803            .into_iter()
3804            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3805            .collect(),
3806    })?;
3807    Ok(())
3808}
3809
3810async fn get_cached_embeddings(
3811    request: proto::GetCachedEmbeddings,
3812    response: Response<proto::GetCachedEmbeddings>,
3813    session: Session,
3814) -> Result<()> {
3815    authorize_access_to_legacy_llm_endpoints(&session).await?;
3816
3817    let db = session.db().await;
3818    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3819
3820    response.send(proto::GetCachedEmbeddingsResponse {
3821        embeddings: embeddings
3822            .into_iter()
3823            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3824            .collect(),
3825    })?;
3826    Ok(())
3827}
3828
3829/// This is leftover from before the LLM service.
3830///
3831/// The endpoints protected by this check will be moved there eventually.
3832async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3833    if session.is_staff() {
3834        Ok(())
3835    } else {
3836        Err(anyhow!("permission denied"))?
3837    }
3838}
3839
3840/// Get a Supermaven API key for the user
3841async fn get_supermaven_api_key(
3842    _request: proto::GetSupermavenApiKey,
3843    response: Response<proto::GetSupermavenApiKey>,
3844    session: Session,
3845) -> Result<()> {
3846    let user_id: String = session.user_id().to_string();
3847    if !session.is_staff() {
3848        return Err(anyhow!("supermaven not enabled for this account"))?;
3849    }
3850
3851    let email = session
3852        .email()
3853        .ok_or_else(|| anyhow!("user must have an email"))?;
3854
3855    let supermaven_admin_api = session
3856        .supermaven_client
3857        .as_ref()
3858        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3859
3860    let result = supermaven_admin_api
3861        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3862        .await?;
3863
3864    response.send(proto::GetSupermavenApiKeyResponse {
3865        api_key: result.api_key,
3866    })?;
3867
3868    Ok(())
3869}
3870
3871/// Start receiving chat updates for a channel
3872async fn join_channel_chat(
3873    request: proto::JoinChannelChat,
3874    response: Response<proto::JoinChannelChat>,
3875    session: Session,
3876) -> Result<()> {
3877    let channel_id = ChannelId::from_proto(request.channel_id);
3878
3879    let db = session.db().await;
3880    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3881        .await?;
3882    let messages = db
3883        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3884        .await?;
3885    response.send(proto::JoinChannelChatResponse {
3886        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3887        messages,
3888    })?;
3889    Ok(())
3890}
3891
3892/// Stop receiving chat updates for a channel
3893async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3894    let channel_id = ChannelId::from_proto(request.channel_id);
3895    session
3896        .db()
3897        .await
3898        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3899        .await?;
3900    Ok(())
3901}
3902
3903/// Retrieve the chat history for a channel
3904async fn get_channel_messages(
3905    request: proto::GetChannelMessages,
3906    response: Response<proto::GetChannelMessages>,
3907    session: Session,
3908) -> Result<()> {
3909    let channel_id = ChannelId::from_proto(request.channel_id);
3910    let messages = session
3911        .db()
3912        .await
3913        .get_channel_messages(
3914            channel_id,
3915            session.user_id(),
3916            MESSAGE_COUNT_PER_PAGE,
3917            Some(MessageId::from_proto(request.before_message_id)),
3918        )
3919        .await?;
3920    response.send(proto::GetChannelMessagesResponse {
3921        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3922        messages,
3923    })?;
3924    Ok(())
3925}
3926
3927/// Retrieve specific chat messages
3928async fn get_channel_messages_by_id(
3929    request: proto::GetChannelMessagesById,
3930    response: Response<proto::GetChannelMessagesById>,
3931    session: Session,
3932) -> Result<()> {
3933    let message_ids = request
3934        .message_ids
3935        .iter()
3936        .map(|id| MessageId::from_proto(*id))
3937        .collect::<Vec<_>>();
3938    let messages = session
3939        .db()
3940        .await
3941        .get_channel_messages_by_id(session.user_id(), &message_ids)
3942        .await?;
3943    response.send(proto::GetChannelMessagesResponse {
3944        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3945        messages,
3946    })?;
3947    Ok(())
3948}
3949
3950/// Retrieve the current users notifications
3951async fn get_notifications(
3952    request: proto::GetNotifications,
3953    response: Response<proto::GetNotifications>,
3954    session: Session,
3955) -> Result<()> {
3956    let notifications = session
3957        .db()
3958        .await
3959        .get_notifications(
3960            session.user_id(),
3961            NOTIFICATION_COUNT_PER_PAGE,
3962            request.before_id.map(db::NotificationId::from_proto),
3963        )
3964        .await?;
3965    response.send(proto::GetNotificationsResponse {
3966        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3967        notifications,
3968    })?;
3969    Ok(())
3970}
3971
3972/// Mark notifications as read
3973async fn mark_notification_as_read(
3974    request: proto::MarkNotificationRead,
3975    response: Response<proto::MarkNotificationRead>,
3976    session: Session,
3977) -> Result<()> {
3978    let database = &session.db().await;
3979    let notifications = database
3980        .mark_notification_as_read_by_id(
3981            session.user_id(),
3982            NotificationId::from_proto(request.notification_id),
3983        )
3984        .await?;
3985    send_notifications(
3986        &*session.connection_pool().await,
3987        &session.peer,
3988        notifications,
3989    );
3990    response.send(proto::Ack {})?;
3991    Ok(())
3992}
3993
3994/// Get the current users information
3995async fn get_private_user_info(
3996    _request: proto::GetPrivateUserInfo,
3997    response: Response<proto::GetPrivateUserInfo>,
3998    session: Session,
3999) -> Result<()> {
4000    let db = session.db().await;
4001
4002    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4003    let user = db
4004        .get_user_by_id(session.user_id())
4005        .await?
4006        .ok_or_else(|| anyhow!("user not found"))?;
4007    let flags = db.get_user_flags(session.user_id()).await?;
4008
4009    response.send(proto::GetPrivateUserInfoResponse {
4010        metrics_id,
4011        staff: user.admin,
4012        flags,
4013        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4014    })?;
4015    Ok(())
4016}
4017
4018/// Accept the terms of service (tos) on behalf of the current user
4019async fn accept_terms_of_service(
4020    _request: proto::AcceptTermsOfService,
4021    response: Response<proto::AcceptTermsOfService>,
4022    session: Session,
4023) -> Result<()> {
4024    let db = session.db().await;
4025
4026    let accepted_tos_at = Utc::now();
4027    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4028        .await?;
4029
4030    response.send(proto::AcceptTermsOfServiceResponse {
4031        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4032    })?;
4033    Ok(())
4034}
4035
/// The minimum account age an account must have in order to use the LLM service.
///
/// `get_llm_api_token` compares this against the time elapsed since the earlier of the
/// user's `created_at` and, when present, `github_user_created_at`. Users with an LLM
/// subscription or the `bypass-account-age-check` flag skip the check.
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4038
/// Issue an LLM service token for the current user.
///
/// The checks run in this order, and the first failure is returned:
/// 1. the user is staff or has the `language-models` feature flag;
/// 2. the user exists;
/// 3. the user has accepted the terms of service;
/// 4. unless the user has an LLM subscription or the `bypass-account-age-check` flag, the
///    account is at least `MIN_ACCOUNT_AGE_FOR_LLM_USE` old.
///
/// The token is built by `LlmTokenClaims::create` from the user record, staff status,
/// billing preferences, the `llm-closed-beta` and `predict-edits` flags, subscription
/// state, the current plan, the session's system id and the server config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;
    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");

    // Gate 1: staff, or users explicitly opted in via the feature flag.
    if !session.is_staff() && !has_language_models_feature_flag {
        Err(anyhow!("permission denied"))?
    }

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .ok_or_else(|| anyhow!("user {} not found", user_id))?;

    // Gate 2: the terms of service must have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let has_llm_subscription = session.has_llm_subscription(&db).await?;

    // Gate 3: account age. The age is measured from the earlier of the Zed account and
    // the linked GitHub account creation times.
    let bypass_account_age_check =
        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
    if !bypass_account_age_check {
        let mut account_created_at = user.created_at;
        if let Some(github_created_at) = user.github_user_created_at {
            account_created_at = account_created_at.min(github_created_at);
        }
        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
            Err(anyhow!("account too young"))?
        }
    }

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_preferences,
        has_llm_closed_beta_feature_flag,
        has_predict_edits_feature_flag,
        has_llm_subscription,
        session.current_plan(&db).await?,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4095
4096fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4097    let message = match message {
4098        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4099        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4100        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4101        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4102        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4103            code: frame.code.into(),
4104            reason: frame.reason,
4105        })),
4106        // We should never receive a frame while reading the message, according
4107        // to the `tungstenite` maintainers:
4108        //
4109        // > It cannot occur when you read messages from the WebSocket, but it
4110        // > can be used when you want to send the raw frames (e.g. you want to
4111        // > send the frames to the WebSocket without composing the full message first).
4112        // >
4113        // > — https://github.com/snapview/tungstenite-rs/issues/268
4114        TungsteniteMessage::Frame(_) => {
4115            bail!("received an unexpected frame while reading the message")
4116        }
4117    };
4118
4119    Ok(message)
4120}
4121
4122fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4123    match message {
4124        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4125        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4126        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4127        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4128        AxumMessage::Close(frame) => {
4129            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4130                code: frame.code.into(),
4131                reason: frame.reason,
4132            }))
4133        }
4134    }
4135}
4136
4137fn notify_membership_updated(
4138    connection_pool: &mut ConnectionPool,
4139    result: MembershipUpdated,
4140    user_id: UserId,
4141    peer: &Peer,
4142) {
4143    for membership in &result.new_channels.channel_memberships {
4144        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4145    }
4146    for channel_id in &result.removed_channels {
4147        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4148    }
4149
4150    let user_channels_update = proto::UpdateUserChannels {
4151        channel_memberships: result
4152            .new_channels
4153            .channel_memberships
4154            .iter()
4155            .map(|cm| proto::ChannelMembership {
4156                channel_id: cm.channel_id.to_proto(),
4157                role: cm.role.into(),
4158            })
4159            .collect(),
4160        ..Default::default()
4161    };
4162
4163    let mut update = build_channels_update(result.new_channels);
4164    update.delete_channels = result
4165        .removed_channels
4166        .into_iter()
4167        .map(|id| id.to_proto())
4168        .collect();
4169    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4170
4171    for connection_id in connection_pool.user_connection_ids(user_id) {
4172        peer.send(connection_id, user_channels_update.clone())
4173            .trace_err();
4174        peer.send(connection_id, update.clone()).trace_err();
4175    }
4176}
4177
4178fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4179    proto::UpdateUserChannels {
4180        channel_memberships: channels
4181            .channel_memberships
4182            .iter()
4183            .map(|m| proto::ChannelMembership {
4184                channel_id: m.channel_id.to_proto(),
4185                role: m.role.into(),
4186            })
4187            .collect(),
4188        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4189        observed_channel_message_id: channels.observed_channel_messages.clone(),
4190    }
4191}
4192
4193fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4194    let mut update = proto::UpdateChannels::default();
4195
4196    for channel in channels.channels {
4197        update.channels.push(channel.to_proto());
4198    }
4199
4200    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4201    update.latest_channel_message_ids = channels.latest_channel_messages;
4202
4203    for (channel_id, participants) in channels.channel_participants {
4204        update
4205            .channel_participants
4206            .push(proto::ChannelParticipants {
4207                channel_id: channel_id.to_proto(),
4208                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4209            });
4210    }
4211
4212    for channel in channels.invited_channels {
4213        update.channel_invitations.push(channel.to_proto());
4214    }
4215
4216    update
4217}
4218
4219fn build_initial_contacts_update(
4220    contacts: Vec<db::Contact>,
4221    pool: &ConnectionPool,
4222) -> proto::UpdateContacts {
4223    let mut update = proto::UpdateContacts::default();
4224
4225    for contact in contacts {
4226        match contact {
4227            db::Contact::Accepted { user_id, busy } => {
4228                update.contacts.push(contact_for_user(user_id, busy, pool));
4229            }
4230            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4231            db::Contact::Incoming { user_id } => {
4232                update
4233                    .incoming_requests
4234                    .push(proto::IncomingContactRequest {
4235                        requester_id: user_id.to_proto(),
4236                    })
4237            }
4238        }
4239    }
4240
4241    update
4242}
4243
4244fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4245    proto::Contact {
4246        user_id: user_id.to_proto(),
4247        online: pool.is_user_online(user_id),
4248        busy,
4249    }
4250}
4251
4252fn room_updated(room: &proto::Room, peer: &Peer) {
4253    broadcast(
4254        None,
4255        room.participants
4256            .iter()
4257            .filter_map(|participant| Some(participant.peer_id?.into())),
4258        |peer_id| {
4259            peer.send(
4260                peer_id,
4261                proto::RoomUpdated {
4262                    room: Some(room.clone()),
4263                },
4264            )
4265        },
4266    );
4267}
4268
4269fn channel_updated(
4270    channel: &db::channel::Model,
4271    room: &proto::Room,
4272    peer: &Peer,
4273    pool: &ConnectionPool,
4274) {
4275    let participants = room
4276        .participants
4277        .iter()
4278        .map(|p| p.user_id)
4279        .collect::<Vec<_>>();
4280
4281    broadcast(
4282        None,
4283        pool.channel_connection_ids(channel.root_id())
4284            .filter_map(|(channel_id, role)| {
4285                role.can_see_channel(channel.visibility)
4286                    .then_some(channel_id)
4287            }),
4288        |peer_id| {
4289            peer.send(
4290                peer_id,
4291                proto::UpdateChannels {
4292                    channel_participants: vec![proto::ChannelParticipants {
4293                        channel_id: channel.id.to_proto(),
4294                        participant_user_ids: participants.clone(),
4295                    }],
4296                    ..Default::default()
4297                },
4298            )
4299        },
4300    );
4301}
4302
4303async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4304    let db = session.db().await;
4305
4306    let contacts = db.get_contacts(user_id).await?;
4307    let busy = db.is_user_busy(user_id).await?;
4308
4309    let pool = session.connection_pool().await;
4310    let updated_contact = contact_for_user(user_id, busy, &pool);
4311    for contact in contacts {
4312        if let db::Contact::Accepted {
4313            user_id: contact_user_id,
4314            ..
4315        } = contact
4316        {
4317            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4318                session
4319                    .peer
4320                    .send(
4321                        contact_conn_id,
4322                        proto::UpdateContacts {
4323                            contacts: vec![updated_contact.clone()],
4324                            remove_contacts: Default::default(),
4325                            incoming_requests: Default::default(),
4326                            remove_incoming_requests: Default::default(),
4327                            outgoing_requests: Default::default(),
4328                            remove_outgoing_requests: Default::default(),
4329                        },
4330                    )
4331                    .trace_err();
4332            }
4333        }
4334    }
4335    Ok(())
4336}
4337
/// Removes `connection_id` from whatever room it is in and notifies everyone affected.
///
/// If `db.leave_room` reports that the connection was not in a room, this returns `Ok(())`
/// without doing anything else. Otherwise it:
/// - notifies collaborators of every project the session left (`project_left`);
/// - broadcasts the updated room (`room_updated`) and, for channel rooms, the updated
///   channel participants (`channel_updated`);
/// - sends `CallCanceled` to every user whose pending call was canceled;
/// - refreshes contact status for this user and for those users;
/// - removes this user from the LiveKit room, and deletes that room if `left_room.deleted`
///   is set.
///
/// Errors from `update_user_contacts` are propagated with `?`. In that case the remaining
/// contact updates and the LiveKit cleanup below are skipped.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // The leaving user plus any user whose call was canceled; deduplicated by the set.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialisation: these are assigned inside the `if let` below, and the
    // `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value from `session.db().await` is a temporary in the `if let`
    // scrutinee. Under edition-2021 drop rules it lives until the end of this `if let`/`else`.
    // If it is a lock guard, it is held while the notifications in this branch are sent.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken *before* `room` is moved out below. `mem::take` leaves the field's default
        // value behind, so the `room` broadcast by `room_updated` carries that default
        // instead of the LiveKit room name.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Only rooms attached to a channel also need a channel participant update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts`, which calls
    // `session.connection_pool()` again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Shadows the `connection_id` parameter: these are the callee's connections.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4411
4412async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4413    let left_channel_buffers = session
4414        .db()
4415        .await
4416        .leave_channel_buffers(session.connection_id)
4417        .await?;
4418
4419    for left_buffer in left_channel_buffers {
4420        channel_buffer_updated(
4421            session.connection_id,
4422            left_buffer.connections,
4423            &proto::UpdateChannelBufferCollaborators {
4424                channel_id: left_buffer.channel_id.to_proto(),
4425                collaborators: left_buffer.collaborators,
4426            },
4427            &session.peer,
4428        );
4429    }
4430
4431    Ok(())
4432}
4433
4434fn project_left(project: &db::LeftProject, session: &Session) {
4435    for connection_id in &project.connection_ids {
4436        if project.should_unshare {
4437            session
4438                .peer
4439                .send(
4440                    *connection_id,
4441                    proto::UnshareProject {
4442                        project_id: project.id.to_proto(),
4443                    },
4444                )
4445                .trace_err();
4446        } else {
4447            session
4448                .peer
4449                .send(
4450                    *connection_id,
4451                    proto::RemoveProjectCollaborator {
4452                        project_id: project.id.to_proto(),
4453                        peer_id: Some(session.connection_id.into()),
4454                    },
4455                )
4456                .trace_err();
4457        }
4458    }
4459}
4460
4461pub trait ResultExt {
4462    type Ok;
4463
4464    fn trace_err(self) -> Option<Self::Ok>;
4465}
4466
4467impl<T, E> ResultExt for Result<T, E>
4468where
4469    E: std::fmt::Debug,
4470{
4471    type Ok = T;
4472
4473    #[track_caller]
4474    fn trace_err(self) -> Option<T> {
4475        match self {
4476            Ok(value) => Some(value),
4477            Err(error) => {
4478                tracing::error!("{:?}", error);
4479                None
4480            }
4481        }
4482    }
4483}