rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
/// Per-connection state passed to every RPC message handler (see `MessageHandler`).
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (a user, or a user impersonated by an admin).
    principal: Principal,
    /// This client's connection id.
    connection_id: ConnectionId,
    /// Shared database handle; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of live connections; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Supermaven admin API client, if one is available for this session.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` suggests this field is not read yet — confirm.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-supplied system id, if any (presumably from `SystemIdHeader` — TODO confirm).
    system_id: Option<String>,
    /// Held by the session but not read directly (leading underscore).
    _executor: Executor,
}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers.
pub struct Server {
    /// This server's id. Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag (`true` while tearing down); connection handlers subscribe to it.
    teardown: watch::Sender<bool>,
}
 244
/// Guard over the locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so a `Send` future cannot
/// hold this synchronous lock across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
 249
/// Serializable view of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
 332            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 333            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 337            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 338            .add_request_handler(
 339                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 340            )
 341            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 342            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 344            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 345            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 346            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 347            .add_message_handler(create_buffer_for_peer)
 348            .add_request_handler(update_buffer)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 353            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 354            .add_request_handler(get_users)
 355            .add_request_handler(fuzzy_search_users)
 356            .add_request_handler(request_contact)
 357            .add_request_handler(remove_contact)
 358            .add_request_handler(respond_to_contact_request)
 359            .add_message_handler(subscribe_to_channels)
 360            .add_request_handler(create_channel)
 361            .add_request_handler(delete_channel)
 362            .add_request_handler(invite_channel_member)
 363            .add_request_handler(remove_channel_member)
 364            .add_request_handler(set_channel_member_role)
 365            .add_request_handler(set_channel_visibility)
 366            .add_request_handler(rename_channel)
 367            .add_request_handler(join_channel_buffer)
 368            .add_request_handler(leave_channel_buffer)
 369            .add_message_handler(update_channel_buffer)
 370            .add_request_handler(rejoin_channel_buffers)
 371            .add_request_handler(get_channel_members)
 372            .add_request_handler(respond_to_channel_invite)
 373            .add_request_handler(join_channel)
 374            .add_request_handler(join_channel_chat)
 375            .add_message_handler(leave_channel_chat)
 376            .add_request_handler(send_channel_message)
 377            .add_request_handler(remove_channel_message)
 378            .add_request_handler(update_channel_message)
 379            .add_request_handler(get_channel_messages)
 380            .add_request_handler(get_channel_messages_by_id)
 381            .add_request_handler(get_notifications)
 382            .add_request_handler(mark_notification_as_read)
 383            .add_request_handler(move_channel)
 384            .add_request_handler(follow)
 385            .add_message_handler(unfollow)
 386            .add_message_handler(update_followers)
 387            .add_request_handler(get_private_user_info)
 388            .add_request_handler(get_llm_api_token)
 389            .add_request_handler(accept_terms_of_service)
 390            .add_message_handler(acknowledge_channel_message)
 391            .add_message_handler(acknowledge_buffer_version)
 392            .add_request_handler(get_supermaven_api_key)
 393            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 395            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 398            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
 400            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 401            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 402            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 403            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 404            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 405            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
 406            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
 407            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
 408            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
 409            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 410            .add_message_handler(update_context)
 411            .add_request_handler({
 412                let app_state = app_state.clone();
 413                move |request, response, session| {
 414                    let app_state = app_state.clone();
 415                    async move {
 416                        count_language_model_tokens(request, response, session, &app_state.config)
 417                            .await
 418                    }
 419                }
 420            })
 421            .add_request_handler(get_cached_embeddings)
 422            .add_request_handler({
 423                let app_state = app_state.clone();
 424                move |request, response, session| {
 425                    compute_embeddings(
 426                        request,
 427                        response,
 428                        session,
 429                        app_state.config.openai_api_key.clone(),
 430                    )
 431                }
 432            });
 433
 434        Arc::new(server)
 435    }
 436
    /// Kicks off cleanup of resources left behind by stale servers.
    ///
    /// Returns `Ok(())` as soon as a detached background task has been spawned. That task:
    /// 1. waits for [`CLEANUP_TIMEOUT`];
    /// 2. asks the database for the rooms and channel buffers it associates with stale
    ///    servers in this environment;
    /// 3. clears their stale participants and collaborators, notifying the affected
    ///    clients and their contacts;
    /// 4. deletes LiveKit rooms that ended up empty;
    /// 5. deletes the stale server records.
    ///
    /// Each fallible step goes through `trace_err`, which turns errors into `None`
    /// (presumably after logging them). A failure therefore skips that step rather than
    /// aborting the whole cleanup.
    pub async fn start(&self) -> Result<()> {
        // Copy/clone everything the detached task needs so it does not borrow `self`.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                // Wait until terminated pods should be gone (see `CLEANUP_TIMEOUT`).
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Results of clearing this room. If the database call below fails
                        // they keep these defaults, so nothing is sent or deleted for it.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Stale participants and users whose calls were canceled both
                            // need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once no participants remain.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before any `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-send each affected user's contact entry (including their busy
                        // state) to every connection of each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Finally, remove the stale server records themselves.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 582
 583    pub fn teardown(&self) {
 584        self.peer.teardown();
 585        self.connection_pool.lock().reset();
 586        let _ = self.teardown.send(true);
 587    }
 588
 589    #[cfg(test)]
 590    pub fn reset(&self, id: ServerId) {
 591        self.teardown();
 592        *self.id.lock() = id;
 593        self.peer.reset(id.0 as u32);
 594        let _ = self.teardown.send(false);
 595    }
 596
 597    #[cfg(test)]
 598    pub fn id(&self) -> ServerId {
 599        *self.id.lock()
 600    }
 601
 602    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 603    where
 604        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 605        Fut: 'static + Send + Future<Output = Result<()>>,
 606        M: EnvelopedMessage,
 607    {
 608        let prev_handler = self.handlers.insert(
 609            TypeId::of::<M>(),
 610            Box::new(move |envelope, session| {
 611                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 612                let received_at = envelope.received_at;
 613                tracing::info!("message received");
 614                let start_time = Instant::now();
 615                let future = (handler)(*envelope, session);
 616                async move {
 617                    let result = future.await;
 618                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 619                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 620                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 621                    let payload_type = M::NAME;
 622
 623                    match result {
 624                        Err(error) => {
 625                            tracing::error!(
 626                                ?error,
 627                                total_duration_ms,
 628                                processing_duration_ms,
 629                                queue_duration_ms,
 630                                payload_type,
 631                                "error handling message"
 632                            )
 633                        }
 634                        Ok(()) => tracing::info!(
 635                            total_duration_ms,
 636                            processing_duration_ms,
 637                            queue_duration_ms,
 638                            "finished handling message"
 639                        ),
 640                    }
 641                }
 642                .boxed()
 643            }),
 644        );
 645        if prev_handler.is_some() {
 646            panic!("registered a handler for the same message twice");
 647        }
 648        self
 649    }
 650
 651    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 652    where
 653        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 654        Fut: 'static + Send + Future<Output = Result<()>>,
 655        M: EnvelopedMessage,
 656    {
 657        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 658        self
 659    }
 660
    /// Registers a handler for request message type `M`, i.e. a message that expects a reply.
    ///
    /// The handler receives the payload, a [`Response`] to reply with, and the [`Session`].
    /// The wrapper registered through `add_handler`:
    /// - returns an error ("handler did not send a response") when the handler returns `Ok`
    ///   but the shared `responded` flag was never set;
    /// - when the handler returns an error, first reports it to the requesting peer via
    ///   `respond_with_error` and then returns the same error. `Error::Internal` values
    ///   are converted with their own `to_proto`. Any other error becomes
    ///   `ErrorCode::Internal` carrying the error's display text.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        // Shared so every per-message future can own a handle to the same handler.
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            // Captured before the payload is moved out of the envelope below.
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // NOTE(review): presumably flipped by `Response` when the handler replies;
                // confirm against `Response::send`.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        // A successful handler must still have replied to the request.
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Tell the requester about the failure before surfacing it locally.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
 699
    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Returns a future, instrumented with a `handle connection` span, that:
    /// 1. returns immediately if the server's teardown flag is already set;
    /// 2. registers `connection` with the peer and builds the per-connection HTTP client,
    ///    the optional Supermaven admin client (only when an admin API key is configured)
    ///    and the [`Session`];
    /// 3. sends the initial client update via `send_initial_client_update`;
    /// 4. dispatches incoming messages to the registered handlers until teardown is
    ///    signalled, the I/O future completes, or the incoming stream ends;
    /// 5. unless it stopped because of teardown, calls `connection_lost` to sign out.
    ///
    /// `send_connection_id`, when provided, is forwarded to `send_initial_client_update`,
    /// which uses it to report the assigned [`ConnectionId`].
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the peer assigns
        // one, and presumably the user fields by `principal.update_span` (verify there).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Receiver for the server-wide teardown flag. It is checked on entry and raced
        // against message handling in the loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            // NOTE(review): this early return (and the one after the initial client update
            // below) happens after `add_connection` but skips `connection_lost`, so the
            // connection is never explicitly disconnected or removed here. Confirm that is
            // intended.
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight at once.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message, so a full semaphore
                // stops this connection from pulling new messages (backpressure). The permit
                // is released when the spawned handler future finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased priority: teardown, then I/O completion, then progress on in-flight
                // foreground handlers, then accepting a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span here; the handler future re-enters it via `instrument`.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached and never block this loop.
                                // Foreground ones are polled through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still running.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 849
 850    async fn send_initial_client_update(
 851        &self,
 852        connection_id: ConnectionId,
 853        principal: &Principal,
 854        zed_version: ZedVersion,
 855        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 856        session: &Session,
 857    ) -> Result<()> {
 858        self.peer.send(
 859            connection_id,
 860            proto::Hello {
 861                peer_id: Some(connection_id.into()),
 862            },
 863        )?;
 864        tracing::info!("sent hello message");
 865        if let Some(send_connection_id) = send_connection_id.take() {
 866            let _ = send_connection_id.send(connection_id);
 867        }
 868
 869        match principal {
 870            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 871                if !user.connected_once {
 872                    self.peer.send(connection_id, proto::ShowContacts {})?;
 873                    self.app_state
 874                        .db
 875                        .set_user_connected_once(user.id, true)
 876                        .await?;
 877                }
 878
 879                update_user_plan(user.id, session).await?;
 880
 881                let contacts = self.app_state.db.get_contacts(user.id).await?;
 882
 883                {
 884                    let mut pool = self.connection_pool.lock();
 885                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 886                    self.peer.send(
 887                        connection_id,
 888                        build_initial_contacts_update(contacts, &pool),
 889                    )?;
 890                }
 891
 892                if should_auto_subscribe_to_channels(zed_version) {
 893                    subscribe_user_to_channels(user.id, session).await?;
 894                }
 895
 896                if let Some(incoming_call) =
 897                    self.app_state.db.incoming_call_for_user(user.id).await?
 898                {
 899                    self.peer.send(connection_id, incoming_call)?;
 900                }
 901
 902                update_user_contacts(user.id, session).await?;
 903            }
 904        }
 905
 906        Ok(())
 907    }
 908
 909    pub async fn invite_code_redeemed(
 910        self: &Arc<Self>,
 911        inviter_id: UserId,
 912        invitee_id: UserId,
 913    ) -> Result<()> {
 914        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 915            if let Some(code) = &user.invite_code {
 916                let pool = self.connection_pool.lock();
 917                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 918                for connection_id in pool.user_connection_ids(inviter_id) {
 919                    self.peer.send(
 920                        connection_id,
 921                        proto::UpdateContacts {
 922                            contacts: vec![invitee_contact.clone()],
 923                            ..Default::default()
 924                        },
 925                    )?;
 926                    self.peer.send(
 927                        connection_id,
 928                        proto::UpdateInviteInfo {
 929                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 930                            count: user.invite_count as u32,
 931                        },
 932                    )?;
 933                }
 934            }
 935        }
 936        Ok(())
 937    }
 938
 939    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 940        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 941            if let Some(invite_code) = &user.invite_code {
 942                let pool = self.connection_pool.lock();
 943                for connection_id in pool.user_connection_ids(user_id) {
 944                    self.peer.send(
 945                        connection_id,
 946                        proto::UpdateInviteInfo {
 947                            url: format!(
 948                                "{}{}",
 949                                self.app_state.config.invite_link_prefix, invite_code
 950                            ),
 951                            count: user.invite_count as u32,
 952                        },
 953                    )?;
 954                }
 955            }
 956        }
 957        Ok(())
 958    }
 959
 960    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 961        let pool = self.connection_pool.lock();
 962        for connection_id in pool.user_connection_ids(user_id) {
 963            self.peer
 964                .send(connection_id, proto::RefreshLlmToken {})
 965                .trace_err();
 966        }
 967    }
 968
 969    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 970        ServerSnapshot {
 971            connection_pool: ConnectionPoolGuard {
 972                guard: self.connection_pool.lock(),
 973                _not_send: PhantomData,
 974            },
 975            peer: &self.peer,
 976        }
 977    }
 978}
 979
 980impl Deref for ConnectionPoolGuard<'_> {
 981    type Target = ConnectionPool;
 982
 983    fn deref(&self) -> &Self::Target {
 984        &self.guard
 985    }
 986}
 987
 988impl DerefMut for ConnectionPoolGuard<'_> {
 989    fn deref_mut(&mut self) -> &mut Self::Target {
 990        &mut self.guard
 991    }
 992}
 993
 994impl Drop for ConnectionPoolGuard<'_> {
 995    fn drop(&mut self) {
 996        #[cfg(test)]
 997        self.check_invariants();
 998    }
 999}
