rpc.rs

   1mod connection_pool;
   2
   3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
   4use crate::llm::LlmTokenClaims;
   5use crate::{
   6    auth,
   7    db::{
   8        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
   9        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
  10        NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
  11        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
  12    },
  13    executor::Executor,
  14    AppState, Config, Error, RateLimit, Result,
  15};
  16use anyhow::{anyhow, bail, Context as _};
  17use async_tungstenite::tungstenite::{
  18    protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
  19};
  20use axum::{
  21    body::Body,
  22    extract::{
  23        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
  24        ConnectInfo, WebSocketUpgrade,
  25    },
  26    headers::{Header, HeaderName},
  27    http::StatusCode,
  28    middleware,
  29    response::IntoResponse,
  30    routing::get,
  31    Extension, Router, TypedHeader,
  32};
  33use chrono::Utc;
  34use collections::{HashMap, HashSet};
  35pub use connection_pool::{ConnectionPool, ZedVersion};
  36use core::fmt::{self, Debug, Formatter};
  37use http_client::HttpClient;
  38use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
  39use reqwest_client::ReqwestClient;
  40use sha2::Digest;
  41use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
  42
  43use futures::{
  44    channel::oneshot, future::BoxFuture, stream::FuturesUnordered, FutureExt, SinkExt, StreamExt,
  45    TryStreamExt,
  46};
  47use prometheus::{register_int_gauge, IntGauge};
  48use rpc::{
  49    proto::{
  50        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
  51        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
  52    },
  53    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
  54};
  55use semantic_version::SemanticVersion;
  56use serde::{Serialize, Serializer};
  57use std::{
  58    any::TypeId,
  59    future::Future,
  60    marker::PhantomData,
  61    mem,
  62    net::SocketAddr,
  63    ops::{Deref, DerefMut},
  64    rc::Rc,
  65    sync::{
  66        atomic::{AtomicBool, Ordering::SeqCst},
  67        Arc, OnceLock,
  68    },
  69    time::{Duration, Instant},
  70};
  71use time::OffsetDateTime;
  72use tokio::sync::{watch, MutexGuard, Semaphore};
  73use tower::ServiceBuilder;
  74use tracing::{
  75    field::{self},
  76    info_span, instrument, Instrument,
  77};
  78
  79pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
  80
  81// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
  82pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
  83
  84const MESSAGE_COUNT_PER_PAGE: usize = 100;
  85const MAX_MESSAGE_LEN: usize = 1024;
  86const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
  87
  88type MessageHandler =
  89    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
  90
  91struct Response<R> {
  92    peer: Arc<Peer>,
  93    receipt: Receipt<R>,
  94    responded: Arc<AtomicBool>,
  95}
  96
  97impl<R: RequestMessage> Response<R> {
  98    fn send(self, payload: R::Response) -> Result<()> {
  99        self.responded.store(true, SeqCst);
 100        self.peer.respond(self.receipt, payload)?;
 101        Ok(())
 102    }
 103}
 104
 105#[derive(Clone, Debug)]
 106pub enum Principal {
 107    User(User),
 108    Impersonated { user: User, admin: User },
 109}
 110
 111impl Principal {
 112    fn update_span(&self, span: &tracing::Span) {
 113        match &self {
 114            Principal::User(user) => {
 115                span.record("user_id", user.id.0);
 116                span.record("login", &user.github_login);
 117            }
 118            Principal::Impersonated { user, admin } => {
 119                span.record("user_id", user.id.0);
 120                span.record("login", &user.github_login);
 121                span.record("impersonator", &admin.github_login);
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Clone)]
 128struct Session {
 129    principal: Principal,
 130    connection_id: ConnectionId,
 131    db: Arc<tokio::sync::Mutex<DbHandle>>,
 132    peer: Arc<Peer>,
 133    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 134    app_state: Arc<AppState>,
 135    supermaven_client: Option<Arc<SupermavenAdminApi>>,
 136    http_client: Arc<dyn HttpClient>,
 137    /// The GeoIP country code for the user.
 138    #[allow(unused)]
 139    geoip_country_code: Option<String>,
 140    system_id: Option<String>,
 141    _executor: Executor,
 142}
 143
 144impl Session {
 145    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
 146        #[cfg(test)]
 147        tokio::task::yield_now().await;
 148        let guard = self.db.lock().await;
 149        #[cfg(test)]
 150        tokio::task::yield_now().await;
 151        guard
 152    }
 153
 154    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
 155        #[cfg(test)]
 156        tokio::task::yield_now().await;
 157        let guard = self.connection_pool.lock();
 158        ConnectionPoolGuard {
 159            guard,
 160            _not_send: PhantomData,
 161        }
 162    }
 163
 164    fn is_staff(&self) -> bool {
 165        match &self.principal {
 166            Principal::User(user) => user.admin,
 167            Principal::Impersonated { .. } => true,
 168        }
 169    }
 170
 171    pub async fn has_llm_subscription(
 172        &self,
 173        db: &MutexGuard<'_, DbHandle>,
 174    ) -> anyhow::Result<bool> {
 175        if self.is_staff() {
 176            return Ok(true);
 177        }
 178
 179        let user_id = self.user_id();
 180
 181        Ok(db.has_active_billing_subscription(user_id).await?)
 182    }
 183
 184    pub async fn current_plan(
 185        &self,
 186        _db: &MutexGuard<'_, DbHandle>,
 187    ) -> anyhow::Result<proto::Plan> {
 188        if self.is_staff() {
 189            Ok(proto::Plan::ZedPro)
 190        } else {
 191            Ok(proto::Plan::Free)
 192        }
 193    }
 194
 195    fn user_id(&self) -> UserId {
 196        match &self.principal {
 197            Principal::User(user) => user.id,
 198            Principal::Impersonated { user, .. } => user.id,
 199        }
 200    }
 201
 202    pub fn email(&self) -> Option<String> {
 203        match &self.principal {
 204            Principal::User(user) => user.email_address.clone(),
 205            Principal::Impersonated { user, .. } => user.email_address.clone(),
 206        }
 207    }
 208}
 209
 210impl Debug for Session {
 211    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
 212        let mut result = f.debug_struct("Session");
 213        match &self.principal {
 214            Principal::User(user) => {
 215                result.field("user", &user.github_login);
 216            }
 217            Principal::Impersonated { user, admin } => {
 218                result.field("user", &user.github_login);
 219                result.field("impersonator", &admin.github_login);
 220            }
 221        }
 222        result.field("connection_id", &self.connection_id).finish()
 223    }
 224}
 225
 226struct DbHandle(Arc<Database>);
 227
 228impl Deref for DbHandle {
 229    type Target = Database;
 230
 231    fn deref(&self) -> &Self::Target {
 232        self.0.as_ref()
 233    }
 234}
 235
 236pub struct Server {
 237    id: parking_lot::Mutex<ServerId>,
 238    peer: Arc<Peer>,
 239    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
 240    app_state: Arc<AppState>,
 241    handlers: HashMap<TypeId, MessageHandler>,
 242    teardown: watch::Sender<bool>,
 243}
 244
 245pub(crate) struct ConnectionPoolGuard<'a> {
 246    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
 247    _not_send: PhantomData<Rc<()>>,
 248}
 249
/// A serializable view of the server's peer and (locked) connection pool.
///
/// Field order is part of the serialized output, so do not reorder fields.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
 256
 257pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
 258where
 259    S: Serializer,
 260    T: Deref<Target = U>,
 261    U: Serialize,
 262{
 263    Serialize::serialize(value.deref(), serializer)
 264}
 265
 266impl Server {
 267    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
 268        let mut server = Self {
 269            id: parking_lot::Mutex::new(id),
 270            peer: Peer::new(id.0 as u32),
 271            app_state: app_state.clone(),
 272            connection_pool: Default::default(),
 273            handlers: Default::default(),
 274            teardown: watch::channel(false).0,
 275        };
 276
 277        server
 278            .add_request_handler(ping)
 279            .add_request_handler(create_room)
 280            .add_request_handler(join_room)
 281            .add_request_handler(rejoin_room)
 282            .add_request_handler(leave_room)
 283            .add_request_handler(set_room_participant_role)
 284            .add_request_handler(call)
 285            .add_request_handler(cancel_call)
 286            .add_message_handler(decline_call)
 287            .add_request_handler(update_participant_location)
 288            .add_request_handler(share_project)
 289            .add_message_handler(unshare_project)
 290            .add_request_handler(join_project)
 291            .add_message_handler(leave_project)
 292            .add_request_handler(update_project)
 293            .add_request_handler(update_worktree)
 294            .add_message_handler(start_language_server)
 295            .add_message_handler(update_language_server)
 296            .add_message_handler(update_diagnostic_summary)
 297            .add_message_handler(update_worktree_settings)
 298            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
 299            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
 300            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
 301            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
 302            .add_request_handler(forward_find_search_candidates_request)
 303            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
 304            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
 305            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
 306            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
 307            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
 308            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
 309            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
 310            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
 311            .add_request_handler(forward_read_only_project_request::<proto::GitBranches>)
 312            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
 313            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
 314            .add_request_handler(
 315                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
 316            )
 317            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
 318            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
 319            .add_request_handler(
 320                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
 321            )
 322            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
 323            .add_request_handler(
 324                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
 325            )
 326            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
 327            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
 328            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
 329            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
 330            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
 331            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
 332            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
 333            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
 334            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
 335            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
 336            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
 337            .add_request_handler(
 338                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
 339            )
 340            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
 341            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
 342            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
 343            .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
 344            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
 345            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
 346            .add_message_handler(create_buffer_for_peer)
 347            .add_request_handler(update_buffer)
 348            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
 349            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
 350            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
 351            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
 352            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
 353            .add_request_handler(get_users)
 354            .add_request_handler(fuzzy_search_users)
 355            .add_request_handler(request_contact)
 356            .add_request_handler(remove_contact)
 357            .add_request_handler(respond_to_contact_request)
 358            .add_message_handler(subscribe_to_channels)
 359            .add_request_handler(create_channel)
 360            .add_request_handler(delete_channel)
 361            .add_request_handler(invite_channel_member)
 362            .add_request_handler(remove_channel_member)
 363            .add_request_handler(set_channel_member_role)
 364            .add_request_handler(set_channel_visibility)
 365            .add_request_handler(rename_channel)
 366            .add_request_handler(join_channel_buffer)
 367            .add_request_handler(leave_channel_buffer)
 368            .add_message_handler(update_channel_buffer)
 369            .add_request_handler(rejoin_channel_buffers)
 370            .add_request_handler(get_channel_members)
 371            .add_request_handler(respond_to_channel_invite)
 372            .add_request_handler(join_channel)
 373            .add_request_handler(join_channel_chat)
 374            .add_message_handler(leave_channel_chat)
 375            .add_request_handler(send_channel_message)
 376            .add_request_handler(remove_channel_message)
 377            .add_request_handler(update_channel_message)
 378            .add_request_handler(get_channel_messages)
 379            .add_request_handler(get_channel_messages_by_id)
 380            .add_request_handler(get_notifications)
 381            .add_request_handler(mark_notification_as_read)
 382            .add_request_handler(move_channel)
 383            .add_request_handler(follow)
 384            .add_message_handler(unfollow)
 385            .add_message_handler(update_followers)
 386            .add_request_handler(get_private_user_info)
 387            .add_request_handler(get_llm_api_token)
 388            .add_request_handler(accept_terms_of_service)
 389            .add_message_handler(acknowledge_channel_message)
 390            .add_message_handler(acknowledge_buffer_version)
 391            .add_request_handler(get_supermaven_api_key)
 392            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
 393            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
 394            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
 395            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
 396            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
 397            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
 398            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
 399            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
 400            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
 401            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
 402            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
 403            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
 404            .add_message_handler(update_context)
 405            .add_request_handler({
 406                let app_state = app_state.clone();
 407                move |request, response, session| {
 408                    let app_state = app_state.clone();
 409                    async move {
 410                        count_language_model_tokens(request, response, session, &app_state.config)
 411                            .await
 412                    }
 413                }
 414            })
 415            .add_request_handler(get_cached_embeddings)
 416            .add_request_handler({
 417                let app_state = app_state.clone();
 418                move |request, response, session| {
 419                    compute_embeddings(
 420                        request,
 421                        response,
 422                        session,
 423                        app_state.config.openai_api_key.clone(),
 424                    )
 425                }
 426            });
 427
 428        Arc::new(server)
 429    }
 430
    /// Schedules a one-time, detached cleanup of resources left behind by stale
    /// servers, then returns `Ok(())` immediately.
    ///
    /// After `CLEANUP_TIMEOUT` elapses, the spawned task does the following:
    /// 1. Asks the database for stale room and channel-buffer ids for this
    ///    environment and `server_id`. NOTE(review): the exact staleness rule
    ///    lives in `stale_server_resource_ids`; confirm there.
    /// 2. For each stale channel buffer: clears stale collaborators and sends the
    ///    refreshed collaborator list to the remaining connections.
    /// 3. For each stale room:
    ///    - clears stale participants and broadcasts the room (and channel) update;
    ///    - notifies users whose calls were canceled;
    ///    - pushes updated contact info to affected users' accepted contacts;
    ///    - deletes the LiveKit room if no participants remain.
    /// 4. Deletes stale server records.
    ///
    /// Every database/peer error is passed through `trace_err` and skipped, so
    /// one failure does not abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here but only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; they keep their
                        // empty defaults if refreshing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were
                            // canceled need their contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the synchronous pool lock is released before
                        // the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to all
                        // of their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the media room only when nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if fetching the stale resource ids failed above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
 576
 577    pub fn teardown(&self) {
 578        self.peer.teardown();
 579        self.connection_pool.lock().reset();
 580        let _ = self.teardown.send(true);
 581    }
 582
 583    #[cfg(test)]
 584    pub fn reset(&self, id: ServerId) {
 585        self.teardown();
 586        *self.id.lock() = id;
 587        self.peer.reset(id.0 as u32);
 588        let _ = self.teardown.send(false);
 589    }
 590
 591    #[cfg(test)]
 592    pub fn id(&self) -> ServerId {
 593        *self.id.lock()
 594    }
 595
 596    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 597    where
 598        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
 599        Fut: 'static + Send + Future<Output = Result<()>>,
 600        M: EnvelopedMessage,
 601    {
 602        let prev_handler = self.handlers.insert(
 603            TypeId::of::<M>(),
 604            Box::new(move |envelope, session| {
 605                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
 606                let received_at = envelope.received_at;
 607                tracing::info!("message received");
 608                let start_time = Instant::now();
 609                let future = (handler)(*envelope, session);
 610                async move {
 611                    let result = future.await;
 612                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
 613                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
 614                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
 615                    let payload_type = M::NAME;
 616
 617                    match result {
 618                        Err(error) => {
 619                            tracing::error!(
 620                                ?error,
 621                                total_duration_ms,
 622                                processing_duration_ms,
 623                                queue_duration_ms,
 624                                payload_type,
 625                                "error handling message"
 626                            )
 627                        }
 628                        Ok(()) => tracing::info!(
 629                            total_duration_ms,
 630                            processing_duration_ms,
 631                            queue_duration_ms,
 632                            "finished handling message"
 633                        ),
 634                    }
 635                }
 636                .boxed()
 637            }),
 638        );
 639        if prev_handler.is_some() {
 640            panic!("registered a handler for the same message twice");
 641        }
 642        self
 643    }
 644
 645    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 646    where
 647        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
 648        Fut: 'static + Send + Future<Output = Result<()>>,
 649        M: EnvelopedMessage,
 650    {
 651        self.add_handler(move |envelope, session| handler(envelope.payload, session));
 652        self
 653    }
 654
 655    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
 656    where
 657        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
 658        Fut: Send + Future<Output = Result<()>>,
 659        M: RequestMessage,
 660    {
 661        let handler = Arc::new(handler);
 662        self.add_handler(move |envelope, session| {
 663            let receipt = envelope.receipt();
 664            let handler = handler.clone();
 665            async move {
 666                let peer = session.peer.clone();
 667                let responded = Arc::new(AtomicBool::default());
 668                let response = Response {
 669                    peer: peer.clone(),
 670                    responded: responded.clone(),
 671                    receipt,
 672                };
 673                match (handler)(envelope.payload, response, session).await {
 674                    Ok(()) => {
 675                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
 676                            Ok(())
 677                        } else {
 678                            Err(anyhow!("handler did not send a response"))?
 679                        }
 680                    }
 681                    Err(error) => {
 682                        let proto_err = match &error {
 683                            Error::Internal(err) => err.to_proto(),
 684                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
 685                        };
 686                        peer.respond_with_error(receipt, proto_err)?;
 687                        Err(error)
 688                    }
 689                }
 690            }
 691        })
 692    }
 693
    /// Drives a single client connection from handshake until sign-out.
    ///
    /// The returned future is instrumented with a "handle connection" span. The span
    /// records the address, connection id, principal fields and GeoIP country code.
    /// When polled, the future:
    /// 1. exits immediately if server teardown has already been signalled;
    /// 2. registers `connection` with the peer and builds the connection's [`Session`],
    ///    including an HTTP client and, when a Supermaven admin API key is configured,
    ///    a Supermaven admin client;
    /// 3. calls `send_initial_client_update`, which also reports the new connection id
    ///    through `send_connection_id` when one is given;
    /// 4. dispatches incoming messages to the registered handlers until the I/O future
    ///    completes, the incoming stream ends, or teardown is signalled;
    /// 5. calls `connection_lost` to clean up. This step is skipped when the future
    ///    returned early in steps 1–3 or because of teardown.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Do not serve new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            // `handle_io` is polled by the select loop below; `incoming_rx` yields the
            // messages received on this connection.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                http_client,
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                // NOTE(review): like the HTTP-client failure above, this returns without
                // calling `connection_lost`. The connection added to `this.peer` (and, if
                // `send_initial_client_update` got that far, to the connection pool) is
                // therefore not removed here. Confirm this is intended.
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256. A permit is
            // acquired *before* reading the next message and released when its handler
            // finishes, so reading stalls while all permits are taken.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are polled in order — teardown, connection I/O,
                // in-flight foreground handlers, then the next incoming message.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler future is created; the
                            // future itself is instrumented with the same span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background handlers run detached and do not block
                                    // the foreground queue.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // The incoming stream ended.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
 843
 844    async fn send_initial_client_update(
 845        &self,
 846        connection_id: ConnectionId,
 847        principal: &Principal,
 848        zed_version: ZedVersion,
 849        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
 850        session: &Session,
 851    ) -> Result<()> {
 852        self.peer.send(
 853            connection_id,
 854            proto::Hello {
 855                peer_id: Some(connection_id.into()),
 856            },
 857        )?;
 858        tracing::info!("sent hello message");
 859        if let Some(send_connection_id) = send_connection_id.take() {
 860            let _ = send_connection_id.send(connection_id);
 861        }
 862
 863        match principal {
 864            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
 865                if !user.connected_once {
 866                    self.peer.send(connection_id, proto::ShowContacts {})?;
 867                    self.app_state
 868                        .db
 869                        .set_user_connected_once(user.id, true)
 870                        .await?;
 871                }
 872
 873                update_user_plan(user.id, session).await?;
 874
 875                let contacts = self.app_state.db.get_contacts(user.id).await?;
 876
 877                {
 878                    let mut pool = self.connection_pool.lock();
 879                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
 880                    self.peer.send(
 881                        connection_id,
 882                        build_initial_contacts_update(contacts, &pool),
 883                    )?;
 884                }
 885
 886                if should_auto_subscribe_to_channels(zed_version) {
 887                    subscribe_user_to_channels(user.id, session).await?;
 888                }
 889
 890                if let Some(incoming_call) =
 891                    self.app_state.db.incoming_call_for_user(user.id).await?
 892                {
 893                    self.peer.send(connection_id, incoming_call)?;
 894                }
 895
 896                update_user_contacts(user.id, session).await?;
 897            }
 898        }
 899
 900        Ok(())
 901    }
 902
 903    pub async fn invite_code_redeemed(
 904        self: &Arc<Self>,
 905        inviter_id: UserId,
 906        invitee_id: UserId,
 907    ) -> Result<()> {
 908        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
 909            if let Some(code) = &user.invite_code {
 910                let pool = self.connection_pool.lock();
 911                let invitee_contact = contact_for_user(invitee_id, false, &pool);
 912                for connection_id in pool.user_connection_ids(inviter_id) {
 913                    self.peer.send(
 914                        connection_id,
 915                        proto::UpdateContacts {
 916                            contacts: vec![invitee_contact.clone()],
 917                            ..Default::default()
 918                        },
 919                    )?;
 920                    self.peer.send(
 921                        connection_id,
 922                        proto::UpdateInviteInfo {
 923                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
 924                            count: user.invite_count as u32,
 925                        },
 926                    )?;
 927                }
 928            }
 929        }
 930        Ok(())
 931    }
 932
 933    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
 934        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
 935            if let Some(invite_code) = &user.invite_code {
 936                let pool = self.connection_pool.lock();
 937                for connection_id in pool.user_connection_ids(user_id) {
 938                    self.peer.send(
 939                        connection_id,
 940                        proto::UpdateInviteInfo {
 941                            url: format!(
 942                                "{}{}",
 943                                self.app_state.config.invite_link_prefix, invite_code
 944                            ),
 945                            count: user.invite_count as u32,
 946                        },
 947                    )?;
 948                }
 949            }
 950        }
 951        Ok(())
 952    }
 953
 954    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
 955        let pool = self.connection_pool.lock();
 956        for connection_id in pool.user_connection_ids(user_id) {
 957            self.peer
 958                .send(connection_id, proto::RefreshLlmToken {})
 959                .trace_err();
 960        }
 961    }
 962
 963    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
 964        ServerSnapshot {
 965            connection_pool: ConnectionPoolGuard {
 966                guard: self.connection_pool.lock(),
 967                _not_send: PhantomData,
 968            },
 969            peer: &self.peer,
 970        }
 971    }
 972}
 973
 974impl<'a> Deref for ConnectionPoolGuard<'a> {
 975    type Target = ConnectionPool;
 976
 977    fn deref(&self) -> &Self::Target {
 978        &self.guard
 979    }
 980}
 981
 982impl<'a> DerefMut for ConnectionPoolGuard<'a> {
 983    fn deref_mut(&mut self) -> &mut Self::Target {
 984        &mut self.guard
 985    }
 986}
 987
 988impl<'a> Drop for ConnectionPoolGuard<'a> {
 989    fn drop(&mut self) {
 990        #[cfg(test)]
 991        self.check_invariants();
 992    }
 993}
 994
 995fn broadcast<F>(
 996    sender_id: Option<ConnectionId>,
 997    receiver_ids: impl IntoIterator<Item = ConnectionId>,
 998    mut f: F,
 999) where