1000
1001fn broadcast<F>(
1002    sender_id: Option<ConnectionId>,
1003    receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004    mut f: F,
1005) where
1006    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008    for receiver_id in receiver_ids {
1009        if Some(receiver_id) != sender_id {
1010            if let Err(error) = f(receiver_id) {
1011                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012            }
1013        }
1014    }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020    fn name() -> &'static HeaderName {
1021        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023    }
1024
1025    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026    where
1027        Self: Sized,
1028        I: Iterator<Item = &'i axum::http::HeaderValue>,
1029    {
1030        let version = values
1031            .next()
1032            .ok_or_else(axum::headers::Error::invalid)?
1033            .to_str()
1034            .map_err(|_| axum::headers::Error::invalid())?
1035            .parse()
1036            .map_err(|_| axum::headers::Error::invalid())?;
1037        Ok(Self(version))
1038    }
1039
1040    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041        values.extend([self.0.to_string().parse().unwrap()]);
1042    }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047    fn name() -> &'static HeaderName {
1048        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050    }
1051
1052    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053    where
1054        Self: Sized,
1055        I: Iterator<Item = &'i axum::http::HeaderValue>,
1056    {
1057        let version = values
1058            .next()
1059            .ok_or_else(axum::headers::Error::invalid)?
1060            .to_str()
1061            .map_err(|_| axum::headers::Error::invalid())?
1062            .parse()
1063            .map_err(|_| axum::headers::Error::invalid())?;
1064        Ok(Self(version))
1065    }
1066
1067    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068        values.extend([self.0.to_string().parse().unwrap()]);
1069    }
1070}
1071
1072pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1073    Router::new()
1074        .route("/rpc", get(handle_websocket_request))
1075        .layer(
1076            ServiceBuilder::new()
1077                .layer(Extension(server.app_state.clone()))
1078                .layer(middleware::from_fn(auth::validate_header)),
1079        )
1080        .route("/metrics", get(handle_metrics))
1081        .layer(Extension(server))
1082}
1083
1084#[allow(clippy::too_many_arguments)]
1085pub async fn handle_websocket_request(
1086    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1087    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1088    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1089    Extension(server): Extension<Arc<Server>>,
1090    Extension(principal): Extension<Principal>,
1091    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1092    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1093    ws: WebSocketUpgrade,
1094) -> axum::response::Response {
1095    if protocol_version != rpc::PROTOCOL_VERSION {
1096        return (
1097            StatusCode::UPGRADE_REQUIRED,
1098            "client must be upgraded".to_string(),
1099        )
1100            .into_response();
1101    }
1102
1103    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1104        return (
1105            StatusCode::UPGRADE_REQUIRED,
1106            "no version header found".to_string(),
1107        )
1108            .into_response();
1109    };
1110
1111    if !version.can_collaborate() {
1112        return (
1113            StatusCode::UPGRADE_REQUIRED,
1114            "client must be upgraded".to_string(),
1115        )
1116            .into_response();
1117    }
1118
1119    let socket_address = socket_address.to_string();
1120    ws.on_upgrade(move |socket| {
1121        let socket = socket
1122            .map_ok(to_tungstenite_message)
1123            .err_into()
1124            .with(|message| async move { to_axum_message(message) });
1125        let connection = Connection::new(Box::pin(socket));
1126        async move {
1127            server
1128                .handle_connection(
1129                    connection,
1130                    socket_address,
1131                    principal,
1132                    version,
1133                    country_code_header.map(|header| header.to_string()),
1134                    system_id_header.map(|header| header.to_string()),
1135                    None,
1136                    Executor::Production,
1137                )
1138                .await;
1139        }
1140    })
1141}
1142
1143pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1144    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1145    let connections_metric = CONNECTIONS_METRIC
1146        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1147
1148    let connections = server
1149        .connection_pool
1150        .lock()
1151        .connections()
1152        .filter(|connection| !connection.admin)
1153        .count();
1154    connections_metric.set(connections as _);
1155
1156    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1157    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1158        register_int_gauge!(
1159            "shared_projects",
1160            "number of open projects with one or more guests"
1161        )
1162        .unwrap()
1163    });
1164
1165    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1166    shared_projects_metric.set(shared_projects as _);
1167
1168    let encoder = prometheus::TextEncoder::new();
1169    let metric_families = prometheus::gather();
1170    let encoded_metrics = encoder
1171        .encode_to_string(&metric_families)
1172        .map_err(|err| anyhow!("{}", err))?;
1173    Ok(encoded_metrics)
1174}
1175
/// Signs a session out after its connection has gone away.
///
/// The following steps run immediately:
/// 1. Disconnect the peer.
/// 2. Remove the connection from the pool. A failure here is returned.
/// 3. Tell the database the connection was lost. A failure here is only logged.
///
/// It then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses on `executor`. In that case the session leaves its room
///   and its channel buffers (failures logged). If the user has no other connection
///   online, any pending call is declined and the affected room is propagated with
///   `room_updated`. Finally the user's contacts are updated, and a failure there is
///   returned.
/// - The teardown flag changes. In that case the cleanup above is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the reconnect-timeout arm wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // The grace period elapsed without the select being cut short by teardown,
            // so release everything this connection was holding.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1222