1000    F: FnMut(ConnectionId) -> anyhow::Result<()>,
1001{
1002    for receiver_id in receiver_ids {
1003        if Some(receiver_id) != sender_id {
1004            if let Err(error) = f(receiver_id) {
1005                tracing::error!("failed to send to {:?} {}", receiver_id, error);
1006            }
1007        }
1008    }
1009}
1010
1011pub struct ProtocolVersion(u32);
1012
1013impl Header for ProtocolVersion {
1014    fn name() -> &'static HeaderName {
1015        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1016        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1017    }
1018
1019    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1020    where
1021        Self: Sized,
1022        I: Iterator<Item = &'i axum::http::HeaderValue>,
1023    {
1024        let version = values
1025            .next()
1026            .ok_or_else(axum::headers::Error::invalid)?
1027            .to_str()
1028            .map_err(|_| axum::headers::Error::invalid())?
1029            .parse()
1030            .map_err(|_| axum::headers::Error::invalid())?;
1031        Ok(Self(version))
1032    }
1033
1034    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1035        values.extend([self.0.to_string().parse().unwrap()]);
1036    }
1037}
1038
1039pub struct AppVersionHeader(SemanticVersion);
1040impl Header for AppVersionHeader {
1041    fn name() -> &'static HeaderName {
1042        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1043        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1044    }
1045
1046    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1047    where
1048        Self: Sized,
1049        I: Iterator<Item = &'i axum::http::HeaderValue>,
1050    {
1051        let version = values
1052            .next()
1053            .ok_or_else(axum::headers::Error::invalid)?
1054            .to_str()
1055            .map_err(|_| axum::headers::Error::invalid())?
1056            .parse()
1057            .map_err(|_| axum::headers::Error::invalid())?;
1058        Ok(Self(version))
1059    }
1060
1061    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1062        values.extend([self.0.to_string().parse().unwrap()]);
1063    }
1064}
1065
1066pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1067    Router::new()
1068        .route("/rpc", get(handle_websocket_request))
1069        .layer(
1070            ServiceBuilder::new()
1071                .layer(Extension(server.app_state.clone()))
1072                .layer(middleware::from_fn(auth::validate_header)),
1073        )
1074        .route("/metrics", get(handle_metrics))
1075        .layer(Extension(server))
1076}
1077
1078#[allow(clippy::too_many_arguments)]
1079pub async fn handle_websocket_request(
1080    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1081    app_version_header: Option<TypedHeader<AppVersionHeader>>,
1082    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1083    Extension(server): Extension<Arc<Server>>,
1084    Extension(principal): Extension<Principal>,
1085    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1086    system_id_header: Option<TypedHeader<SystemIdHeader>>,
1087    ws: WebSocketUpgrade,
1088) -> axum::response::Response {
1089    if protocol_version != rpc::PROTOCOL_VERSION {
1090        return (
1091            StatusCode::UPGRADE_REQUIRED,
1092            "client must be upgraded".to_string(),
1093        )
1094            .into_response();
1095    }
1096
1097    let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1098        return (
1099            StatusCode::UPGRADE_REQUIRED,
1100            "no version header found".to_string(),
1101        )
1102            .into_response();
1103    };
1104
1105    if !version.can_collaborate() {
1106        return (
1107            StatusCode::UPGRADE_REQUIRED,
1108            "client must be upgraded".to_string(),
1109        )
1110            .into_response();
1111    }
1112
1113    let socket_address = socket_address.to_string();
1114    ws.on_upgrade(move |socket| {
1115        let socket = socket
1116            .map_ok(to_tungstenite_message)
1117            .err_into()
1118            .with(|message| async move { to_axum_message(message) });
1119        let connection = Connection::new(Box::pin(socket));
1120        async move {
1121            server
1122                .handle_connection(
1123                    connection,
1124                    socket_address,
1125                    principal,
1126                    version,
1127                    country_code_header.map(|header| header.to_string()),
1128                    system_id_header.map(|header| header.to_string()),
1129                    None,
1130                    Executor::Production,
1131                )
1132                .await;
1133        }
1134    })
1135}
1136
1137pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1138    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139    let connections_metric = CONNECTIONS_METRIC
1140        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1141
1142    let connections = server
1143        .connection_pool
1144        .lock()
1145        .connections()
1146        .filter(|connection| !connection.admin)
1147        .count();
1148    connections_metric.set(connections as _);
1149
1150    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1152        register_int_gauge!(
1153            "shared_projects",
1154            "number of open projects with one or more guests"
1155        )
1156        .unwrap()
1157    });
1158
1159    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1160    shared_projects_metric.set(shared_projects as _);
1161
1162    let encoder = prometheus::TextEncoder::new();
1163    let metric_families = prometheus::gather();
1164    let encoded_metrics = encoder
1165        .encode_to_string(&metric_families)
1166        .map_err(|err| anyhow!("{}", err))?;
1167    Ok(encoded_metrics)
1168}
1169
/// Cleans up after a client connection has gone away.
///
/// The cleanup runs in two phases.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the connection pool;
/// - records the lost connection in the database (a database error is only passed to
///   `trace_err`).
///
/// Then it waits `RECONNECT_TIMEOUT`, presumably to give the client time to reconnect.
/// - If the timeout elapses first, the session leaves its room and channel buffers.
///   If the user then has no connection online, `decline_call` is run for them and any
///   returned room is broadcast via `room_updated`. Finally the user's contacts are
///   updated.
/// - If server teardown is signalled first, nothing further is done.
///
/// # Errors
/// Returns an error if removing the connection from the pool or updating the user's
/// contacts fails. The other steps only pass their errors to `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: if both are ready, the reconnect timeout wins over teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            // NOTE(review): this uses `log::info!` while the rest of the file logs through `tracing`.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only run `decline_call` when the user has no other connection online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1216
1217/// Acknowledges a ping from a client, used to keep the connection alive.
1218async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1219    response.send(proto::Ack {})?;
1220    Ok(())
1221}
1222
1223/// Creates a new room for calling (outside of channels)
1224async fn create_room(
1225    _request: proto::CreateRoom,
1226    response: Response<proto::CreateRoom>,
1227    session: Session,
1228) -> Result<()> {
1229    let livekit_room = nanoid::nanoid!(30);
1230
1231    let live_kit_connection_info = util::maybe!(async {
1232        let live_kit = session.app_state.livekit_client.as_ref();
1233        let live_kit = live_kit?;
1234        let user_id = session.user_id().to_string();
1235
1236        let token = live_kit
1237            .room_token(&livekit_room, &user_id.to_string())
1238            .trace_err()?;
1239
1240        Some(proto::LiveKitConnectionInfo {
1241            server_url: live_kit.url().into(),
1242            token,
1243            can_publish: true,
1244        })
1245    })
1246    .await;
1247
1248    let room = session
1249        .db()
1250        .await
1251        .create_room(session.user_id(), session.connection_id, &livekit_room)
1252        .await?;
1253
1254    response.send(proto::CreateRoomResponse {
1255        room: Some(room.clone()),
1256        live_kit_connection_info,
1257    })?;
1258
1259    update_user_contacts(session.user_id(), &session).await?;
1260    Ok(())
1261}
1262
1263/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1264async fn join_room(
1265    request: proto::JoinRoom,
1266    response: Response<proto::JoinRoom>,
1267    session: Session,
1268) -> Result<()> {
1269    let room_id = RoomId::from_proto(request.id);
1270
1271    let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1272
1273    if let Some(channel_id) = channel_id {
1274        return join_channel_internal(channel_id, Box::new(response), session).await;
1275    }
1276
1277    let joined_room = {
1278        let room = session
1279            .db()
1280            .await
1281            .join_room(room_id, session.user_id(), session.connection_id)
1282            .await?;
1283        room_updated(&room.room, &session.peer);
1284        room.into_inner()
1285    };
1286
1287    for connection_id in session
1288        .connection_pool()
1289        .await
1290        .user_connection_ids(session.user_id())
1291    {
1292        session
1293            .peer
1294            .send(
1295                connection_id,
1296                proto::CallCanceled {
1297                    room_id: room_id.to_proto(),
1298                },
1299            )
1300            .trace_err();
1301    }
1302
1303    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1304    {
1305        live_kit
1306            .room_token(
1307                &joined_room.room.livekit_room,
1308                &session.user_id().to_string(),
1309            )
1310            .trace_err()
1311            .map(|token| proto::LiveKitConnectionInfo {
1312                server_url: live_kit.url().into(),
1313                token,
1314                can_publish: true,
1315            })
1316    } else {
1317        None
1318    };
1319
1320    response.send(proto::JoinRoomResponse {
1321        room: Some(joined_room.room),
1322        channel_id: None,
1323        live_kit_connection_info,
1324    })?;
1325
1326    update_user_contacts(session.user_id(), &session).await?;
1327    Ok(())
1328}
1329
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Calls the database's `rejoin_room` and replies with the room, its reshared projects
/// (with their collaborators) and its rejoined projects. It then:
/// - passes the room to `room_updated`;
/// - for each reshared project, sends every collaborator an `UpdateProjectCollaborator`.
///   The message maps the project's `old_connection_id` to the current connection.
///   The project's worktrees are also forwarded to every collaborator except this one;
/// - resends rejoined-project state to this client via `notify_rejoined_projects`;
/// - if the room has a channel, calls `channel_updated`;
/// - updates the caller's contacts.
///
/// # Errors
/// Fails if the database call, the response, `notify_rejoined_projects`, or the
/// contact update fails. Per-collaborator send failures are only passed to `trace_err`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned at the end of the scope below, after `rejoined_room` is consumed.
    let room;
    let channel;
    {
        // `rejoined_room` lives only in this scope and is consumed by `into_inner()`
        // before `session.connection_pool()` is awaited below.
        // NOTE(review): presumably the guard holds a room-level lock — confirm.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // No sender filtering here, unlike the `broadcast` below.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to every collaborator except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1421
1422fn notify_rejoined_projects(
1423    rejoined_projects: &mut Vec<RejoinedProject>,
1424    session: &Session,
1425) -> Result<()> {
1426    for project in rejoined_projects.iter() {
1427        for collaborator in &project.collaborators {
1428            session
1429                .peer
1430                .send(
1431                    collaborator.connection_id,
1432                    proto::UpdateProjectCollaborator {
1433                        project_id: project.id.to_proto(),
1434                        old_peer_id: Some(project.old_connection_id.into()),
1435                        new_peer_id: Some(session.connection_id.into()),
1436                    },
1437                )
1438                .trace_err();
1439        }
1440    }
1441
1442    for project in rejoined_projects {
1443        for worktree in mem::take(&mut project.worktrees) {
1444            // Stream this worktree's entries.
1445            let message = proto::UpdateWorktree {
1446                project_id: project.id.to_proto(),
1447                worktree_id: worktree.id,
1448                abs_path: worktree.abs_path.clone(),
1449                root_name: worktree.root_name,
1450                updated_entries: worktree.updated_entries,
1451                removed_entries: worktree.removed_entries,
1452                scan_id: worktree.scan_id,
1453                is_last_update: worktree.completed_scan_id == worktree.scan_id,
1454                updated_repositories: worktree.updated_repositories,
1455                removed_repositories: worktree.removed_repositories,
1456            };
1457            for update in proto::split_worktree_update(message) {
1458                session.peer.send(session.connection_id, update.clone())?;
1459            }
1460
1461            // Stream this worktree's diagnostics.
1462            for summary in worktree.diagnostic_summaries {
1463                session.peer.send(
1464                    session.connection_id,
1465                    proto::UpdateDiagnosticSummary {
1466                        project_id: project.id.to_proto(),
1467                        worktree_id: worktree.id,
1468                        summary: Some(summary),
1469                    },
1470                )?;
1471            }
1472
1473            for settings_file in worktree.settings_files {
1474                session.peer.send(
1475                    session.connection_id,
1476                    proto::UpdateWorktreeSettings {
1477                        project_id: project.id.to_proto(),
1478                        worktree_id: worktree.id,
1479                        path: settings_file.path,
1480                        content: Some(settings_file.content),
1481                        kind: Some(settings_file.kind.to_proto().into()),
1482                    },
1483                )?;
1484            }
1485        }
1486
1487        for language_server in &project.language_servers {
1488            session.peer.send(
1489                session.connection_id,
1490                proto::UpdateLanguageServer {
1491                    project_id: project.id.to_proto(),
1492                    language_server_id: language_server.id,
1493                    variant: Some(
1494                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1495                            proto::LspDiskBasedDiagnosticsUpdated {},
1496                        ),
1497                    ),
1498                },
1499            )?;
1500        }
1501    }
1502    Ok(())
1503}
1504
1505/// leave room disconnects from the room.
1506async fn leave_room(
1507    _: proto::LeaveRoom,
1508    response: Response<proto::LeaveRoom>,
1509    session: Session,
1510) -> Result<()> {
1511    leave_room_for_session(&session, session.connection_id).await?;
1512    response.send(proto::Ack {})?;
1513    Ok(())
1514}
1515
1516/// Updates the permissions of someone else in the room.
1517async fn set_room_participant_role(
1518    request: proto::SetRoomParticipantRole,
1519    response: Response<proto::SetRoomParticipantRole>,
1520    session: Session,
1521) -> Result<()> {
1522    let user_id = UserId::from_proto(request.user_id);
1523    let role = ChannelRole::from(request.role());
1524
1525    let (livekit_room, can_publish) = {
1526        let room = session
1527            .db()
1528            .await
1529            .set_room_participant_role(
1530                session.user_id(),
1531                RoomId::from_proto(request.room_id),
1532                user_id,
1533                role,
1534            )
1535            .await?;
1536
1537        let livekit_room = room.livekit_room.clone();
1538        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1539        room_updated(&room, &session.peer);
1540        (livekit_room, can_publish)
1541    };
1542
1543    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1544        live_kit
1545            .update_participant(
1546                livekit_room.clone(),
1547                request.user_id.to_string(),
1548                livekit_api::proto::ParticipantPermission {
1549                    can_subscribe: true,
1550                    can_publish,
1551                    can_publish_data: can_publish,
1552                    hidden: false,
1553                    recorder: false,
1554                },
1555            )
1556            .await
1557            .trace_err();
1558    }
1559
1560    response.send(proto::Ack {})?;
1561    Ok(())
1562}
1563
1564/// Call someone else into the current room
1565async fn call(
1566    request: proto::Call,
1567    response: Response<proto::Call>,
1568    session: Session,
1569) -> Result<()> {
1570    let room_id = RoomId::from_proto(request.room_id);
1571    let calling_user_id = session.user_id();
1572    let calling_connection_id = session.connection_id;
1573    let called_user_id = UserId::from_proto(request.called_user_id);
1574    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1575    if !session
1576        .db()
1577        .await
1578        .has_contact(calling_user_id, called_user_id)
1579        .await?
1580    {
1581        return Err(anyhow!("cannot call a user who isn't a contact"))?;
1582    }
1583
1584    let incoming_call = {
1585        let (room, incoming_call) = &mut *session
1586            .db()
1587            .await
1588            .call(
1589                room_id,
1590                calling_user_id,
1591                calling_connection_id,
1592                called_user_id,
1593                initial_project_id,
1594            )
1595            .await?;
1596        room_updated(room, &session.peer);
1597        mem::take(incoming_call)
1598    };
1599    update_user_contacts(called_user_id, &session).await?;
1600
1601    let mut calls = session
1602        .connection_pool()
1603        .await
1604        .user_connection_ids(called_user_id)
1605        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1606        .collect::<FuturesUnordered<_>>();
1607
1608    while let Some(call_response) = calls.next().await {
1609        match call_response.as_ref() {
1610            Ok(_) => {
1611                response.send(proto::Ack {})?;
1612                return Ok(());
1613            }
1614            Err(_) => {
1615                call_response.trace_err();
1616            }
1617        }
1618    }
1619
1620    {
1621        let room = session
1622            .db()
1623            .await
1624            .call_failed(room_id, called_user_id)
1625            .await?;
1626        room_updated(&room, &session.peer);
1627    }
1628    update_user_contacts(called_user_id, &session).await?;
1629
1630    Err(anyhow!("failed to ring user"))?
1631}
1632
1633/// Cancel an outgoing call.
1634async fn cancel_call(
1635    request: proto::CancelCall,
1636    response: Response<proto::CancelCall>,
1637    session: Session,
1638) -> Result<()> {
1639    let called_user_id = UserId::from_proto(request.called_user_id);
1640    let room_id = RoomId::from_proto(request.room_id);
1641    {
1642        let room = session
1643            .db()
1644            .await
1645            .cancel_call(room_id, session.connection_id, called_user_id)
1646            .await?;
1647        room_updated(&room, &session.peer);
1648    }
1649
1650    for connection_id in session
1651        .connection_pool()
1652        .await
1653        .user_connection_ids(called_user_id)
1654    {
1655        session
1656            .peer
1657            .send(
1658                connection_id,
1659                proto::CallCanceled {
1660                    room_id: room_id.to_proto(),
1661                },
1662            )
1663            .trace_err();
1664    }
1665    response.send(proto::Ack {})?;
1666
1667    update_user_contacts(called_user_id, &session).await?;
1668    Ok(())
1669}
1670
1671/// Decline an incoming call.
1672async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1673    let room_id = RoomId::from_proto(message.room_id);
1674    {
1675        let room = session
1676            .db()
1677            .await
1678            .decline_call(Some(room_id), session.user_id())
1679            .await?
1680            .ok_or_else(|| anyhow!("failed to decline call"))?;
1681        room_updated(&room, &session.peer);
1682    }
1683
1684    for connection_id in session
1685        .connection_pool()
1686        .await
1687        .user_connection_ids(session.user_id())
1688    {
1689        session
1690            .peer
1691            .send(
1692                connection_id,
1693                proto::CallCanceled {
1694                    room_id: room_id.to_proto(),
1695                },
1696            )
1697            .trace_err();
1698    }
1699    update_user_contacts(session.user_id(), &session).await?;
1700    Ok(())
1701}
1702
1703/// Updates other participants in the room with your current location.
1704async fn update_participant_location(
1705    request: proto::UpdateParticipantLocation,
1706    response: Response<proto::UpdateParticipantLocation>,
1707    session: Session,
1708) -> Result<()> {
1709    let room_id = RoomId::from_proto(request.room_id);
1710    let location = request
1711        .location
1712        .ok_or_else(|| anyhow!("invalid location"))?;
1713
1714    let db = session.db().await;
1715    let room = db
1716        .update_room_participant_location(room_id, session.connection_id, location)
1717        .await?;
1718
1719    room_updated(&room, &session.peer);
1720    response.send(proto::Ack {})?;
1721    Ok(())
1722}
1723
1724/// Share a project into the room.
1725async fn share_project(
1726    request: proto::ShareProject,
1727    response: Response<proto::ShareProject>,
1728    session: Session,
1729) -> Result<()> {
1730    let (project_id, room) = &*session
1731        .db()
1732        .await
1733        .share_project(
1734            RoomId::from_proto(request.room_id),
1735            session.connection_id,
1736            &request.worktrees,
1737            request.is_ssh_project,
1738        )
1739        .await?;
1740    response.send(proto::ShareProjectResponse {
1741        project_id: project_id.to_proto(),
1742    })?;
1743    room_updated(room, &session.peer);
1744
1745    Ok(())
1746}
1747
1748/// Unshare a project from the room.
1749async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1750    let project_id = ProjectId::from_proto(message.project_id);
1751    unshare_project_internal(project_id, session.connection_id, &session).await
1752}
1753
1754async fn unshare_project_internal(
1755    project_id: ProjectId,
1756    connection_id: ConnectionId,
1757    session: &Session,
1758) -> Result<()> {
1759    let delete = {
1760        let room_guard = session
1761            .db()
1762            .await
1763            .unshare_project(project_id, connection_id)
1764            .await?;
1765
1766        let (delete, room, guest_connection_ids) = &*room_guard;
1767
1768        let message = proto::UnshareProject {
1769            project_id: project_id.to_proto(),
1770        };
1771
1772        broadcast(
1773            Some(connection_id),
1774            guest_connection_ids.iter().copied(),
1775            |conn_id| session.peer.send(conn_id, message.clone()),
1776        );
1777        if let Some(room) = room {
1778            room_updated(room, &session.peer);
1779        }
1780
1781        *delete
1782    };
1783
1784    if delete {
1785        let db = session.db().await;
1786        db.delete_project(project_id).await?;
1787    }
1788
1789    Ok(())
1790}
1791
/// Joins someone else's shared project as a guest.
///
/// Adds this connection to the project through `Database::join_project`, which
/// yields the project and the replica id assigned to the guest. The database
/// handle is then released, and the rest of the join is handed to
/// `join_project_internal`: notifying existing collaborators, replying, and
/// streaming the project state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state; the returned
    // project guard remains usable without it.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1810
/// Responder used by `join_project_internal` to reply to a join request.
///
/// Abstracting over the responder lets `join_project_internal` be driven by
/// anything that can accept a `proto::JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies through the regular RPC response of a `proto::JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified so this calls the inherent `Response::send` rather
        // than recursing into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1819
1820fn join_project_internal(
1821    response: impl JoinProjectInternalResponse,
1822    session: Session,
1823    project: &mut Project,
1824    replica_id: &ReplicaId,
1825) -> Result<()> {
1826    let collaborators = project
1827        .collaborators
1828        .iter()
1829        .filter(|collaborator| collaborator.connection_id != session.connection_id)
1830        .map(|collaborator| collaborator.to_proto())
1831        .collect::<Vec<_>>();
1832    let project_id = project.id;
1833    let guest_user_id = session.user_id();
1834
1835    let worktrees = project
1836        .worktrees
1837        .iter()
1838        .map(|(id, worktree)| proto::WorktreeMetadata {
1839            id: *id,
1840            root_name: worktree.root_name.clone(),
1841            visible: worktree.visible,
1842            abs_path: worktree.abs_path.clone(),
1843        })
1844        .collect::<Vec<_>>();
1845
1846    let add_project_collaborator = proto::AddProjectCollaborator {
1847        project_id: project_id.to_proto(),
1848        collaborator: Some(proto::Collaborator {
1849            peer_id: Some(session.connection_id.into()),
1850            replica_id: replica_id.0 as u32,
1851            user_id: guest_user_id.to_proto(),
1852            is_host: false,
1853        }),
1854    };
1855
1856    for collaborator in &collaborators {
1857        session
1858            .peer
1859            .send(
1860                collaborator.peer_id.unwrap().into(),
1861                add_project_collaborator.clone(),
1862            )
1863            .trace_err();
1864    }
1865
1866    // First, we send the metadata associated with each worktree.
1867    response.send(proto::JoinProjectResponse {
1868        project_id: project.id.0 as u64,
1869        worktrees: worktrees.clone(),
1870        replica_id: replica_id.0 as u32,
1871        collaborators: collaborators.clone(),
1872        language_servers: project.language_servers.clone(),
1873        role: project.role.into(),
1874    })?;
1875
1876    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1877        // Stream this worktree's entries.
1878        let message = proto::UpdateWorktree {
1879            project_id: project_id.to_proto(),
1880            worktree_id,
1881            abs_path: worktree.abs_path.clone(),
1882            root_name: worktree.root_name,
1883            updated_entries: worktree.entries,
1884            removed_entries: Default::default(),
1885            scan_id: worktree.scan_id,
1886            is_last_update: worktree.scan_id == worktree.completed_scan_id,
1887            updated_repositories: worktree.repository_entries.into_values().collect(),
1888            removed_repositories: Default::default(),
1889        };
1890        for update in proto::split_worktree_update(message) {
1891            session.peer.send(session.connection_id, update.clone())?;
1892        }
1893
1894        // Stream this worktree's diagnostics.
1895        for summary in worktree.diagnostic_summaries {
1896            session.peer.send(
1897                session.connection_id,
1898                proto::UpdateDiagnosticSummary {
1899                    project_id: project_id.to_proto(),
1900                    worktree_id: worktree.id,
1901                    summary: Some(summary),
1902                },
1903            )?;
1904        }
1905
1906        for settings_file in worktree.settings_files {
1907            session.peer.send(
1908                session.connection_id,
1909                proto::UpdateWorktreeSettings {
1910                    project_id: project_id.to_proto(),
1911                    worktree_id: worktree.id,
1912                    path: settings_file.path,
1913                    content: Some(settings_file.content),
1914                    kind: Some(settings_file.kind.to_proto() as i32),
1915                },
1916            )?;
1917        }
1918    }
1919
1920    for language_server in &project.language_servers {
1921        session.peer.send(
1922            session.connection_id,
1923            proto::UpdateLanguageServer {
1924                project_id: project_id.to_proto(),
1925                language_server_id: language_server.id,
1926                variant: Some(
1927                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1928                        proto::LspDiskBasedDiagnosticsUpdated {},
1929                    ),
1930                ),
1931            },
1932        )?;
1933    }
1934
1935    Ok(())
1936}
1937
1938/// Leave someone elses shared project.
1939async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1940    let sender_id = session.connection_id;
1941    let project_id = ProjectId::from_proto(request.project_id);
1942    let db = session.db().await;
1943
1944    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1945    tracing::info!(
1946        %project_id,
1947        "leave project"
1948    );
1949
1950    project_left(project, &session);
1951    if let Some(room) = room {
1952        room_updated(room, &session.peer);
1953    }
1954
1955    Ok(())
1956}
1957
1958/// Updates other participants with changes to the project
1959async fn update_project(
1960    request: proto::UpdateProject,
1961    response: Response<proto::UpdateProject>,
1962    session: Session,
1963) -> Result<()> {
1964    let project_id = ProjectId::from_proto(request.project_id);
1965    let (room, guest_connection_ids) = &*session
1966        .db()
1967        .await
1968        .update_project(project_id, session.connection_id, &request.worktrees)
1969        .await?;
1970    broadcast(
1971        Some(session.connection_id),
1972        guest_connection_ids.iter().copied(),
1973        |connection_id| {
1974            session
1975                .peer
1976                .forward_send(session.connection_id, connection_id, request.clone())
1977        },
1978    );
1979    if let Some(room) = room {
1980        room_updated(room, &session.peer);
1981    }
1982    response.send(proto::Ack {})?;
1983
1984    Ok(())
1985}
1986
1987/// Updates other participants with changes to the worktree
1988async fn update_worktree(
1989    request: proto::UpdateWorktree,
1990    response: Response<proto::UpdateWorktree>,
1991    session: Session,
1992) -> Result<()> {
1993    let guest_connection_ids = session
1994        .db()
1995        .await
1996        .update_worktree(&request, session.connection_id)
1997        .await?;
1998
1999    broadcast(
2000        Some(session.connection_id),
2001        guest_connection_ids.iter().copied(),
2002        |connection_id| {
2003            session
2004                .peer
2005                .forward_send(session.connection_id, connection_id, request.clone())
2006        },
2007    );
2008    response.send(proto::Ack {})?;
2009    Ok(())
2010}
2011
2012/// Updates other participants with changes to the diagnostics
2013async fn update_diagnostic_summary(
2014    message: proto::UpdateDiagnosticSummary,
2015    session: Session,
2016) -> Result<()> {
2017    let guest_connection_ids = session
2018        .db()
2019        .await
2020        .update_diagnostic_summary(&message, session.connection_id)
2021        .await?;
2022
2023    broadcast(
2024        Some(session.connection_id),
2025        guest_connection_ids.iter().copied(),
2026        |connection_id| {
2027            session
2028                .peer
2029                .forward_send(session.connection_id, connection_id, message.clone())
2030        },
2031    );
2032
2033    Ok(())
2034}
2035
2036/// Updates other participants with changes to the worktree settings
2037async fn update_worktree_settings(
2038    message: proto::UpdateWorktreeSettings,
2039    session: Session,
2040) -> Result<()> {
2041    let guest_connection_ids = session
2042        .db()
2043        .await
2044        .update_worktree_settings(&message, session.connection_id)
2045        .await?;
2046
2047    broadcast(
2048        Some(session.connection_id),
2049        guest_connection_ids.iter().copied(),
2050        |connection_id| {
2051            session
2052                .peer
2053                .forward_send(session.connection_id, connection_id, message.clone())
2054        },
2055    );
2056
2057    Ok(())
2058}
2059
2060/// Notify other participants that a  language server has started.
2061async fn start_language_server(
2062    request: proto::StartLanguageServer,
2063    session: Session,
2064) -> Result<()> {
2065    let guest_connection_ids = session
2066        .db()
2067        .await
2068        .start_language_server(&request, session.connection_id)
2069        .await?;
2070
2071    broadcast(
2072        Some(session.connection_id),
2073        guest_connection_ids.iter().copied(),
2074        |connection_id| {
2075            session
2076                .peer
2077                .forward_send(session.connection_id, connection_id, request.clone())
2078        },
2079    );
2080    Ok(())
2081}
2082
2083/// Notify other participants that a language server has changed.
2084async fn update_language_server(
2085    request: proto::UpdateLanguageServer,
2086    session: Session,
2087) -> Result<()> {
2088    let project_id = ProjectId::from_proto(request.project_id);
2089    let project_connection_ids = session
2090        .db()
2091        .await
2092        .project_connection_ids(project_id, session.connection_id, true)
2093        .await?;
2094    broadcast(
2095        Some(session.connection_id),
2096        project_connection_ids.iter().copied(),
2097        |connection_id| {
2098            session
2099                .peer
2100                .forward_send(session.connection_id, connection_id, request.clone())
2101        },
2102    );
2103    Ok(())
2104}
2105
2106/// forward a project request to the host. These requests should be read only
2107/// as guests are allowed to send them.
2108async fn forward_read_only_project_request<T>(
2109    request: T,
2110    response: Response<T>,
2111    session: Session,
2112) -> Result<()>
2113where
2114    T: EntityMessage + RequestMessage,
2115{
2116    let project_id = ProjectId::from_proto(request.remote_entity_id());
2117    let host_connection_id = session
2118        .db()
2119        .await
2120        .host_for_read_only_project_request(project_id, session.connection_id)
2121        .await?;
2122    let payload = session
2123        .peer
2124        .forward_request(session.connection_id, host_connection_id, request)
2125        .await?;
2126    response.send(payload)?;
2127    Ok(())
2128}
2129
2130async fn forward_find_search_candidates_request(
2131    request: proto::FindSearchCandidates,
2132    response: Response<proto::FindSearchCandidates>,
2133    session: Session,
2134) -> Result<()> {
2135    let project_id = ProjectId::from_proto(request.remote_entity_id());
2136    let host_connection_id = session
2137        .db()
2138        .await
2139        .host_for_read_only_project_request(project_id, session.connection_id)
2140        .await?;
2141    let payload = session
2142        .peer
2143        .forward_request(session.connection_id, host_connection_id, request)
2144        .await?;
2145    response.send(payload)?;
2146    Ok(())
2147}
2148
2149/// forward a project request to the host. These requests are disallowed
2150/// for guests.
2151async fn forward_mutating_project_request<T>(
2152    request: T,
2153    response: Response<T>,
2154    session: Session,
2155) -> Result<()>
2156where
2157    T: EntityMessage + RequestMessage,
2158{
2159    let project_id = ProjectId::from_proto(request.remote_entity_id());
2160
2161    let host_connection_id = session
2162        .db()
2163        .await
2164        .host_for_mutating_project_request(project_id, session.connection_id)
2165        .await?;
2166    let payload = session
2167        .peer
2168        .forward_request(session.connection_id, host_connection_id, request)
2169        .await?;
2170    response.send(payload)?;
2171    Ok(())
2172}
2173
2174/// Notify other participants that a new buffer has been created
2175async fn create_buffer_for_peer(
2176    request: proto::CreateBufferForPeer,
2177    session: Session,
2178) -> Result<()> {
2179    session
2180        .db()
2181        .await
2182        .check_user_is_project_host(
2183            ProjectId::from_proto(request.project_id),
2184            session.connection_id,
2185        )
2186        .await?;
2187    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2188    session
2189        .peer
2190        .forward_send(session.connection_id, peer_id.into(), request)?;
2191    Ok(())
2192}
2193
2194/// Notify other participants that a buffer has been updated. This is
2195/// allowed for guests as long as the update is limited to selections.
2196async fn update_buffer(
2197    request: proto::UpdateBuffer,
2198    response: Response<proto::UpdateBuffer>,
2199    session: Session,
2200) -> Result<()> {
2201    let project_id = ProjectId::from_proto(request.project_id);
2202    let mut capability = Capability::ReadOnly;
2203
2204    for op in request.operations.iter() {
2205        match op.variant {
2206            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2207            Some(_) => capability = Capability::ReadWrite,
2208        }
2209    }
2210
2211    let host = {
2212        let guard = session
2213            .db()
2214            .await
2215            .connections_for_buffer_update(project_id, session.connection_id, capability)
2216            .await?;
2217
2218        let (host, guests) = &*guard;
2219
2220        broadcast(
2221            Some(session.connection_id),
2222            guests.clone(),
2223            |connection_id| {
2224                session
2225                    .peer
2226                    .forward_send(session.connection_id, connection_id, request.clone())
2227            },
2228        );
2229
2230        *host
2231    };
2232
2233    if host != session.connection_id {
2234        session
2235            .peer
2236            .forward_request(session.connection_id, host, request.clone())
2237            .await?;
2238    }
2239
2240    response.send(proto::Ack {})?;
2241    Ok(())
2242}
2243
2244async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2245    let project_id = ProjectId::from_proto(message.project_id);
2246
2247    let operation = message.operation.as_ref().context("invalid operation")?;
2248    let capability = match operation.variant.as_ref() {
2249        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2250            if let Some(buffer_op) = buffer_op.operation.as_ref() {
2251                match buffer_op.variant {
2252                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2253                        Capability::ReadOnly
2254                    }
2255                    _ => Capability::ReadWrite,
2256                }
2257            } else {
2258                Capability::ReadWrite
2259            }
2260        }
2261        Some(_) => Capability::ReadWrite,
2262        None => Capability::ReadOnly,
2263    };
2264
2265    let guard = session
2266        .db()
2267        .await
2268        .connections_for_buffer_update(project_id, session.connection_id, capability)
2269        .await?;
2270
2271    let (host, guests) = &*guard;
2272
2273    broadcast(
2274        Some(session.connection_id),
2275        guests.iter().chain([host]).copied(),
2276        |connection_id| {
2277            session
2278                .peer
2279                .forward_send(session.connection_id, connection_id, message.clone())
2280        },
2281    );
2282
2283    Ok(())
2284}
2285
2286/// Notify other participants that a project has been updated.
2287async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2288    request: T,
2289    session: Session,
2290) -> Result<()> {
2291    let project_id = ProjectId::from_proto(request.remote_entity_id());
2292    let project_connection_ids = session
2293        .db()
2294        .await
2295        .project_connection_ids(project_id, session.connection_id, false)
2296        .await?;
2297
2298    broadcast(
2299        Some(session.connection_id),
2300        project_connection_ids.iter().copied(),
2301        |connection_id| {
2302            session
2303                .peer
2304                .forward_send(session.connection_id, connection_id, request.clone())
2305        },
2306    );
2307    Ok(())
2308}
2309
2310/// Start following another user in a call.
2311async fn follow(
2312    request: proto::Follow,
2313    response: Response<proto::Follow>,
2314    session: Session,
2315) -> Result<()> {
2316    let room_id = RoomId::from_proto(request.room_id);
2317    let project_id = request.project_id.map(ProjectId::from_proto);
2318    let leader_id = request
2319        .leader_id
2320        .ok_or_else(|| anyhow!("invalid leader id"))?
2321        .into();
2322    let follower_id = session.connection_id;
2323
2324    session
2325        .db()
2326        .await
2327        .check_room_participants(room_id, leader_id, session.connection_id)
2328        .await?;
2329
2330    let response_payload = session
2331        .peer
2332        .forward_request(session.connection_id, leader_id, request)
2333        .await?;
2334    response.send(response_payload)?;
2335
2336    if let Some(project_id) = project_id {
2337        let room = session
2338            .db()
2339            .await
2340            .follow(room_id, project_id, leader_id, follower_id)
2341            .await?;
2342        room_updated(&room, &session.peer);
2343    }
2344
2345    Ok(())
2346}
2347
2348/// Stop following another user in a call.
2349async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2350    let room_id = RoomId::from_proto(request.room_id);
2351    let project_id = request.project_id.map(ProjectId::from_proto);
2352    let leader_id = request
2353        .leader_id
2354        .ok_or_else(|| anyhow!("invalid leader id"))?
2355        .into();
2356    let follower_id = session.connection_id;
2357
2358    session
2359        .db()
2360        .await
2361        .check_room_participants(room_id, leader_id, session.connection_id)
2362        .await?;
2363
2364    session
2365        .peer
2366        .forward_send(session.connection_id, leader_id, request)?;
2367
2368    if let Some(project_id) = project_id {
2369        let room = session
2370            .db()
2371            .await
2372            .unfollow(room_id, project_id, leader_id, follower_id)
2373            .await?;
2374        room_updated(&room, &session.peer);
2375    }
2376
2377    Ok(())
2378}
2379
2380/// Notify everyone following you of your current location.
2381async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2382    let room_id = RoomId::from_proto(request.room_id);
2383    let database = session.db.lock().await;
2384
2385    let connection_ids = if let Some(project_id) = request.project_id {
2386        let project_id = ProjectId::from_proto(project_id);
2387        database
2388            .project_connection_ids(project_id, session.connection_id, true)
2389            .await?
2390    } else {
2391        database
2392            .room_connection_ids(room_id, session.connection_id)
2393            .await?
2394    };
2395
2396    // For now, don't send view update messages back to that view's current leader.
2397    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2398        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2399        _ => None,
2400    });
2401
2402    for connection_id in connection_ids.iter().cloned() {
2403        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2404            session
2405                .peer
2406                .forward_send(session.connection_id, connection_id, request.clone())?;
2407        }
2408    }
2409    Ok(())
2410}
2411
2412/// Get public data about users.
2413async fn get_users(
2414    request: proto::GetUsers,
2415    response: Response<proto::GetUsers>,
2416    session: Session,
2417) -> Result<()> {
2418    let user_ids = request
2419        .user_ids
2420        .into_iter()
2421        .map(UserId::from_proto)
2422        .collect();
2423    let users = session
2424        .db()
2425        .await
2426        .get_users_by_ids(user_ids)
2427        .await?
2428        .into_iter()
2429        .map(|user| proto::User {
2430            id: user.id.to_proto(),
2431            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2432            github_login: user.github_login,
2433            email: user.email_address,
2434            name: user.name,
2435        })
2436        .collect();
2437    response.send(proto::UsersResponse { users })?;
2438    Ok(())
2439}
2440
2441/// Search for users (to invite) buy Github login
2442async fn fuzzy_search_users(
2443    request: proto::FuzzySearchUsers,
2444    response: Response<proto::FuzzySearchUsers>,
2445    session: Session,
2446) -> Result<()> {
2447    let query = request.query;
2448    let users = match query.len() {
2449        0 => vec![],
2450        1 | 2 => session
2451            .db()
2452            .await
2453            .get_user_by_github_login(&query)
2454            .await?
2455            .into_iter()
2456            .collect(),
2457        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2458    };
2459    let users = users
2460        .into_iter()
2461        .filter(|user| user.id != session.user_id())
2462        .map(|user| proto::User {
2463            id: user.id.to_proto(),
2464            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2465            github_login: user.github_login,
2466            name: user.name,
2467            email: user.email_address,
2468        })
2469        .collect();
2470    response.send(proto::UsersResponse { users })?;
2471    Ok(())
2472}
2473
2474/// Send a contact request to another user.
2475async fn request_contact(
2476    request: proto::RequestContact,
2477    response: Response<proto::RequestContact>,
2478    session: Session,
2479) -> Result<()> {
2480    let requester_id = session.user_id();
2481    let responder_id = UserId::from_proto(request.responder_id);
2482    if requester_id == responder_id {
2483        return Err(anyhow!("cannot add yourself as a contact"))?;
2484    }
2485
2486    let notifications = session
2487        .db()
2488        .await
2489        .send_contact_request(requester_id, responder_id)
2490        .await?;
2491
2492    // Update outgoing contact requests of requester
2493    let mut update = proto::UpdateContacts::default();
2494    update.outgoing_requests.push(responder_id.to_proto());
2495    for connection_id in session
2496        .connection_pool()
2497        .await
2498        .user_connection_ids(requester_id)
2499    {
2500        session.peer.send(connection_id, update.clone())?;
2501    }
2502
2503    // Update incoming contact requests of responder
2504    let mut update = proto::UpdateContacts::default();
2505    update
2506        .incoming_requests
2507        .push(proto::IncomingContactRequest {
2508            requester_id: requester_id.to_proto(),
2509        });
2510    let connection_pool = session.connection_pool().await;
2511    for connection_id in connection_pool.user_connection_ids(responder_id) {
2512        session.peer.send(connection_id, update.clone())?;
2513    }
2514
2515    send_notifications(&connection_pool, &session.peer, notifications);
2516
2517    response.send(proto::Ack {})?;
2518    Ok(())
2519}
2520
2521/// Accept or decline a contact request
2522async fn respond_to_contact_request(
2523    request: proto::RespondToContactRequest,
2524    response: Response<proto::RespondToContactRequest>,
2525    session: Session,
2526) -> Result<()> {
2527    let responder_id = session.user_id();
2528    let requester_id = UserId::from_proto(request.requester_id);
2529    let db = session.db().await;
2530    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2531        db.dismiss_contact_notification(responder_id, requester_id)
2532            .await?;
2533    } else {
2534        let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2535
2536        let notifications = db
2537            .respond_to_contact_request(responder_id, requester_id, accept)
2538            .await?;
2539        let requester_busy = db.is_user_busy(requester_id).await?;
2540        let responder_busy = db.is_user_busy(responder_id).await?;
2541
2542        let pool = session.connection_pool().await;
2543        // Update responder with new contact
2544        let mut update = proto::UpdateContacts::default();
2545        if accept {
2546            update
2547                .contacts
2548                .push(contact_for_user(requester_id, requester_busy, &pool));
2549        }
2550        update
2551            .remove_incoming_requests
2552            .push(requester_id.to_proto());
2553        for connection_id in pool.user_connection_ids(responder_id) {
2554            session.peer.send(connection_id, update.clone())?;
2555        }
2556
2557        // Update requester with new contact
2558        let mut update = proto::UpdateContacts::default();
2559        if accept {
2560            update
2561                .contacts
2562                .push(contact_for_user(responder_id, responder_busy, &pool));
2563        }
2564        update
2565            .remove_outgoing_requests
2566            .push(responder_id.to_proto());
2567
2568        for connection_id in pool.user_connection_ids(requester_id) {
2569            session.peer.send(connection_id, update.clone())?;
2570        }
2571
2572        send_notifications(&pool, &session.peer, notifications);
2573    }
2574
2575    response.send(proto::Ack {})?;
2576    Ok(())
2577}
2578
2579/// Remove a contact.
2580async fn remove_contact(
2581    request: proto::RemoveContact,
2582    response: Response<proto::RemoveContact>,
2583    session: Session,
2584) -> Result<()> {
2585    let requester_id = session.user_id();
2586    let responder_id = UserId::from_proto(request.user_id);
2587    let db = session.db().await;
2588    let (contact_accepted, deleted_notification_id) =
2589        db.remove_contact(requester_id, responder_id).await?;
2590
2591    let pool = session.connection_pool().await;
2592    // Update outgoing contact requests of requester
2593    let mut update = proto::UpdateContacts::default();
2594    if contact_accepted {
2595        update.remove_contacts.push(responder_id.to_proto());
2596    } else {
2597        update
2598            .remove_outgoing_requests
2599            .push(responder_id.to_proto());
2600    }
2601    for connection_id in pool.user_connection_ids(requester_id) {
2602        session.peer.send(connection_id, update.clone())?;
2603    }
2604
2605    // Update incoming contact requests of responder
2606    let mut update = proto::UpdateContacts::default();
2607    if contact_accepted {
2608        update.remove_contacts.push(requester_id.to_proto());
2609    } else {
2610        update
2611            .remove_incoming_requests
2612            .push(requester_id.to_proto());
2613    }
2614    for connection_id in pool.user_connection_ids(responder_id) {
2615        session.peer.send(connection_id, update.clone())?;
2616        if let Some(notification_id) = deleted_notification_id {
2617            session.peer.send(
2618                connection_id,
2619                proto::DeleteNotification {
2620                    notification_id: notification_id.to_proto(),
2621                },
2622            )?;
2623        }
2624    }
2625
2626    response.send(proto::Ack {})?;
2627    Ok(())
2628}
2629
2630fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2631    version.0.minor() < 139
2632}
2633
2634async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
2635    let plan = session.current_plan(&session.db().await).await?;
2636
2637    session
2638        .peer
2639        .send(
2640            session.connection_id,
2641            proto::UpdateUserPlan { plan: plan.into() },
2642        )
2643        .trace_err();
2644
2645    Ok(())
2646}
2647
2648async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2649    subscribe_user_to_channels(session.user_id(), &session).await?;
2650    Ok(())
2651}
2652
2653async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2654    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2655    let mut pool = session.connection_pool().await;
2656    for membership in &channels_for_user.channel_memberships {
2657        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2658    }
2659    session.peer.send(
2660        session.connection_id,
2661        build_update_user_channels(&channels_for_user),
2662    )?;
2663    session.peer.send(
2664        session.connection_id,
2665        build_channels_update(channels_for_user),
2666    )?;
2667    Ok(())
2668}
2669
2670/// Creates a new channel.
2671async fn create_channel(
2672    request: proto::CreateChannel,
2673    response: Response<proto::CreateChannel>,
2674    session: Session,
2675) -> Result<()> {
2676    let db = session.db().await;
2677
2678    let parent_id = request.parent_id.map(ChannelId::from_proto);
2679    let (channel, membership) = db
2680        .create_channel(&request.name, parent_id, session.user_id())
2681        .await?;
2682
2683    let root_id = channel.root_id();
2684    let channel = Channel::from_model(channel);
2685
2686    response.send(proto::CreateChannelResponse {
2687        channel: Some(channel.to_proto()),
2688        parent_id: request.parent_id,
2689    })?;
2690
2691    let mut connection_pool = session.connection_pool().await;
2692    if let Some(membership) = membership {
2693        connection_pool.subscribe_to_channel(
2694            membership.user_id,
2695            membership.channel_id,
2696            membership.role,
2697        );
2698        let update = proto::UpdateUserChannels {
2699            channel_memberships: vec![proto::ChannelMembership {
2700                channel_id: membership.channel_id.to_proto(),
2701                role: membership.role.into(),
2702            }],
2703            ..Default::default()
2704        };
2705        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2706            session.peer.send(connection_id, update.clone())?;
2707        }
2708    }
2709
2710    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2711        if !role.can_see_channel(channel.visibility) {
2712            continue;
2713        }
2714
2715        let update = proto::UpdateChannels {
2716            channels: vec![channel.to_proto()],
2717            ..Default::default()
2718        };
2719        session.peer.send(connection_id, update.clone())?;
2720    }
2721
2722    Ok(())
2723}
2724
2725/// Delete a channel
2726async fn delete_channel(
2727    request: proto::DeleteChannel,
2728    response: Response<proto::DeleteChannel>,
2729    session: Session,
2730) -> Result<()> {
2731    let db = session.db().await;
2732
2733    let channel_id = request.channel_id;
2734    let (root_channel, removed_channels) = db
2735        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2736        .await?;
2737    response.send(proto::Ack {})?;
2738
2739    // Notify members of removed channels
2740    let mut update = proto::UpdateChannels::default();
2741    update
2742        .delete_channels
2743        .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2744
2745    let connection_pool = session.connection_pool().await;
2746    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2747        session.peer.send(connection_id, update.clone())?;
2748    }
2749
2750    Ok(())
2751}
2752
2753/// Invite someone to join a channel.
2754async fn invite_channel_member(
2755    request: proto::InviteChannelMember,
2756    response: Response<proto::InviteChannelMember>,
2757    session: Session,
2758) -> Result<()> {
2759    let db = session.db().await;
2760    let channel_id = ChannelId::from_proto(request.channel_id);
2761    let invitee_id = UserId::from_proto(request.user_id);
2762    let InviteMemberResult {
2763        channel,
2764        notifications,
2765    } = db
2766        .invite_channel_member(
2767            channel_id,
2768            invitee_id,
2769            session.user_id(),
2770            request.role().into(),
2771        )
2772        .await?;
2773
2774    let update = proto::UpdateChannels {
2775        channel_invitations: vec![channel.to_proto()],
2776        ..Default::default()
2777    };
2778
2779    let connection_pool = session.connection_pool().await;
2780    for connection_id in connection_pool.user_connection_ids(invitee_id) {
2781        session.peer.send(connection_id, update.clone())?;
2782    }
2783
2784    send_notifications(&connection_pool, &session.peer, notifications);
2785
2786    response.send(proto::Ack {})?;
2787    Ok(())
2788}
2789
2790/// remove someone from a channel
2791async fn remove_channel_member(
2792    request: proto::RemoveChannelMember,
2793    response: Response<proto::RemoveChannelMember>,
2794    session: Session,
2795) -> Result<()> {
2796    let db = session.db().await;
2797    let channel_id = ChannelId::from_proto(request.channel_id);
2798    let member_id = UserId::from_proto(request.user_id);
2799
2800    let RemoveChannelMemberResult {
2801        membership_update,
2802        notification_id,
2803    } = db
2804        .remove_channel_member(channel_id, member_id, session.user_id())
2805        .await?;
2806
2807    let mut connection_pool = session.connection_pool().await;
2808    notify_membership_updated(
2809        &mut connection_pool,
2810        membership_update,
2811        member_id,
2812        &session.peer,
2813    );
2814    for connection_id in connection_pool.user_connection_ids(member_id) {
2815        if let Some(notification_id) = notification_id {
2816            session
2817                .peer
2818                .send(
2819                    connection_id,
2820                    proto::DeleteNotification {
2821                        notification_id: notification_id.to_proto(),
2822                    },
2823                )
2824                .trace_err();
2825        }
2826    }
2827
2828    response.send(proto::Ack {})?;
2829    Ok(())
2830}
2831
2832/// Toggle the channel between public and private.
2833/// Care is taken to maintain the invariant that public channels only descend from public channels,
2834/// (though members-only channels can appear at any point in the hierarchy).
2835async fn set_channel_visibility(
2836    request: proto::SetChannelVisibility,
2837    response: Response<proto::SetChannelVisibility>,
2838    session: Session,
2839) -> Result<()> {
2840    let db = session.db().await;
2841    let channel_id = ChannelId::from_proto(request.channel_id);
2842    let visibility = request.visibility().into();
2843
2844    let channel_model = db
2845        .set_channel_visibility(channel_id, visibility, session.user_id())
2846        .await?;
2847    let root_id = channel_model.root_id();
2848    let channel = Channel::from_model(channel_model);
2849
2850    let mut connection_pool = session.connection_pool().await;
2851    for (user_id, role) in connection_pool
2852        .channel_user_ids(root_id)
2853        .collect::<Vec<_>>()
2854        .into_iter()
2855    {
2856        let update = if role.can_see_channel(channel.visibility) {
2857            connection_pool.subscribe_to_channel(user_id, channel_id, role);
2858            proto::UpdateChannels {
2859                channels: vec![channel.to_proto()],
2860                ..Default::default()
2861            }
2862        } else {
2863            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2864            proto::UpdateChannels {
2865                delete_channels: vec![channel.id.to_proto()],
2866                ..Default::default()
2867            }
2868        };
2869
2870        for connection_id in connection_pool.user_connection_ids(user_id) {
2871            session.peer.send(connection_id, update.clone())?;
2872        }
2873    }
2874
2875    response.send(proto::Ack {})?;
2876    Ok(())
2877}
2878
2879/// Alter the role for a user in the channel.
2880async fn set_channel_member_role(
2881    request: proto::SetChannelMemberRole,
2882    response: Response<proto::SetChannelMemberRole>,
2883    session: Session,
2884) -> Result<()> {
2885    let db = session.db().await;
2886    let channel_id = ChannelId::from_proto(request.channel_id);
2887    let member_id = UserId::from_proto(request.user_id);
2888    let result = db
2889        .set_channel_member_role(
2890            channel_id,
2891            session.user_id(),
2892            member_id,
2893            request.role().into(),
2894        )
2895        .await?;
2896
2897    match result {
2898        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2899            let mut connection_pool = session.connection_pool().await;
2900            notify_membership_updated(
2901                &mut connection_pool,
2902                membership_update,
2903                member_id,
2904                &session.peer,
2905            )
2906        }
2907        db::SetMemberRoleResult::InviteUpdated(channel) => {
2908            let update = proto::UpdateChannels {
2909                channel_invitations: vec![channel.to_proto()],
2910                ..Default::default()
2911            };
2912
2913            for connection_id in session
2914                .connection_pool()
2915                .await
2916                .user_connection_ids(member_id)
2917            {
2918                session.peer.send(connection_id, update.clone())?;
2919            }
2920        }
2921    }
2922
2923    response.send(proto::Ack {})?;
2924    Ok(())
2925}
2926
2927/// Change the name of a channel
2928async fn rename_channel(
2929    request: proto::RenameChannel,
2930    response: Response<proto::RenameChannel>,
2931    session: Session,
2932) -> Result<()> {
2933    let db = session.db().await;
2934    let channel_id = ChannelId::from_proto(request.channel_id);
2935    let channel_model = db
2936        .rename_channel(channel_id, session.user_id(), &request.name)
2937        .await?;
2938    let root_id = channel_model.root_id();
2939    let channel = Channel::from_model(channel_model);
2940
2941    response.send(proto::RenameChannelResponse {
2942        channel: Some(channel.to_proto()),
2943    })?;
2944
2945    let connection_pool = session.connection_pool().await;
2946    let update = proto::UpdateChannels {
2947        channels: vec![channel.to_proto()],
2948        ..Default::default()
2949    };
2950    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2951        if role.can_see_channel(channel.visibility) {
2952            session.peer.send(connection_id, update.clone())?;
2953        }
2954    }
2955
2956    Ok(())
2957}
2958
2959/// Move a channel to a new parent.
2960async fn move_channel(
2961    request: proto::MoveChannel,
2962    response: Response<proto::MoveChannel>,
2963    session: Session,
2964) -> Result<()> {
2965    let channel_id = ChannelId::from_proto(request.channel_id);
2966    let to = ChannelId::from_proto(request.to);
2967
2968    let (root_id, channels) = session
2969        .db()
2970        .await
2971        .move_channel(channel_id, to, session.user_id())
2972        .await?;
2973
2974    let connection_pool = session.connection_pool().await;
2975    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2976        let channels = channels
2977            .iter()
2978            .filter_map(|channel| {
2979                if role.can_see_channel(channel.visibility) {
2980                    Some(channel.to_proto())
2981                } else {
2982                    None
2983                }
2984            })
2985            .collect::<Vec<_>>();
2986        if channels.is_empty() {
2987            continue;
2988        }
2989
2990        let update = proto::UpdateChannels {
2991            channels,
2992            ..Default::default()
2993        };
2994
2995        session.peer.send(connection_id, update.clone())?;
2996    }
2997
2998    response.send(Ack {})?;
2999    Ok(())
3000}
3001
3002/// Get the list of channel members
3003async fn get_channel_members(
3004    request: proto::GetChannelMembers,
3005    response: Response<proto::GetChannelMembers>,
3006    session: Session,
3007) -> Result<()> {
3008    let db = session.db().await;
3009    let channel_id = ChannelId::from_proto(request.channel_id);
3010    let limit = if request.limit == 0 {
3011        u16::MAX as u64
3012    } else {
3013        request.limit
3014    };
3015    let (members, users) = db
3016        .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3017        .await?;
3018    response.send(proto::GetChannelMembersResponse { members, users })?;
3019    Ok(())
3020}
3021
3022/// Accept or decline a channel invitation.
3023async fn respond_to_channel_invite(
3024    request: proto::RespondToChannelInvite,
3025    response: Response<proto::RespondToChannelInvite>,
3026    session: Session,
3027) -> Result<()> {
3028    let db = session.db().await;
3029    let channel_id = ChannelId::from_proto(request.channel_id);
3030    let RespondToChannelInvite {
3031        membership_update,
3032        notifications,
3033    } = db
3034        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3035        .await?;
3036
3037    let mut connection_pool = session.connection_pool().await;
3038    if let Some(membership_update) = membership_update {
3039        notify_membership_updated(
3040            &mut connection_pool,
3041            membership_update,
3042            session.user_id(),
3043            &session.peer,
3044        );
3045    } else {
3046        let update = proto::UpdateChannels {
3047            remove_channel_invitations: vec![channel_id.to_proto()],
3048            ..Default::default()
3049        };
3050
3051        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3052            session.peer.send(connection_id, update.clone())?;
3053        }
3054    };
3055
3056    send_notifications(&connection_pool, &session.peer, notifications);
3057
3058    response.send(proto::Ack {})?;
3059
3060    Ok(())
3061}
3062
3063/// Join the channels' room
3064async fn join_channel(
3065    request: proto::JoinChannel,
3066    response: Response<proto::JoinChannel>,
3067    session: Session,
3068) -> Result<()> {
3069    let channel_id = ChannelId::from_proto(request.channel_id);
3070    join_channel_internal(channel_id, Box::new(response), session).await
3071}
3072
/// A response handle that can reply with a `proto::JoinRoomResponse`.
///
/// Implemented for both `JoinChannel` and `JoinRoom` responses, so that
/// `join_channel_internal` can serve either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::send` for `JoinChannel` requests.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to `Response::send` for `JoinRoom` requests, which reply with the
// same `JoinRoomResponse` type.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3086
/// Shared implementation for joining a channel's room.
///
/// Steps:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. Reply with the room, the channel ID and (when a LiveKit client is
///    configured) LiveKit connection info.
/// 4. Propagate any membership change and room update.
/// 5. Broadcast the channel update, then refresh the user's contacts.
///
/// The reply is sent before the "channel not returned" check, so that error can
/// be returned after the client has already received a `JoinRoomResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database and connection-pool guards taken in this block are dropped at
    // its end, before `channel_updated` locks the connection pool again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire
            // it. Presumably `leave_room_for_session` takes the database lock itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails:
        // the error is passed through `trace_err` and `?` turns it into `None`.
        // Guests get a guest token and cannot publish; every other role gets a room
        // token with publishing allowed.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Moves the channel out of `joined_room`; errors if the database returned none.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3182