1223/// Acknowledges a ping from a client, used to keep the connection alive.
1224async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1225    response.send(proto::Ack {})?;
1226    Ok(())
1227}
1228
1229/// Creates a new room for calling (outside of channels)
1230async fn create_room(
1231    _request: proto::CreateRoom,
1232    response: Response<proto::CreateRoom>,
1233    session: Session,
1234) -> Result<()> {
1235    let livekit_room = nanoid::nanoid!(30);
1236
1237    let live_kit_connection_info = util::maybe!(async {
1238        let live_kit = session.app_state.livekit_client.as_ref();
1239        let live_kit = live_kit?;
1240        let user_id = session.user_id().to_string();
1241
1242        let token = live_kit
1243            .room_token(&livekit_room, &user_id.to_string())
1244            .trace_err()?;
1245
1246        Some(proto::LiveKitConnectionInfo {
1247            server_url: live_kit.url().into(),
1248            token,
1249            can_publish: true,
1250        })
1251    })
1252    .await;
1253
1254    let room = session
1255        .db()
1256        .await
1257        .create_room(session.user_id(), session.connection_id, &livekit_room)
1258        .await?;
1259
1260    response.send(proto::CreateRoomResponse {
1261        room: Some(room.clone()),
1262        live_kit_connection_info,
1263    })?;
1264
1265    update_user_contacts(session.user_id(), &session).await?;
1266    Ok(())
1267}
1268
1269/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1270async fn join_room(
1271    request: proto::JoinRoom,
1272    response: Response<proto::JoinRoom>,
1273    session: Session,
1274) -> Result<()> {
1275    let room_id = RoomId::from_proto(request.id);
1276
1277    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1278
1279    if let Some(channel_id) = channel_id {
1280        return join_channel_internal(channel_id, Box::new(response), session).await;
1281    }
1282
1283    let joined_room = {
1284        let room = session
1285            .db()
1286            .await
1287            .join_room(room_id, session.user_id(), session.connection_id)
1288            .await?;
1289        room_updated(&room.room, &session.peer);
1290        room.into_inner()
1291    };
1292
1293    for connection_id in session
1294        .connection_pool()
1295        .await
1296        .user_connection_ids(session.user_id())
1297    {
1298        session
1299            .peer
1300            .send(
1301                connection_id,
1302                proto::CallCanceled {
1303                    room_id: room_id.to_proto(),
1304                },
1305            )
1306            .trace_err();
1307    }
1308
1309    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1310    {
1311        live_kit
1312            .room_token(
1313                &joined_room.room.livekit_room,
1314                &session.user_id().to_string(),
1315            )
1316            .trace_err()
1317            .map(|token| proto::LiveKitConnectionInfo {
1318                server_url: live_kit.url().into(),
1319                token,
1320                can_publish: true,
1321            })
1322    } else {
1323        None
1324    };
1325
1326    response.send(proto::JoinRoomResponse {
1327        room: Some(joined_room.room),
1328        channel_id: None,
1329        live_kit_connection_info,
1330    })?;
1331
1332    update_user_contacts(session.user_id(), &session).await?;
1333    Ok(())
1334}
1335
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The rejoin is recorded in the database, and the client receives the room together with
/// its reshared and rejoined projects. The function then:
/// 1. Propagates the room state with `room_updated`.
/// 2. For every reshared project, tells its collaborators about the peer-id change
///    (`UpdateProjectCollaborator`) and forwards the project's worktrees to them
///    (`UpdateProject`, excluding this connection).
/// 3. Replays rejoined-project state through `notify_rejoined_projects`.
/// 4. If the room belongs to a channel, calls `channel_updated` for it.
/// 5. Updates the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Filled inside the block below once the database result has been consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators from this session's old connection to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared worktrees to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the database result here so it does not outlive this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1427
1428fn notify_rejoined_projects(
1429    rejoined_projects: &mut Vec<RejoinedProject>,
1430    session: &Session,
1431) -> Result<()> {
1432    for project in rejoined_projects.iter() {
1433        for collaborator in &project.collaborators {
1434            session
1435                .peer
1436                .send(
1437                    collaborator.connection_id,
1438                    proto::UpdateProjectCollaborator {
1439                        project_id: project.id.to_proto(),
1440                        old_peer_id: Some(project.old_connection_id.into()),
1441                        new_peer_id: Some(session.connection_id.into()),
1442                    },
1443                )
1444                .trace_err();
1445        }
1446    }
1447
1448    for project in rejoined_projects {
1449        for worktree in mem::take(&mut project.worktrees) {
1450            // Stream this worktree's entries.
1451            let message = proto::UpdateWorktree {
1452                project_id: project.id.to_proto(),
1453                worktree_id: worktree.id,
1454                abs_path: worktree.abs_path.clone(),
1455                root_name: worktree.root_name,
1456                updated_entries: worktree.updated_entries,
1457                removed_entries: worktree.removed_entries,
1458                scan_id: worktree.scan_id,
1459                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1460                updated_repositories: worktree.updated_repositories,
1461                removed_repositories: worktree.removed_repositories,
1462            };
1463            for update in proto::split_worktree_update(message) {
1464                session.peer.send(session.connection_id, update.clone())?;
1465            }
1466
1467            // Stream this worktree's diagnostics.
1468            for summary in worktree.diagnostic_summaries {
1469                session.peer.send(
1470                    session.connection_id,
1471                    proto::UpdateDiagnosticSummary {
1472                        project_id: project.id.to_proto(),
1473                        worktree_id: worktree.id,
1474                        summary: Some(summary),
1475                    },
1476                )?;
1477            }
1478
1479            for settings_file in worktree.settings_files {
1480                session.peer.send(
1481                    session.connection_id,
1482                    proto::UpdateWorktreeSettings {
1483                        project_id: project.id.to_proto(),
1484                        worktree_id: worktree.id,
1485                        path: settings_file.path,
1486                        content: Some(settings_file.content),
1487                        kind: Some(settings_file.kind.to_proto().into()),
1488                    },
1489                )?;
1490            }
1491        }
1492
1493        for language_server in &project.language_servers {
1494            session.peer.send(
1495                session.connection_id,
1496                proto::UpdateLanguageServer {
1497                    project_id: project.id.to_proto(),
1498                    language_server_id: language_server.id,
1499                    variant: Some(
1500                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1501                            proto::LspDiskBasedDiagnosticsUpdated {},
1502                        ),
1503                    ),
1504                },
1505            )?;
1506        }
1507    }
1508    Ok(())
1509}
1510
1511/// leave room disconnects from the room.
1512async fn leave_room(
1513    _: proto::LeaveRoom,
1514    response: Response<proto::LeaveRoom>,
1515    session: Session,
1516) -> Result<()> {
1517    leave_room_for_session(&session, session.connection_id).await?;
1518    response.send(proto::Ack {})?;
1519    Ok(())
1520}
1521
1522/// Updates the permissions of someone else in the room.
1523async fn set_room_participant_role(
1524    request: proto::SetRoomParticipantRole,
1525    response: Response<proto::SetRoomParticipantRole>,
1526    session: Session,
1527) -> Result<()> {
1528    let user_id = UserId::from_proto(request.user_id);
1529    let role = ChannelRole::from(request.role());
1530
1531    let (livekit_room, can_publish) = {
1532        let room = session
1533            .db()
1534            .await
1535            .set_room_participant_role(
1536                session.user_id(),
1537                RoomId::from_proto(request.room_id),
1538                user_id,
1539                role,
1540            )
1541            .await?;
1542
1543        let livekit_room = room.livekit_room.clone();
1544        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1545        room_updated(&room, &session.peer);
1546        (livekit_room, can_publish)
1547    };
1548
1549    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1550        live_kit
1551            .update_participant(
1552                livekit_room.clone(),
1553                request.user_id.to_string(),
1554                livekit_api::proto::ParticipantPermission {
1555                    can_subscribe: true,
1556                    can_publish,
1557                    can_publish_data: can_publish,
1558                    hidden: false,
1559                    recorder: false,
1560                },
1561            )
1562            .await
1563            .trace_err();
1564    }
1565
1566    response.send(proto::Ack {})?;
1567    Ok(())
1568}
1569
1570/// Call someone else into the current room
1571async fn call(
1572    request: proto::Call,
1573    response: Response<proto::Call>,
1574    session: Session,
1575) -> Result<()> {
1576    let room_id = RoomId::from_proto(request.room_id);
1577    let calling_user_id = session.user_id();
1578    let calling_connection_id = session.connection_id;
1579    let called_user_id = UserId::from_proto(request.called_user_id);
1580    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1581    if !session
1582        .db()
1583        .await
1584        .has_contact(calling_user_id, called_user_id)
1585        .await?
1586    {
1587        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1588    }
1589
1590    let incoming_call = {
1591        let (room, incoming_call) = &mut *session
1592            .db()
1593            .await
1594            .call(
1595                room_id,
1596                calling_user_id,
1597                calling_connection_id,
1598                called_user_id,
1599                initial_project_id,
1600            )
1601            .await?;
1602        room_updated(room, &session.peer);
1603        mem::take(incoming_call)
1604    };
1605    update_user_contacts(called_user_id, &session).await?;
1606
1607    let mut calls = session
1608        .connection_pool()
1609        .await
1610        .user_connection_ids(called_user_id)
1611        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1612        .collect::<FuturesUnordered<_>>();
1613
1614    while let Some(call_response) = calls.next().await {
1615        match call_response.as_ref() {
1616            Ok(_) => {
1617                response.send(proto::Ack {})?;
1618                return Ok(());
1619            }
1620            Err(_) => {
1621                call_response.trace_err();
1622            }
1623        }
1624    }
1625
1626    {
1627        let room = session
1628            .db()
1629            .await
1630            .call_failed(room_id, called_user_id)
1631            .await?;
1632        room_updated(&room, &session.peer);
1633    }
1634    update_user_contacts(called_user_id, &session).await?;
1635
1636    Err(anyhow!("failed to ring user"))?
1637}
1638
1639/// Cancel an outgoing call.
1640async fn cancel_call(
1641    request: proto::CancelCall,
1642    response: Response<proto::CancelCall>,
1643    session: Session,
1644) -> Result<()> {
1645    let called_user_id = UserId::from_proto(request.called_user_id);
1646    let room_id = RoomId::from_proto(request.room_id);
1647    {
1648        let room = session
1649            .db()
1650            .await
1651            .cancel_call(room_id, session.connection_id, called_user_id)
1652            .await?;
1653        room_updated(&room, &session.peer);
1654    }
1655
1656    for connection_id in session
1657        .connection_pool()
1658        .await
1659        .user_connection_ids(called_user_id)
1660    {
1661        session
1662            .peer
1663            .send(
1664                connection_id,
1665                proto::CallCanceled {
1666                    room_id: room_id.to_proto(),
1667                },
1668            )
1669            .trace_err();
1670    }
1671    response.send(proto::Ack {})?;
1672
1673    update_user_contacts(called_user_id, &session).await?;
1674    Ok(())
1675}
1676
1677/// Decline an incoming call.
1678async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1679    let room_id = RoomId::from_proto(message.room_id);
1680    {
1681        let room = session
1682            .db()
1683            .await
1684            .decline_call(Some(room_id), session.user_id())
1685            .await?
1686            .ok_or_else(|| anyhow!("failed to decline call"))?;
1687        room_updated(&room, &session.peer);
1688    }
1689
1690    for connection_id in session
1691        .connection_pool()
1692        .await
1693        .user_connection_ids(session.user_id())
1694    {
1695        session
1696            .peer
1697            .send(
1698                connection_id,
1699                proto::CallCanceled {
1700                    room_id: room_id.to_proto(),
1701                },
1702            )
1703            .trace_err();
1704    }
1705    update_user_contacts(session.user_id(), &session).await?;
1706    Ok(())
1707}
1708
1709/// Updates other participants in the room with your current location.
1710async fn update_participant_location(
1711    request: proto::UpdateParticipantLocation,
1712    response: Response<proto::UpdateParticipantLocation>,
1713    session: Session,
1714) -> Result<()> {
1715    let room_id = RoomId::from_proto(request.room_id);
1716    let location = request
1717        .location
1718        .ok_or_else(|| anyhow!("invalid location"))?;
1719
1720    let db = session.db().await;
1721    let room = db
1722        .update_room_participant_location(room_id, session.connection_id, location)
1723        .await?;
1724
1725    room_updated(&room, &session.peer);
1726    response.send(proto::Ack {})?;
1727    Ok(())
1728}
1729
1730/// Share a project into the room.
1731async fn share_project(
1732    request: proto::ShareProject,
1733    response: Response<proto::ShareProject>,
1734    session: Session,
1735) -> Result<()> {
1736    let (project_id, room) = &*session
1737        .db()
1738        .await
1739        .share_project(
1740            RoomId::from_proto(request.room_id),
1741            session.connection_id,
1742            &request.worktrees,
1743            request.is_ssh_project,
1744        )
1745        .await?;
1746    response.send(proto::ShareProjectResponse {
1747        project_id: project_id.to_proto(),
1748    })?;
1749    room_updated(room, &session.peer);
1750
1751    Ok(())
1752}
1753
1754/// Unshare a project from the room.
1755async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1756    let project_id = ProjectId::from_proto(message.project_id);
1757    unshare_project_internal(project_id, session.connection_id, &session).await
1758}
1759
1760async fn unshare_project_internal(
1761    project_id: ProjectId,
1762    connection_id: ConnectionId,
1763    session: &Session,
1764) -> Result<()> {
1765    let delete = {
1766        let room_guard = session
1767            .db()
1768            .await
1769            .unshare_project(project_id, connection_id)
1770            .await?;
1771
1772        let (delete, room, guest_connection_ids) = &*room_guard;
1773
1774        let message = proto::UnshareProject {
1775            project_id: project_id.to_proto(),
1776        };
1777
1778        broadcast(
1779            Some(connection_id),
1780            guest_connection_ids.iter().copied(),
1781            |conn_id| session.peer.send(conn_id, message.clone()),
1782        );
1783        if let Some(room) = room {
1784            room_updated(room, &session.peer);
1785        }
1786
1787        *delete
1788    };
1789
1790    if delete {
1791        let db = session.db().await;
1792        db.delete_project(project_id).await?;
1793    }
1794
1795    Ok(())
1796}
1797
1798/// Join someone elses shared project.
1799async fn join_project(
1800    request: proto::JoinProject,
1801    response: Response<proto::JoinProject>,
1802    session: Session,
1803) -> Result<()> {
1804    let project_id = ProjectId::from_proto(request.project_id);
1805
1806    tracing::info!(%project_id, "join project");
1807
1808    let db = session.db().await;
1809    let (project, replica_id) = &mut *db
1810        .join_project(project_id, session.connection_id, session.user_id())
1811        .await?;
1812    drop(db);
1813    tracing::info!(%project_id, "join remote project");
1814    join_project_internal(response, session, project, replica_id)
1815}
1816
/// Destination for the `JoinProjectResponse` produced by `join_project_internal`.
///
/// `join_project_internal` accepts any implementor of this trait rather than a concrete
/// `Response<proto::JoinProject>`; that type implements it directly below.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
1820impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1821    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1822        Response::<proto::JoinProject>::send(self, result)
1823    }
1824}
1825
1826fn join_project_internal(
1827    response: impl JoinProjectInternalResponse,
1828    session: Session,
1829    project: &mut Project,
1830    replica_id: &ReplicaId,
1831) -> Result<()> {
1832    let collaborators = project
1833        .collaborators
1834        .iter()
1835        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1836        .map(|collaborator| collaborator.to_proto())
1837        .collect::<Vec<_>>();
1838    let project_id = project.id;
1839    let guest_user_id = session.user_id();
1840
1841    let worktrees = project
1842        .worktrees
1843        .iter()
1844        .map(|(id, worktree)| proto::WorktreeMetadata {
1845            id: *id,
1846            root_name: worktree.root_name.clone(),
1847            visible: worktree.visible,
1848            abs_path: worktree.abs_path.clone(),
1849        })
1850        .collect::<Vec<_>>();
1851
1852    let add_project_collaborator = proto::AddProjectCollaborator {
1853        project_id: project_id.to_proto(),
1854        collaborator: Some(proto::Collaborator {
1855            peer_id: Some(session.connection_id.into()),
1856            replica_id: replica_id.0 as u32,
1857            user_id: guest_user_id.to_proto(),
1858            is_host: false,
1859        }),
1860    };
1861
1862    for collaborator in &collaborators {
1863        session
1864            .peer
1865            .send(
1866                collaborator.peer_id.unwrap().into(),
1867                add_project_collaborator.clone(),
1868            )
1869            .trace_err();
1870    }
1871
1872    // First, we send the metadata associated with each worktree.
1873    response.send(proto::JoinProjectResponse {
1874        project_id: project.id.0 as u64,
1875        worktrees: worktrees.clone(),
1876        replica_id: replica_id.0 as u32,
1877        collaborators: collaborators.clone(),
1878        language_servers: project.language_servers.clone(),
1879        role: project.role.into(),
1880    })?;
1881
1882    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1883        // Stream this worktree's entries.
1884        let message = proto::UpdateWorktree {
1885            project_id: project_id.to_proto(),
1886            worktree_id,
1887            abs_path: worktree.abs_path.clone(),
1888            root_name: worktree.root_name,
1889            updated_entries: worktree.entries,
1890            removed_entries: Default::default(),
1891            scan_id: worktree.scan_id,
1892            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1893            updated_repositories: worktree.repository_entries.into_values().collect(),
1894            removed_repositories: Default::default(),
1895        };
1896        for update in proto::split_worktree_update(message) {
1897            session.peer.send(session.connection_id, update.clone())?;
1898        }
1899
1900        // Stream this worktree's diagnostics.
1901        for summary in worktree.diagnostic_summaries {
1902            session.peer.send(
1903                session.connection_id,
1904                proto::UpdateDiagnosticSummary {
1905                    project_id: project_id.to_proto(),
1906                    worktree_id: worktree.id,
1907                    summary: Some(summary),
1908                },
1909            )?;
1910        }
1911
1912        for settings_file in worktree.settings_files {
1913            session.peer.send(
1914                session.connection_id,
1915                proto::UpdateWorktreeSettings {
1916                    project_id: project_id.to_proto(),
1917                    worktree_id: worktree.id,
1918                    path: settings_file.path,
1919                    content: Some(settings_file.content),
1920                    kind: Some(settings_file.kind.to_proto() as i32),
1921                },
1922            )?;
1923        }
1924    }
1925
1926    for language_server in &project.language_servers {
1927        session.peer.send(
1928            session.connection_id,
1929            proto::UpdateLanguageServer {
1930                project_id: project_id.to_proto(),
1931                language_server_id: language_server.id,
1932                variant: Some(
1933                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1934                        proto::LspDiskBasedDiagnosticsUpdated {},
1935                    ),
1936                ),
1937            },
1938        )?;
1939    }
1940
1941    Ok(())
1942}
1943
1944/// Leave someone elses shared project.
1945async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1946    let sender_id = session.connection_id;
1947    let project_id = ProjectId::from_proto(request.project_id);
1948    let db = session.db().await;
1949
1950    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1951    tracing::info!(
1952        %project_id,
1953        "leave project"
1954    );
1955
1956    project_left(project, &session);
1957    if let Some(room) = room {
1958        room_updated(room, &session.peer);
1959    }
1960
1961    Ok(())
1962}
1963
1964/// Updates other participants with changes to the project
1965async fn update_project(
1966    request: proto::UpdateProject,
1967    response: Response<proto::UpdateProject>,
1968    session: Session,
1969) -> Result<()> {
1970    let project_id = ProjectId::from_proto(request.project_id);
1971    let (room, guest_connection_ids) = &*session
1972        .db()
1973        .await
1974        .update_project(project_id, session.connection_id, &request.worktrees)
1975        .await?;
1976    broadcast(
1977        Some(session.connection_id),
1978        guest_connection_ids.iter().copied(),
1979        |connection_id| {
1980            session
1981                .peer
1982                .forward_send(session.connection_id, connection_id, request.clone())
1983        },
1984    );
1985    if let Some(room) = room {
1986        room_updated(room, &session.peer);
1987    }
1988    response.send(proto::Ack {})?;
1989
1990    Ok(())
1991}
1992
1993/// Updates other participants with changes to the worktree
1994async fn update_worktree(
1995    request: proto::UpdateWorktree,
1996    response: Response<proto::UpdateWorktree>,
1997    session: Session,
1998) -> Result<()> {
1999    let guest_connection_ids = session
2000        .db()
2001        .await
2002        .update_worktree(&request, session.connection_id)
2003        .await?;
2004
2005    broadcast(
2006        Some(session.connection_id),
2007        guest_connection_ids.iter().copied(),
2008        |connection_id| {
2009            session
2010                .peer
2011                .forward_send(session.connection_id, connection_id, request.clone())
2012        },
2013    );
2014    response.send(proto::Ack {})?;
2015    Ok(())
2016}
2017
2018/// Updates other participants with changes to the diagnostics
2019async fn update_diagnostic_summary(
2020    message: proto::UpdateDiagnosticSummary,
2021    session: Session,
2022) -> Result<()> {
2023    let guest_connection_ids = session
2024        .db()
2025        .await
2026        .update_diagnostic_summary(&message, session.connection_id)
2027        .await?;
2028
2029    broadcast(
2030        Some(session.connection_id),
2031        guest_connection_ids.iter().copied(),
2032        |connection_id| {
2033            session
2034                .peer
2035                .forward_send(session.connection_id, connection_id, message.clone())
2036        },
2037    );
2038
2039    Ok(())
2040}
2041
2042/// Updates other participants with changes to the worktree settings
2043async fn update_worktree_settings(
2044    message: proto::UpdateWorktreeSettings,
2045    session: Session,
2046) -> Result<()> {
2047    let guest_connection_ids = session
2048        .db()
2049        .await
2050        .update_worktree_settings(&message, session.connection_id)
2051        .await?;
2052
2053    broadcast(
2054        Some(session.connection_id),
2055        guest_connection_ids.iter().copied(),
2056        |connection_id| {
2057            session
2058                .peer
2059                .forward_send(session.connection_id, connection_id, message.clone())
2060        },
2061    );
2062
2063    Ok(())
2064}
2065
2066/// Notify other participants that a  language server has started.
2067async fn start_language_server(
2068    request: proto::StartLanguageServer,
2069    session: Session,
2070) -> Result<()> {
2071    let guest_connection_ids = session
2072        .db()
2073        .await
2074        .start_language_server(&request, session.connection_id)
2075        .await?;
2076
2077    broadcast(
2078        Some(session.connection_id),
2079        guest_connection_ids.iter().copied(),
2080        |connection_id| {
2081            session
2082                .peer
2083                .forward_send(session.connection_id, connection_id, request.clone())
2084        },
2085    );
2086    Ok(())
2087}
2088
2089/// Notify other participants that a language server has changed.
2090async fn update_language_server(
2091    request: proto::UpdateLanguageServer,
2092    session: Session,
2093) -> Result<()> {
2094    let project_id = ProjectId::from_proto(request.project_id);
2095    let project_connection_ids = session
2096        .db()
2097        .await
2098        .project_connection_ids(project_id, session.connection_id, true)
2099        .await?;
2100    broadcast(
2101        Some(session.connection_id),
2102        project_connection_ids.iter().copied(),
2103        |connection_id| {
2104            session
2105                .peer
2106                .forward_send(session.connection_id, connection_id, request.clone())
2107        },
2108    );
2109    Ok(())
2110}
2111
2112/// forward a project request to the host. These requests should be read only
2113/// as guests are allowed to send them.
2114async fn forward_read_only_project_request<T>(
2115    request: T,
2116    response: Response<T>,
2117    session: Session,
2118) -> Result<()>
2119where
2120    T: EntityMessage + RequestMessage,
2121{
2122    let project_id = ProjectId::from_proto(request.remote_entity_id());
2123    let host_connection_id = session
2124        .db()
2125        .await
2126        .host_for_read_only_project_request(project_id, session.connection_id)
2127        .await?;
2128    let payload = session
2129        .peer
2130        .forward_request(session.connection_id, host_connection_id, request)
2131        .await?;
2132    response.send(payload)?;
2133    Ok(())
2134}
2135
2136async fn forward_find_search_candidates_request(
2137    request: proto::FindSearchCandidates,
2138    response: Response<proto::FindSearchCandidates>,
2139    session: Session,
2140) -> Result<()> {
2141    let project_id = ProjectId::from_proto(request.remote_entity_id());
2142    let host_connection_id = session
2143        .db()
2144        .await
2145        .host_for_read_only_project_request(project_id, session.connection_id)
2146        .await?;
2147    let payload = session
2148        .peer
2149        .forward_request(session.connection_id, host_connection_id, request)
2150        .await?;
2151    response.send(payload)?;
2152    Ok(())
2153}
2154
2155/// forward a project request to the host. These requests are disallowed
2156/// for guests.
2157async fn forward_mutating_project_request<T>(
2158    request: T,
2159    response: Response<T>,
2160    session: Session,
2161) -> Result<()>
2162where
2163    T: EntityMessage + RequestMessage,
2164{
2165    let project_id = ProjectId::from_proto(request.remote_entity_id());
2166
2167    let host_connection_id = session
2168        .db()
2169        .await
2170        .host_for_mutating_project_request(project_id, session.connection_id)
2171        .await?;
2172    let payload = session
2173        .peer
2174        .forward_request(session.connection_id, host_connection_id, request)
2175        .await?;
2176    response.send(payload)?;
2177    Ok(())
2178}
2179
2180/// Notify other participants that a new buffer has been created
2181async fn create_buffer_for_peer(
2182    request: proto::CreateBufferForPeer,
2183    session: Session,
2184) -> Result<()> {
2185    session
2186        .db()
2187        .await
2188        .check_user_is_project_host(
2189            ProjectId::from_proto(request.project_id),
2190            session.connection_id,
2191        )
2192        .await?;
2193    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2194    session
2195        .peer
2196        .forward_send(session.connection_id, peer_id.into(), request)?;
2197    Ok(())
2198}
2199
2200/// Notify other participants that a buffer has been updated. This is
2201/// allowed for guests as long as the update is limited to selections.
2202async fn update_buffer(
2203    request: proto::UpdateBuffer,
2204    response: Response<proto::UpdateBuffer>,
2205    session: Session,
2206) -> Result<()> {
2207    let project_id = ProjectId::from_proto(request.project_id);
2208    let mut capability = Capability::ReadOnly;
2209
2210    for op in request.operations.iter() {
2211        match op.variant {
2212            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2213            Some(_) => capability = Capability::ReadWrite,
2214        }
2215    }
2216
2217    let host = {
2218        let guard = session
2219            .db()
2220            .await
2221            .connections_for_buffer_update(project_id, session.connection_id, capability)
2222            .await?;
2223
2224        let (host, guests) = &*guard;
2225
2226        broadcast(
2227            Some(session.connection_id),
2228            guests.clone(),
2229            |connection_id| {
2230                session
2231                    .peer
2232                    .forward_send(session.connection_id, connection_id, request.clone())
2233            },
2234        );
2235
2236        *host
2237    };
2238
2239    if host != session.connection_id {
2240        session
2241            .peer
2242            .forward_request(session.connection_id, host, request.clone())
2243            .await?;
2244    }
2245
2246    response.send(proto::Ack {})?;
2247    Ok(())
2248}
2249
2250async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2251    let project_id = ProjectId::from_proto(message.project_id);
2252
2253    let operation = message.operation.as_ref().context("invalid operation")?;
2254    let capability = match operation.variant.as_ref() {
2255        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2256            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2257                match buffer_op.variant {
2258                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2259                        Capability::ReadOnly
2260                    }
2261                    _ => Capability::ReadWrite,
2262                }
2263            } else {
2264                Capability::ReadWrite
2265            }
2266        }
2267        Some(_) => Capability::ReadWrite,
2268        None => Capability::ReadOnly,
2269    };
2270
2271    let guard = session
2272        .db()
2273        .await
2274        .connections_for_buffer_update(project_id, session.connection_id, capability)
2275        .await?;
2276
2277    let (host, guests) = &*guard;
2278
2279    broadcast(
2280        Some(session.connection_id),
2281        guests.iter().chain([host]).copied(),
2282        |connection_id| {
2283            session
2284                .peer
2285                .forward_send(session.connection_id, connection_id, message.clone())
2286        },
2287    );
2288
2289    Ok(())
2290}
2291
2292/// Notify other participants that a project has been updated.
2293async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2294    request: T,
2295    session: Session,
2296) -> Result<()> {
2297    let project_id = ProjectId::from_proto(request.remote_entity_id());
2298    let project_connection_ids = session
2299        .db()
2300        .await
2301        .project_connection_ids(project_id, session.connection_id, false)
2302        .await?;
2303
2304    broadcast(
2305        Some(session.connection_id),
2306        project_connection_ids.iter().copied(),
2307        |connection_id| {
2308            session
2309                .peer
2310                .forward_send(session.connection_id, connection_id, request.clone())
2311        },
2312    );
2313    Ok(())
2314}
2315
2316/// Start following another user in a call.
2317async fn follow(
2318    request: proto::Follow,
2319    response: Response<proto::Follow>,
2320    session: Session,
2321) -> Result<()> {
2322    let room_id = RoomId::from_proto(request.room_id);
2323    let project_id = request.project_id.map(ProjectId::from_proto);
2324    let leader_id = request
2325        .leader_id
2326        .ok_or_else(|| anyhow!("invalid leader id"))?
2327        .into();
2328    let follower_id = session.connection_id;
2329
2330    session
2331        .db()
2332        .await
2333        .check_room_participants(room_id, leader_id, session.connection_id)
2334        .await?;
2335
2336    let response_payload = session
2337        .peer
2338        .forward_request(session.connection_id, leader_id, request)
2339        .await?;
2340    response.send(response_payload)?;
2341
2342    if let Some(project_id) = project_id {
2343        let room = session
2344            .db()
2345            .await
2346            .follow(room_id, project_id, leader_id, follower_id)
2347            .await?;
2348        room_updated(&room, &session.peer);
2349    }
2350
2351    Ok(())
2352}
2353
2354/// Stop following another user in a call.
2355async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2356    let room_id = RoomId::from_proto(request.room_id);
2357    let project_id = request.project_id.map(ProjectId::from_proto);
2358    let leader_id = request
2359        .leader_id
2360        .ok_or_else(|| anyhow!("invalid leader id"))?
2361        .into();
2362    let follower_id = session.connection_id;
2363
2364    session
2365        .db()
2366        .await
2367        .check_room_participants(room_id, leader_id, session.connection_id)
2368        .await?;
2369
2370    session
2371        .peer
2372        .forward_send(session.connection_id, leader_id, request)?;
2373
2374    if let Some(project_id) = project_id {
2375        let room = session
2376            .db()
2377            .await
2378            .unfollow(room_id, project_id, leader_id, follower_id)
2379            .await?;
2380        room_updated(&room, &session.peer);
2381    }
2382
2383    Ok(())
2384}
2385
2386/// Notify everyone following you of your current location.
2387async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2388    let room_id = RoomId::from_proto(request.room_id);
2389    let database = session.db.lock().await;
2390
2391    let connection_ids = if let Some(project_id) = request.project_id {
2392        let project_id = ProjectId::from_proto(project_id);
2393        database
2394            .project_connection_ids(project_id, session.connection_id, true)
2395            .await?
2396    } else {
2397        database
2398            .room_connection_ids(room_id, session.connection_id)
2399            .await?
2400    };
2401
2402    // For now, don't send view update messages back to that view's current leader.
2403    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2404        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2405        _ => None,
2406    });
2407
2408    for connection_id in connection_ids.iter().cloned() {
2409        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2410            session
2411                .peer
2412                .forward_send(session.connection_id, connection_id, request.clone())?;
2413        }
2414    }
2415    Ok(())
2416}
2417
2418/// Get public data about users.
2419async fn get_users(
2420    request: proto::GetUsers,
2421    response: Response<proto::GetUsers>,
2422    session: Session,
2423) -> Result<()> {
2424    let user_ids = request
2425        .user_ids
2426        .into_iter()
2427        .map(UserId::from_proto)
2428        .collect();
2429    let users = session
2430        .db()
2431        .await
2432        .get_users_by_ids(user_ids)
2433        .await?
2434        .into_iter()
2435        .map(|user| proto::User {
2436            id: user.id.to_proto(),
2437            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2438            github_login: user.github_login,
2439            email: user.email_address,
2440            name: user.name,
2441        })
2442        .collect();
2443    response.send(proto::UsersResponse { users })?;
2444    Ok(())
2445}
2446
2447/// Search for users (to invite) buy Github login
2448async fn fuzzy_search_users(
2449    request: proto::FuzzySearchUsers,
2450    response: Response<proto::FuzzySearchUsers>,
2451    session: Session,
2452) -> Result<()> {
2453    let query = request.query;
2454    let users = match query.len() {
2455        0 => vec![],
2456        1 | 2 => session
2457            .db()
2458            .await
2459            .get_user_by_github_login(&query)
2460            .await?
2461            .into_iter()
2462            .collect(),
2463        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2464    };
2465    let users = users
2466        .into_iter()
2467        .filter(|user| user.id != session.user_id())
2468        .map(|user| proto::User {
2469            id: user.id.to_proto(),
2470            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2471            github_login: user.github_login,
2472            name: user.name,
2473            email: user.email_address,
2474        })
2475        .collect();
2476    response.send(proto::UsersResponse { users })?;
2477    Ok(())
2478}
2479
2480/// Send a contact request to another user.
2481async fn request_contact(
2482    request: proto::RequestContact,
2483    response: Response<proto::RequestContact>,
2484    session: Session,
2485) -> Result<()> {
2486    let requester_id = session.user_id();
2487    let responder_id = UserId::from_proto(request.responder_id);
2488    if requester_id == responder_id {
2489        return Err(anyhow!("cannot add yourself as a contact"))?;
2490    }
2491
2492    let notifications = session
2493        .db()
2494        .await
2495        .send_contact_request(requester_id, responder_id)
2496        .await?;
2497
2498    // Update outgoing contact requests of requester
2499    let mut update = proto::UpdateContacts::default();
2500    update.outgoing_requests.push(responder_id.to_proto());
2501    for connection_id in session
2502        .connection_pool()
2503        .await
2504        .user_connection_ids(requester_id)
2505    {
2506        session.peer.send(connection_id, update.clone())?;
2507    }
2508
2509    // Update incoming contact requests of responder
2510    let mut update = proto::UpdateContacts::default();
2511    update
2512        .incoming_requests
2513        .push(proto::IncomingContactRequest {
2514            requester_id: requester_id.to_proto(),
2515        });
2516    let connection_pool = session.connection_pool().await;
2517    for connection_id in connection_pool.user_connection_ids(responder_id) {
2518        session.peer.send(connection_id, update.clone())?;
2519    }
2520
2521    send_notifications(&connection_pool, &session.peer, notifications);
2522
2523    response.send(proto::Ack {})?;
2524    Ok(())
2525}
2526
2527/// Accept or decline a contact request
2528async fn respond_to_contact_request(
2529    request: proto::RespondToContactRequest,
2530    response: Response<proto::RespondToContactRequest>,
2531    session: Session,
2532) -> Result<()> {
2533    let responder_id = session.user_id();
2534    let requester_id = UserId::from_proto(request.requester_id);
2535    let db = session.db().await;
2536    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2537        db.dismiss_contact_notification(responder_id, requester_id)
2538            .await?;
2539    } else {
2540        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2541
2542        let notifications = db
2543            .respond_to_contact_request(responder_id, requester_id, accept)
2544            .await?;
2545        let requester_busy = db.is_user_busy(requester_id).await?;
2546        let responder_busy = db.is_user_busy(responder_id).await?;
2547
2548        let pool = session.connection_pool().await;
2549        // Update responder with new contact
2550        let mut update = proto::UpdateContacts::default();
2551        if accept {
2552            update
2553                .contacts
2554                .push(contact_for_user(requester_id, requester_busy, &pool));
2555        }
2556        update
2557            .remove_incoming_requests
2558            .push(requester_id.to_proto());
2559        for connection_id in pool.user_connection_ids(responder_id) {
2560            session.peer.send(connection_id, update.clone())?;
2561        }
2562
2563        // Update requester with new contact
2564        let mut update = proto::UpdateContacts::default();
2565        if accept {
2566            update
2567                .contacts
2568                .push(contact_for_user(responder_id, responder_busy, &pool));
2569        }
2570        update
2571            .remove_outgoing_requests
2572            .push(responder_id.to_proto());
2573
2574        for connection_id in pool.user_connection_ids(requester_id) {
2575            session.peer.send(connection_id, update.clone())?;
2576        }
2577
2578        send_notifications(&pool, &session.peer, notifications);
2579    }
2580
2581    response.send(proto::Ack {})?;
2582    Ok(())
2583}
2584
2585/// Remove a contact.
2586async fn remove_contact(
2587    request: proto::RemoveContact,
2588    response: Response<proto::RemoveContact>,
2589    session: Session,
2590) -> Result<()> {
2591    let requester_id = session.user_id();
2592    let responder_id = UserId::from_proto(request.user_id);
2593    let db = session.db().await;
2594    let (contact_accepted, deleted_notification_id) =
2595        db.remove_contact(requester_id, responder_id).await?;
2596
2597    let pool = session.connection_pool().await;
2598    // Update outgoing contact requests of requester
2599    let mut update = proto::UpdateContacts::default();
2600    if contact_accepted {
2601        update.remove_contacts.push(responder_id.to_proto());
2602    } else {
2603        update
2604            .remove_outgoing_requests
2605            .push(responder_id.to_proto());
2606    }
2607    for connection_id in pool.user_connection_ids(requester_id) {
2608        session.peer.send(connection_id, update.clone())?;
2609    }
2610
2611    // Update incoming contact requests of responder
2612    let mut update = proto::UpdateContacts::default();
2613    if contact_accepted {
2614        update.remove_contacts.push(requester_id.to_proto());
2615    } else {
2616        update
2617            .remove_incoming_requests
2618            .push(requester_id.to_proto());
2619    }
2620    for connection_id in pool.user_connection_ids(responder_id) {
2621        session.peer.send(connection_id, update.clone())?;
2622        if let Some(notification_id) = deleted_notification_id {
2623            session.peer.send(
2624                connection_id,
2625                proto::DeleteNotification {
2626                    notification_id: notification_id.to_proto(),
2627                },
2628            )?;
2629        }
2630    }
2631
2632    response.send(proto::Ack {})?;
2633    Ok(())
2634}
2635
2636fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2637    version.0.minor() < 139
2638}
2639
2640async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2641    let plan = session.current_plan(&session.db().await).await?;
2642
2643    session
2644        .peer
2645        .send(
2646            session.connection_id,
2647            proto::UpdateUserPlan { plan: plan.into() },
2648        )
2649        .trace_err();
2650
2651    Ok(())
2652}
2653
2654async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2655    subscribe_user_to_channels(session.user_id(), &session).await?;
2656    Ok(())
2657}
2658
2659async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2660    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2661    let mut pool = session.connection_pool().await;
2662    for membership in &channels_for_user.channel_memberships {
2663        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2664    }
2665    session.peer.send(
2666        session.connection_id,
2667        build_update_user_channels(&channels_for_user),
2668    )?;
2669    session.peer.send(
2670        session.connection_id,
2671        build_channels_update(channels_for_user),
2672    )?;
2673    Ok(())
2674}
2675
2676/// Creates a new channel.
2677async fn create_channel(
2678    request: proto::CreateChannel,
2679    response: Response<proto::CreateChannel>,
2680    session: Session,
2681) -> Result<()> {
2682    let db = session.db().await;
2683
2684    let parent_id = request.parent_id.map(ChannelId::from_proto);
2685    let (channel, membership) = db
2686        .create_channel(&request.name, parent_id, session.user_id())
2687        .await?;
2688
2689    let root_id = channel.root_id();
2690    let channel = Channel::from_model(channel);
2691
2692    response.send(proto::CreateChannelResponse {
2693        channel: Some(channel.to_proto()),
2694        parent_id: request.parent_id,
2695    })?;
2696
2697    let mut connection_pool = session.connection_pool().await;
2698    if let Some(membership) = membership {
2699        connection_pool.subscribe_to_channel(
2700            membership.user_id,
2701            membership.channel_id,
2702            membership.role,
2703        );
2704        let update = proto::UpdateUserChannels {
2705            channel_memberships: vec![proto::ChannelMembership {
2706                channel_id: membership.channel_id.to_proto(),
2707                role: membership.role.into(),
2708            }],
2709            ..Default::default()
2710        };
2711        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2712            session.peer.send(connection_id, update.clone())?;
2713        }
2714    }
2715
2716    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2717        if !role.can_see_channel(channel.visibility) {
2718            continue;
2719        }
2720
2721        let update = proto::UpdateChannels {
2722            channels: vec![channel.to_proto()],
2723            ..Default::default()
2724        };
2725        session.peer.send(connection_id, update.clone())?;
2726    }
2727
2728    Ok(())
2729}
2730
2731/// Delete a channel
2732async fn delete_channel(
2733    request: proto::DeleteChannel,
2734    response: Response<proto::DeleteChannel>,
2735    session: Session,
2736) -> Result<()> {
2737    let db = session.db().await;
2738
2739    let channel_id = request.channel_id;
2740    let (root_channel, removed_channels) = db
2741        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2742        .await?;
2743    response.send(proto::Ack {})?;
2744
2745    // Notify members of removed channels
2746    let mut update = proto::UpdateChannels::default();
2747    update
2748        .delete_channels
2749        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2750
2751    let connection_pool = session.connection_pool().await;
2752    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2753        session.peer.send(connection_id, update.clone())?;
2754    }
2755
2756    Ok(())
2757}
2758
2759/// Invite someone to join a channel.
2760async fn invite_channel_member(
2761    request: proto::InviteChannelMember,
2762    response: Response<proto::InviteChannelMember>,
2763    session: Session,
2764) -> Result<()> {
2765    let db = session.db().await;
2766    let channel_id = ChannelId::from_proto(request.channel_id);
2767    let invitee_id = UserId::from_proto(request.user_id);
2768    let InviteMemberResult {
2769        channel,
2770        notifications,
2771    } = db
2772        .invite_channel_member(
2773            channel_id,
2774            invitee_id,
2775            session.user_id(),
2776            request.role().into(),
2777        )
2778        .await?;
2779
2780    let update = proto::UpdateChannels {
2781        channel_invitations: vec![channel.to_proto()],
2782        ..Default::default()
2783    };
2784
2785    let connection_pool = session.connection_pool().await;
2786    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2787        session.peer.send(connection_id, update.clone())?;
2788    }
2789
2790    send_notifications(&connection_pool, &session.peer, notifications);
2791
2792    response.send(proto::Ack {})?;
2793    Ok(())
2794}
2795
2796/// remove someone from a channel
2797async fn remove_channel_member(
2798    request: proto::RemoveChannelMember,
2799    response: Response<proto::RemoveChannelMember>,
2800    session: Session,
2801) -> Result<()> {
2802    let db = session.db().await;
2803    let channel_id = ChannelId::from_proto(request.channel_id);
2804    let member_id = UserId::from_proto(request.user_id);
2805
2806    let RemoveChannelMemberResult {
2807        membership_update,
2808        notification_id,
2809    } = db
2810        .remove_channel_member(channel_id, member_id, session.user_id())
2811        .await?;
2812
2813    let mut connection_pool = session.connection_pool().await;
2814    notify_membership_updated(
2815        &mut connection_pool,
2816        membership_update,
2817        member_id,
2818        &session.peer,
2819    );
2820    for connection_id in connection_pool.user_connection_ids(member_id) {
2821        if let Some(notification_id) = notification_id {
2822            session
2823                .peer
2824                .send(
2825                    connection_id,
2826                    proto::DeleteNotification {
2827                        notification_id: notification_id.to_proto(),
2828                    },
2829                )
2830                .trace_err();
2831        }
2832    }
2833
2834    response.send(proto::Ack {})?;
2835    Ok(())
2836}
2837
2838/// Toggle the channel between public and private.
2839/// Care is taken to maintain the invariant that public channels only descend from public channels,
2840/// (though members-only channels can appear at any point in the hierarchy).
2841async fn set_channel_visibility(
2842    request: proto::SetChannelVisibility,
2843    response: Response<proto::SetChannelVisibility>,
2844    session: Session,
2845) -> Result<()> {
2846    let db = session.db().await;
2847    let channel_id = ChannelId::from_proto(request.channel_id);
2848    let visibility = request.visibility().into();
2849
2850    let channel_model = db
2851        .set_channel_visibility(channel_id, visibility, session.user_id())
2852        .await?;
2853    let root_id = channel_model.root_id();
2854    let channel = Channel::from_model(channel_model);
2855
2856    let mut connection_pool = session.connection_pool().await;
2857    for (user_id, role) in connection_pool
2858        .channel_user_ids(root_id)
2859        .collect::<Vec<_>>()
2860        .into_iter()
2861    {
2862        let update = if role.can_see_channel(channel.visibility) {
2863            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2864            proto::UpdateChannels {
2865                channels: vec![channel.to_proto()],
2866                ..Default::default()
2867            }
2868        } else {
2869            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2870            proto::UpdateChannels {
2871                delete_channels: vec![channel.id.to_proto()],
2872                ..Default::default()
2873            }
2874        };
2875
2876        for connection_id in connection_pool.user_connection_ids(user_id) {
2877            session.peer.send(connection_id, update.clone())?;
2878        }
2879    }
2880
2881    response.send(proto::Ack {})?;
2882    Ok(())
2883}
2884
2885/// Alter the role for a user in the channel.
2886async fn set_channel_member_role(
2887    request: proto::SetChannelMemberRole,
2888    response: Response<proto::SetChannelMemberRole>,
2889    session: Session,
2890) -> Result<()> {
2891    let db = session.db().await;
2892    let channel_id = ChannelId::from_proto(request.channel_id);
2893    let member_id = UserId::from_proto(request.user_id);
2894    let result = db
2895        .set_channel_member_role(
2896            channel_id,
2897            session.user_id(),
2898            member_id,
2899            request.role().into(),
2900        )
2901        .await?;
2902
2903    match result {
2904        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2905            let mut connection_pool = session.connection_pool().await;
2906            notify_membership_updated(
2907                &mut connection_pool,
2908                membership_update,
2909                member_id,
2910                &session.peer,
2911            )
2912        }
2913        db::SetMemberRoleResult::InviteUpdated(channel) => {
2914            let update = proto::UpdateChannels {
2915                channel_invitations: vec![channel.to_proto()],
2916                ..Default::default()
2917            };
2918
2919            for connection_id in session
2920                .connection_pool()
2921                .await
2922                .user_connection_ids(member_id)
2923            {
2924                session.peer.send(connection_id, update.clone())?;
2925            }
2926        }
2927    }
2928
2929    response.send(proto::Ack {})?;
2930    Ok(())
2931}
2932
2933/// Change the name of a channel
2934async fn rename_channel(
2935    request: proto::RenameChannel,
2936    response: Response<proto::RenameChannel>,
2937    session: Session,
2938) -> Result<()> {
2939    let db = session.db().await;
2940    let channel_id = ChannelId::from_proto(request.channel_id);
2941    let channel_model = db
2942        .rename_channel(channel_id, session.user_id(), &request.name)
2943        .await?;
2944    let root_id = channel_model.root_id();
2945    let channel = Channel::from_model(channel_model);
2946
2947    response.send(proto::RenameChannelResponse {
2948        channel: Some(channel.to_proto()),
2949    })?;
2950
2951    let connection_pool = session.connection_pool().await;
2952    let update = proto::UpdateChannels {
2953        channels: vec![channel.to_proto()],
2954        ..Default::default()
2955    };
2956    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2957        if role.can_see_channel(channel.visibility) {
2958            session.peer.send(connection_id, update.clone())?;
2959        }
2960    }
2961
2962    Ok(())
2963}
2964
2965/// Move a channel to a new parent.
2966async fn move_channel(
2967    request: proto::MoveChannel,
2968    response: Response<proto::MoveChannel>,
2969    session: Session,
2970) -> Result<()> {
2971    let channel_id = ChannelId::from_proto(request.channel_id);
2972    let to = ChannelId::from_proto(request.to);
2973
2974    let (root_id, channels) = session
2975        .db()
2976        .await
2977        .move_channel(channel_id, to, session.user_id())
2978        .await?;
2979
2980    let connection_pool = session.connection_pool().await;
2981    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2982        let channels = channels
2983            .iter()
2984            .filter_map(|channel| {
2985                if role.can_see_channel(channel.visibility) {
2986                    Some(channel.to_proto())
2987                } else {
2988                    None
2989                }
2990            })
2991            .collect::<Vec<_>>();
2992        if channels.is_empty() {
2993            continue;
2994        }
2995
2996        let update = proto::UpdateChannels {
2997            channels,
2998            ..Default::default()
2999        };
3000
3001        session.peer.send(connection_id, update.clone())?;
3002    }
3003
3004    response.send(Ack {})?;
3005    Ok(())
3006}
3007
3008/// Get the list of channel members
3009async fn get_channel_members(
3010    request: proto::GetChannelMembers,
3011    response: Response<proto::GetChannelMembers>,
3012    session: Session,
3013) -> Result<()> {
3014    let db = session.db().await;
3015    let channel_id = ChannelId::from_proto(request.channel_id);
3016    let limit = if request.limit == 0 {
3017        u16::MAX as u64
3018    } else {
3019        request.limit
3020    };
3021    let (members, users) = db
3022        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3023        .await?;
3024    response.send(proto::GetChannelMembersResponse { members, users })?;
3025    Ok(())
3026}
3027
3028/// Accept or decline a channel invitation.
3029async fn respond_to_channel_invite(
3030    request: proto::RespondToChannelInvite,
3031    response: Response<proto::RespondToChannelInvite>,
3032    session: Session,
3033) -> Result<()> {
3034    let db = session.db().await;
3035    let channel_id = ChannelId::from_proto(request.channel_id);
3036    let RespondToChannelInvite {
3037        membership_update,
3038        notifications,
3039    } = db
3040        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3041        .await?;
3042
3043    let mut connection_pool = session.connection_pool().await;
3044    if let Some(membership_update) = membership_update {
3045        notify_membership_updated(
3046            &mut connection_pool,
3047            membership_update,
3048            session.user_id(),
3049            &session.peer,
3050        );
3051    } else {
3052        let update = proto::UpdateChannels {
3053            remove_channel_invitations: vec![channel_id.to_proto()],
3054            ..Default::default()
3055        };
3056
3057        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3058            session.peer.send(connection_id, update.clone())?;
3059        }
3060    };
3061
3062    send_notifications(&connection_pool, &session.peer, notifications);
3063
3064    response.send(proto::Ack {})?;
3065
3066    Ok(())
3067}
3068
3069/// Join the channels' room
3070async fn join_channel(
3071    request: proto::JoinChannel,
3072    response: Response<proto::JoinChannel>,
3073    session: Session,
3074) -> Result<()> {
3075    let channel_id = ChannelId::from_proto(request.channel_id);
3076    join_channel_internal(channel_id, Box::new(response), session).await
3077}
3078
/// A response handle that can carry a `proto::JoinRoomResponse` reply.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// kind of request with the same payload.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets `join_channel_internal` reply to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path resolves to `Response`'s inherent `send`
        // (inherent items take precedence over trait items in `Type::item`
        // paths), not to this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets `join_channel_internal` reply to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path resolves to `Response`'s inherent `send`
        // (inherent items take precedence over trait items in `Type::item`
        // paths), not to this trait method of the same name.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3092