3183/// Start editing the channel notes
3184async fn join_channel_buffer(
3185    request: proto::JoinChannelBuffer,
3186    response: Response<proto::JoinChannelBuffer>,
3187    session: Session,
3188) -> Result<()> {
3189    let db = session.db().await;
3190    let channel_id = ChannelId::from_proto(request.channel_id);
3191
3192    let open_response = db
3193        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3194        .await?;
3195
3196    let collaborators = open_response.collaborators.clone();
3197    response.send(open_response)?;
3198
3199    let update = UpdateChannelBufferCollaborators {
3200        channel_id: channel_id.to_proto(),
3201        collaborators: collaborators.clone(),
3202    };
3203    channel_buffer_updated(
3204        session.connection_id,
3205        collaborators
3206            .iter()
3207            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3208        &update,
3209        &session.peer,
3210    );
3211
3212    Ok(())
3213}
3214
3215/// Edit the channel notes
3216async fn update_channel_buffer(
3217    request: proto::UpdateChannelBuffer,
3218    session: Session,
3219) -> Result<()> {
3220    let db = session.db().await;
3221    let channel_id = ChannelId::from_proto(request.channel_id);
3222
3223    let (collaborators, epoch, version) = db
3224        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3225        .await?;
3226
3227    channel_buffer_updated(
3228        session.connection_id,
3229        collaborators.clone(),
3230        &proto::UpdateChannelBuffer {
3231            channel_id: channel_id.to_proto(),
3232            operations: request.operations,
3233        },
3234        &session.peer,
3235    );
3236
3237    let pool = &*session.connection_pool().await;
3238
3239    let non_collaborators =
3240        pool.channel_connection_ids(channel_id)
3241            .filter_map(|(connection_id, _)| {
3242                if collaborators.contains(&connection_id) {
3243                    None
3244                } else {
3245                    Some(connection_id)
3246                }
3247            });
3248
3249    broadcast(None, non_collaborators, |peer_id| {
3250        session.peer.send(
3251            peer_id,
3252            proto::UpdateChannels {
3253                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3254                    channel_id: channel_id.to_proto(),
3255                    epoch: epoch as u64,
3256                    version: version.clone(),
3257                }],
3258                ..Default::default()
3259            },
3260        )
3261    });
3262
3263    Ok(())
3264}
3265
3266/// Rejoin the channel notes after a connection blip
3267async fn rejoin_channel_buffers(
3268    request: proto::RejoinChannelBuffers,
3269    response: Response<proto::RejoinChannelBuffers>,
3270    session: Session,
3271) -> Result<()> {
3272    let db = session.db().await;
3273    let buffers = db
3274        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3275        .await?;
3276
3277    for rejoined_buffer in &buffers {
3278        let collaborators_to_notify = rejoined_buffer
3279            .buffer
3280            .collaborators
3281            .iter()
3282            .filter_map(|c| Some(c.peer_id?.into()));
3283        channel_buffer_updated(
3284            session.connection_id,
3285            collaborators_to_notify,
3286            &proto::UpdateChannelBufferCollaborators {
3287                channel_id: rejoined_buffer.buffer.channel_id,
3288                collaborators: rejoined_buffer.buffer.collaborators.clone(),
3289            },
3290            &session.peer,
3291        );
3292    }
3293
3294    response.send(proto::RejoinChannelBuffersResponse {
3295        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3296    })?;
3297
3298    Ok(())
3299}
3300
3301/// Stop editing the channel notes
3302async fn leave_channel_buffer(
3303    request: proto::LeaveChannelBuffer,
3304    response: Response<proto::LeaveChannelBuffer>,
3305    session: Session,
3306) -> Result<()> {
3307    let db = session.db().await;
3308    let channel_id = ChannelId::from_proto(request.channel_id);
3309
3310    let left_buffer = db
3311        .leave_channel_buffer(channel_id, session.connection_id)
3312        .await?;
3313
3314    response.send(Ack {})?;
3315
3316    channel_buffer_updated(
3317        session.connection_id,
3318        left_buffer.connections,
3319        &proto::UpdateChannelBufferCollaborators {
3320            channel_id: channel_id.to_proto(),
3321            collaborators: left_buffer.collaborators,
3322        },
3323        &session.peer,
3324    );
3325
3326    Ok(())
3327}
3328
3329fn channel_buffer_updated<T: EnvelopedMessage>(
3330    sender_id: ConnectionId,
3331    collaborators: impl IntoIterator<Item = ConnectionId>,
3332    message: &T,
3333    peer: &Peer,
3334) {
3335    broadcast(Some(sender_id), collaborators, |peer_id| {
3336        peer.send(peer_id, message.clone())
3337    });
3338}
3339
3340fn send_notifications(
3341    connection_pool: &ConnectionPool,
3342    peer: &Peer,
3343    notifications: db::NotificationBatch,
3344) {
3345    for (user_id, notification) in notifications {
3346        for connection_id in connection_pool.user_connection_ids(user_id) {
3347            if let Err(error) = peer.send(
3348                connection_id,
3349                proto::AddNotification {
3350                    notification: Some(notification.clone()),
3351                },
3352            ) {
3353                tracing::error!(
3354                    "failed to send notification to {:?} {}",
3355                    connection_id,
3356                    error
3357                );
3358            }
3359        }
3360    }
3361}
3362
3363/// Send a message to the channel
3364async fn send_channel_message(
3365    request: proto::SendChannelMessage,
3366    response: Response<proto::SendChannelMessage>,
3367    session: Session,
3368) -> Result<()> {
3369    // Validate the message body.
3370    let body = request.body.trim().to_string();
3371    if body.len() > MAX_MESSAGE_LEN {
3372        return Err(anyhow!("message is too long"))?;
3373    }
3374    if body.is_empty() {
3375        return Err(anyhow!("message can't be blank"))?;
3376    }
3377
3378    // TODO: adjust mentions if body is trimmed
3379
3380    let timestamp = OffsetDateTime::now_utc();
3381    let nonce = request
3382        .nonce
3383        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3384
3385    let channel_id = ChannelId::from_proto(request.channel_id);
3386    let CreatedChannelMessage {
3387        message_id,
3388        participant_connection_ids,
3389        notifications,
3390    } = session
3391        .db()
3392        .await
3393        .create_channel_message(
3394            channel_id,
3395            session.user_id(),
3396            &body,
3397            &request.mentions,
3398            timestamp,
3399            nonce.clone().into(),
3400            request.reply_to_message_id.map(MessageId::from_proto),
3401        )
3402        .await?;
3403
3404    let message = proto::ChannelMessage {
3405        sender_id: session.user_id().to_proto(),
3406        id: message_id.to_proto(),
3407        body,
3408        mentions: request.mentions,
3409        timestamp: timestamp.unix_timestamp() as u64,
3410        nonce: Some(nonce),
3411        reply_to_message_id: request.reply_to_message_id,
3412        edited_at: None,
3413    };
3414    broadcast(
3415        Some(session.connection_id),
3416        participant_connection_ids.clone(),
3417        |connection| {
3418            session.peer.send(
3419                connection,
3420                proto::ChannelMessageSent {
3421                    channel_id: channel_id.to_proto(),
3422                    message: Some(message.clone()),
3423                },
3424            )
3425        },
3426    );
3427    response.send(proto::SendChannelMessageResponse {
3428        message: Some(message),
3429    })?;
3430
3431    let pool = &*session.connection_pool().await;
3432    let non_participants =
3433        pool.channel_connection_ids(channel_id)
3434            .filter_map(|(connection_id, _)| {
3435                if participant_connection_ids.contains(&connection_id) {
3436                    None
3437                } else {
3438                    Some(connection_id)
3439                }
3440            });
3441    broadcast(None, non_participants, |peer_id| {
3442        session.peer.send(
3443            peer_id,
3444            proto::UpdateChannels {
3445                latest_channel_message_ids: vec![proto::ChannelMessageId {
3446                    channel_id: channel_id.to_proto(),
3447                    message_id: message_id.to_proto(),
3448                }],
3449                ..Default::default()
3450            },
3451        )
3452    });
3453    send_notifications(pool, &session.peer, notifications);
3454
3455    Ok(())
3456}
3457
3458/// Delete a channel message
3459async fn remove_channel_message(
3460    request: proto::RemoveChannelMessage,
3461    response: Response<proto::RemoveChannelMessage>,
3462    session: Session,
3463) -> Result<()> {
3464    let channel_id = ChannelId::from_proto(request.channel_id);
3465    let message_id = MessageId::from_proto(request.message_id);
3466    let (connection_ids, existing_notification_ids) = session
3467        .db()
3468        .await
3469        .remove_channel_message(channel_id, message_id, session.user_id())
3470        .await?;
3471
3472    broadcast(
3473        Some(session.connection_id),
3474        connection_ids,
3475        move |connection| {
3476            session.peer.send(connection, request.clone())?;
3477
3478            for notification_id in &existing_notification_ids {
3479                session.peer.send(
3480                    connection,
3481                    proto::DeleteNotification {
3482                        notification_id: (*notification_id).to_proto(),
3483                    },
3484                )?;
3485            }
3486
3487            Ok(())
3488        },
3489    );
3490    response.send(proto::Ack {})?;
3491    Ok(())
3492}
3493
3494async fn update_channel_message(
3495    request: proto::UpdateChannelMessage,
3496    response: Response<proto::UpdateChannelMessage>,
3497    session: Session,
3498) -> Result<()> {
3499    let channel_id = ChannelId::from_proto(request.channel_id);
3500    let message_id = MessageId::from_proto(request.message_id);
3501    let updated_at = OffsetDateTime::now_utc();
3502    let UpdatedChannelMessage {
3503        message_id,
3504        participant_connection_ids,
3505        notifications,
3506        reply_to_message_id,
3507        timestamp,
3508        deleted_mention_notification_ids,
3509        updated_mention_notifications,
3510    } = session
3511        .db()
3512        .await
3513        .update_channel_message(
3514            channel_id,
3515            message_id,
3516            session.user_id(),
3517            request.body.as_str(),
3518            &request.mentions,
3519            updated_at,
3520        )
3521        .await?;
3522
3523    let nonce = request
3524        .nonce
3525        .clone()
3526        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3527
3528    let message = proto::ChannelMessage {
3529        sender_id: session.user_id().to_proto(),
3530        id: message_id.to_proto(),
3531        body: request.body.clone(),
3532        mentions: request.mentions.clone(),
3533        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3534        nonce: Some(nonce),
3535        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3536        edited_at: Some(updated_at.unix_timestamp() as u64),
3537    };
3538
3539    response.send(proto::Ack {})?;
3540
3541    let pool = &*session.connection_pool().await;
3542    broadcast(
3543        Some(session.connection_id),
3544        participant_connection_ids,
3545        |connection| {
3546            session.peer.send(
3547                connection,
3548                proto::ChannelMessageUpdate {
3549                    channel_id: channel_id.to_proto(),
3550                    message: Some(message.clone()),
3551                },
3552            )?;
3553
3554            for notification_id in &deleted_mention_notification_ids {
3555                session.peer.send(
3556                    connection,
3557                    proto::DeleteNotification {
3558                        notification_id: (*notification_id).to_proto(),
3559                    },
3560                )?;
3561            }
3562
3563            for notification in &updated_mention_notifications {
3564                session.peer.send(
3565                    connection,
3566                    proto::UpdateNotification {
3567                        notification: Some(notification.clone()),
3568                    },
3569                )?;
3570            }
3571
3572            Ok(())
3573        },
3574    );
3575
3576    send_notifications(pool, &session.peer, notifications);
3577
3578    Ok(())
3579}
3580
3581/// Mark a channel message as read
3582async fn acknowledge_channel_message(
3583    request: proto::AckChannelMessage,
3584    session: Session,
3585) -> Result<()> {
3586    let channel_id = ChannelId::from_proto(request.channel_id);
3587    let message_id = MessageId::from_proto(request.message_id);
3588    let notifications = session
3589        .db()
3590        .await
3591        .observe_channel_message(channel_id, session.user_id(), message_id)
3592        .await?;
3593    send_notifications(
3594        &*session.connection_pool().await,
3595        &session.peer,
3596        notifications,
3597    );
3598    Ok(())
3599}
3600
3601/// Mark a buffer version as synced
3602async fn acknowledge_buffer_version(
3603    request: proto::AckBufferOperation,
3604    session: Session,
3605) -> Result<()> {
3606    let buffer_id = BufferId::from_proto(request.buffer_id);
3607    session
3608        .db()
3609        .await
3610        .observe_buffer_version(
3611            buffer_id,
3612            session.user_id(),
3613            request.epoch as i32,
3614            &request.version,
3615        )
3616        .await?;
3617    Ok(())
3618}
3619
3620async fn count_language_model_tokens(
3621    request: proto::CountLanguageModelTokens,
3622    response: Response<proto::CountLanguageModelTokens>,
3623    session: Session,
3624    config: &Config,
3625) -> Result<()> {
3626    authorize_access_to_legacy_llm_endpoints(&session).await?;
3627
3628    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3629        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
3630        proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
3631    };
3632
3633    session
3634        .app_state
3635        .rate_limiter
3636        .check(&*rate_limit, session.user_id())
3637        .await?;
3638
3639    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
3640        Some(proto::LanguageModelProvider::Google) => {
3641            let api_key = config
3642                .google_ai_api_key
3643                .as_ref()
3644                .context("no Google AI API key configured on the server")?;
3645            google_ai::count_tokens(
3646                session.http_client.as_ref(),
3647                google_ai::API_URL,
3648                api_key,
3649                serde_json::from_str(&request.request)?,
3650            )
3651            .await?
3652        }
3653        _ => return Err(anyhow!("unsupported provider"))?,
3654    };
3655
3656    response.send(proto::CountLanguageModelTokensResponse {
3657        token_count: result.total_tokens as u32,
3658    })?;
3659
3660    Ok(())
3661}
3662
3663struct ZedProCountLanguageModelTokensRateLimit;
3664
3665impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
3666    fn capacity(&self) -> usize {
3667        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
3668            .ok()
3669            .and_then(|v| v.parse().ok())
3670            .unwrap_or(600) // Picked arbitrarily
3671    }
3672
3673    fn refill_duration(&self) -> chrono::Duration {
3674        chrono::Duration::hours(1)
3675    }
3676
3677    fn db_name(&self) -> &'static str {
3678        "zed-pro:count-language-model-tokens"
3679    }
3680}
3681
3682struct FreeCountLanguageModelTokensRateLimit;
3683
3684impl RateLimit for FreeCountLanguageModelTokensRateLimit {
3685    fn capacity(&self) -> usize {
3686        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
3687            .ok()
3688            .and_then(|v| v.parse().ok())
3689            .unwrap_or(600 / 10) // Picked arbitrarily
3690    }
3691
3692    fn refill_duration(&self) -> chrono::Duration {
3693        chrono::Duration::hours(1)
3694    }
3695
3696    fn db_name(&self) -> &'static str {
3697        "free:count-language-model-tokens"
3698    }
3699}
3700
3701struct ZedProComputeEmbeddingsRateLimit;
3702
3703impl RateLimit for ZedProComputeEmbeddingsRateLimit {
3704    fn capacity(&self) -> usize {
3705        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
3706            .ok()
3707            .and_then(|v| v.parse().ok())
3708            .unwrap_or(5000) // Picked arbitrarily
3709    }
3710
3711    fn refill_duration(&self) -> chrono::Duration {
3712        chrono::Duration::hours(1)
3713    }
3714
3715    fn db_name(&self) -> &'static str {
3716        "zed-pro:compute-embeddings"
3717    }
3718}
3719
3720struct FreeComputeEmbeddingsRateLimit;
3721
3722impl RateLimit for FreeComputeEmbeddingsRateLimit {
3723    fn capacity(&self) -> usize {
3724        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
3725            .ok()
3726            .and_then(|v| v.parse().ok())
3727            .unwrap_or(5000 / 10) // Picked arbitrarily
3728    }
3729
3730    fn refill_duration(&self) -> chrono::Duration {
3731        chrono::Duration::hours(1)
3732    }
3733
3734    fn db_name(&self) -> &'static str {
3735        "free:compute-embeddings"
3736    }
3737}
3738
3739async fn compute_embeddings(
3740    request: proto::ComputeEmbeddings,
3741    response: Response<proto::ComputeEmbeddings>,
3742    session: Session,
3743    api_key: Option<Arc<str>>,
3744) -> Result<()> {
3745    let api_key = api_key.context("no OpenAI API key configured on the server")?;
3746    authorize_access_to_legacy_llm_endpoints(&session).await?;
3747
3748    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
3749        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
3750        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
3751    };
3752
3753    session
3754        .app_state
3755        .rate_limiter
3756        .check(&*rate_limit, session.user_id())
3757        .await?;
3758
3759    let embeddings = match request.model.as_str() {
3760        "openai/text-embedding-3-small" => {
3761            open_ai::embed(
3762                session.http_client.as_ref(),
3763                OPEN_AI_API_URL,
3764                &api_key,
3765                OpenAiEmbeddingModel::TextEmbedding3Small,
3766                request.texts.iter().map(|text| text.as_str()),
3767            )
3768            .await?
3769        }
3770        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
3771    };
3772
3773    let embeddings = request
3774        .texts
3775        .iter()
3776        .map(|text| {
3777            let mut hasher = sha2::Sha256::new();
3778            hasher.update(text.as_bytes());
3779            let result = hasher.finalize();
3780            result.to_vec()
3781        })
3782        .zip(
3783            embeddings
3784                .data
3785                .into_iter()
3786                .map(|embedding| embedding.embedding),
3787        )
3788        .collect::<HashMap<_, _>>();
3789
3790    let db = session.db().await;
3791    db.save_embeddings(&request.model, &embeddings)
3792        .await
3793        .context("failed to save embeddings")
3794        .trace_err();
3795
3796    response.send(proto::ComputeEmbeddingsResponse {
3797        embeddings: embeddings
3798            .into_iter()
3799            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3800            .collect(),
3801    })?;
3802    Ok(())
3803}
3804
3805async fn get_cached_embeddings(
3806    request: proto::GetCachedEmbeddings,
3807    response: Response<proto::GetCachedEmbeddings>,
3808    session: Session,
3809) -> Result<()> {
3810    authorize_access_to_legacy_llm_endpoints(&session).await?;
3811
3812    let db = session.db().await;
3813    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
3814
3815    response.send(proto::GetCachedEmbeddingsResponse {
3816        embeddings: embeddings
3817            .into_iter()
3818            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
3819            .collect(),
3820    })?;
3821    Ok(())
3822}
3823
3824/// This is leftover from before the LLM service.
3825///
3826/// The endpoints protected by this check will be moved there eventually.
3827async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
3828    if session.is_staff() {
3829        Ok(())
3830    } else {
3831        Err(anyhow!("permission denied"))?
3832    }
3833}
3834
3835/// Get a Supermaven API key for the user
3836async fn get_supermaven_api_key(
3837    _request: proto::GetSupermavenApiKey,
3838    response: Response<proto::GetSupermavenApiKey>,
3839    session: Session,
3840) -> Result<()> {
3841    let user_id: String = session.user_id().to_string();
3842    if !session.is_staff() {
3843        return Err(anyhow!("supermaven not enabled for this account"))?;
3844    }
3845
3846    let email = session
3847        .email()
3848        .ok_or_else(|| anyhow!("user must have an email"))?;
3849
3850    let supermaven_admin_api = session
3851        .supermaven_client
3852        .as_ref()
3853        .ok_or_else(|| anyhow!("supermaven not configured"))?;
3854
3855    let result = supermaven_admin_api
3856        .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3857        .await?;
3858
3859    response.send(proto::GetSupermavenApiKeyResponse {
3860        api_key: result.api_key,
3861    })?;
3862
3863    Ok(())
3864}
3865
3866/// Start receiving chat updates for a channel
3867async fn join_channel_chat(
3868    request: proto::JoinChannelChat,
3869    response: Response<proto::JoinChannelChat>,
3870    session: Session,
3871) -> Result<()> {
3872    let channel_id = ChannelId::from_proto(request.channel_id);
3873
3874    let db = session.db().await;
3875    db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3876        .await?;
3877    let messages = db
3878        .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3879        .await?;
3880    response.send(proto::JoinChannelChatResponse {
3881        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3882        messages,
3883    })?;
3884    Ok(())
3885}
3886
3887/// Stop receiving chat updates for a channel
3888async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3889    let channel_id = ChannelId::from_proto(request.channel_id);
3890    session
3891        .db()
3892        .await
3893        .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3894        .await?;
3895    Ok(())
3896}
3897
3898/// Retrieve the chat history for a channel
3899async fn get_channel_messages(
3900    request: proto::GetChannelMessages,
3901    response: Response<proto::GetChannelMessages>,
3902    session: Session,
3903) -> Result<()> {
3904    let channel_id = ChannelId::from_proto(request.channel_id);
3905    let messages = session
3906        .db()
3907        .await
3908        .get_channel_messages(
3909            channel_id,
3910            session.user_id(),
3911            MESSAGE_COUNT_PER_PAGE,
3912            Some(MessageId::from_proto(request.before_message_id)),
3913        )
3914        .await?;
3915    response.send(proto::GetChannelMessagesResponse {
3916        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3917        messages,
3918    })?;
3919    Ok(())
3920}
3921
3922/// Retrieve specific chat messages
3923async fn get_channel_messages_by_id(
3924    request: proto::GetChannelMessagesById,
3925    response: Response<proto::GetChannelMessagesById>,
3926    session: Session,
3927) -> Result<()> {
3928    let message_ids = request
3929        .message_ids
3930        .iter()
3931        .map(|id| MessageId::from_proto(*id))
3932        .collect::<Vec<_>>();
3933    let messages = session
3934        .db()
3935        .await
3936        .get_channel_messages_by_id(session.user_id(), &message_ids)
3937        .await?;
3938    response.send(proto::GetChannelMessagesResponse {
3939        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3940        messages,
3941    })?;
3942    Ok(())
3943}
3944
3945/// Retrieve the current users notifications
3946async fn get_notifications(
3947    request: proto::GetNotifications,
3948    response: Response<proto::GetNotifications>,
3949    session: Session,
3950) -> Result<()> {
3951    let notifications = session
3952        .db()
3953        .await
3954        .get_notifications(
3955            session.user_id(),
3956            NOTIFICATION_COUNT_PER_PAGE,
3957            request.before_id.map(db::NotificationId::from_proto),
3958        )
3959        .await?;
3960    response.send(proto::GetNotificationsResponse {
3961        done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3962        notifications,
3963    })?;
3964    Ok(())
3965}
3966
3967/// Mark notifications as read
3968async fn mark_notification_as_read(
3969    request: proto::MarkNotificationRead,
3970    response: Response<proto::MarkNotificationRead>,
3971    session: Session,
3972) -> Result<()> {
3973    let database = &session.db().await;
3974    let notifications = database
3975        .mark_notification_as_read_by_id(
3976            session.user_id(),
3977            NotificationId::from_proto(request.notification_id),
3978        )
3979        .await?;
3980    send_notifications(
3981        &*session.connection_pool().await,
3982        &session.peer,
3983        notifications,
3984    );
3985    response.send(proto::Ack {})?;
3986    Ok(())
3987}
3988
3989/// Get the current users information
3990async fn get_private_user_info(
3991    _request: proto::GetPrivateUserInfo,
3992    response: Response<proto::GetPrivateUserInfo>,
3993    session: Session,
3994) -> Result<()> {
3995    let db = session.db().await;
3996
3997    let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3998    let user = db
3999        .get_user_by_id(session.user_id())
4000        .await?
4001        .ok_or_else(|| anyhow!("user not found"))?;
4002    let flags = db.get_user_flags(session.user_id()).await?;
4003
4004    response.send(proto::GetPrivateUserInfoResponse {
4005        metrics_id,
4006        staff: user.admin,
4007        flags,
4008        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4009    })?;
4010    Ok(())
4011}
4012
4013/// Accept the terms of service (tos) on behalf of the current user
4014async fn accept_terms_of_service(
4015    _request: proto::AcceptTermsOfService,
4016    response: Response<proto::AcceptTermsOfService>,
4017    session: Session,
4018) -> Result<()> {
4019    let db = session.db().await;
4020
4021    let accepted_tos_at = Utc::now();
4022    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4023        .await?;
4024
4025    response.send(proto::AcceptTermsOfServiceResponse {
4026        accepted_tos_at: accepted_tos_at.timestamp() as u64,
4027    })?;
4028    Ok(())
4029}
4030
4031/// The minimum account age an account must have in order to use the LLM service.
4032const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4033
4034async fn get_llm_api_token(
4035    _request: proto::GetLlmToken,
4036    response: Response<proto::GetLlmToken>,
4037    session: Session,
4038) -> Result<()> {
4039    let db = session.db().await;
4040
4041    let flags = db.get_user_flags(session.user_id()).await?;
4042    let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4043    let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4044    let has_predict_edits_feature_flag = flags.iter().any(|flag| flag == "predict-edits");
4045
4046    if !session.is_staff() && !has_language_models_feature_flag {
4047        Err(anyhow!("permission denied"))?
4048    }
4049
4050    let user_id = session.user_id();
4051    let user = db
4052        .get_user_by_id(user_id)
4053        .await?
4054        .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4055
4056    if user.accepted_tos_at.is_none() {
4057        Err(anyhow!("terms of service not accepted"))?
4058    }
4059
4060    let has_llm_subscription = session.has_llm_subscription(&db).await?;
4061
4062    let bypass_account_age_check =
4063        has_llm_subscription || flags.iter().any(|flag| flag == "bypass-account-age-check");
4064    if !bypass_account_age_check {
4065        let mut account_created_at = user.created_at;
4066        if let Some(github_created_at) = user.github_user_created_at {
4067            account_created_at = account_created_at.min(github_created_at);
4068        }
4069        if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4070            Err(anyhow!("account too young"))?
4071        }
4072    }
4073
4074    let billing_preferences = db.get_billing_preferences(user.id).await?;
4075
4076    let token = LlmTokenClaims::create(
4077        &user,
4078        session.is_staff(),
4079        billing_preferences,
4080        has_llm_closed_beta_feature_flag,
4081        has_predict_edits_feature_flag,
4082        has_llm_subscription,
4083        session.current_plan(&db).await?,
4084        session.system_id.clone(),
4085        &session.app_state.config,
4086    )?;
4087    response.send(proto::GetLlmTokenResponse { token })?;
4088    Ok(())
4089}
4090
4091fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4092    let message = match message {
4093        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4094        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4095        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4096        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4097        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4098            code: frame.code.into(),
4099            reason: frame.reason,
4100        })),
4101        // We should never receive a frame while reading the message, according
4102        // to the `tungstenite` maintainers:
4103        //
4104        // > It cannot occur when you read messages from the WebSocket, but it
4105        // > can be used when you want to send the raw frames (e.g. you want to
4106        // > send the frames to the WebSocket without composing the full message first).
4107        // >
4108        // > — https://github.com/snapview/tungstenite-rs/issues/268
4109        TungsteniteMessage::Frame(_) => {
4110            bail!("received an unexpected frame while reading the message")
4111        }
4112    };
4113
4114    Ok(message)
4115}
4116
4117fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4118    match message {
4119        AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4120        AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4121        AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4122        AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4123        AxumMessage::Close(frame) => {
4124            TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4125                code: frame.code.into(),
4126                reason: frame.reason,
4127            }))
4128        }
4129    }
4130}
4131
4132fn notify_membership_updated(
4133    connection_pool: &mut ConnectionPool,
4134    result: MembershipUpdated,
4135    user_id: UserId,
4136    peer: &Peer,
4137) {
4138    for membership in &result.new_channels.channel_memberships {
4139        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4140    }
4141    for channel_id in &result.removed_channels {
4142        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4143    }
4144
4145    let user_channels_update = proto::UpdateUserChannels {
4146        channel_memberships: result
4147            .new_channels
4148            .channel_memberships
4149            .iter()
4150            .map(|cm| proto::ChannelMembership {
4151                channel_id: cm.channel_id.to_proto(),
4152                role: cm.role.into(),
4153            })
4154            .collect(),
4155        ..Default::default()
4156    };
4157
4158    let mut update = build_channels_update(result.new_channels);
4159    update.delete_channels = result
4160        .removed_channels
4161        .into_iter()
4162        .map(|id| id.to_proto())
4163        .collect();
4164    update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4165
4166    for connection_id in connection_pool.user_connection_ids(user_id) {
4167        peer.send(connection_id, user_channels_update.clone())
4168            .trace_err();
4169        peer.send(connection_id, update.clone()).trace_err();
4170    }
4171}
4172
4173fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4174    proto::UpdateUserChannels {
4175        channel_memberships: channels
4176            .channel_memberships
4177            .iter()
4178            .map(|m| proto::ChannelMembership {
4179                channel_id: m.channel_id.to_proto(),
4180                role: m.role.into(),
4181            })
4182            .collect(),
4183        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4184        observed_channel_message_id: channels.observed_channel_messages.clone(),
4185    }
4186}
4187
4188fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4189    let mut update = proto::UpdateChannels::default();
4190
4191    for channel in channels.channels {
4192        update.channels.push(channel.to_proto());
4193    }
4194
4195    update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4196    update.latest_channel_message_ids = channels.latest_channel_messages;
4197
4198    for (channel_id, participants) in channels.channel_participants {
4199        update
4200            .channel_participants
4201            .push(proto::ChannelParticipants {
4202                channel_id: channel_id.to_proto(),
4203                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4204            });
4205    }
4206
4207    for channel in channels.invited_channels {
4208        update.channel_invitations.push(channel.to_proto());
4209    }
4210
4211    update
4212}
4213
4214fn build_initial_contacts_update(
4215    contacts: Vec<db::Contact>,
4216    pool: &ConnectionPool,
4217) -> proto::UpdateContacts {
4218    let mut update = proto::UpdateContacts::default();
4219
4220    for contact in contacts {
4221        match contact {
4222            db::Contact::Accepted { user_id, busy } => {
4223                update.contacts.push(contact_for_user(user_id, busy, pool));
4224            }
4225            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4226            db::Contact::Incoming { user_id } => {
4227                update
4228                    .incoming_requests
4229                    .push(proto::IncomingContactRequest {
4230                        requester_id: user_id.to_proto(),
4231                    })
4232            }
4233        }
4234    }
4235
4236    update
4237}
4238
4239fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4240    proto::Contact {
4241        user_id: user_id.to_proto(),
4242        online: pool.is_user_online(user_id),
4243        busy,
4244    }
4245}
4246
4247fn room_updated(room: &proto::Room, peer: &Peer) {
4248    broadcast(
4249        None,
4250        room.participants
4251            .iter()
4252            .filter_map(|participant| Some(participant.peer_id?.into())),
4253        |peer_id| {
4254            peer.send(
4255                peer_id,
4256                proto::RoomUpdated {
4257                    room: Some(room.clone()),
4258                },
4259            )
4260        },
4261    );
4262}
4263
4264fn channel_updated(
4265    channel: &db::channel::Model,
4266    room: &proto::Room,
4267    peer: &Peer,
4268    pool: &ConnectionPool,
4269) {
4270    let participants = room
4271        .participants
4272        .iter()
4273        .map(|p| p.user_id)
4274        .collect::<Vec<_>>();
4275
4276    broadcast(
4277        None,
4278        pool.channel_connection_ids(channel.root_id())
4279            .filter_map(|(channel_id, role)| {
4280                role.can_see_channel(channel.visibility)
4281                    .then_some(channel_id)
4282            }),
4283        |peer_id| {
4284            peer.send(
4285                peer_id,
4286                proto::UpdateChannels {
4287                    channel_participants: vec![proto::ChannelParticipants {
4288                        channel_id: channel.id.to_proto(),
4289                        participant_user_ids: participants.clone(),
4290                    }],
4291                    ..Default::default()
4292                },
4293            )
4294        },
4295    );
4296}
4297
4298async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4299    let db = session.db().await;
4300
4301    let contacts = db.get_contacts(user_id).await?;
4302    let busy = db.is_user_busy(user_id).await?;
4303
4304    let pool = session.connection_pool().await;
4305    let updated_contact = contact_for_user(user_id, busy, &pool);
4306    for contact in contacts {
4307        if let db::Contact::Accepted {
4308            user_id: contact_user_id,
4309            ..
4310        } = contact
4311        {
4312            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4313                session
4314                    .peer
4315                    .send(
4316                        contact_conn_id,
4317                        proto::UpdateContacts {
4318                            contacts: vec![updated_contact.clone()],
4319                            remove_contacts: Default::default(),
4320                            incoming_requests: Default::default(),
4321                            remove_incoming_requests: Default::default(),
4322                            outgoing_requests: Default::default(),
4323                            remove_outgoing_requests: Default::default(),
4324                        },
4325                    )
4326                    .trace_err();
4327            }
4328        }
4329    }
4330    Ok(())
4331}
4332
4333async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4334    let mut contacts_to_update = HashSet::default();
4335
4336    let room_id;
4337    let canceled_calls_to_user_ids;
4338    let livekit_room;
4339    let delete_livekit_room;
4340    let room;
4341    let channel;
4342
4343    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4344        contacts_to_update.insert(session.user_id());
4345
4346        for project in left_room.left_projects.values() {
4347            project_left(project, session);
4348        }
4349
4350        room_id = RoomId::from_proto(left_room.room.id);
4351        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4352        livekit_room = mem::take(&mut left_room.room.livekit_room);
4353        delete_livekit_room = left_room.deleted;
4354        room = mem::take(&mut left_room.room);
4355        channel = mem::take(&mut left_room.channel);
4356
4357        room_updated(&room, &session.peer);
4358    } else {
4359        return Ok(());
4360    }
4361
4362    if let Some(channel) = channel {
4363        channel_updated(
4364            &channel,
4365            &room,
4366            &session.peer,
4367            &*session.connection_pool().await,
4368        );
4369    }
4370
4371    {
4372        let pool = session.connection_pool().await;
4373        for canceled_user_id in canceled_calls_to_user_ids {
4374            for connection_id in pool.user_connection_ids(canceled_user_id) {
4375                session
4376                    .peer
4377                    .send(
4378                        connection_id,
4379                        proto::CallCanceled {
4380                            room_id: room_id.to_proto(),
4381                        },
4382                    )
4383                    .trace_err();
4384            }
4385            contacts_to_update.insert(canceled_user_id);
4386        }
4387    }
4388
4389    for contact_user_id in contacts_to_update {
4390        update_user_contacts(contact_user_id, session).await?;
4391    }
4392
4393    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4394        live_kit
4395            .remove_participant(livekit_room.clone(), session.user_id().to_string())
4396            .await
4397            .trace_err();
4398
4399        if delete_livekit_room {
4400            live_kit.delete_room(livekit_room).await.trace_err();
4401        }
4402    }
4403
4404    Ok(())
4405}
4406
4407async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4408    let left_channel_buffers = session
4409        .db()
4410        .await
4411        .leave_channel_buffers(session.connection_id)
4412        .await?;
4413
4414    for left_buffer in left_channel_buffers {
4415        channel_buffer_updated(
4416            session.connection_id,
4417            left_buffer.connections,
4418            &proto::UpdateChannelBufferCollaborators {
4419                channel_id: left_buffer.channel_id.to_proto(),
4420                collaborators: left_buffer.collaborators,
4421            },
4422            &session.peer,
4423        );
4424    }
4425
4426    Ok(())
4427}
4428
4429fn project_left(project: &db::LeftProject, session: &Session) {
4430    for connection_id in &project.connection_ids {
4431        if project.should_unshare {
4432            session
4433                .peer
4434                .send(
4435                    *connection_id,
4436                    proto::UnshareProject {
4437                        project_id: project.id.to_proto(),
4438                    },
4439                )
4440                .trace_err();
4441        } else {
4442            session
4443                .peer
4444                .send(
4445                    *connection_id,
4446                    proto::RemoveProjectCollaborator {
4447                        project_id: project.id.to_proto(),
4448                        peer_id: Some(session.connection_id.into()),
4449                    },
4450                )
4451                .trace_err();
4452        }
4453    }
4454}
4455
4456pub trait ResultExt {
4457    type Ok;
4458
4459    fn trace_err(self) -> Option<Self::Ok>;
4460}
4461
4462impl<T, E> ResultExt for Result<T, E>
4463where
4464    E: std::fmt::Debug,
4465{
4466    type Ok = T;
4467
4468    #[track_caller]
4469    fn trace_err(self) -> Option<T> {
4470        match self {
4471            Ok(value) => Some(value),
4472            Err(error) => {
4473                tracing::error!("{:?}", error);
4474                None
4475            }
4476        }
4477    }
4478}