/// Join the room of `channel_id` on behalf of `session` and reply through
/// `response`.
///
/// Steps:
/// 1. Clean up a stale room connection left by this user, if any.
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit connection info (publish-less guest token for
///    `ChannelRole::Guest`, a regular room token otherwise).
/// 4. Reply, push any membership update, and broadcast the room change.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Fails on database errors, on a failed reply send, or when the database
/// returns no channel for the joined room ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are
    // dropped at its end, before `channel_updated` re-locks the pool and
    // `update_user_contacts` runs.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard first: `leave_room_for_session` is
            // handed the session and re-acquires what it needs itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting the
        // token failed (the error is logged by `trace_err` and the join
        // still succeeds, just without LiveKit info).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; everyone
                    // else gets a regular room token with publishing allowed.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3188
3189/// Start editing the channel notes
3190async fn join_channel_buffer(
3191    request: proto::JoinChannelBuffer,
3192    response: Response<proto::JoinChannelBuffer>,
3193    session: Session,
3194) -> Result<()> {
3195    let db = session.db().await;
3196    let channel_id = ChannelId::from_proto(request.channel_id);
3197
3198    let open_response = db
3199        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3200        .await?;
3201
3202    let collaborators = open_response.collaborators.clone();
3203    response.send(open_response)?;
3204
3205    let update = UpdateChannelBufferCollaborators {
3206        channel_id: channel_id.to_proto(),
3207        collaborators: collaborators.clone(),
3208    };
3209    channel_buffer_updated(
3210        session.connection_id,
3211        collaborators
3212            .iter()
3213            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3214        &update,
3215        &session.peer,
3216    );
3217
3218    Ok(())
3219}
3220
3221/// Edit the channel notes
3222async fn update_channel_buffer(
3223    request: proto::UpdateChannelBuffer,
3224    session: Session,
3225) -> Result<()> {
3226    let db = session.db().await;
3227    let channel_id = ChannelId::from_proto(request.channel_id);
3228
3229    let (collaborators, epoch, version) = db
3230        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3231        .await?;
3232
3233    channel_buffer_updated(
3234        session.connection_id,
3235        collaborators.clone(),
3236        &proto::UpdateChannelBuffer {
3237            channel_id: channel_id.to_proto(),
3238            operations: request.operations,
3239        },
3240        &session.peer,
3241    );
3242
3243    let pool = &*session.connection_pool().await;
3244
3245    let non_collaborators =
3246        pool.channel_connection_ids(channel_id)
3247            .filter_map(|(connection_id, _)| {
3248                if collaborators.contains(&connection_id) {
3249                    None
3250                } else {
3251                    Some(connection_id)
3252                }
3253            });
3254
3255    broadcast(None, non_collaborators, |peer_id| {
3256        session.peer.send(
3257            peer_id,
3258            proto::UpdateChannels {
3259                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3260                    channel_id: channel_id.to_proto(),
3261                    epoch: epoch as u64,
3262                    version: version.clone(),
3263                }],
3264                ..Default::default()
3265            },
3266        )
3267    });
3268
3269    Ok(())
3270}
3271
3272/// Rejoin the channel notes after a connection blip
3273async fn rejoin_channel_buffers(
3274    request: proto::RejoinChannelBuffers,
3275    response: Response<proto::RejoinChannelBuffers>,
3276    session: Session,
3277) -> Result<()> {
3278    let db = session.db().await;
3279    let buffers = db
3280        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3281        .await?;
3282
3283    for rejoined_buffer in &buffers {
3284        let collaborators_to_notify = rejoined_buffer
3285            .buffer
3286            .collaborators
3287            .iter()
3288            .filter_map(|c| Some(c.peer_id?.into()));
3289        channel_buffer_updated(
3290            session.connection_id,
3291            collaborators_to_notify,
3292            &proto::UpdateChannelBufferCollaborators {
3293                channel_id: rejoined_buffer.buffer.channel_id,
3294                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3295            },
3296            &session.peer,
3297        );
3298    }
3299
3300    response.send(proto::RejoinChannelBuffersResponse {
3301        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3302    })?;
3303
3304    Ok(())
3305}
3306
3307/// Stop editing the channel notes
3308async fn leave_channel_buffer(
3309    request: proto::LeaveChannelBuffer,
3310    response: Response<proto::LeaveChannelBuffer>,
3311    session: Session,
3312) -> Result<()> {
3313    let db = session.db().await;
3314    let channel_id = ChannelId::from_proto(request.channel_id);
3315
3316    let left_buffer = db
3317        .leave_channel_buffer(channel_id, session.connection_id)
3318        .await?;
3319
3320    response.send(Ack {})?;
3321
3322    channel_buffer_updated(
3323        session.connection_id,
3324        left_buffer.connections,
3325        &proto::UpdateChannelBufferCollaborators {
3326            channel_id: channel_id.to_proto(),
3327            collaborators: left_buffer.collaborators,
3328        },
3329        &session.peer,
3330    );
3331
3332    Ok(())
3333}
3334
3335fn channel_buffer_updated<T: EnvelopedMessage>(
3336    sender_id: ConnectionId,
3337    collaborators: impl IntoIterator<Item = ConnectionId>,
3338    message: &T,
3339    peer: &Peer,
3340) {
3341    broadcast(Some(sender_id), collaborators, |peer_id| {
3342        peer.send(peer_id, message.clone())
3343    });
3344}
3345
3346fn send_notifications(
3347    connection_pool: &ConnectionPool,
3348    peer: &Peer,
3349    notifications: db::NotificationBatch,
3350) {
3351    for (user_id, notification) in notifications {
3352        for connection_id in connection_pool.user_connection_ids(user_id) {
3353            if let Err(error) = peer.send(
3354                connection_id,
3355                proto::AddNotification {
3356                    notification: Some(notification.clone()),
3357                },
3358            ) {
3359                tracing::error!(
3360                    "failed to send notification to {:?} {}",
3361                    connection_id,
3362                    error
3363                );
3364            }
3365        }
3366    }
3367}
3368
3369/// Send a message to the channel
3370async fn send_channel_message(
3371    request: proto::SendChannelMessage,
3372    response: Response<proto::SendChannelMessage>,
3373    session: Session,
3374) -> Result<()> {
3375    // Validate the message body.
3376    let body = request.body.trim().to_string();
3377    if body.len() > MAX_MESSAGE_LEN {
3378        return Err(anyhow!("message is too long"))?;
3379    }
3380    if body.is_empty() {
3381        return Err(anyhow!("message can't be blank"))?;
3382    }
3383
3384    // TODO: adjust mentions if body is trimmed
3385
3386    let timestamp = OffsetDateTime::now_utc();
3387    let nonce = request
3388        .nonce
3389        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3390
3391    let channel_id = ChannelId::from_proto(request.channel_id);
3392    let CreatedChannelMessage {
3393        message_id,
3394        participant_connection_ids,
3395        notifications,
3396    } = session
3397        .db()
3398        .await
3399        .create_channel_message(
3400            channel_id,
3401            session.user_id(),
3402            &body,
3403            &request.mentions,
3404            timestamp,
3405            nonce.clone().into(),
3406            request.reply_to_message_id.map(MessageId::from_proto),
3407        )
3408        .await?;
3409
3410    let message = proto::ChannelMessage {
3411        sender_id: session.user_id().to_proto(),
3412        id: message_id.to_proto(),
3413        body,
3414        mentions: request.mentions,
3415        timestamp: timestamp.unix_timestamp() as u64,
3416        nonce: Some(nonce),
3417        reply_to_message_id: request.reply_to_message_id,
3418        edited_at: None,
3419    };
3420    broadcast(
3421        Some(session.connection_id),
3422        participant_connection_ids.clone(),
3423        |connection| {
3424            session.peer.send(
3425                connection,
3426                proto::ChannelMessageSent {
3427                    channel_id: channel_id.to_proto(),
3428                    message: Some(message.clone()),
3429                },
3430            )
3431        },
3432    );
3433    response.send(proto::SendChannelMessageResponse {
3434        message: Some(message),
3435    })?;
3436
3437    let pool = &*session.connection_pool().await;
3438    let non_participants =
3439        pool.channel_connection_ids(channel_id)
3440            .filter_map(|(connection_id, _)| {
3441                if participant_connection_ids.contains(&connection_id) {
3442                    None
3443                } else {
3444                    Some(connection_id)
3445                }
3446            });
3447    broadcast(None, non_participants, |peer_id| {
3448        session.peer.send(
3449            peer_id,
3450            proto::UpdateChannels {
3451                latest_channel_message_ids: vec![proto::ChannelMessageId {
3452                    channel_id: channel_id.to_proto(),
3453                    message_id: message_id.to_proto(),
3454                }],
3455                ..Default::default()
3456            },
3457        )
3458    });
3459    send_notifications(pool, &session.peer, notifications);
3460
3461    Ok(())
3462}
3463
3464/// Delete a channel message
3465async fn remove_channel_message(
3466    request: proto::RemoveChannelMessage,
3467    response: Response<proto::RemoveChannelMessage>,
3468    session: Session,
3469) -> Result<()> {
3470    let channel_id = ChannelId::from_proto(request.channel_id);
3471    let message_id = MessageId::from_proto(request.message_id);
3472    let (connection_ids, existing_notification_ids) = session
3473        .db()
3474        .await
3475        .remove_channel_message(channel_id, message_id, session.user_id())
3476        .await?;
3477
3478    broadcast(
3479        Some(session.connection_id),
3480        connection_ids,
3481        move |connection| {
3482            session.peer.send(connection, request.clone())?;
3483
3484            for notification_id in &existing_notification_ids {
3485                session.peer.send(
3486                    connection,
3487                    proto::DeleteNotification {
3488                        notification_id: (*notification_id).to_proto(),
3489                    },
3490                )?;
3491            }
3492
3493            Ok(())
3494        },
3495    );
3496    response.send(proto::Ack {})?;
3497    Ok(())
3498}
3499
3500async fn update_channel_message(
3501    request: proto::UpdateChannelMessage,
3502    response: Response<proto::UpdateChannelMessage>,
3503    session: Session,
3504) -> Result<()> {
3505    let channel_id = ChannelId::from_proto(request.channel_id);
3506    let message_id = MessageId::from_proto(request.message_id);
3507    let updated_at = OffsetDateTime::now_utc();
3508    let UpdatedChannelMessage {
3509        message_id,
3510        participant_connection_ids,
3511        notifications,
3512        reply_to_message_id,
3513        timestamp,
3514        deleted_mention_notification_ids,
3515        updated_mention_notifications,
3516    } = session
3517        .db()
3518        .await
3519        .update_channel_message(
3520            channel_id,
3521            message_id,
3522            session.user_id(),
3523            request.body.as_str(),
3524            &request.mentions,
3525            updated_at,
3526        )
3527        .await?;
3528
3529    let nonce = request
3530        .nonce
3531        .clone()
3532        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3533
3534    let message = proto::ChannelMessage {
3535        sender_id: session.user_id().to_proto(),
3536        id: message_id.to_proto(),
3537        body: request.body.clone(),
3538        mentions: request.mentions.clone(),
3539        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3540        nonce: Some(nonce),
3541        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3542        edited_at: Some(updated_at.unix_timestamp() as u64),
3543    };
3544
3545    response.send(proto::Ack {})?;
3546
3547    let pool = &*session.connection_pool().await;
3548    broadcast(
3549        Some(session.connection_id),
3550        participant_connection_ids,
3551        |connection| {
3552            session.peer.send(
3553                connection,
3554                proto::ChannelMessageUpdate {
3555                    channel_id: channel_id.to_proto(),
3556                    message: Some(message.clone()),
3557                },
3558            )?;
3559
3560            for notification_id in &deleted_mention_notification_ids {
3561                session.peer.send(
3562                    connection,
3563                    proto::DeleteNotification {
3564                        notification_id: (*notification_id).to_proto(),
3565                    },
3566                )?;
3567            }
3568
3569            for notification in &updated_mention_notifications {
3570                session.peer.send(
3571                    connection,
3572                    proto::UpdateNotification {
3573                        notification: Some(notification.clone()),
3574                    },
3575                )?;
3576            }
3577
3578            Ok(())
3579        },
3580    );
3581
3582    send_notifications(pool, &session.peer, notifications);
3583
3584    Ok(())
3585}
3586
3587/// Mark a channel message as read
3588async fn acknowledge_channel_message(
3589    request: proto::AckChannelMessage,
3590    session: Session,
3591) -> Result<()> {
3592    let channel_id = ChannelId::from_proto(request.channel_id);
3593    let message_id = MessageId::from_proto(request.message_id);
3594    let notifications = session
3595        .db()
3596        .await
3597        .observe_channel_message(channel_id, session.user_id(), message_id)
3598        .await?;
3599    send_notifications(
3600        &*session.connection_pool().await,
3601        &session.peer,
3602        notifications,
3603    );
3604    Ok(())
3605}
3606
3607/// Mark a buffer version as synced
3608async fn acknowledge_buffer_version(
3609    request: proto::AckBufferOperation,
3610    session: Session,
3611) -> Result<()> {
3612    let buffer_id = BufferId::from_proto(request.buffer_id);
3613    session
3614        .db()
3615        .await
3616        .observe_buffer_version(
3617            buffer_id,
3618            session.user_id(),
3619            request.epoch as i32,
3620            &request.version,
3621        )
3622        .await?;
3623    Ok(())
3624}
3625
3626async fn count_language_model_tokens(
3627    request: proto::CountLanguageModelTokens,
3628    response: Response<proto::CountLanguageModelTokens>,
3629    session: Session,
3630    config: &Config,
3631) -> Result<()> {
3632    authorize_access_to_legacy_llm_endpoints(&session).await?;
3633
3634    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3635        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3636        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3637    };
3638
3639    session
3640        .app_state
3641        .rate_limiter
3642        .check(&*rate_limit, session.user_id())
3643        .await?;
3644
3645    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3646        Some(proto::LanguageModelProvider::Google) => {
3647            let api_key = config
3648                .google_ai_api_key
3649                .as_ref()
3650                .context("no Google AI API key configured on the server")?;
3651            google_ai::count_tokens(
3652                session.http_client.as_ref(),
3653                google_ai::API_URL,
3654                api_key,
3655                serde_json::from_str(&request.request)?,
3656            )
3657            .await?
3658        }
3659        _ => return Err(anyhow!("unsupported provider"))?,
3660    };
3661
3662    response.send(proto::CountLanguageModelTokensResponse {
3663        token_count: result.total_tokens as u32,
3664    })?;
3665
3666    Ok(())
3667}
3668
3669struct ZedProCountLanguageModelTokensRateLimit;
3670
3671impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3672    fn capacity(&self) -> usize {
3673        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3674            .ok()
3675            .and_then(|v| v.parse().ok())
3676            .unwrap_or(600) // Picked arbitrarily
3677    }
3678
3679    fn refill_duration(&self) -> chrono::Duration {
3680        chrono::Duration::hours(1)
3681    }
3682
3683    fn db_name(&self) -> &'static str {
3684        "zed-pro:count-language-model-tokens"
3685    }
3686}
3687
3688struct FreeCountLanguageModelTokensRateLimit;
3689
3690impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3691    fn capacity(&self) -> usize {
3692        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3693            .ok()
3694            .and_then(|v| v.parse().ok())
3695            .unwrap_or(600 / 10) // Picked arbitrarily
3696    }
3697
3698    fn refill_duration(&self) -> chrono::Duration {
3699        chrono::Duration::hours(1)
3700    }
3701
3702    fn db_name(&self) -> &'static str {
3703        "free:count-language-model-tokens"
3704    }
3705}
3706
3707struct ZedProComputeEmbeddingsRateLimit;
3708
3709impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3710    fn capacity(&self) -> usize {
3711        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3712            .ok()
3713            .and_then(|v| v.parse().ok())
3714            .unwrap_or(5000) // Picked arbitrarily
3715    }
3716
3717    fn refill_duration(&self) -> chrono::Duration {
3718        chrono::Duration::hours(1)
3719    }
3720
3721    fn db_name(&self) -> &'static str {
3722        "zed-pro:compute-embeddings"
3723    }
3724}
3725
3726struct FreeComputeEmbeddingsRateLimit;
3727
3728impl RateLimit for FreeComputeEmbeddingsRateLimit {
3729    fn capacity(&self) -> usize {
3730        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3731            .ok()
3732            .and_then(|v| v.parse().ok())
3733            .unwrap_or(5000 / 10) // Picked arbitrarily
3734    }
3735
3736    fn refill_duration(&self) -> chrono::Duration {
3737        chrono::Duration::hours(1)
3738    }
3739
3740    fn db_name(&self) -> &'static str {
3741        "free:compute-embeddings"
3742    }
3743}
3744
3745async fn compute_embeddings(
3746    request: proto::ComputeEmbeddings,
3747    response: Response<proto::ComputeEmbeddings>,
3748    session: Session,
3749    api_key: Option<Arc<str>>,
3750) -> Result<()> {
3751    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3752    authorize_access_to_legacy_llm_endpoints(&session).await?;
3753
3754    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3755        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3756        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3757    };
3758
3759    session
3760        .app_state
3761        .rate_limiter
3762        .check(&*rate_limit, session.user_id())
3763        .await?;
3764
3765    let embeddings = match request.model.as_str() {
3766        "openai/text-embedding-3-small" => {
3767            open_ai::embed(
3768                session.http_client.as_ref(),
3769                OPEN_AI_API_URL,
3770                &api_key,
3771                OpenAiEmbeddingModel::TextEmbedding3Small,
3772                request.texts.iter().map(|text| text.as_str()),
3773            )
3774            .await?
3775        }
3776        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3777    };
3778
3779    let embeddings = request
3780        .texts
3781        .iter()
3782        .map(|text| {
3783            let mut hasher = sha2::Sha256::new();
3784            hasher.update(text.as_bytes());
3785            let result = hasher.finalize();
3786            result.to_vec()
3787        })
3788        .zip(
3789            embeddings
3790                .data
3791                .into_iter()
3792                .map(|embedding| embedding.embedding),
3793        )
3794        .collect::<HashMap<_, _>>();
3795
3796    let db = session.db().await;
3797    db.save_embeddings(&request.model, &embeddings)
3798        .await
3799        .context("failed to save embeddings")
3800        .trace_err();
3801
3802    response.send(proto::ComputeEmbeddingsResponse {
3803        embeddings: embeddings
3804            .into_iter()
3805            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3806            .collect(),
3807    })?;
3808    Ok(())
3809}
3810
3811async fn get_cached_embeddings(
3812    request: proto::GetCachedEmbeddings,
3813    response: Response<proto::GetCachedEmbeddings>,
3814    session: Session,
3815) -> Result<()> {
3816    authorize_access_to_legacy_llm_endpoints(&session).await?;
3817
3818    let db = session.db().await;
3819    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3820
3821    response.send(proto::GetCachedEmbeddingsResponse {
3822        embeddings: embeddings
3823            .into_iter()
3824            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3825            .collect(),
3826    })?;
3827    Ok(())
3828}
3829
3830/// This is leftover from before the LLM service.
3831///
3832/// The endpoints protected by this check will be moved there eventually.
3833async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3834    if session.is_staff() {
3835        Ok(())
3836    } else {
3837        Err(anyhow!("permission denied"))?
3838    }
3839}
3840
3841/// Get a Supermaven API key for the user
3842async fn get_supermaven_api_key(
3843    _request: proto::GetSupermavenApiKey,
3844    response: Response<proto::GetSupermavenApiKey>,
3845    session: Session,
3846) -> Result<()> {
3847    let user_id: String = session.user_id().to_string();
3848    if !session.is_staff() {
3849        return Err(anyhow!("supermaven not enabled for this account"))?;
3850    }
3851
3852    let email = session
3853        .email()
3854        .ok_or_else(|| anyhow!("user must have an email"))?;
3855
3856    let supermaven_admin_api = session
3857        .supermaven_client
3858        .as_ref()
3859        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3860
3861    let result = supermaven_admin_api
3862        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3863        .await?;
3864
3865    response.send(proto::GetSupermavenApiKeyResponse {
3866        api_key: result.api_key,
3867    })?;
3868
3869    Ok(())
3870}
3871
3872/// Start receiving chat updates for a channel
3873async fn join_channel_chat(
3874    request: proto::JoinChannelChat,
3875    response: Response<proto::JoinChannelChat>,
3876    session: Session,
3877) -> Result<()> {
3878    let channel_id = ChannelId::from_proto(request.channel_id);
3879
3880    let db = session.db().await;
3881    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3882        .await?;
3883    let messages = db
3884        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3885        .await?;
3886    response.send(proto::JoinChannelChatResponse {
3887        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3888        messages,
3889    })?;
3890    Ok(())
3891}
3892
3893/// Stop receiving chat updates for a channel
3894async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3895    let channel_id = ChannelId::from_proto(request.channel_id);
3896    session
3897        .db()
3898        .await
3899        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3900        .await?;
3901    Ok(())
3902}
3903
3904/// Retrieve the chat history for a channel
3905async fn get_channel_messages(
3906    request: proto::GetChannelMessages,
3907    response: Response<proto::GetChannelMessages>,
3908    session: Session,
3909) -> Result<()> {
3910    let channel_id = ChannelId::from_proto(request.channel_id);
3911    let messages = session
3912        .db()
3913        .await
3914        .get_channel_messages(
3915            channel_id,
3916            session.user_id(),
3917            MESSAGE_COUNT_PER_PAGE,
3918            Some(MessageId::from_proto(request.before_message_id)),
3919        )
3920        .await?;
3921    response.send(proto::GetChannelMessagesResponse {
3922        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3923        messages,
3924    })?;
3925    Ok(())
3926}
3927
3928/// Retrieve specific chat messages
3929async fn get_channel_messages_by_id(
3930    request: proto::GetChannelMessagesById,
3931    response: Response<proto::GetChannelMessagesById>,
3932    session: Session,
3933) -> Result<()> {
3934    let message_ids = request
3935        .message_ids
3936        .iter()
3937        .map(|id| MessageId::from_proto(*id))
3938        .collect::<Vec<_>>();
3939    let messages = session
3940        .db()
3941        .await
3942        .get_channel_messages_by_id(session.user_id(), &message_ids)
3943        .await?;
3944    response.send(proto::GetChannelMessagesResponse {
3945        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3946        messages,
3947    })?;
3948    Ok(())
3949}
3950
3951/// Retrieve the current users notifications
3952async fn get_notifications(
3953    request: proto::GetNotifications,
3954    response: Response<proto::GetNotifications>,
3955    session: Session,
3956) -> Result<()> {
3957    let notifications = session
3958        .db()
3959        .await
3960        .get_notifications(
3961            session.user_id(),
3962            NOTIFICATION_COUNT_PER_PAGE,
3963            request.before_id.map(db::NotificationId::from_proto),
3964        )
3965        .await?;
3966    response.send(proto::GetNotificationsResponse {
3967        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3968        notifications,
3969    })?;
3970    Ok(())
3971}
3972
3973/// Mark notifications as read
3974async fn mark_notification_as_read(
3975    request: proto::MarkNotificationRead,
3976    response: Response<proto::MarkNotificationRead>,
3977    session: Session,
3978) -> Result<()> {
3979    let database = &session.db().await;
3980    let notifications = database
3981        .mark_notification_as_read_by_id(
3982            session.user_id(),
3983            NotificationId::from_proto(request.notification_id),
3984        )
3985        .await?;
3986    send_notifications(
3987        &*session.connection_pool().await,
3988        &session.peer,
3989        notifications,
3990    );
3991    response.send(proto::Ack {})?;
3992    Ok(())
3993}
3994
3995/// Get the current users information
3996async fn get_private_user_info(
3997    _request: proto::GetPrivateUserInfo,
3998    response: Response<proto::GetPrivateUserInfo>,
3999    session: Session,
4000) -> Result<()> {
4001    let db = session.db().await;
4002
4003    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4004    let user = db
4005        .get_user_by_id(session.user_id())
4006        .await?
4007        .ok_or_else(|| anyhow!("user not found"))?;
4008    let flags = db.get_user_flags(session.user_id()).await?;
4009
4010    response.send(proto::GetPrivateUserInfoResponse {
4011        metrics_id,
4012        staff: user.admin,
4013        flags,
4014        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4015    })?;
4016    Ok(())
4017}
4018
4019/// Accept the terms of service (tos) on behalf of the current user
4020async fn accept_terms_of_service(
4021    _request: proto::AcceptTermsOfService,
4022    response: Response<proto::AcceptTermsOfService>,
4023    session: Session,
4024) -> Result<()> {
4025    let db = session.db().await;
4026
4027    let accepted_tos_at = Utc::now();
4028    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4029        .await?;
4030
4031    response.send(proto::AcceptTermsOfServiceResponse {
4032        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4033    })?;
4034    Ok(())
4035}
4036
4037/// The minimum account age an account must have in order to use the LLM service.
4038const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4039
4040async fn get_llm_api_token(
4041    _request: proto::GetLlmToken,
4042    response: Response<proto::GetLlmToken>,
4043    session: Session,
4044) -> Result<()> {
4045    let db = session.db().await;
4046
4047    let flags = db.get_user_flags(session.user_id()).await?;
4048    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4049    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4050    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4051
4052    if !session.is_staff() && !has_language_models_feature_flag {
4053        Err(anyhow!("permission denied"))?
4054    }
4055
4056    let user_id = session.user_id();
4057    let user = db
4058        .get_user_by_id(user_id)
4059        .await?
4060        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4061
4062    if user.accepted_tos_at.is_none() {
4063        Err(anyhow!("terms of service not accepted"))?
4064    }
4065
4066    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4067
4068    let bypass_account_age_check =
4069        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4070    if !bypass_account_age_check {
4071        let mut account_created_at = user.created_at;
4072        if let Some(github_created_at) = user.github_user_created_at {
4073            account_created_at = account_created_at.min(github_created_at);
4074        }
4075        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4076            Err(anyhow!("account too young"))?
4077        }
4078    }
4079
4080    let billing_preferences = db.get_billing_preferences(user.id).await?;
4081
4082    let token = LlmTokenClaims::create(
4083        &user,
4084        session.is_staff(),
4085        billing_preferences,
4086        has_llm_closed_beta_feature_flag,
4087        has_predict_edits_feature_flag,
4088        has_llm_subscription,
4089        session.current_plan(&db).await?,
4090        session.system_id.clone(),
4091        &session.app_state.config,
4092    )?;
4093    response.send(proto::GetLlmTokenResponse { token })?;
4094    Ok(())
4095}
4096
4097fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4098    let message = match message {
4099        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4100        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4101        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4102        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4103        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4104            code: frame.code.into(),
4105            reason: frame.reason,
4106        })),
4107        // We should never receive a frame while reading the message, according
4108        // to the `tungstenite` maintainers:
4109        //
4110        // > It cannot occur when you read messages from the WebSocket, but it
4111        // > can be used when you want to send the raw frames (e.g. you want to
4112        // > send the frames to the WebSocket without composing the full message first).
4113        // >
4114        // > — https://github.com/snapview/tungstenite-rs/issues/268
4115        TungsteniteMessage::Frame(_) => {
4116            bail!("received an unexpected frame while reading the message")
4117        }
4118    };
4119
4120    Ok(message)
4121}
4122
4123fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4124    match message {
4125        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4126        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4127        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4128        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4129        AxumMessage::Close(frame) => {
4130            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4131                code: frame.code.into(),
4132                reason: frame.reason,
4133            }))
4134        }
4135    }
4136}
4137
4138fn notify_membership_updated(
4139    connection_pool: &mut ConnectionPool,
4140    result: MembershipUpdated,
4141    user_id: UserId,
4142    peer: &Peer,
4143) {
4144    for membership in &result.new_channels.channel_memberships {
4145        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4146    }
4147    for channel_id in &result.removed_channels {
4148        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4149    }
4150
4151    let user_channels_update = proto::UpdateUserChannels {
4152        channel_memberships: result
4153            .new_channels
4154            .channel_memberships
4155            .iter()
4156            .map(|cm| proto::ChannelMembership {
4157                channel_id: cm.channel_id.to_proto(),
4158                role: cm.role.into(),
4159            })
4160            .collect(),
4161        ..Default::default()
4162    };
4163
4164    let mut update = build_channels_update(result.new_channels);
4165    update.delete_channels = result
4166        .removed_channels
4167        .into_iter()
4168        .map(|id| id.to_proto())
4169        .collect();
4170    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4171
4172    for connection_id in connection_pool.user_connection_ids(user_id) {
4173        peer.send(connection_id, user_channels_update.clone())
4174            .trace_err();
4175        peer.send(connection_id, update.clone()).trace_err();
4176    }
4177}
4178
4179fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4180    proto::UpdateUserChannels {
4181        channel_memberships: channels
4182            .channel_memberships
4183            .iter()
4184            .map(|m| proto::ChannelMembership {
4185                channel_id: m.channel_id.to_proto(),
4186                role: m.role.into(),
4187            })
4188            .collect(),
4189        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4190        observed_channel_message_id: channels.observed_channel_messages.clone(),
4191    }
4192}
4193
4194fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4195    let mut update = proto::UpdateChannels::default();
4196
4197    for channel in channels.channels {
4198        update.channels.push(channel.to_proto());
4199    }
4200
4201    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4202    update.latest_channel_message_ids = channels.latest_channel_messages;
4203
4204    for (channel_id, participants) in channels.channel_participants {
4205        update
4206            .channel_participants
4207            .push(proto::ChannelParticipants {
4208                channel_id: channel_id.to_proto(),
4209                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4210            });
4211    }
4212
4213    for channel in channels.invited_channels {
4214        update.channel_invitations.push(channel.to_proto());
4215    }
4216
4217    update
4218}
4219
4220fn build_initial_contacts_update(
4221    contacts: Vec<db::Contact>,
4222    pool: &ConnectionPool,
4223) -> proto::UpdateContacts {
4224    let mut update = proto::UpdateContacts::default();
4225
4226    for contact in contacts {
4227        match contact {
4228            db::Contact::Accepted { user_id, busy } => {
4229                update.contacts.push(contact_for_user(user_id, busy, pool));
4230            }
4231            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4232            db::Contact::Incoming { user_id } => {
4233                update
4234                    .incoming_requests
4235                    .push(proto::IncomingContactRequest {
4236                        requester_id: user_id.to_proto(),
4237                    })
4238            }
4239        }
4240    }
4241
4242    update
4243}
4244
4245fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4246    proto::Contact {
4247        user_id: user_id.to_proto(),
4248        online: pool.is_user_online(user_id),
4249        busy,
4250    }
4251}
4252
4253fn room_updated(room: &proto::Room, peer: &Peer) {
4254    broadcast(
4255        None,
4256        room.participants
4257            .iter()
4258            .filter_map(|participant| Some(participant.peer_id?.into())),
4259        |peer_id| {
4260            peer.send(
4261                peer_id,
4262                proto::RoomUpdated {
4263                    room: Some(room.clone()),
4264                },
4265            )
4266        },
4267    );
4268}
4269
4270fn channel_updated(
4271    channel: &db::channel::Model,
4272    room: &proto::Room,
4273    peer: &Peer,
4274    pool: &ConnectionPool,
4275) {
4276    let participants = room
4277        .participants
4278        .iter()
4279        .map(|p| p.user_id)
4280        .collect::<Vec<_>>();
4281
4282    broadcast(
4283        None,
4284        pool.channel_connection_ids(channel.root_id())
4285            .filter_map(|(channel_id, role)| {
4286                role.can_see_channel(channel.visibility)
4287                    .then_some(channel_id)
4288            }),
4289        |peer_id| {
4290            peer.send(
4291                peer_id,
4292                proto::UpdateChannels {
4293                    channel_participants: vec![proto::ChannelParticipants {
4294                        channel_id: channel.id.to_proto(),
4295                        participant_user_ids: participants.clone(),
4296                    }],
4297                    ..Default::default()
4298                },
4299            )
4300        },
4301    );
4302}
4303
4304async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4305    let db = session.db().await;
4306
4307    let contacts = db.get_contacts(user_id).await?;
4308    let busy = db.is_user_busy(user_id).await?;
4309
4310    let pool = session.connection_pool().await;
4311    let updated_contact = contact_for_user(user_id, busy, &pool);
4312    for contact in contacts {
4313        if let db::Contact::Accepted {
4314            user_id: contact_user_id,
4315            ..
4316        } = contact
4317        {
4318            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4319                session
4320                    .peer
4321                    .send(
4322                        contact_conn_id,
4323                        proto::UpdateContacts {
4324                            contacts: vec![updated_contact.clone()],
4325                            remove_contacts: Default::default(),
4326                            incoming_requests: Default::default(),
4327                            remove_incoming_requests: Default::default(),
4328                            outgoing_requests: Default::default(),
4329                            remove_outgoing_requests: Default::default(),
4330                        },
4331                    )
4332                    .trace_err();
4333            }
4334        }
4335    }
4336    Ok(())
4337}
4338
4339async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4340    let mut contacts_to_update = HashSet::default();
4341
4342    let room_id;
4343    let canceled_calls_to_user_ids;
4344    let livekit_room;
4345    let delete_livekit_room;
4346    let room;
4347    let channel;
4348
4349    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4350        contacts_to_update.insert(session.user_id());
4351
4352        for project in left_room.left_projects.values() {
4353            project_left(project, session);
4354        }
4355
4356        room_id = RoomId::from_proto(left_room.room.id);
4357        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4358        livekit_room = mem::take(&mut left_room.room.livekit_room);
4359        delete_livekit_room = left_room.deleted;
4360        room = mem::take(&mut left_room.room);
4361        channel = mem::take(&mut left_room.channel);
4362
4363        room_updated(&room, &session.peer);
4364    } else {
4365        return Ok(());
4366    }
4367
4368    if let Some(channel) = channel {
4369        channel_updated(
4370            &channel,
4371            &room,
4372            &session.peer,
4373            &*session.connection_pool().await,
4374        );
4375    }
4376
4377    {
4378        let pool = session.connection_pool().await;
4379        for canceled_user_id in canceled_calls_to_user_ids {
4380            for connection_id in pool.user_connection_ids(canceled_user_id) {
4381                session
4382                    .peer
4383                    .send(
4384                        connection_id,
4385                        proto::CallCanceled {
4386                            room_id: room_id.to_proto(),
4387                        },
4388                    )
4389                    .trace_err();
4390            }
4391            contacts_to_update.insert(canceled_user_id);
4392        }
4393    }
4394
4395    for contact_user_id in contacts_to_update {
4396        update_user_contacts(contact_user_id, session).await?;
4397    }
4398
4399    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4400        live_kit
4401            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4402            .await
4403            .trace_err();
4404
4405        if delete_livekit_room {
4406            live_kit.delete_room(livekit_room).await.trace_err();
4407        }
4408    }
4409
4410    Ok(())
4411}
4412
4413async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4414    let left_channel_buffers = session
4415        .db()
4416        .await
4417        .leave_channel_buffers(session.connection_id)
4418        .await?;
4419
4420    for left_buffer in left_channel_buffers {
4421        channel_buffer_updated(
4422            session.connection_id,
4423            left_buffer.connections,
4424            &proto::UpdateChannelBufferCollaborators {
4425                channel_id: left_buffer.channel_id.to_proto(),
4426                collaborators: left_buffer.collaborators,
4427            },
4428            &session.peer,
4429        );
4430    }
4431
4432    Ok(())
4433}
4434
4435fn project_left(project: &db::LeftProject, session: &Session) {
4436    for connection_id in &project.connection_ids {
4437        if project.should_unshare {
4438            session
4439                .peer
4440                .send(
4441                    *connection_id,
4442                    proto::UnshareProject {
4443                        project_id: project.id.to_proto(),
4444                    },
4445                )
4446                .trace_err();
4447        } else {
4448            session
4449                .peer
4450                .send(
4451                    *connection_id,
4452                    proto::RemoveProjectCollaborator {
4453                        project_id: project.id.to_proto(),
4454                        peer_id: Some(session.connection_id.into()),
4455                    },
4456                )
4457                .trace_err();
4458        }
4459    }
4460}
4461
4462pub trait ResultExt {
4463    type Ok;
4464
4465    fn trace_err(self) -> Option<Self::Ok>;
4466}
4467
4468impl<T, E> ResultExt for Result<T, E>
4469where
4470    E: std::fmt::Debug,
4471{
4472    type Ok = T;
4473
4474    #[track_caller]
4475    fn trace_err(self) -> Option<T> {
4476        match self {
4477            Ok(value) => Some(value),
4478            Err(error) => {
4479                tracing::error!("{:?}", error);
4480                None
4481            }
4482        }
4483    }
4484}